mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
override output shape in fused assign (#5930)
* override output shape in fused assign This makes ``` FUSE_ARANGE=1 JIT=0 python3 examples/llama.py --gen 1 --prompt "Hello." --count 10 --temperature 0 --timing ``` work. In general we should assert ASSIGN doesn't change shape. * merge asserts
This commit is contained in:
parent
341c394c89
commit
3d4742dd2e
2 changed files with 20 additions and 3 deletions
|
|
@ -1443,5 +1443,20 @@ class TestIndexing(unittest.TestCase):
|
|||
run_schedule(check_schedule(ref, 3))
|
||||
np.testing.assert_equal(fused.numpy(), ref.numpy())
|
||||
|
||||
def test_fuse_assign_contiguous(self):
|
||||
x = Tensor.zeros(4, 4, dtype=dtypes.int).contiguous().realize()
|
||||
a = Tensor.arange(8).reshape(4, 2)
|
||||
self.check_schedule(x.shrink((None, (0, 2))).assign(a.contiguous()), 2)
|
||||
np.testing.assert_equal(x.numpy(), [[0, 1, 0, 0], [2, 3, 0, 0], [4, 5, 0, 0], [6, 7, 0, 0]])
|
||||
|
||||
def test_assign_non_contiguous(self):
|
||||
x = Tensor.zeros(4, 4, dtype=dtypes.int).contiguous().realize()
|
||||
y = Tensor.randint(4, 2)
|
||||
a = Tensor.arange(8).reshape(4, 2)+y
|
||||
x.shrink((None, (0, 2))).assign(a).realize()
|
||||
xref = np.zeros((4, 4), dtype=int)
|
||||
xref[:, :2] = np.arange(8).reshape(4, 2)+y.numpy()
|
||||
np.testing.assert_equal(x.numpy(), xref)
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main(verbosity=2)
|
||||
|
|
|
|||
|
|
@ -152,10 +152,12 @@ def _lower_lazybuffer(outs:List[LazyBuffer], realizes:Dict[LazyBuffer, None]):
|
|||
for i, out in enumerate(outs):
|
||||
output_st = ShapeTracker.from_shape(reduce_st(*deque(reduce_info.values(), 1).pop()) if reduce_info else out.shape)
|
||||
lop = _recursive_lazyop(out, inputs, tuple(outs), var_vals, output_st, realizes, assign_targets, reduce_info, cache=cache)
|
||||
output_view = out.arg[0] if out.op is MetaOps.ASSIGN and out.arg else output_st
|
||||
output_view, vv = output_view.simplify().unbind()
|
||||
if out.op is MetaOps.ASSIGN and out.arg:
|
||||
assert out.arg[0].shape == out.shape, f"ASSIGN must not override output shape {out.arg[0].shape} != {out.shape}"
|
||||
output_st = out.arg[0].reshape(output_st.shape)
|
||||
output_st, vv = output_st.simplify().unbind()
|
||||
if vv: var_vals.update(vv)
|
||||
ast.append(LazyOp(BufferOps.STORE, (lop,), MemBuffer(i, out.dtype, output_view)))
|
||||
ast.append(LazyOp(BufferOps.STORE, (lop,), MemBuffer(i, out.dtype, output_st)))
|
||||
return LazyOp(MetaOps.KERNEL, tuple(ast)), list(inputs), var_vals, dedup([x[0].metadata for x in cache if x[0].metadata and x[0] not in inputs])
|
||||
|
||||
# *** DAG creation: decide which LazyBuffers should realize ***
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue