This commit is contained in:
David Hou 2024-02-29 13:52:54 -08:00
commit 5e8a67a535

View file

@ -262,7 +262,7 @@ def train_cifar():
X_test, Y_test = X_test.cast(dtypes.default_float), Y_test.cast(dtypes.default_float)
if len(GPUS) > 1:
for k, x in get_state_dict(model):
for k, x in get_state_dict(model).items():
if not getenv('SYNCBN') and ('running_mean' in k or 'running_bias' in k):
x.shard_(GPUS, axis=0)
else: