mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
support MLB reshaping on-axis for evenly sharded (#3484)
* support MLB reshaping on-axis for evenly sharded * update test * not -> !=
This commit is contained in:
parent
358a24eae6
commit
5cfcc2a8d7
2 changed files with 57 additions and 3 deletions
|
|
@ -307,6 +307,54 @@ class TestMultiTensor(unittest.TestCase):
|
|||
# for i, ast in enumerate(asts):
|
||||
# print(f"{i} {ast}")
|
||||
|
||||
def test_reshape_on_axis(self):
|
||||
devices = (d0, d1, d2)
|
||||
|
||||
t0 = Tensor.rand((26, 15, 7)).shard(devices, axis=1)
|
||||
|
||||
# test split and rejoin to the right
|
||||
t1 = t0.reshape((26, 3, 5, 7))
|
||||
t2 = t0.reshape((26, 3, 35))
|
||||
t3 = t1.reshape((26, 15, 7))
|
||||
t4 = t2.reshape((26, 105,))
|
||||
|
||||
for t in [t0, t1, t2, t3, t4]:
|
||||
assert t.lazydata.axis == 1
|
||||
np.testing.assert_allclose(t.numpy().flatten(), t0.numpy().flatten())
|
||||
|
||||
# test shape-one axis
|
||||
t5 = t4.reshape((26, 1, 105))
|
||||
assert t5.lazydata.axis == 2
|
||||
np.testing.assert_allclose(t.numpy().flatten(), t5.numpy().flatten())
|
||||
|
||||
# test split and rejoin to the right and reshape to the left
|
||||
t5 = t0.reshape((2, 13, 3, 5, 7))
|
||||
t6 = t0.reshape((13, 2, 3, 7, 5))
|
||||
t7 = t0.reshape((1, 13, 2, 3, 1, 7, 5))
|
||||
np.testing.assert_allclose(t5.numpy().flatten(), t0.numpy().flatten())
|
||||
assert t5.lazydata.axis == 2
|
||||
np.testing.assert_allclose(t6.numpy().flatten(), t0.numpy().flatten())
|
||||
assert t6.lazydata.axis == 2
|
||||
np.testing.assert_allclose(t7.numpy().flatten(), t0.numpy().flatten())
|
||||
assert t7.lazydata.axis == 3
|
||||
|
||||
# test no left join
|
||||
with self.assertRaises((AssertionError, ValueError)):
|
||||
t0.reshape((26*15,7))
|
||||
|
||||
def test_reshape_on_axis_uneven(self):
|
||||
devices = (d0, d1, d2)
|
||||
t0 = Tensor.rand((4, 8, 15)).shard(devices, axis=1)
|
||||
|
||||
# no split axis if uneven
|
||||
with self.assertRaises((AssertionError, ValueError)):
|
||||
t0.reshape((4,4,2,15))
|
||||
|
||||
# ok to split reshape left and right though
|
||||
t1 = t0.reshape(2, 2, 8, 3, 5)
|
||||
np.testing.assert_allclose(t0.numpy().flatten(), t1.numpy().flatten())
|
||||
assert t1.lazydata.axis == 2
|
||||
|
||||
@unittest.skipIf(CI and Device.DEFAULT in {"GPU", "CUDA", "METAL"}, "no GPU CI")
|
||||
class TestShrinkMultiTensorShardedAxis(unittest.TestCase):
|
||||
# shrink a multitensor on sharded axis
|
||||
|
|
|
|||
|
|
@ -98,9 +98,15 @@ class MultiLazyBuffer:
|
|||
def reshape(self, arg:Tuple[sint, ...]):
|
||||
if self.axis is None: return MultiLazyBuffer([x.reshape(arg) for x in self.lbs], None, self.real)
|
||||
arg_acc:List[sint] = list(itertools.accumulate(arg, operator.mul, initial=1))
|
||||
# new_axis is the one that preserves prod(prior to new_axis) and prod(post to new_axis)
|
||||
new_axis = [tuple(p) for p in zip(arg_acc, arg_acc[1:])].index((prod(self.shape[:self.axis]), prod(self.shape[:self.axis+1])))
|
||||
return MultiLazyBuffer([x.reshape(tuple(x.shape[self.axis] if a == new_axis else s for a,s in enumerate(arg))) for x in self.lbs],
|
||||
# new_axis is the last one that preserves prod(prior to new_axis) and must not move items between shards
|
||||
# todo: what to do about shrinking to self.shape[self.axis]==1 len(self.real_lbs)==1?
|
||||
new_axis = len(arg_acc) - arg_acc[::-1].index(prod(self.shape[:self.axis])) - 1
|
||||
if arg[new_axis] != self.shape[self.axis]:
|
||||
assert self.shape[self.axis] % len(self.real_lbs) == 0, f"cannot reshape on-axis for uneven shard {self.axis} {self.shape} {len(self.real_lbs)}"
|
||||
assert arg[new_axis] % len(self.real_lbs) == 0, f"new on-axis shape must divide evenly between devices {new_axis} {arg} {len(self.real_lbs)}"
|
||||
return MultiLazyBuffer([x.reshape(tuple(s if a != new_axis else
|
||||
x.shape[self.axis] if s == self.shape[self.axis] else
|
||||
s // len(self.real_lbs) for a,s in enumerate(arg))) for x in self.lbs],
|
||||
new_axis, self.real)
|
||||
|
||||
def pad(self, arg:Tuple[Tuple[sint, sint], ...]):
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue