support MLB reshaping on-axis for evenly sharded (#3484)

* support MLB reshaping on-axis for evenly sharded

* update test

* not -> !=
This commit is contained in:
David Hou 2024-02-23 04:51:36 -08:00 committed by GitHub
commit 5cfcc2a8d7
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 57 additions and 3 deletions

View file

@ -307,6 +307,54 @@ class TestMultiTensor(unittest.TestCase):
# for i, ast in enumerate(asts):
# print(f"{i} {ast}")
def test_reshape_on_axis(self):
  """Reshapes of an evenly sharded tensor keep the data intact and track where the sharded axis lands."""
  devices = (d0, d1, d2)
  t0 = Tensor.rand((26, 15, 7)).shard(devices, axis=1)
  # Reference values; every reshape below must reproduce them in flattened order.
  expected = t0.numpy().flatten()

  # test split and rejoin to the right
  t1 = t0.reshape((26, 3, 5, 7))
  t2 = t0.reshape((26, 3, 35))
  t3 = t1.reshape((26, 15, 7))
  t4 = t2.reshape((26, 105,))
  for t in [t0, t1, t2, t3, t4]:
    assert t.lazydata.axis == 1
    np.testing.assert_allclose(t.numpy().flatten(), expected)

  # test shape-one axis: inserting a size-1 dim before the sharded axis shifts it right by one
  t_unit = t4.reshape((26, 1, 105))
  assert t_unit.lazydata.axis == 2
  # compare against the reference explicitly rather than the loop variable left over from above
  np.testing.assert_allclose(t_unit.numpy().flatten(), expected)

  # test split and rejoin to the right and reshape to the left
  t5 = t0.reshape((2, 13, 3, 5, 7))
  t6 = t0.reshape((13, 2, 3, 7, 5))
  t7 = t0.reshape((1, 13, 2, 3, 1, 7, 5))
  np.testing.assert_allclose(t5.numpy().flatten(), expected)
  assert t5.lazydata.axis == 2
  np.testing.assert_allclose(t6.numpy().flatten(), expected)
  assert t6.lazydata.axis == 2
  np.testing.assert_allclose(t7.numpy().flatten(), expected)
  assert t7.lazydata.axis == 3

  # test no left join: merging the sharded axis with the axis to its left must be rejected
  with self.assertRaises((AssertionError, ValueError)):
    t0.reshape((26*15,7))
def test_reshape_on_axis_uneven(self):
  """An unevenly sharded axis (8 over 3 devices) cannot be split, but its neighbours can still be reshaped."""
  sharded = Tensor.rand((4, 8, 15)).shard((d0, d1, d2), axis=1)

  # no split axis if uneven
  with self.assertRaises((AssertionError, ValueError)):
    sharded.reshape((4,4,2,15))

  # ok to split reshape left and right though
  resplit = sharded.reshape(2, 2, 8, 3, 5)
  np.testing.assert_allclose(sharded.numpy().flatten(), resplit.numpy().flatten())
  # the leading 4 became (2, 2), so the sharded axis moved from index 1 to index 2
  assert resplit.lazydata.axis == 2
@unittest.skipIf(CI and Device.DEFAULT in {"GPU", "CUDA", "METAL"}, "no GPU CI")
class TestShrinkMultiTensorShardedAxis(unittest.TestCase):
# shrink a multitensor on sharded axis

View file

@ -98,9 +98,15 @@ class MultiLazyBuffer:
def reshape(self, arg:Tuple[sint, ...]):
  """Reshape every shard so that the combined buffer takes shape `arg`, without moving items between shards.

  When `self.axis` is None each buffer in `self.lbs` is simply reshaped to `arg`.
  Otherwise the sharded axis is relocated to the last position in `arg` whose preceding
  dimensions multiply to the same product as the dimensions preceding `self.axis` in
  `self.shape`. Dimensions left of the sharded axis may therefore be split or merged
  freely, but the sharded axis cannot be merged with the one on its left.

  If the on-axis length changes, both the old and the new on-axis length must be divisible by
  `len(self.real_lbs)` (used here as the number of shards); each shard then receives an equal slice.
  If the on-axis length is unchanged, each shard keeps its own extent on that axis, so uneven
  shards are still accepted.

  Raises:
    ValueError: no position in `arg` preserves the product of the dimensions left of the sharded axis.
    AssertionError: the on-axis length changes and the old or new length does not divide evenly.
  """
  if self.axis is None: return MultiLazyBuffer([x.reshape(arg) for x in self.lbs], None, self.real)
  # arg_acc[i] is the product of arg[:i]; it has len(arg)+1 entries and starts at 1
  arg_acc:List[sint] = list(itertools.accumulate(arg, operator.mul, initial=1))
  # new_axis is the last one that preserves prod(prior to new_axis) and must not move items between shards
  # todo: what to do about shrinking to self.shape[self.axis]==1 len(self.real_lbs)==1?
  new_axis = len(arg_acc) - arg_acc[::-1].index(prod(self.shape[:self.axis])) - 1
  n_shards = len(self.real_lbs)
  if arg[new_axis] != self.shape[self.axis]:
    # the on-axis length changes: an equal split is only well defined for even shards
    assert self.shape[self.axis] % n_shards == 0, f"cannot reshape on-axis for uneven shard {self.axis} {self.shape} {n_shards}"
    assert arg[new_axis] % n_shards == 0, f"new on-axis shape must divide evenly between devices {new_axis} {arg} {n_shards}"
  # per-shard shape: copy arg, except on new_axis where the shard either keeps its own
  # extent (on-axis length unchanged) or takes an equal 1/n_shards slice of the new length
  return MultiLazyBuffer([x.reshape(tuple(s if a != new_axis else
                                          x.shape[self.axis] if s == self.shape[self.axis] else
                                          s // n_shards for a,s in enumerate(arg))) for x in self.lbs],
                         new_axis, self.real)
def pad(self, arg:Tuple[Tuple[sint, sint], ...]):