Compare commits

...

1 commit

Author SHA1 Message Date
George Hotz
0b389b23c8 remove custom kernel contiguous 2026-04-30 13:18:32 -07:00
2 changed files with 3 additions and 5 deletions

View file

@ -321,7 +321,6 @@ class TestCustomKernel(unittest.TestCase):
self.assertEqual(GlobalCounters.kernel_count, 2)
self.assertEqual(z.tolist(), x.add(2).tolist())
@unittest.expectedFailure
def test_custom_kernel_sched_copy(self): self.test_custom_kernel_sched(use_custom=True)
class TestUOpReduce(unittest.TestCase):

View file

@ -953,10 +953,9 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
body = self if self.op is Ops.TUPLE else UOp.maketuple(self)
return UOp(Ops.FUNCTION, dtypes.void, (body,)+srcs, CallInfo(grad_fxn, metadata, name, precompile, precompile_backward))
def custom_kernel(*srcs:UOp, fxn:Callable, grad_fxn:Callable|None=None) -> list[UOp]:
contig_srcs = tuple(x.contiguous() if x.op is not Ops.AFTER else x for x in srcs)
placeholders = [UOp.placeholder_like(s, slot=i) for i,s in enumerate(contig_srcs)]
kernel = fxn(*placeholders).call(*contig_srcs, grad_fxn=grad_fxn)
return [s.after(kernel) for s in contig_srcs]
placeholders = [UOp.placeholder_like(s, slot=i) for i,s in enumerate(srcs)]
kernel = fxn(*placeholders).call(*srcs, grad_fxn=grad_fxn)
return [s.after(kernel) for s in srcs]
@dataclass(frozen=True)
class KernelInfo: