cifar move GlobalCounters.reset() before shard (#3217)

* cifar move GlobalCounters.reset() before shard

also shard mini batch inplace

* don't eval with DISABLE_BACKWARD
This commit is contained in:
chenyu 2024-01-23 16:07:43 -05:00 committed by GitHub
commit 9e5409be6c
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -172,8 +172,8 @@ def train_cifar():
mask = make_square_mask(X.shape, mask_size)
order = list(range(0, X.shape[0]))
random.shuffle(order)
X_patch = Tensor(X.numpy()[order,...])
Y_patch = Tensor(Y.numpy()[order])
X_patch = Tensor(X.numpy()[order], device=X.device)
Y_patch = Tensor(Y.numpy()[order], device=Y.device)
X_cutmix = mask.where(X_patch, X)
mix_portion = float(mask_size**2)/(X.shape[-2]*X.shape[-1])
Y_cutmix = mix_portion * Y_patch + (1. - mix_portion) * Y
@ -326,7 +326,7 @@ def train_cifar():
with Tensor.train():
st = time.monotonic()
while i <= STEPS:
if i % getenv("EVAL_STEPS", STEPS) == 0 and i > 1:
if i % getenv("EVAL_STEPS", STEPS) == 0 and i > 1 and not getenv("DISABLE_BACKWARD"):
# Use Tensor.training = False here actually bricks batchnorm, even with track_running_stats=True
corrects = []
corrects_ema = []
@ -334,7 +334,8 @@ def train_cifar():
losses_ema = []
for Xt, Yt in fetch_batches(X_test, Y_test, BS=EVAL_BS, is_train=False):
if len(GPUS) > 1:
Xt, Yt = Xt.shard(GPUS, axis=0), Yt.shard(GPUS, axis=0)
Xt.shard_(GPUS, axis=0)
Yt.shard_(GPUS, axis=0)
correct, loss = eval_step_jitted(model, Xt, Yt)
losses.append(loss.numpy().tolist())
@ -355,11 +356,12 @@ def train_cifar():
if STEPS == 0 or i == STEPS: break
GlobalCounters.reset()
X, Y = next(batcher)
if len(GPUS) > 1:
X, Y = X.shard(GPUS, axis=0), Y.shard(GPUS, axis=0)
X.shard_(GPUS, axis=0)
Y.shard_(GPUS, axis=0)
GlobalCounters.reset()
with Context(BEAM=getenv("LATEBEAM", BEAM.value), WINO=getenv("LATEWINO", WINO.value)):
loss = train_step_jitted(model, [opt_bias, opt_non_bias], [lr_sched_bias, lr_sched_non_bias], X, Y)
et = time.monotonic()