mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
Compare commits
1 commit
master
...
remove_ck_
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
0b389b23c8 |
2 changed files with 3 additions and 5 deletions
|
|
@ -321,7 +321,6 @@ class TestCustomKernel(unittest.TestCase):
|
|||
self.assertEqual(GlobalCounters.kernel_count, 2)
|
||||
self.assertEqual(z.tolist(), x.add(2).tolist())
|
||||
|
||||
@unittest.expectedFailure
|
||||
def test_custom_kernel_sched_copy(self): self.test_custom_kernel_sched(use_custom=True)
|
||||
|
||||
class TestUOpReduce(unittest.TestCase):
|
||||
|
|
|
|||
|
|
@ -953,10 +953,9 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
|
|||
body = self if self.op is Ops.TUPLE else UOp.maketuple(self)
|
||||
return UOp(Ops.FUNCTION, dtypes.void, (body,)+srcs, CallInfo(grad_fxn, metadata, name, precompile, precompile_backward))
|
||||
def custom_kernel(*srcs:UOp, fxn:Callable, grad_fxn:Callable|None=None) -> list[UOp]:
|
||||
contig_srcs = tuple(x.contiguous() if x.op is not Ops.AFTER else x for x in srcs)
|
||||
placeholders = [UOp.placeholder_like(s, slot=i) for i,s in enumerate(contig_srcs)]
|
||||
kernel = fxn(*placeholders).call(*contig_srcs, grad_fxn=grad_fxn)
|
||||
return [s.after(kernel) for s in contig_srcs]
|
||||
placeholders = [UOp.placeholder_like(s, slot=i) for i,s in enumerate(srcs)]
|
||||
kernel = fxn(*placeholders).call(*srcs, grad_fxn=grad_fxn)
|
||||
return [s.after(kernel) for s in srcs]
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class KernelInfo:
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue