mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
allow after on contiguous in spec (#16169)
* feat: allow after on contiguous * feat: add test
This commit is contained in:
parent
7c3e3fa154
commit
a613bcfc6d
2 changed files with 17 additions and 2 deletions
|
|
@ -495,6 +495,20 @@ class TestFunctionTuple(unittest.TestCase):
|
|||
Tensor.realize(a.grad)
|
||||
np.testing.assert_allclose(a.grad.numpy(), [2., 2., 2., 2.])
|
||||
|
||||
def test_custom_kernel_precompile_further_compute(self):
|
||||
def my_kernel(C:UOp, A:UOp) -> UOp:
|
||||
i = UOp.range(A.shape[0], 0)
|
||||
return C[i].store(A[i] * 2.0).end(i).sink(arg=KernelInfo(name="my_kernel"))
|
||||
|
||||
@function(precompile=True)
|
||||
def f(a:Tensor):
|
||||
c = Tensor.invalids(*a.shape, dtype=a.dtype, device=a.device)
|
||||
c = Tensor.custom_kernel(c, a, fxn=my_kernel)[0]
|
||||
return c + 1
|
||||
|
||||
a = Tensor([1., 2., 3., 4.]).contiguous().realize()
|
||||
np.testing.assert_allclose(f(a).numpy(), [3., 5., 7., 9.])
|
||||
|
||||
class TestFunctionGrad(unittest.TestCase):
|
||||
def test_function_grad_ops(self, precompile=False, precompile_backward=False):
|
||||
N = 64
|
||||
|
|
|
|||
|
|
@ -79,8 +79,9 @@ spec_shared = PatternMatcher([
|
|||
(UPat(Ops.DEFINE_LOCAL, name="x"), lambda x: isinstance(x.dtype, PtrDType) and x.dtype.addrspace == AddrSpace.LOCAL),
|
||||
(UPat(Ops.DEFINE_REG, src=(), name="x"), lambda x: isinstance(x.arg, int)),
|
||||
|
||||
# AFTER on Movement Op, PARAM, BUFFER, or another AFTER
|
||||
(UPat(Ops.AFTER, src=(UPat(GroupOp.Movement.union({Ops.PARAM, Ops.BUFFER, Ops.DEFINE_REG, Ops.DEFINE_LOCAL, Ops.AFTER, Ops.MULTI, Ops.BITCAST})),),
|
||||
# AFTER on Movement Op, PARAM, BUFFER, CONTIGUOUS, or another AFTER
|
||||
(UPat(Ops.AFTER, src=(UPat(GroupOp.Movement.union({Ops.PARAM, Ops.BUFFER, Ops.CONTIGUOUS, Ops.DEFINE_REG, Ops.DEFINE_LOCAL, Ops.AFTER, Ops.MULTI,
|
||||
Ops.BITCAST})),),
|
||||
allow_any_len=True), lambda: True),
|
||||
|
||||
# CUSTOM (inline and non inline)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue