allow after on contiguous in spec (#16169)

* feat: allow after on contiguous

* feat: add test
This commit is contained in:
wozeparrot 2026-05-12 16:11:44 -04:00 committed by GitHub
commit a613bcfc6d
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 17 additions and 2 deletions

View file

@ -495,6 +495,20 @@ class TestFunctionTuple(unittest.TestCase):
Tensor.realize(a.grad)
np.testing.assert_allclose(a.grad.numpy(), [2., 2., 2., 2.])
def test_custom_kernel_precompile_further_compute(self):
def my_kernel(C:UOp, A:UOp) -> UOp:
i = UOp.range(A.shape[0], 0)
return C[i].store(A[i] * 2.0).end(i).sink(arg=KernelInfo(name="my_kernel"))
@function(precompile=True)
def f(a:Tensor):
c = Tensor.invalids(*a.shape, dtype=a.dtype, device=a.device)
c = Tensor.custom_kernel(c, a, fxn=my_kernel)[0]
return c + 1
a = Tensor([1., 2., 3., 4.]).contiguous().realize()
np.testing.assert_allclose(f(a).numpy(), [3., 5., 7., 9.])
class TestFunctionGrad(unittest.TestCase):
def test_function_grad_ops(self, precompile=False, precompile_backward=False):
N = 64

View file

@ -79,8 +79,9 @@ spec_shared = PatternMatcher([
(UPat(Ops.DEFINE_LOCAL, name="x"), lambda x: isinstance(x.dtype, PtrDType) and x.dtype.addrspace == AddrSpace.LOCAL),
(UPat(Ops.DEFINE_REG, src=(), name="x"), lambda x: isinstance(x.arg, int)),
# AFTER on Movement Op, PARAM, BUFFER, or another AFTER
(UPat(Ops.AFTER, src=(UPat(GroupOp.Movement.union({Ops.PARAM, Ops.BUFFER, Ops.DEFINE_REG, Ops.DEFINE_LOCAL, Ops.AFTER, Ops.MULTI, Ops.BITCAST})),),
# AFTER on Movement Op, PARAM, BUFFER, CONTIGUOUS, or another AFTER
(UPat(Ops.AFTER, src=(UPat(GroupOp.Movement.union({Ops.PARAM, Ops.BUFFER, Ops.CONTIGUOUS, Ops.DEFINE_REG, Ops.DEFINE_LOCAL, Ops.AFTER, Ops.MULTI,
Ops.BITCAST})),),
allow_any_len=True), lambda: True),
# CUSTOM (inline and non inline)