mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
cifar move GlobalCounters.reset() before shard (#3217)
* cifar move GlobalCounters.reset() before shard; also shard mini batch in place
* don't eval with DISABLE_BACKWARD
This commit is contained in:
parent
595d05a250
commit
9e5409be6c
1 changed file with 8 additions and 6 deletions
|
|
@ -172,8 +172,8 @@ def train_cifar():
|
|||
mask = make_square_mask(X.shape, mask_size)
|
||||
order = list(range(0, X.shape[0]))
|
||||
random.shuffle(order)
|
||||
X_patch = Tensor(X.numpy()[order,...])
|
||||
Y_patch = Tensor(Y.numpy()[order])
|
||||
X_patch = Tensor(X.numpy()[order], device=X.device)
|
||||
Y_patch = Tensor(Y.numpy()[order], device=Y.device)
|
||||
X_cutmix = mask.where(X_patch, X)
|
||||
mix_portion = float(mask_size**2)/(X.shape[-2]*X.shape[-1])
|
||||
Y_cutmix = mix_portion * Y_patch + (1. - mix_portion) * Y
|
||||
|
|
@ -326,7 +326,7 @@ def train_cifar():
|
|||
with Tensor.train():
|
||||
st = time.monotonic()
|
||||
while i <= STEPS:
|
||||
if i % getenv("EVAL_STEPS", STEPS) == 0 and i > 1:
|
||||
if i % getenv("EVAL_STEPS", STEPS) == 0 and i > 1 and not getenv("DISABLE_BACKWARD"):
|
||||
# Use Tensor.training = False here actually bricks batchnorm, even with track_running_stats=True
|
||||
corrects = []
|
||||
corrects_ema = []
|
||||
|
|
@ -334,7 +334,8 @@ def train_cifar():
|
|||
losses_ema = []
|
||||
for Xt, Yt in fetch_batches(X_test, Y_test, BS=EVAL_BS, is_train=False):
|
||||
if len(GPUS) > 1:
|
||||
Xt, Yt = Xt.shard(GPUS, axis=0), Yt.shard(GPUS, axis=0)
|
||||
Xt.shard_(GPUS, axis=0)
|
||||
Yt.shard_(GPUS, axis=0)
|
||||
|
||||
correct, loss = eval_step_jitted(model, Xt, Yt)
|
||||
losses.append(loss.numpy().tolist())
|
||||
|
|
@ -355,11 +356,12 @@ def train_cifar():
|
|||
|
||||
if STEPS == 0 or i == STEPS: break
|
||||
|
||||
GlobalCounters.reset()
|
||||
X, Y = next(batcher)
|
||||
if len(GPUS) > 1:
|
||||
X, Y = X.shard(GPUS, axis=0), Y.shard(GPUS, axis=0)
|
||||
X.shard_(GPUS, axis=0)
|
||||
Y.shard_(GPUS, axis=0)
|
||||
|
||||
GlobalCounters.reset()
|
||||
with Context(BEAM=getenv("LATEBEAM", BEAM.value), WINO=getenv("LATEWINO", WINO.value)):
|
||||
loss = train_step_jitted(model, [opt_bias, opt_non_bias], [lr_sched_bias, lr_sched_non_bias], X, Y)
|
||||
et = time.monotonic()
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue