invalid clone tests and prereq [PR] (#16675)

This commit is contained in:
chenyu 2026-06-19 13:20:43 -04:00 committed by GitHub
commit 67c3e589a1
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 21 additions and 1 deletions

View file

@ -529,6 +529,26 @@ class TestFunctionTuple(unittest.TestCase):
np.testing.assert_allclose(g(a).numpy(), 14.0)
def test_custom_kernel_inplace_output_is_implicit(self):
  # The kernel below both loads from and stores to C (an in-place add), so C is not a
  # write-only output. With allow_implicit=False, calling the precompiled function is
  # expected to raise a RuntimeError whose message matches "implicit buffer".
  def inplace_add(C:UOp, A:UOp) -> UOp:
    idx = UOp.range(A.shape[0], 0)
    # store C[idx] + A[idx] back into C[idx]
    updated = C[idx].store(C[idx].load() + A[idx])
    return updated.end(idx).sink(arg=KernelInfo(name="inplace_add"))

  @function(precompile=True, allow_implicit=False)
  def f(a:Tensor):
    # the output buffer is a fresh empty tensor shaped like the input
    out = Tensor.empty(*a.shape, dtype=a.dtype, device=a.device)
    return Tensor.custom_kernel(out, a, fxn=inplace_add)[0]

  with self.assertRaisesRegex(RuntimeError, "implicit buffer"):
    f(Tensor([1., 2., 3., 4.]).contiguous().realize())
def test_custom_kernel_write_only_persistent_output_is_implicit(self):
  # The kernel only stores to C (never loads it), and C is an already-realized tensor
  # captured from the enclosing scope. With allow_implicit=True the call must succeed
  # and the realized tensor must end up holding 2 * input.
  def write(C:UOp, A:UOp) -> UOp:
    idx = UOp.range(A.shape[0], 0)
    doubled = A[idx] * 2.0
    return C[idx].store(doubled).end(idx).sink(arg=KernelInfo(name="write"))

  state = Tensor([100., 200., 300., 400.], device="CPU").contiguous().realize()

  @function(precompile=True, allow_implicit=True)
  def f(a:Tensor):
    return Tensor.custom_kernel(state, a, fxn=write)[0]

  inp = Tensor([1., 2., 3., 4.], device="CPU").contiguous().realize()
  f(inp).realize()
  # the initial 100..400 contents of state must have been overwritten by the kernel
  np.testing.assert_allclose(state.numpy(), [2., 4., 6., 8.])
def test_custom_kernel_precompile_further_compute(self, multi=False, kernel_count:int=2):
devs = ("CPU:0", "CPU:1")
def my_kernel(C:UOp, A:UOp) -> UOp:

View file

@ -1151,7 +1151,7 @@ class ProgramInfo:
if u.op is Ops.PARAM and u.addrspace != AddrSpace.ALU: _globals.append(u.arg.slot)
if u.op in (Ops.STORE, Ops.LOAD):
if (idx:=u.src[0]).op in (Ops.INDEX, Ops.SHRINK) or (u.src[0].op is Ops.CAST and (idx:=u.src[0].src[0]).op is Ops.INDEX):
if (buf:=idx.src[0]).op is Ops.PARAM: (outs if u.op is Ops.STORE else ins).append(buf.arg.slot)
if (buf:=idx.src[0].buf_uop).op is Ops.PARAM: (outs if u.op is Ops.STORE else ins).append(buf.arg.slot)
if u.op is Ops.SPECIAL:
if u.arg[0] == 'i': local_size = None
special_size = local_size if u.arg[0] == 'l' else global_size