mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
fix transform_precompiled_call for MULTI (#16510)
based on my understanding for https://github.com/tinygrad/tinygrad/pull/16084
This commit is contained in:
parent
f11f63007d
commit
bb407d8b3c
2 changed files with 20 additions and 5 deletions
|
|
@ -505,15 +505,23 @@ class TestFunctionTuple(unittest.TestCase):
|
|||
return C[i].store(A[i] * 2.0).end(i).sink(arg=KernelInfo(name="double_kernel"))
|
||||
def double_grad(d_c:UOp, call:UOp): return (None, (Tensor(d_c) * 2.0).uop)
|
||||
|
||||
a = Tensor.full((4, 4), 7.0).contiguous().shard(devs, axis=0).realize()
|
||||
|
||||
@function(precompile=True, precompile_backward=True)
|
||||
def f(a:Tensor):
|
||||
c = Tensor(Tensor.invalids(a.shape[0]//len(devs), a.shape[1], dtype=a.dtype, device=devs).uop.multi(0), device=devs)
|
||||
return Tensor.custom_kernel(c, a, fxn=double_kernel, grad_fxn=double_grad)[0]
|
||||
|
||||
a = Tensor.full((4, 4), 7.0).contiguous().shard(devs, axis=0)
|
||||
Tensor.realize(a)
|
||||
np.testing.assert_allclose(f(a).numpy(), 14.0)
|
||||
|
||||
# g is f with empty output instead of invalids
|
||||
@function(precompile=True, allow_implicit=True)
|
||||
def g(a:Tensor):
|
||||
c = Tensor(Tensor.empty(a.shape[0]//len(devs), a.shape[1], dtype=a.dtype, device=devs).uop.multi(0), device=devs)
|
||||
return Tensor.custom_kernel(c, a, fxn=double_kernel, grad_fxn=double_grad)[0]
|
||||
|
||||
np.testing.assert_allclose(g(a).numpy(), 14.0)
|
||||
|
||||
def test_custom_kernel_precompile_further_compute(self):
|
||||
def my_kernel(C:UOp, A:UOp) -> UOp:
|
||||
i = UOp.range(A.shape[0], 0)
|
||||
|
|
|
|||
|
|
@ -98,6 +98,14 @@ def contiguous_mops_to_view(c:UOp, src:UOp):
|
|||
|
||||
return None
|
||||
|
||||
def _precompiled_output_redirect(s:UOp, t:UOp) -> UOp|None:
|
||||
# how output s lands in the caller's buffer t, or None if it must be copied into t
|
||||
# materialize straight into t
|
||||
if s.op is Ops.CONTIGUOUS: return t.after(t.store(s.src[0]))
|
||||
# rebind output storage to t
|
||||
if s.op in {Ops.BUFFER, Ops.MULTI} and s.has_buffer_identity(): return t
|
||||
return None
|
||||
|
||||
def transform_precompiled_call(c:UOp) -> UOp|None:
|
||||
if not c.arg.precompile: return None
|
||||
assert c.src[0].op is Ops.TUPLE, f"expected TUPLE body for precompiled FUNCTION, got {c.src[0].op}"
|
||||
|
|
@ -116,9 +124,8 @@ def transform_precompiled_call(c:UOp) -> UOp|None:
|
|||
while s.op is Ops.AFTER:
|
||||
after_deps.extend(s.src[1:])
|
||||
s = s.src[0]
|
||||
base = s.base
|
||||
if base.op in {Ops.CONTIGUOUS, Ops.BUFFER} and base.shape == t.shape and base not in subs:
|
||||
subs[base] = t.after(t.store(base.src[0])) if base.op is Ops.CONTIGUOUS else t
|
||||
if (placed := _precompiled_output_redirect(s, t)) is not None and s not in subs:
|
||||
subs[s] = placed
|
||||
items.append(s.after(*after_deps) if after_deps else s)
|
||||
else:
|
||||
items.append(t.after(t.store(s), *after_deps))
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue