mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
resnet mlperf logging (#4361)
* resnet mlperf logging * cropping too much?
This commit is contained in:
parent
f635c4d273
commit
22376e53b7
4 changed files with 100 additions and 7 deletions
|
|
@ -19,6 +19,26 @@ def train_resnet():
|
|||
from examples.mlperf.initializers import Conv2dHeNormal, Linear
|
||||
from examples.hlb_cifar10 import UnsyncedBatchNorm
|
||||
|
||||
INITMLPERF = getenv("INITMLPERF")
|
||||
RUNMLPERF = getenv("RUNMLPERF")
|
||||
if getenv("LOGMLPERF"):
|
||||
from mlperf_logging import mllog
|
||||
import mlperf_logging.mllog.constants as mllog_constants
|
||||
MLLOGGER = mllog.get_mllogger()
|
||||
if INITMLPERF:
|
||||
# common.yaml
|
||||
MLLOGGER.event(key=mllog_constants.SUBMISSION_ORG, value="tinycorp")
|
||||
MLLOGGER.event(key=mllog_constants.SUBMISSION_PLATFORM, value=getenv("SUBMISSION_PLATFORM", "tinybox"))
|
||||
MLLOGGER.event(key=mllog_constants.SUBMISSION_DIVISION, value=mllog_constants.CLOSED)
|
||||
MLLOGGER.event(key=mllog_constants.SUBMISSION_STATUS, value=mllog_constants.ONPREM)
|
||||
MLLOGGER.event(key=mllog_constants.CACHE_CLEAR, value=True)
|
||||
MLLOGGER.start(key=mllog_constants.INIT_START)
|
||||
# closed_common.yaml
|
||||
MLLOGGER.event(key=mllog_constants.SUBMISSION_BENCHMARK, value=mllog_constants.RESNET)
|
||||
MLLOGGER.event(key=mllog_constants.GRADIENT_ACCUMULATION_STEPS, value=1)
|
||||
else:
|
||||
MLLOGGER = None
|
||||
|
||||
config = {}
|
||||
seed = config["seed"] = getenv("SEED", 42)
|
||||
Tensor.manual_seed(seed) # seed for weight initialization
|
||||
|
|
@ -86,6 +106,31 @@ def train_resnet():
|
|||
scheduler_group = LRSchedulerGroup(scheduler, scheduler_skip)
|
||||
print(f"training with batch size {BS} for {epochs} epochs")
|
||||
|
||||
# log mlperf hparams
|
||||
if MLLOGGER:
|
||||
if INITMLPERF:
|
||||
MLLOGGER.start(key=mllog_constants.INIT_START)
|
||||
|
||||
MLLOGGER.event(key=mllog_constants.GLOBAL_BATCH_SIZE, value=BS)
|
||||
from extra.datasets.imagenet import get_train_files, get_val_files
|
||||
MLLOGGER.event(key=mllog_constants.TRAIN_SAMPLES, value=len(get_train_files()))
|
||||
MLLOGGER.event(key=mllog_constants.EVAL_SAMPLES, value=len(get_val_files()))
|
||||
|
||||
MLLOGGER.event(key=mllog_constants.OPT_NAME, value="lars")
|
||||
assert scheduler.initial_lr == scheduler_skip.initial_lr
|
||||
assert scheduler.end_lr == scheduler_skip.end_lr
|
||||
assert scheduler.power == scheduler_skip.power
|
||||
MLLOGGER.event(key=mllog_constants.LARS_OPT_BASE_LEARNING_RATE, value=scheduler.initial_lr)
|
||||
MLLOGGER.event(key=mllog_constants.LARS_OPT_END_LR, value=scheduler.end_lr)
|
||||
MLLOGGER.event(key=mllog_constants.LARS_OPT_LR_DECAY_POLY_POWER, value=scheduler.power)
|
||||
MLLOGGER.event(key=mllog_constants.LARS_OPT_LR_DECAY_STEPS, value=epochs)
|
||||
MLLOGGER.event(key=mllog_constants.LARS_EPSILON, value=0) # does not support epsilon != 0
|
||||
MLLOGGER.event(key=mllog_constants.LARS_OPT_LEARNING_RATE_WARMUP_EPOCHS, value=lr_warmup_epochs)
|
||||
MLLOGGER.event(key=mllog_constants.LARS_OPT_MOMENTUM, value=optimizer.momentum)
|
||||
MLLOGGER.event(key=mllog_constants.LARS_OPT_WEIGHT_DECAY, value=optimizer.wd)
|
||||
if RUNMLPERF:
|
||||
MLLOGGER.start(key=mllog_constants.RUN_START)
|
||||
|
||||
# ** resume from checkpointing **
|
||||
start_epoch = 0
|
||||
if ckpt:=getenv("RESUME", ""):
|
||||
|
|
@ -136,6 +181,8 @@ def train_resnet():
|
|||
step_times = []
|
||||
for e in range(start_epoch, epochs):
|
||||
# ** train loop **
|
||||
if MLLOGGER:
|
||||
MLLOGGER.start(key=mllog_constants.EPOCH_START, value=e+1, metadata=dict(epoch_num=e+1))
|
||||
Tensor.training = True
|
||||
BEAM.value = TRAIN_BEAM
|
||||
batch_loader = batch_load_resnet(batch_size=BS, val=False, shuffle=True, seed=seed*epochs + e)
|
||||
|
|
@ -186,9 +233,13 @@ def train_resnet():
|
|||
# if we are doing beam search, run the first eval too
|
||||
if (TRAIN_BEAM or EVAL_BEAM) and e == start_epoch: break
|
||||
return
|
||||
if MLLOGGER and RUNMLPERF:
|
||||
MLLOGGER.event(key=mllog_constants.EPOCH_STOP, value=e+1, metadata=dict(epoch_num=e+1))
|
||||
|
||||
# ** eval loop **
|
||||
if (e + 1 - eval_start_epoch) % eval_epochs == 0 and steps_in_val_epoch > 0:
|
||||
if MLLOGGER and RUNMLPERF:
|
||||
MLLOGGER.start(key=mllog_constants.EVAL_START, value=e+1, metadata=dict(epoch_num=e+1))
|
||||
if getenv("RESET_STEP", 1): train_step.reset() # free the train step memory :(
|
||||
eval_times = []
|
||||
eval_loss = 0.0
|
||||
|
|
@ -217,7 +268,10 @@ def train_resnet():
|
|||
eval_num_samples += num_samples
|
||||
proc, next_proc = next_proc, None # return old cookie
|
||||
i += 1
|
||||
if i == BENCHMARK: return
|
||||
if i == BENCHMARK:
|
||||
if MLLOGGER and INITMLPERF:
|
||||
MLLOGGER.event(key=mllog_constants.INIT_STOP)
|
||||
return
|
||||
|
||||
et = time.time()
|
||||
eval_times.append(et - st)
|
||||
|
|
@ -231,6 +285,9 @@ def train_resnet():
|
|||
tqdm.write(f"eval loss: {total_loss:.2f}, eval time: {total_fw_time:.2f}, eval top 1 acc: {total_top_1:.3f}")
|
||||
if WANDB:
|
||||
wandb.log({"eval/loss": total_loss, "eval/top_1_acc": total_top_1, "eval/forward_time": total_fw_time, "epoch": e + 1})
|
||||
if MLLOGGER and RUNMLPERF:
|
||||
MLLOGGER.event(key=mllog_constants.EVAL_ACCURACY, value=total_top_1, metadata=dict(epoch_num=e+1))
|
||||
MLLOGGER.event(key=mllog_constants.EVAL_STOP, value=e+1, metadata=dict(epoch_num=e+1))
|
||||
|
||||
# save model if achieved target
|
||||
if not achieved and total_top_1 >= target:
|
||||
|
|
@ -240,6 +297,8 @@ def train_resnet():
|
|||
print(f" *** Model saved to {fn} ***")
|
||||
achieved = True
|
||||
# stop once achieve the target
|
||||
if MLLOGGER and RUNMLPERF:
|
||||
MLLOGGER.event(key=mllog_constants.RUN_STOP, metadata=dict(status="success"))
|
||||
break
|
||||
|
||||
# checkpoint every time we eval
|
||||
|
|
|
|||
|
|
@ -0,0 +1,17 @@
|
|||
export PYTHONPATH="."
|
||||
export MODEL="resnet"
|
||||
export SUBMISSION_PLATFORM="tinybox_green"
|
||||
export DEFAULT_FLOAT="HALF" GPUS=6 BS=1536 EVAL_BS=48 LR=7
|
||||
|
||||
export SPLIT_REDUCEOP=1 LAZYCACHE=0 RESET_STEP=0
|
||||
|
||||
export TRAIN_BEAM=3 IGNORE_JIT_FIRST_BEAM=1 BEAM_UOPS_MAX=1500 BEAM_UPCAST_MAX=128 BEAM_LOCAL_MAX=1024 BEAM_MIN_PROGRESS=50
|
||||
|
||||
# pip install -e ".[mlperf]"
|
||||
export LOGMLPERF=1
|
||||
|
||||
# init
|
||||
BENCHMARK=10 INITMLPERF=1 python3 examples/mlperf/model_train.py | tee resnet.log
|
||||
|
||||
# run
|
||||
WANDB=1 PARALLEL=0 RUNMLPERF=1 python3 examples/mlperf/model_train.py | tee -a resnet.log
|
||||
|
|
@ -0,0 +1,17 @@
|
|||
export PYTHONPATH="."
|
||||
export MODEL="resnet"
|
||||
export SUBMISSION_PLATFORM="tinybox_red"
|
||||
export DEFAULT_FLOAT="HALF" GPUS=6 BS=1536 EVAL_BS=48 LR=7
|
||||
|
||||
export SPLIT_REDUCEOP=1 LAZYCACHE=0 RESET_STEP=0
|
||||
|
||||
export TRAIN_BEAM=4 IGNORE_JIT_FIRST_BEAM=1 BEAM_UOPS_MAX=1500 BEAM_UPCAST_MAX=128 BEAM_LOCAL_MAX=1024 BEAM_MIN_PROGRESS=25
|
||||
|
||||
# pip install -e ".[mlperf]"
|
||||
export LOGMLPERF=1
|
||||
|
||||
# init
|
||||
BENCHMARK=10 INITMLPERF=1 python3 examples/mlperf/model_train.py | tee resnet.log
|
||||
|
||||
# run
|
||||
WANDB=1 PARALLEL=0 RUNMLPERF=1 python3 examples/mlperf/model_train.py | tee -a resnet.log
|
||||
|
|
@ -60,7 +60,7 @@ def random_resized_crop(img, size, scale=(0.10, 1.0), ratio=(3/4, 4/3)):
|
|||
|
||||
# Crop
|
||||
random_solution_found = False
|
||||
for _ in range(10):
|
||||
for _ in range(100):
|
||||
aspect_ratio = random.uniform(ratio[0], ratio[1])
|
||||
max_scale = min(min(w * aspect_ratio / h, h / aspect_ratio / w), scale[1])
|
||||
target_area = area * random.uniform(scale[0], max_scale)
|
||||
|
|
@ -69,12 +69,12 @@ def random_resized_crop(img, size, scale=(0.10, 1.0), ratio=(3/4, 4/3)):
|
|||
h_new = int(round(math.sqrt(target_area / aspect_ratio)))
|
||||
|
||||
if 0 < w_new <= w and 0 < h_new <= h:
|
||||
crop_left = random.randint(0, w - w_new + 1)
|
||||
crop_top = random.randint(0, h - h_new + 1)
|
||||
crop_left = random.randint(0, w - w_new)
|
||||
crop_top = random.randint(0, h - h_new)
|
||||
|
||||
img = img.crop((crop_left, crop_top, crop_left + w_new, crop_top + h_new))
|
||||
random_solution_found = True
|
||||
break
|
||||
img = img.crop((crop_left, crop_top, crop_left + w_new, crop_top + h_new))
|
||||
random_solution_found = True
|
||||
break
|
||||
|
||||
if not random_solution_found:
|
||||
# Center crop
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue