Compare commits

...

4 commits

Author SHA1 Message Date
George Hotz
308eb13eae codegen 2026-02-03 18:09:15 +08:00
George Hotz
1b1d81e3d5 allow kernel_spec in tensor for custom kernels 2026-02-03 18:01:07 +08:00
George Hotz
50b7b283dc allow after on param 2026-02-03 17:41:54 +08:00
George Hotz
39da624581 new custom_kernel function in tensor 2026-02-03 16:28:53 +08:00
5 changed files with 26 additions and 13 deletions

View file

@ -2,7 +2,7 @@ import time
from typing import cast
from collections import deque
from tinygrad.uop.ops import UOp, Ops, buffers, UOpMetaClass, track_rewrites, PatternMatcher, UPat, graph_rewrite, graph_rewrite_map, Kernel
from tinygrad.uop.spec import type_verify, tensor_spec
from tinygrad.uop.spec import type_verify, tensor_spec, kernel_spec
from tinygrad.device import Buffer, MultiBuffer
from tinygrad.helpers import DEBUG, cpu_profile, TracingKey, SPEC, flatten, pluralize, SCACHE, Metadata
from tinygrad.engine.realize import ExecItem
@ -144,7 +144,7 @@ def complete_create_schedule_with_vars(big_sink:UOp) -> tuple[dict[UOp, UOp], li
if not SCACHE or (sc_ret:=schedule_cache.get(sched_cache_key, None)) is None:
# verify Tensors match the spec (on big_sink, we only need to do this if cache misses)
if SPEC: type_verify(big_sink, tensor_spec)
if SPEC: type_verify(big_sink, tensor_spec+kernel_spec)
# hack to preserve metadata
graph_rewrite_map(big_sink, pm_pre_sched_cache, ctx=({}, {}), name="preserve metadata")

View file

@ -68,7 +68,15 @@ def resolve_custom_kernel(ck:UOp) -> UOp:
placeholders = [UOp.placeholder_like(s, slot=i) for i,s in enumerate(ck.src)]
return UOp(Ops.KERNEL, src=ck.src, arg=Kernel(ck.arg.fxn(*placeholders)))
param_to_ptr = PatternMatcher([
(UPat(Ops.PARAM, name="x"), lambda x:
None if isinstance(x.dtype, PtrDType) else x.replace(src=(), dtype=x.dtype.ptr(size=x.size)).reshape(x.shape)),
])
def resolve_call(c:UOp) -> UOp:
if c.src[0].op in {Ops.SINK, Ops.PROGRAM}:
# CALL is KERNEL...sort of
return UOp(Ops.KERNEL, src=c.src[1:], arg=Kernel(graph_rewrite(c.src[0], param_to_ptr)))
params = sorted([x for x in c.src[0].toposort() if x.op == Ops.PARAM], key=lambda x: x.arg)
args = c.src[1:]
# TODO: this check belongs in spec, not here

View file

@ -248,7 +248,10 @@ class Tensor(OpMixin):
This API is alpha and may change.
"""
return [Tensor(u, device=u.device) for u in UOp.custom_kernel(*[t.uop for t in (self,)+lst], fxn=fxn, grad_fxn=grad_fxn)]
contig_srcs = tuple(x.contiguous() if x.uop.op is not Ops.AFTER else x for x in ((self,)+lst))
params = [x.as_param(i) for i,x in enumerate(contig_srcs)]
kernel = UOp.call(*[x.uop for x in contig_srcs], fxn=fxn(*[x.uop for x in params]), arg=grad_fxn)
return [Tensor(s.uop.after(kernel), device=s.device) for s in contig_srcs]
def schedule_with_vars(self, *lst:Tensor) -> tuple[list[ExecItem], dict[str, int]]:
"""

View file

@ -58,6 +58,17 @@ shared_spec = PatternMatcher([
# RANGE/SPECIAL define loops, END closes them
(UPat(Ops.END, src=(UPat(), UPat(Ops.RANGE))), lambda: True),
# codegen: standalone LINEAR/SOURCE/BINARY
(UPat(Ops.LINEAR, dtypes.void), lambda: True),
(UPat(Ops.SOURCE, dtypes.void, src=()), lambda: True),
(UPat(Ops.BINARY, dtypes.void, src=()), lambda: True),
# codegen: PROGRAM with progressive sources through the pipeline (SINK, DEVICE, LINEAR?, SOURCE?, BINARY?)
(UPat(Ops.PROGRAM, dtypes.void, src=(UPat(Ops.SINK), UPat(Ops.DEVICE))), lambda: True),
(UPat(Ops.PROGRAM, dtypes.void, src=(UPat(Ops.SINK), UPat(Ops.DEVICE), UPat(Ops.LINEAR))), lambda: True),
(UPat(Ops.PROGRAM, dtypes.void, src=(UPat(Ops.SINK), UPat(Ops.DEVICE), UPat(Ops.LINEAR), UPat(Ops.SOURCE))), lambda: True),
(UPat(Ops.PROGRAM, dtypes.void, src=(UPat(Ops.SINK), UPat(Ops.DEVICE), UPat(Ops.LINEAR), UPat(Ops.SOURCE), UPat(Ops.BINARY))), lambda: True),
])
# ***** UOp spec in the Tensor graph *****
@ -266,16 +277,6 @@ full_spec = PatternMatcher([
# in progress MSTACK may lose device
(UPat((Ops.MSELECT, Ops.MSTACK), name="x"), lambda x: True),
# codegen: PROGRAM with progressive sources through the pipeline (SINK, DEVICE, LINEAR?, SOURCE?, BINARY?)
(UPat(Ops.PROGRAM, dtypes.void, src=(UPat(Ops.SINK), UPat(Ops.DEVICE))), lambda: True),
(UPat(Ops.PROGRAM, dtypes.void, src=(UPat(Ops.SINK), UPat(Ops.DEVICE), UPat(Ops.LINEAR))), lambda: True),
(UPat(Ops.PROGRAM, dtypes.void, src=(UPat(Ops.SINK), UPat(Ops.DEVICE), UPat(Ops.LINEAR), UPat(Ops.SOURCE))), lambda: True),
(UPat(Ops.PROGRAM, dtypes.void, src=(UPat(Ops.SINK), UPat(Ops.DEVICE), UPat(Ops.LINEAR), UPat(Ops.SOURCE), UPat(Ops.BINARY))), lambda: True),
# codegen: standalone LINEAR/SOURCE/BINARY
(UPat(Ops.LINEAR, dtypes.void), lambda: True),
(UPat(Ops.SOURCE, dtypes.void, src=()), lambda: True),
(UPat(Ops.BINARY, dtypes.void, src=()), lambda: True),
# temp VECTORIZE/INDEX during rewrite have the wrong dtype
(UPat(Ops.VECTORIZE), lambda: True),
(UPat(Ops.INDEX), lambda: True),

View file

@ -109,6 +109,7 @@ def uop_to_json(x:UOp) -> dict[int, dict]:
if u.op is Ops.KERNEL:
ast_str = f"SINK{tuple(s.op for s in u.arg.ast.src)}" if u.arg.ast.op is Ops.SINK else repr(u.arg.ast.op)
argst = f"<Kernel {len(list(u.arg.ast.toposort()))} {ast_str} {[str(m) for m in u.arg.metadata]}>"
if u.op is Ops.BINARY: argst = f"<{len(u.arg)} bytes>"
label = f"{str(u.op).split('.')[1]}{(chr(10)+word_wrap(argst.replace(':', ''))) if u.arg is not None else ''}"
if u.dtype != dtypes.void: label += f"\n{u.dtype}"
for idx,x in enumerate(u.src[:1] if u.op in {Ops.BUFFERIZE, Ops.INDEX} else (u.src if u.op is not Ops.END else [])):