UOp.shard support axis=None [PR] (#16538)

match Tensor
This commit is contained in:
chenyu 2026-06-08 11:36:50 -04:00 committed by GitHub
commit 12764161c9
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 10 additions and 2 deletions

View file

@ -204,6 +204,12 @@ class TestMultiAxis(unittest.TestCase):
self.assertEqual(t.reshape(2, 16).uop.axis, 0)
self.assertEqual(t.reshape(2, 2, 8).uop.axis, 0)
def test_uop_shard_axis_none(self):
devices = ("NULL:0", "NULL:1")
u = Tensor.ones(8).contiguous().realize().uop
self.assertIsNone(u.shard(devices).axis)
self.assertEqual(u.shard(devices, 0).axis, 0)
def test_empty_like_sharded(self):
t = Tensor.ones(4, 8).shard(("NULL:0", "NULL:1"), axis=0)
e = t.empty_like()

View file

@ -377,7 +377,7 @@ class Tensor(OpMixin):
if not isinstance(self.device, str): raise RuntimeError("can't shard a multi-device tensor")
if len(devices) == 1: return self.to(devices[0])
devices = cast(tuple[str, ...], canonicalize_device(devices))
uop = self.uop.shard(devices, self._resolve_dim(axis)) if axis is not None else self.uop.copy_to_device(devices)
uop = self.uop.shard(devices, None if axis is None else self._resolve_dim(axis))
return Tensor(uop).is_param_(self.is_param)
def shard_(self, devices:tuple[str, ...], axis:int|None=None) -> Tensor:

View file

@ -656,7 +656,9 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
if self.shape[axis] % dcount != 0: raise RuntimeError(f"multi axis uneven: {self.shape[axis]=} {axis=} {dcount=}")
sz = self.shape[axis] // dcount
return self.shrink(tuple((0,s) if i != axis else (dnum*sz,dnum*sz+sz) for i,s in enumerate(self.shape)))
def shard(self, devices:tuple[str, ...], axis:int) -> UOp: return self.copy_to_device(devices)._shard(axis, len(devices)).multi(axis)
def shard(self, devices:tuple[str, ...], axis:int|None=None) -> UOp:
copied = self.copy_to_device(devices)
return copied if axis is None else copied._shard(axis, len(devices)).multi(axis)
def copy_to_device(self, device:str|tuple[str, ...]|UOp, arg=None):
assert arg is None or isinstance(self.device, tuple)