mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
pass num_devices to UnsyncedBatchNorm in test, allow UnsyncedBatchNorm to be used with LB
This commit is contained in:
parent
48564f0c47
commit
99536555e4
2 changed files with 2 additions and 2 deletions
|
|
@ -40,7 +40,7 @@ class UnsyncedBatchNorm:
|
|||
self.num_batches_tracked = Tensor.zeros(1, requires_grad=False)
|
||||
|
||||
def __call__(self, x:Tensor):
|
||||
assert isinstance(x.lazydata, MultiLazyBuffer) and len(x.lazydata.lbs) == self.num_devices and x.lazydata.axis == 0
|
||||
if isinstance(x.lazydata, MultiLazyBuffer): assert x.lazydata.axis is None or x.lazydata.axis == 0 and len(x.lazydata.lbs) == self.num_devices
|
||||
|
||||
rshape, x = x.shape, x.reshape(self.num_devices, -1, *x.shape[1:])
|
||||
batch_mean, batch_invstd = self.calc_stats(x)
|
||||
|
|
|
|||
|
|
@ -582,7 +582,7 @@ class TestShrinkMultiTensorShardedAxis(unittest.TestCase):
|
|||
|
||||
with Tensor.train():
|
||||
synced_bn = BatchNorm2d(8)
|
||||
unsynced_bn = UnsyncedBatchNorm(8)
|
||||
unsynced_bn = UnsyncedBatchNorm(8, num_devices=len(devices))
|
||||
|
||||
for p in get_parameters([synced_bn, unsynced_bn]):
|
||||
p.shard_(devices)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue