mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
gather to single device (#16354)
This commit is contained in:
parent
942cb42b97
commit
76fc39ccc0
2 changed files with 25 additions and 1 deletions
|
|
@ -239,7 +239,28 @@ class TestMultiTensor(unittest.TestCase):
|
|||
kernel_counts[ring] = GlobalCounters.kernel_count
|
||||
self.assertEqual(t.device, Device.DEFAULT)
|
||||
np.testing.assert_equal(t.numpy(), np.arange(32))
|
||||
self.assertNotEqual(kernel_counts[0], kernel_counts[2])
|
||||
self.assertEqual(kernel_counts[0], kernel_counts[2])
|
||||
|
||||
def test_to_single_device_gather_memory(self):
|
||||
nrows, ncols = 64, 1024
|
||||
nbytes = nrows*ncols*4
|
||||
for devs in (devices_2, devices_4):
|
||||
ndev = len(devs)
|
||||
for axis in (0, 1):
|
||||
sh = Tensor.arange(nrows*ncols).reshape(nrows, ncols).clone().shard(devs, axis).realize()
|
||||
kernels, mem = {}, {}
|
||||
for ring in (0, 2):
|
||||
GlobalCounters.reset()
|
||||
with Context(RING=ring, SCACHE=0):
|
||||
t = sh.to(Device.DEFAULT)
|
||||
t.realize()
|
||||
kernels[ring], mem[ring] = GlobalCounters.kernel_count, GlobalCounters.global_mem
|
||||
self.assertEqual(t.device, Device.DEFAULT)
|
||||
np.testing.assert_equal(t.numpy(), np.arange(nrows*ncols).reshape(nrows, ncols))
|
||||
self.assertEqual(kernels[0], kernels[2])
|
||||
self.assertEqual(mem[0], mem[2])
|
||||
self.assertLess(kernels[0], 2*ndev)
|
||||
self.assertLessEqual(mem[0], 4*nbytes)
|
||||
|
||||
def test_allreduce_all2all(self):
|
||||
with Context(ALL2ALL=2):
|
||||
|
|
|
|||
|
|
@ -111,6 +111,9 @@ def flip_multi(root:UOp, multi:UOp):
|
|||
|
||||
def copy_multi(multi:UOp, device:str | tuple[str, ...] | UOp):
|
||||
assert multi.axis is not None, "all multi ops have axis"
|
||||
if isinstance(device, UOp) and isinstance(device.arg, str):
|
||||
pieces = [multi.src[0].mselect(i).copy_to_device(device) for i in range(len(multi.device))]
|
||||
return pieces[0].cat(*pieces[1:], dim=multi.axis)
|
||||
return multi.src[0]._unshard(multi.axis).allreduce(Ops.ADD, device)
|
||||
|
||||
def store_after_multi(dest:UOp, src:UOp): return dest.after(dest.store(src.src[0])).multi(src.axis)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue