gather to single device (#16354)

This commit is contained in:
wozeparrot 2026-05-25 20:27:08 -04:00 committed by GitHub
commit 76fc39ccc0
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 25 additions and 1 deletions

View file

@ -239,7 +239,28 @@ class TestMultiTensor(unittest.TestCase):
kernel_counts[ring] = GlobalCounters.kernel_count
self.assertEqual(t.device, Device.DEFAULT)
np.testing.assert_equal(t.numpy(), np.arange(32))
self.assertNotEqual(kernel_counts[0], kernel_counts[2])
self.assertEqual(kernel_counts[0], kernel_counts[2])
def test_to_single_device_gather_memory(self):
nrows, ncols = 64, 1024
nbytes = nrows*ncols*4
for devs in (devices_2, devices_4):
ndev = len(devs)
for axis in (0, 1):
sh = Tensor.arange(nrows*ncols).reshape(nrows, ncols).clone().shard(devs, axis).realize()
kernels, mem = {}, {}
for ring in (0, 2):
GlobalCounters.reset()
with Context(RING=ring, SCACHE=0):
t = sh.to(Device.DEFAULT)
t.realize()
kernels[ring], mem[ring] = GlobalCounters.kernel_count, GlobalCounters.global_mem
self.assertEqual(t.device, Device.DEFAULT)
np.testing.assert_equal(t.numpy(), np.arange(nrows*ncols).reshape(nrows, ncols))
self.assertEqual(kernels[0], kernels[2])
self.assertEqual(mem[0], mem[2])
self.assertLess(kernels[0], 2*ndev)
self.assertLessEqual(mem[0], 4*nbytes)
def test_allreduce_all2all(self):
with Context(ALL2ALL=2):

View file

@ -111,6 +111,9 @@ def flip_multi(root:UOp, multi:UOp):
def copy_multi(multi:UOp, device:str | tuple[str, ...] | UOp):
assert multi.axis is not None, "all multi ops have axis"
if isinstance(device, UOp) and isinstance(device.arg, str):
pieces = [multi.src[0].mselect(i).copy_to_device(device) for i in range(len(multi.device))]
return pieces[0].cat(*pieces[1:], dim=multi.axis)
return multi.src[0]._unshard(multi.axis).allreduce(Ops.ADD, device)
def store_after_multi(dest:UOp, src:UOp): return dest.after(dest.store(src.src[0])).multi(src.axis)