pass num_devices to UnsyncedBatchNorm in test, allow UnsyncedBatchNorm to be used with LB

This commit is contained in:
David Hou 2024-02-22 14:51:16 -08:00
commit 99536555e4
2 changed files with 2 additions and 2 deletions

View file

@ -40,7 +40,7 @@ class UnsyncedBatchNorm:
self.num_batches_tracked = Tensor.zeros(1, requires_grad=False)
def __call__(self, x:Tensor):
assert isinstance(x.lazydata, MultiLazyBuffer) and len(x.lazydata.lbs) == self.num_devices and x.lazydata.axis == 0
if isinstance(x.lazydata, MultiLazyBuffer): assert x.lazydata.axis is None or x.lazydata.axis == 0 and len(x.lazydata.lbs) == self.num_devices
rshape, x = x.shape, x.reshape(self.num_devices, -1, *x.shape[1:])
batch_mean, batch_invstd = self.calc_stats(x)

View file

@ -582,7 +582,7 @@ class TestShrinkMultiTensorShardedAxis(unittest.TestCase):
with Tensor.train():
synced_bn = BatchNorm2d(8)
unsynced_bn = UnsyncedBatchNorm(8)
unsynced_bn = UnsyncedBatchNorm(8, num_devices=len(devices))
for p in get_parameters([synced_bn, unsynced_bn]):
p.shard_(devices)