resolve_call

This commit is contained in:
George Hotz 2026-02-05 12:57:31 +08:00
commit d83ddc05c8
6 changed files with 32 additions and 36 deletions

View file

@ -1,7 +1,7 @@
import time
from typing import cast
from collections import deque
from tinygrad.uop.ops import UOp, Ops, buffers, UOpMetaClass, track_rewrites, PatternMatcher, UPat, graph_rewrite, graph_rewrite_map, Kernel
from tinygrad.uop.ops import UOp, Ops, buffers, UOpMetaClass, track_rewrites, PatternMatcher, UPat, graph_rewrite, graph_rewrite_map, CallInfo
from tinygrad.uop.spec import type_verify, tensor_spec
from tinygrad.device import Buffer, MultiBuffer
from tinygrad.helpers import DEBUG, cpu_profile, TracingKey, SPEC, flatten, pluralize, SCACHE, Metadata
@ -29,7 +29,7 @@ def create_schedule(sched_sink:UOp) -> tuple[list[ExecItem], UOp]:
if u.op is not Ops.AFTER or u.src[1].op is Ops.RANGE: continue
k = u.src[1]
in_degree.setdefault(k, 0)
for s in k.src[0].src if k.op is Ops.END else k.src:
for s in k.src[0].src if k.op is Ops.END else k.src[1:]:
match (s := _unwrap_src(s)).op:
case Ops.AFTER:
children.setdefault(s.src[1], []).append(k)
@ -55,13 +55,13 @@ def create_schedule(sched_sink:UOp) -> tuple[list[ExecItem], UOp]:
while len(queue):
k = rk = queue.popleft()
if k.op is Ops.END: k = k.src[0]
assert k.op in {Ops.RANGE, Ops.KERNEL}, f"unexpected op in queue: {k.op}"
assert k.op in {Ops.RANGE, Ops.CALL}, f"unexpected op in queue: {k.op}"
if k.op is Ops.RANGE: schedule.append(k)
elif k.op is Ops.KERNEL:
ast = (kernel:=cast(Kernel, k.arg)).ast
buf_uops = tuple(_unwrap_src(s).buf_uop for s in k.src if s.op is not Ops.BIND)
bound_ranges = tuple(s for s in k.src if s.op is Ops.BIND and len(s.src) > 1 and s.src[1].op is Ops.RANGE)
schedule.append((ast, buf_uops, kernel.metadata, {}, bound_ranges))
elif k.op is Ops.CALL:
ast = k.src[0]
buf_uops = tuple(_unwrap_src(s).buf_uop for s in k.src[1:] if s.op is not Ops.BIND)
bound_ranges = tuple(s for s in k.src[1:] if s.op is Ops.BIND and len(s.src) > 1 and s.src[1].op is Ops.RANGE)
schedule.append((ast, buf_uops, cast(CallInfo, k.arg).metadata, {}, bound_ranges))
if rk.op is Ops.END: schedule.append(rk)
for x in children.get(rk, []):
in_degree[x] -= 1

View file

@ -9,7 +9,7 @@ from tinygrad.helpers import argsort, all_same, cpu_profile, PCONTIG, colored
ALWAYS_CONTIGUOUS: set[Ops] = {Ops.CONTIGUOUS, Ops.ASSIGN, Ops.COPY, Ops.BUFFER, Ops.BUFFER_VIEW,
Ops.CONST, Ops.BIND, Ops.DEVICE, Ops.MSELECT, Ops.MSTACK, Ops.PARAM,
Ops.DEFINE_LOCAL, Ops.DEFINE_REG, Ops.LOAD, Ops.KERNEL, Ops.ENCDEC}
Ops.DEFINE_LOCAL, Ops.DEFINE_REG, Ops.LOAD, Ops.CALL, Ops.ENCDEC}
def realize(ctx:dict[UOp, None], tr:UOp) -> None: ctx[tr] = None
@ -20,7 +20,7 @@ def realize_srcs(ctx:dict[UOp, None], rb:UOp) -> None:
def realize_assign(ctx:dict[UOp, None], a:UOp) -> None:
if a.src[1].op not in ALWAYS_CONTIGUOUS: ctx[a.src[1]] = None
# if it's a kernel, we don't realize it
if a.src[1].op is not Ops.KERNEL: ctx[a] = None
if a.src[1].op is not Ops.CALL: ctx[a] = None
pm_generate_realize_map = PatternMatcher([
# always realize SINK src
@ -99,7 +99,7 @@ def remove_movement_op_after_rangeify(ctx:IndexingContext, x:UOp):
if x in ctx.range_map or x.src[0].op is Ops.INDEX: return x.src[0]
def add_third_op_to_assign_to_track_shape(ctx:IndexingContext, assign:UOp):
if assign.src[1].op is Ops.KERNEL: return None
if assign.src[1].op is Ops.CALL: return None
to_mop = graph_rewrite(assign.src[0], PatternMatcher([(UPat(GroupOp.Movement, name="x"), lambda x: x.replace(tag=()))]))
ret = assign.replace(src=assign.src+(to_mop,))
ctx.range_map[ret] = ctx.range_map[assign]
@ -174,7 +174,7 @@ def run_rangeify(tsink:UOp, debug:bool=False) -> tuple[UOp, IndexingContext]:
if x.op in {Ops.DEVICE, Ops.UNIQUE}: continue
# no ranges on kernels, they are internal
if x.op is Ops.KERNEL: continue
if x.op is Ops.CALL: continue
if x.dtype.scalar() == dtypes.index: continue # TODO: why do I need this?
ending_ranges[x] = sum([ending_ranges.get(u, []) for u in consumer_map[x]], [])

View file

@ -2,7 +2,7 @@ from dataclasses import dataclass, field, replace
import itertools
from tinygrad.dtype import dtypes, PtrDType, ImageDType, AddrSpace
from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp, resolve, GroupOp, _substitute, KernelInfo
from tinygrad.uop.ops import graph_rewrite, identity_element, sint, AxisType, BottomUpGate, Kernel, _remove_all_tags, range_str
from tinygrad.uop.ops import graph_rewrite, identity_element, sint, AxisType, BottomUpGate, _remove_all_tags, range_str
from tinygrad.uop.symbolic import symbolic
from tinygrad.helpers import argsort, prod, all_same, getenv, flatten, dedup, all_int, DEBUG, SPLIT_REDUCEOP, DEBUG_RANGEIFY
from tinygrad.helpers import PCONTIG, partition, get_single_element, panic
@ -66,9 +66,11 @@ mop_cleanup = PatternMatcher([
def resolve_custom_kernel(ck:UOp) -> UOp:
placeholders = [UOp.placeholder_like(s, slot=i) for i,s in enumerate(ck.src)]
return UOp(Ops.KERNEL, src=ck.src, arg=Kernel(ck.arg.fxn(*placeholders)))
return ck.arg.fxn(*placeholders).call(*ck.src)
def resolve_call(c:UOp) -> UOp:
def resolve_call(c:UOp) -> UOp|None:
# don't resolve CALLs with SINK - those are kernel calls from split_store/custom_kernel
if c.src[0].op is Ops.SINK: return None
params = sorted([x for x in c.src[0].toposort() if x.op == Ops.PARAM], key=lambda x: x.arg)
args = c.src[1:]
# TODO: this check belongs in spec, not here
@ -83,7 +85,7 @@ earliest_rewrites = mop_cleanup+PatternMatcher([
# just removing it works...
(UPat((Ops.DETACH, Ops.CONTIGUOUS_BACKWARD), name="x"), lambda x: x.src[0]),
# resolve calls
# resolve calls (but not CALLs with SINK - those are kernel calls from split_store)
(UPat(Ops.CALL, name="c"), resolve_call),
# resolve custom kernels
@ -384,8 +386,8 @@ pm_add_buffers = pm_mops+pm_flatten_bufferize+to_bufferview+PatternMatcher([
(UPat((Ops.MSELECT, Ops.MSTACK), src=UPat(Ops.RESHAPE), name="m"),
lambda m: m.replace(src=tuple([x.src[0].base for x in m.src]), tag=None).reshape(m.shape).rtag(m.tag)),
# remove any RESHAPEs on KERNEL
(UPat(Ops.KERNEL, name="k"), lambda k: k.replace(src=tuple(x.src[0] if x.op is Ops.RESHAPE else x for x in k.src))),
# remove any RESHAPEs on CALL
(UPat(Ops.CALL, name="k"), lambda k: k.replace(src=(k.src[0],)+tuple(x.src[0] if x.op is Ops.RESHAPE else x for x in k.src[1:]))),
])
pm_add_buffers_local = pm_mops+pm_flatten_bufferize+to_bufferview+PatternMatcher([
@ -519,20 +521,13 @@ def split_store(ctx:list[UOp], x:UOp) -> UOp|None:
metadata = tuple(dedup(flatten([x for x in metadatas if x is not None])))[::-1]
kernel = ret.call(*lctx.map.values(), *lctx.vars.keys(), metadata=metadata)
if ret.op is Ops.SINK and not all_same([x.device for x in kernel.src[1:] if x.op is not Ops.BIND]):
raise RuntimeError(f"all buffers must be on the same device: {tuple(b.buf_uop for b in kernel[1:].src)}")
return kernel
# old
kernel_arg = Kernel(ret,tuple(dedup(flatten([x for x in metadatas if x is not None])))[::-1])
kernel = UOp(Ops.KERNEL, src=tuple(lctx.map.values())+tuple(lctx.vars.keys()), arg=kernel_arg)
if ret.op is Ops.SINK and not all_same([x.device for x in kernel.src if x.op is not Ops.BIND]):
raise RuntimeError(f"all buffers must be on the same device: {tuple(b.buf_uop for b in kernel.src)}")
raise RuntimeError(f"all buffers must be on the same device: {tuple(b.buf_uop for b in kernel.src[1:])}")
return kernel
split_kernels = PatternMatcher([
(UPat((Ops.STORE, Ops.END), name="x"), split_store),
# if it's a Kernel, stop
(UPat(Ops.SINK, name="sink"), lambda sink: panic(BottomUpGate) if isinstance(sink.arg, KernelInfo) else None),
(UPat(Ops.SINK, name="sink"), lambda sink: panic(BottomUpGate()) if isinstance(sink.arg, KernelInfo) else None),
])
def tag_uop(ctx:tuple[list[UOp], set[UOp]], x:UOp):
@ -545,7 +540,7 @@ def tag_uop(ctx:tuple[list[UOp], set[UOp]], x:UOp):
return x.replace(tag=(len(ctx[0])-1,))
add_tags = PatternMatcher([
# don't tag BUFFERs, they are global
(UPat(GroupOp.All-{Ops.BUFFER, Ops.CONST, Ops.DEVICE, Ops.UNIQUE, Ops.LUNIQUE, Ops.DEFINE_VAR, Ops.BIND, Ops.KERNEL, Ops.END,
(UPat(GroupOp.All-{Ops.BUFFER, Ops.CONST, Ops.DEVICE, Ops.UNIQUE, Ops.LUNIQUE, Ops.DEFINE_VAR, Ops.BIND, Ops.CALL, Ops.END,
Ops.MSTACK, Ops.MSELECT, Ops.RANGE}.union(GroupOp.Movement), name="x"), tag_uop),
(UPat({Ops.MSTACK, Ops.MSELECT}, name="x"), lambda ctx,x: None if all(s.op is Ops.BUFFER for s in x.src) else tag_uop(ctx, x)),
])

View file

@ -364,7 +364,7 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
@recursive_property
def trace_num(self):
num = next(ucount)
# KERNEL also has a UOp in the arg
# KERNEL has a UOp in the arg, CALL has it in src[0] so no special handling needed
arg = type(self.arg)(self.arg.ast.trace_num, self.arg.metadata) if self.op is Ops.KERNEL else self.arg
uop_fields[num] = (self.op, self.dtype, tuple(s.trace_num for s in self.src), arg, self.tag)+((self.metadata,) if TRACEMETA>=2 else ())
return num
@ -1414,7 +1414,7 @@ pm_pyrender_extra = PatternMatcher([
# NOTE: you can remove pm_pyrender_extra and it'll still be correct
pm_pyrender = pm_pyrender_extra+PatternMatcher([
(UPat(Ops.KERNEL, name="u"), lambda ctx,u: f"UOp(Ops.KERNEL, src={srcs(ctx,u.src)}, arg=Kernel({ctx[u.arg.ast]}(), {u.arg.metadata}))"),
(UPat(Ops.CALL, name="u"), lambda ctx,u: f"{ctx[u.src[0]]}.call({', '.join(ctx[s] for s in u.src[1:])}, metadata={u.arg.metadata})"),
(UPat(GroupOp.All, name="u"), lambda ctx,u: f"UOp({u.op}, {u.dtype}, {srcs(ctx,u.src)}"+(f", {repr(u.arg)})" if u.arg is not None else ")")),
])
@ -1424,7 +1424,7 @@ def pyrender(ast:UOp) -> str:
cmap = consumer_map_from_toposort(lst)
not_rendered = {Ops.CONST, Ops.VCONST, Ops.DEVICE}
always_rendered = {Ops.PARAM, Ops.LOAD, Ops.SPECIAL, Ops.RANGE, Ops.CONTIGUOUS, Ops.VECTORIZE,
Ops.BUFFER, Ops.COPY, Ops.KERNEL, Ops.WHERE, Ops.END, Ops.ASSIGN}
Ops.BUFFER, Ops.COPY, Ops.CALL, Ops.WHERE, Ops.END, Ops.ASSIGN}
to_render: set[UOp] = {ast}
for u in lst:

View file

@ -87,8 +87,9 @@ _tensor_spec = PatternMatcher([
(UPat(Ops.BUFFER, src=(UPat((Ops.LUNIQUE, Ops.UNIQUE)), UPat(Ops.DEVICE)), name="buf"),
lambda buf: isinstance(buf.arg, int) and isinstance(buf.dtype, (DType, ImageDType))),
# KERNEL can attach to an AFTER to describe the compute required to realize a BUFFER
(UPat(Ops.KERNEL, src=UPat((Ops.BUFFER, Ops.AFTER, Ops.MSELECT, Ops.MSTACK, Ops.BIND))), lambda: True),
# CALL can attach to an AFTER to describe the compute required to realize a BUFFER
# src[0] is the function (SINK), src[1:] are buffers/bindings
(UPat(Ops.CALL, src=(UPat(Ops.SINK), UPat((Ops.BUFFER, Ops.AFTER, Ops.MSELECT, Ops.MSTACK, Ops.BIND))), allow_any_len=True), lambda: True),
# ASSIGN has a target and a value. It can also optionally depend on other assigns
(UPat(Ops.ASSIGN, name="x"), lambda x: len(x.src) >= 2 and all(s.op is Ops.ASSIGN for s in x.src[2:])),
@ -249,8 +250,8 @@ full_spec = PatternMatcher([
# vectorized index
(UPat(Ops.INDEX, src=(UPat((Ops.VECTORIZE, Ops.CAST)), UPat())), lambda: True),
# linearizer: outputs + intermediate KERNELs
(UPat(Ops.KERNEL, dtype=dtypes.void), lambda: True),
# linearizer: outputs + intermediate CALLs
(UPat(Ops.CALL, dtype=dtypes.void), lambda: True),
# Invalid must have type Index
(UPat(Ops.CONST, arg=Invalid, name="x"), lambda x: x.dtype.scalar() == dtypes.index),

View file

@ -251,7 +251,7 @@ symbolic = symbolic_simple+commutative+PatternMatcher([
((UPat.var("x", dtypes.index) + UPat.cvar("c")).cast(dtypes.sints, name="cast"), lambda x,c,cast:x.cast(cast.dtype)+c.cast(cast.dtype)),
# only RANGE/IF/STORE/KERNEL have side effects
(UPat(Ops.AFTER, name="x"), lambda x: x.replace(src=(x.src[0],)+
tuple(flatten([(y,) if y.op in {Ops.RANGE, Ops.STORE, Ops.KERNEL, Ops.BARRIER, Ops.END, Ops.UNROLL} else y.src for y in x.src[1:]])))),
tuple(flatten([(y,) if y.op in {Ops.RANGE, Ops.STORE, Ops.CALL, Ops.BARRIER, Ops.END, Ops.UNROLL} else y.src for y in x.src[1:]])))),
# after with 1 src is just src[0]
(UPat(Ops.AFTER, src=(UPat.var("s"),)), lambda s: s),
# VECTORIZE/CONST