override output shape in fused assign (#5930)

* override output shape in fused assign

This makes

```
FUSE_ARANGE=1 JIT=0 python3 examples/llama.py --gen 1 --prompt "Hello." --count 10 --temperature 0 --timing
```
work. In general we should assert ASSIGN doesn't change shape.

* merge asserts
This commit is contained in:
qazal 2024-08-06 18:28:50 +08:00 committed by GitHub
commit 3d4742dd2e
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 20 additions and 3 deletions

View file

@ -1443,5 +1443,20 @@ class TestIndexing(unittest.TestCase):
run_schedule(check_schedule(ref, 3))
np.testing.assert_equal(fused.numpy(), ref.numpy())
def test_fuse_assign_contiguous(self):
x = Tensor.zeros(4, 4, dtype=dtypes.int).contiguous().realize()
a = Tensor.arange(8).reshape(4, 2)
self.check_schedule(x.shrink((None, (0, 2))).assign(a.contiguous()), 2)
np.testing.assert_equal(x.numpy(), [[0, 1, 0, 0], [2, 3, 0, 0], [4, 5, 0, 0], [6, 7, 0, 0]])
def test_assign_non_contiguous(self):
x = Tensor.zeros(4, 4, dtype=dtypes.int).contiguous().realize()
y = Tensor.randint(4, 2)
a = Tensor.arange(8).reshape(4, 2)+y
x.shrink((None, (0, 2))).assign(a).realize()
xref = np.zeros((4, 4), dtype=int)
xref[:, :2] = np.arange(8).reshape(4, 2)+y.numpy()
np.testing.assert_equal(x.numpy(), xref)
if __name__ == '__main__':
unittest.main(verbosity=2)

View file

@ -152,10 +152,12 @@ def _lower_lazybuffer(outs:List[LazyBuffer], realizes:Dict[LazyBuffer, None]):
for i, out in enumerate(outs):
output_st = ShapeTracker.from_shape(reduce_st(*deque(reduce_info.values(), 1).pop()) if reduce_info else out.shape)
lop = _recursive_lazyop(out, inputs, tuple(outs), var_vals, output_st, realizes, assign_targets, reduce_info, cache=cache)
output_view = out.arg[0] if out.op is MetaOps.ASSIGN and out.arg else output_st
output_view, vv = output_view.simplify().unbind()
if out.op is MetaOps.ASSIGN and out.arg:
assert out.arg[0].shape == out.shape, f"ASSIGN must not override output shape {out.arg[0].shape} != {out.shape}"
output_st = out.arg[0].reshape(output_st.shape)
output_st, vv = output_st.simplify().unbind()
if vv: var_vals.update(vv)
ast.append(LazyOp(BufferOps.STORE, (lop,), MemBuffer(i, out.dtype, output_view)))
ast.append(LazyOp(BufferOps.STORE, (lop,), MemBuffer(i, out.dtype, output_st)))
return LazyOp(MetaOps.KERNEL, tuple(ast)), list(inputs), var_vals, dedup([x[0].metadata for x in cache if x[0].metadata and x[0] not in inputs])
# *** DAG creation: decide which LazyBuffers should realize ***