mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
manually shard running mean and running var
This commit is contained in:
parent
06e07950f6
commit
e0fd41bb46
1 changed file with 5 additions and 2 deletions
|
|
@@ -305,8 +305,11 @@ def train_cifar():
|
|||
X_test, Y_test = X_test.cast(dtypes.default_float), Y_test.cast(dtypes.default_float)
|
||||
|
||||
if len(GPUS) > 1:
|
||||
for x in get_parameters(model):
|
||||
x.to_(GPUS)
|
||||
for k, x in get_state_dict(model):
|
||||
if 'running_mean' in k or 'running_bias' in k:
|
||||
x.shard_(GPUS, axis=0)
|
||||
else:
|
||||
x.to_(GPUS)
|
||||
|
||||
# parse the training params into bias and non-bias
|
||||
params_dict = get_state_dict(model)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue