fix transform_precompiled_call for MULTI (#16510)

based on my understanding for https://github.com/tinygrad/tinygrad/pull/16084
This commit is contained in:
chenyu 2026-06-04 20:09:58 -04:00 committed by GitHub
commit bb407d8b3c
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 20 additions and 5 deletions

View file

@ -505,15 +505,23 @@ class TestFunctionTuple(unittest.TestCase):
return C[i].store(A[i] * 2.0).end(i).sink(arg=KernelInfo(name="double_kernel"))
def double_grad(d_c:UOp, call:UOp): return (None, (Tensor(d_c) * 2.0).uop)
a = Tensor.full((4, 4), 7.0).contiguous().shard(devs, axis=0).realize()
@function(precompile=True, precompile_backward=True)
def f(a:Tensor):
c = Tensor(Tensor.invalids(a.shape[0]//len(devs), a.shape[1], dtype=a.dtype, device=devs).uop.multi(0), device=devs)
return Tensor.custom_kernel(c, a, fxn=double_kernel, grad_fxn=double_grad)[0]
a = Tensor.full((4, 4), 7.0).contiguous().shard(devs, axis=0)
Tensor.realize(a)
np.testing.assert_allclose(f(a).numpy(), 14.0)
# g is f with empty output instead of invalids
@function(precompile=True, allow_implicit=True)
def g(a:Tensor):
c = Tensor(Tensor.empty(a.shape[0]//len(devs), a.shape[1], dtype=a.dtype, device=devs).uop.multi(0), device=devs)
return Tensor.custom_kernel(c, a, fxn=double_kernel, grad_fxn=double_grad)[0]
np.testing.assert_allclose(g(a).numpy(), 14.0)
def test_custom_kernel_precompile_further_compute(self):
def my_kernel(C:UOp, A:UOp) -> UOp:
i = UOp.range(A.shape[0], 0)

View file

@ -98,6 +98,14 @@ def contiguous_mops_to_view(c:UOp, src:UOp):
return None
def _precompiled_output_redirect(s:UOp, t:UOp) -> UOp|None:
# how output s lands in the caller's buffer t, or None if it must be copied into t
# materialize straight into t
if s.op is Ops.CONTIGUOUS: return t.after(t.store(s.src[0]))
# rebind output storage to t
if s.op in {Ops.BUFFER, Ops.MULTI} and s.has_buffer_identity(): return t
return None
def transform_precompiled_call(c:UOp) -> UOp|None:
if not c.arg.precompile: return None
assert c.src[0].op is Ops.TUPLE, f"expected TUPLE body for precompiled FUNCTION, got {c.src[0].op}"
@ -116,9 +124,8 @@ def transform_precompiled_call(c:UOp) -> UOp|None:
while s.op is Ops.AFTER:
after_deps.extend(s.src[1:])
s = s.src[0]
base = s.base
if base.op in {Ops.CONTIGUOUS, Ops.BUFFER} and base.shape == t.shape and base not in subs:
subs[base] = t.after(t.store(base.src[0])) if base.op is Ops.CONTIGUOUS else t
if (placed := _precompiled_output_redirect(s, t)) is not None and s not in subs:
subs[s] = placed
items.append(s.after(*after_deps) if after_deps else s)
else:
items.append(t.after(t.store(s), *after_deps))