mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
parent
ebc5390c9a
commit
12764161c9
3 changed files with 10 additions and 2 deletions
|
|
@ -204,6 +204,12 @@ class TestMultiAxis(unittest.TestCase):
|
|||
self.assertEqual(t.reshape(2, 16).uop.axis, 0)
|
||||
self.assertEqual(t.reshape(2, 2, 8).uop.axis, 0)
|
||||
|
||||
def test_uop_shard_axis_none(self):
|
||||
devices = ("NULL:0", "NULL:1")
|
||||
u = Tensor.ones(8).contiguous().realize().uop
|
||||
self.assertIsNone(u.shard(devices).axis)
|
||||
self.assertEqual(u.shard(devices, 0).axis, 0)
|
||||
|
||||
def test_empty_like_sharded(self):
|
||||
t = Tensor.ones(4, 8).shard(("NULL:0", "NULL:1"), axis=0)
|
||||
e = t.empty_like()
|
||||
|
|
|
|||
|
|
@ -377,7 +377,7 @@ class Tensor(OpMixin):
|
|||
if not isinstance(self.device, str): raise RuntimeError("can't shard a multi-device tensor")
|
||||
if len(devices) == 1: return self.to(devices[0])
|
||||
devices = cast(tuple[str, ...], canonicalize_device(devices))
|
||||
uop = self.uop.shard(devices, self._resolve_dim(axis)) if axis is not None else self.uop.copy_to_device(devices)
|
||||
uop = self.uop.shard(devices, None if axis is None else self._resolve_dim(axis))
|
||||
return Tensor(uop).is_param_(self.is_param)
|
||||
|
||||
def shard_(self, devices:tuple[str, ...], axis:int|None=None) -> Tensor:
|
||||
|
|
|
|||
|
|
@ -656,7 +656,9 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
|
|||
if self.shape[axis] % dcount != 0: raise RuntimeError(f"multi axis uneven: {self.shape[axis]=} {axis=} {dcount=}")
|
||||
sz = self.shape[axis] // dcount
|
||||
return self.shrink(tuple((0,s) if i != axis else (dnum*sz,dnum*sz+sz) for i,s in enumerate(self.shape)))
|
||||
def shard(self, devices:tuple[str, ...], axis:int) -> UOp: return self.copy_to_device(devices)._shard(axis, len(devices)).multi(axis)
|
||||
def shard(self, devices:tuple[str, ...], axis:int|None=None) -> UOp:
|
||||
copied = self.copy_to_device(devices)
|
||||
return copied if axis is None else copied._shard(axis, len(devices)).multi(axis)
|
||||
|
||||
def copy_to_device(self, device:str|tuple[str, ...]|UOp, arg=None):
|
||||
assert arg is None or isinstance(self.device, tuple)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue