manually shard running mean and running var

This commit is contained in:
David Hou 2024-02-29 12:12:47 -08:00
commit e0fd41bb46

View file

@ -305,8 +305,11 @@ def train_cifar():
X_test, Y_test = X_test.cast(dtypes.default_float), Y_test.cast(dtypes.default_float)
if len(GPUS) > 1:
for x in get_parameters(model):
x.to_(GPUS)
# Distribute every state tensor over GPUS. Tensors whose state-dict key names a
# running statistic ('running_mean' / 'running_var') are sharded along axis 0;
# all other tensors are moved with to_(GPUS).
# NOTE: iterate .items() -- looping over the mapping itself yields only the keys,
# so unpacking into (k, x) would fail.
# NOTE(review): assumes axis 0 of the running statistics is the dimension meant
# to be split per device -- confirm against the batchnorm implementation.
for k, x in get_state_dict(model).items():
  if 'running_mean' in k or 'running_var' in k:
    x.shard_(GPUS, axis=0)
  else:
    x.to_(GPUS)
# parse the training params into bias and non-bias
params_dict = get_state_dict(model)