Compare commits

...

1 commit

Author SHA1 Message Date
George Hotz
e336f3cf8c CALL with return value is FUNCTION 2026-04-16 12:36:14 +08:00
10 changed files with 43 additions and 31 deletions

View file

@ -94,13 +94,13 @@ def contiguous_mops_to_view(c:UOp, src:UOp):
if (view := _make_buffer_view(src)) is None: return None if (view := _make_buffer_view(src)) is None: return None
return view.contiguous(tag=c.tag) return view.contiguous(tag=c.tag)
def transform_precompiled_call(c:UOp) -> UOp|None: def transform_precompiled_function(c:UOp) -> UOp|None:
if not c.arg.precompile: return None if not c.arg.precompile: return None
if c.src[0].op is Ops.SINK: return None if c.src[0].op is Ops.SINK: return None
assert c.src[0].op is Ops.TUPLE, f"expected TUPLE body for precompiled call, got {c.src[0].op}" assert c.src[0].op is Ops.TUPLE, f"expected TUPLE body for precompiled function, got {c.src[0].op}"
input_buffers = tuple(x.contiguous() if x.op not in {Ops.AFTER, Ops.BIND} else x for x in c.src[1:]) input_buffers = tuple(x.contiguous() if x.op not in {Ops.AFTER, Ops.BIND} else x for x in c.src[1:])
# add the outputs to the call # add the outputs to the function
srcs = c.src[0].src srcs = c.src[0].src
resolved = [c.gettuple(i) for i in range(len(srcs))] resolved = [c.gettuple(i) for i in range(len(srcs))]
outs = tuple(_buffer_like(r) for r in resolved) outs = tuple(_buffer_like(r) for r in resolved)
@ -108,7 +108,7 @@ def transform_precompiled_call(c:UOp) -> UOp|None:
fxn = UOp.sink(*[t.after(t.store(s)) for t,s in zip(targets, srcs)]) fxn = UOp.sink(*[t.after(t.store(s)) for t,s in zip(targets, srcs)])
# create the new thing for the big graph # create the new thing for the big graph
new_call = c.replace(src=(fxn, *input_buffers, *outs), tag=None) new_call = c.replace(op=Ops.CALL, src=(fxn, *input_buffers, *outs), tag=None)
rets = tuple(o.after(new_call) for o in outs) rets = tuple(o.after(new_call) for o in outs)
# if the CALL has symbolic shapes, shrink the max-sized output to the actual symbolic shape # if the CALL has symbolic shapes, shrink the max-sized output to the actual symbolic shape
@ -119,8 +119,8 @@ def transform_precompiled_call(c:UOp) -> UOp|None:
# NOTE: adding rules to here is bad. these all need to run before the schedule cache # NOTE: adding rules to here is bad. these all need to run before the schedule cache
pm_early_transform_tensor_graph = PatternMatcher([ pm_early_transform_tensor_graph = PatternMatcher([
# transform precompiled CALLs # transform precompiled FUNCTIONs -> CALLs
(UPat(Ops.CALL, name="c"), transform_precompiled_call), (UPat(Ops.FUNCTION, name="c"), transform_precompiled_function),
# resolve TUPLE+GETTUPLE (for precompiled calls) # resolve TUPLE+GETTUPLE (for precompiled calls)
(UPat(Ops.GETTUPLE, src=(UPat(Ops.TUPLE, name="t"),), name="g"), lambda g,t: t.src[g.arg]), (UPat(Ops.GETTUPLE, src=(UPat(Ops.TUPLE, name="t"),), name="g"), lambda g,t: t.src[g.arg]),

View file

@ -95,15 +95,15 @@ def compute_gradient(root:UOp, root_grad:UOp, targets:set[UOp]) -> dict[UOp, UOp
if t0 not in grads or grads[t0].op is Ops.NOOP: continue if t0 not in grads or grads[t0].op is Ops.NOOP: continue
# GETTUPLE: accumulate gradient into a TUPLE UOp on the CALL, process when we hit the CALL # GETTUPLE: accumulate gradient into a TUPLE UOp on the CALL, process when we hit the CALL
if t0.op is Ops.GETTUPLE: if t0.op is Ops.GETTUPLE:
k = t0.src[0] # the CALL k = t0.src[0] # the FUNCTION
assert k.op is Ops.CALL and k.src[0].op is Ops.TUPLE assert k.op is Ops.FUNCTION and k.src[0].op is Ops.TUPLE
n_outputs = len(k.src[0].src) n_outputs = len(k.src[0].src)
prev = grads[k].src if k in grads else tuple(UOp(Ops.NOOP) for _ in range(n_outputs)) prev = grads[k].src if k in grads else tuple(UOp(Ops.NOOP) for _ in range(n_outputs))
grads[k] = UOp.maketuple(*(prev[i] + grads[t0] if i == t0.arg and prev[i].op is not Ops.NOOP else grads[k] = UOp.maketuple(*(prev[i] + grads[t0] if i == t0.arg and prev[i].op is not Ops.NOOP else
grads[t0] if i == t0.arg else prev[i] for i in range(n_outputs))) grads[t0] if i == t0.arg else prev[i] for i in range(n_outputs)))
continue continue
# CALL: pass needed param set so backward only computes required gradients # FUNCTION: pass needed param set so backward only computes required gradients
if t0.op is Ops.CALL: if t0.op is Ops.FUNCTION:
needed = {i for i, arg in enumerate(t0.src[1:]) if arg in targets or in_target_path.get(arg, False)} needed = {i for i, arg in enumerate(t0.src[1:]) if arg in targets or in_target_path.get(arg, False)}
lgrads:tuple[UOp|None, ...]|None = call_gradient(grads[t0], t0, needed) lgrads:tuple[UOp|None, ...]|None = call_gradient(grads[t0], t0, needed)
else: else:

View file

@ -1,5 +1,5 @@
from tinygrad.helpers import all_same, prod, getenv, ALLREDUCE_CAST from tinygrad.helpers import all_same, prod, getenv, ALLREDUCE_CAST
from tinygrad.uop.ops import Ops, UOp, PatternMatcher, UPat, GroupOp, graph_rewrite, should_resolve_call from tinygrad.uop.ops import Ops, UOp, PatternMatcher, UPat, GroupOp, graph_rewrite, KernelInfo
from tinygrad.dtype import dtypes from tinygrad.dtype import dtypes
from tinygrad.schedule.allreduce import handle_allreduce from tinygrad.schedule.allreduce import handle_allreduce
@ -116,6 +116,15 @@ def store_after_multi(dest:UOp, src:UOp): return dest.after(dest.store(src.src[0
def passthrough_multi(root:UOp, multi:UOp): def passthrough_multi(root:UOp, multi:UOp):
return UOp(root.op, root.dtype, (multi.src[0],)+tuple(x.src[0] if x.op is Ops.MULTI else x for x in root.src[1:]), root.arg).multi(multi.axis) return UOp(root.op, root.dtype, (multi.src[0],)+tuple(x.src[0] if x.op is Ops.MULTI else x for x in root.src[1:]), root.arg).multi(multi.axis)
# TODO: this is all junk
def should_resolve_call(c:UOp) -> bool:
  """Return True when call `c` should be resolved, False when it must be left untouched.

  A call is left untouched when any of the following holds:
    * its body (`c.src[0]`) is a SINK whose arg is a `KernelInfo` (a real kernel call),
    * its body op is PROGRAM, LINEAR, COPY or CUSTOM_FUNCTION,
    * `c.arg.precompile` is truthy.
  """
  body = c.src[0]
  # don't resolve real kernel calls, sink or program
  is_kernel_sink = body.op is Ops.SINK and isinstance(body.arg, KernelInfo)
  is_opaque_body = body.op in {Ops.PROGRAM, Ops.LINEAR, Ops.COPY, Ops.CUSTOM_FUNCTION}
  # the precompile flag is only consulted when neither body check already rejected the call
  return not (is_kernel_sink or is_opaque_body or c.arg.precompile)
def rewrite_into_call(call:UOp): def rewrite_into_call(call:UOp):
if not should_resolve_call(call): return None if not should_resolve_call(call): return None
new_body = graph_rewrite(call.src[0], multi_pm, name="subcall") new_body = graph_rewrite(call.src[0], multi_pm, name="subcall")

View file

@ -2,7 +2,7 @@ from dataclasses import dataclass, field, replace
import itertools import itertools
from tinygrad.dtype import dtypes, PtrDType, AddrSpace, Invalid from tinygrad.dtype import dtypes, PtrDType, AddrSpace, Invalid
from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp, resolve, GroupOp, _substitute, KernelInfo from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp, resolve, GroupOp, _substitute, KernelInfo
from tinygrad.uop.ops import graph_rewrite, sint, AxisType, BottomUpGate, profile_matches, should_resolve_call, identity_element from tinygrad.uop.ops import graph_rewrite, sint, AxisType, BottomUpGate, profile_matches, identity_element
from tinygrad.uop.symbolic import symbolic from tinygrad.uop.symbolic import symbolic
from tinygrad.helpers import prod, all_same, getenv, dedup, all_int, DEBUG, SPLIT_REDUCEOP, DEBUG_RANGEIFY, VIZ, MAX_KERNEL_BUFFERS from tinygrad.helpers import prod, all_same, getenv, dedup, all_int, DEBUG, SPLIT_REDUCEOP, DEBUG_RANGEIFY, VIZ, MAX_KERNEL_BUFFERS
from tinygrad.helpers import PCONTIG, FLOAT16, OPENPILOT_HACKS, argsort, partition, get_single_element from tinygrad.helpers import PCONTIG, FLOAT16, OPENPILOT_HACKS, argsort, partition, get_single_element
@ -126,8 +126,7 @@ mop_cleanup = PatternMatcher([
]) ])
pm_gather_params = PatternMatcher([ (UPat(Ops.PARAM, name="p"), lambda ctx, p: ctx.append(p)), ]) pm_gather_params = PatternMatcher([ (UPat(Ops.PARAM, name="p"), lambda ctx, p: ctx.append(p)), ])
def resolve_call(c:UOp, allow_param_mismatch=True) -> UOp|None: def resolve_function(c:UOp, allow_param_mismatch=True) -> UOp|None:
if not should_resolve_call(c): return None
params: list[UOp] = [] params: list[UOp] = []
graph_rewrite(c.src[0], pm_gather_params, bottom_up=True, ctx=params, name="gather params") graph_rewrite(c.src[0], pm_gather_params, bottom_up=True, ctx=params, name="gather params")
params = sorted(params, key=lambda x: x.arg) params = sorted(params, key=lambda x: x.arg)
@ -150,8 +149,8 @@ earliest_rewrites = mop_cleanup+PatternMatcher([
(UPat(Ops.COPY, src=(UPat.var("s"), UPat.var("d"))), (UPat(Ops.COPY, src=(UPat.var("s"), UPat.var("d"))),
lambda s,d: s.substitute({UOp(Ops.DEVICE, arg=s.device):d}) if s.base.op is Ops.CONST else None), lambda s,d: s.substitute({UOp(Ops.DEVICE, arg=s.device):d}) if s.base.op is Ops.CONST else None),
# resolve calls # resolve functions
(UPat(Ops.CALL, name="c"), resolve_call), (UPat(Ops.FUNCTION, name="c"), resolve_function),
# resolve TUPLE+GETTUPLE # resolve TUPLE+GETTUPLE
(UPat(Ops.GETTUPLE, src=(UPat(Ops.TUPLE, name="t"),), name="g"), lambda g,t: t.src[g.arg]), (UPat(Ops.GETTUPLE, src=(UPat(Ops.TUPLE, name="t"),), name="g"), lambda g,t: t.src[g.arg]),

View file

@ -222,7 +222,7 @@ class Tensor(OpMixin):
param = UOp.param(slot, self.dtype, self.shape, self.device) param = UOp.param(slot, self.dtype, self.shape, self.device)
return Tensor(param) return Tensor(param)
def call(self, *lst:Tensor, fxn:Tensor|UOp, grad_fxn:Callable|None=None) -> Tensor: def call(self, *lst:Tensor, fxn:Tensor|UOp, grad_fxn:Callable|None=None) -> Tensor:
fret = (fxn.uop if isinstance(fxn, Tensor) else fxn).call(*[t.uop for t in (self,)+lst], grad_fxn=grad_fxn) fret = (fxn.uop if isinstance(fxn, Tensor) else fxn).function(*[t.uop for t in (self,)+lst], grad_fxn=grad_fxn)
return Tensor(fret.gettuple(0)) return Tensor(fret.gettuple(0))
def custom_kernel(self, *lst:Tensor, fxn:Callable, grad_fxn:Callable|None=None) -> list[Tensor]: def custom_kernel(self, *lst:Tensor, fxn:Callable, grad_fxn:Callable|None=None) -> list[Tensor]:

View file

@ -26,7 +26,7 @@ class Ops(FastEnum):
# uops that aren't rendered # uops that aren't rendered
NOOP = auto(); REWRITE_ERROR = auto() NOOP = auto(); REWRITE_ERROR = auto()
PARAM = auto(); CALL = auto() PARAM = auto(); CALL = auto(); FUNCTION = auto()
# renderer # renderer
# LINEAR is a list of UOps, SOURCE has a str arg that's human readable, BINARY has bytes arg that's compiled # LINEAR is a list of UOps, SOURCE has a str arg that's human readable, BINARY has bytes arg that's compiled

View file

@ -215,17 +215,17 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
# late ops don't have shape # late ops don't have shape
case Ops.UNIQUE | Ops.LUNIQUE | Ops.DEVICE | Ops.RANGE | Ops.LOAD | Ops.STORE | Ops.IF | Ops.BARRIER | Ops.CUSTOM | Ops.CUSTOMI | \ case Ops.UNIQUE | Ops.LUNIQUE | Ops.DEVICE | Ops.RANGE | Ops.LOAD | Ops.STORE | Ops.IF | Ops.BARRIER | Ops.CUSTOM | Ops.CUSTOMI | \
Ops.VECTORIZE | Ops.GEP | Ops.SPECIAL | Ops.UNROLL | Ops.CONTRACT | Ops.SINK | \ Ops.VECTORIZE | Ops.GEP | Ops.SPECIAL | Ops.UNROLL | Ops.CONTRACT | Ops.SINK | \
Ops.LINEAR | Ops.PROGRAM | Ops.SOURCE | Ops.BINARY | Ops.INS | Ops.TUPLE: Ops.LINEAR | Ops.PROGRAM | Ops.SOURCE | Ops.BINARY | Ops.INS | Ops.TUPLE | Ops.CALL | Ops.FUNCTION:
return None return None
case Ops.GETTUPLE: case Ops.GETTUPLE:
# GETTUPLE extracts from a TUPLE (possibly through a CALL) # GETTUPLE extracts from a TUPLE (possibly through a FUNCTION)
in_tuple = self.src[0].src[0] if self.src[0].op is Ops.CALL else self.src[0] in_tuple = self.src[0].src[0] if self.src[0].op is Ops.FUNCTION else self.src[0]
assert in_tuple.op is Ops.TUPLE assert in_tuple.op is Ops.TUPLE
inner_shape = in_tuple.src[self.arg]._shape inner_shape = in_tuple.src[self.arg]._shape
if inner_shape is None: return None if inner_shape is None: return None
# if through a CALL, substitute internal PARAMs in the shape with corresponding args # if through a FUNCTION, substitute internal PARAMs in the shape with corresponding args
if self.src[0].op is Ops.CALL: if self.src[0].op is Ops.FUNCTION:
return tuple(graph_rewrite(s, _pm_resolve_params, self.src[0].src[1:], walk=True) if isinstance(s, UOp) else s for s in inner_shape) return tuple(graph_rewrite(s, _pm_resolve_params, self.src[0].src[1:], walk=True) if isinstance(s, UOp) else s for s in inner_shape)
return inner_shape return inner_shape
@ -262,8 +262,6 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
case Ops.REDUCE | Ops.MSTACK | Ops.MSELECT | Ops.DETACH | Ops.CONTIGUOUS | Ops.CONTIGUOUS_BACKWARD | Ops.AFTER | Ops.END: case Ops.REDUCE | Ops.MSTACK | Ops.MSELECT | Ops.DETACH | Ops.CONTIGUOUS | Ops.CONTIGUOUS_BACKWARD | Ops.AFTER | Ops.END:
return self.src[0]._shape return self.src[0]._shape
case Ops.CALL: return None
# TODO: disallow shape changing bitcast # TODO: disallow shape changing bitcast
case Ops.BITCAST: case Ops.BITCAST:
ps = self.src[0]._shape ps = self.src[0]._shape
@ -421,8 +419,8 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
def maketuple(*srcs:UOp): # pylint: disable=no-self-argument def maketuple(*srcs:UOp): # pylint: disable=no-self-argument
return UOp(Ops.TUPLE, dtypes.void, srcs) return UOp(Ops.TUPLE, dtypes.void, srcs)
def gettuple(self, idx:int) -> UOp: def gettuple(self, idx:int) -> UOp:
in_tuple = self.src[0] if self.op is Ops.CALL else self in_tuple = self.src[0] if self.op is Ops.FUNCTION else self
assert in_tuple.op is Ops.TUPLE, f"gettuple requires CALL or TUPLE source, got {self.op}" assert in_tuple.op is Ops.TUPLE, f"gettuple requires FUNCTION or TUPLE source, got {self.op}"
return UOp(Ops.GETTUPLE, in_tuple.src[idx].dtype, (self,), idx) return UOp(Ops.GETTUPLE, in_tuple.src[idx].dtype, (self,), idx)
def group(*srcs:UOp|None): # pylint: disable=no-self-argument def group(*srcs:UOp|None): # pylint: disable=no-self-argument
if len(srcs) == 1 and isinstance(srcs[0], UOp): return srcs[0] if len(srcs) == 1 and isinstance(srcs[0], UOp): return srcs[0]
@ -941,6 +939,10 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
# value-producing bodies are always wrapped in TUPLE so CALL dtype is always void # value-producing bodies are always wrapped in TUPLE so CALL dtype is always void
body = self if self.op in UOp._NO_TUPLE_WRAP else UOp.maketuple(self) body = self if self.op in UOp._NO_TUPLE_WRAP else UOp.maketuple(self)
return UOp(Ops.CALL, dtypes.void, (body,)+srcs, CallInfo(grad_fxn, metadata, name, precompile, precompile_backward)) return UOp(Ops.CALL, dtypes.void, (body,)+srcs, CallInfo(grad_fxn, metadata, name, precompile, precompile_backward))
def function(self, *srcs:UOp, grad_fxn:Callable|None=None, metadata:tuple[Metadata, ...]=(),
             name:str|None=None, precompile:bool=False, precompile_backward:bool=False) -> UOp:
  """Build a void-dtype `Ops.FUNCTION` UOp with this UOp as its body.

  The body is always wrapped in a TUPLE and placed first in the sources, followed by `srcs`.
  The keyword arguments are packed, in order, into the `CallInfo` stored as the UOp's arg.
  Asserts that `self.ranges` is empty before building the UOp.
  """
  assert len(self.ranges) == 0, f"ranges {self.ranges} are leaking out of the call in {self.pyrender()}"
  # unlike a conditional wrap, the body here is unconditionally put inside a TUPLE
  body = UOp.maketuple(self)
  info = CallInfo(grad_fxn, metadata, name, precompile, precompile_backward)
  return UOp(Ops.FUNCTION, dtypes.void, (body,)+srcs, info)
def custom_kernel(*srcs:UOp, fxn:Callable, grad_fxn:Callable|None=None) -> list[UOp]: def custom_kernel(*srcs:UOp, fxn:Callable, grad_fxn:Callable|None=None) -> list[UOp]:
contig_srcs = tuple(x.contiguous() if x.op is not Ops.AFTER else x for x in srcs) contig_srcs = tuple(x.contiguous() if x.op is not Ops.AFTER else x for x in srcs)
placeholders = [UOp.placeholder_like(s, slot=i) for i,s in enumerate(contig_srcs)] placeholders = [UOp.placeholder_like(s, slot=i) for i,s in enumerate(contig_srcs)]

View file

@ -134,12 +134,13 @@ _tensor_spec = PatternMatcher([
# allow CALL/PARAM/CUSTOM_FUNCTION — CALL dtype is always void # allow CALL/PARAM/CUSTOM_FUNCTION — CALL dtype is always void
(UPat(Ops.CALL, dtypes.void), lambda: True), (UPat(Ops.CALL, dtypes.void), lambda: True),
(UPat(Ops.FUNCTION, dtypes.void), lambda: True),
(UPat(Ops.PARAM), lambda: True), (UPat(Ops.PARAM), lambda: True),
(UPat(Ops.CUSTOM_FUNCTION, name="x"), lambda x: isinstance(x.arg, str)), (UPat(Ops.CUSTOM_FUNCTION, name="x"), lambda x: isinstance(x.arg, str)),
# TUPLE must have void dtype, GETTUPLE can only appear on CALL or TUPLE # TUPLE must have void dtype, GETTUPLE can only appear on CALL or TUPLE
(UPat(Ops.TUPLE, dtypes.void), lambda: True), (UPat(Ops.TUPLE, dtypes.void), lambda: True),
(UPat(Ops.GETTUPLE, src=(UPat((Ops.CALL, Ops.TUPLE)),), name="g"), lambda g: isinstance(g.arg, int)), (UPat(Ops.GETTUPLE, src=(UPat((Ops.FUNCTION, Ops.TUPLE)),), name="g"), lambda g: isinstance(g.arg, int)),
# ** for custom kernels ** # ** for custom kernels **

View file

@ -54,7 +54,7 @@ const layoutUOp = (g, { graph, change }, opts) => {
width = Math.max(width, ctx.measureText(line).width); width = Math.max(width, ctx.measureText(line).width);
height += lineHeight; height += lineHeight;
} }
const callNode = label.startsWith("CALL\n"); const callNode = label.startsWith("CALL\n") || label.startsWith("FUNCTION\n");
if (callNode) callCount++; if (callNode) callCount++;
g.setNode(k, {...rectDims(width, height), label, ref, id:k, color, tag, callNode}); g.setNode(k, {...rectDims(width, height), label, ref, id:k, color, tag, callNode});
// add edges // add edges

View file

@ -50,7 +50,8 @@ uops_colors = {Ops.LOAD: "#ffc0c0", Ops.STORE: "#87CEEB", Ops.CONST: "#e0e0e0",
Ops.INDEX: "#cef263", Ops.WMMA: "#efefc0", Ops.MULTI: "#f6ccff", Ops.INS: "#eec4ff", Ops.INDEX: "#cef263", Ops.WMMA: "#efefc0", Ops.MULTI: "#f6ccff", Ops.INS: "#eec4ff",
**{x:"#D8F9E4" for x in GroupOp.Movement}, **{x:"#ffffc0" for x in GroupOp.ALU}, Ops.THREEFRY:"#ffff80", **{x:"#D8F9E4" for x in GroupOp.Movement}, **{x:"#ffffc0" for x in GroupOp.ALU}, Ops.THREEFRY:"#ffff80",
Ops.BUFFER_VIEW: "#E5EAFF", Ops.BUFFER: "#B0BDFF", Ops.COPY: "#a040a0", Ops.CUSTOM_FUNCTION: "#bf71b6", Ops.BUFFER_VIEW: "#E5EAFF", Ops.BUFFER: "#B0BDFF", Ops.COPY: "#a040a0", Ops.CUSTOM_FUNCTION: "#bf71b6",
Ops.CALL: "#00B7C8", Ops.PARAM: "#14686F", Ops.SOURCE: "#c0c0c0", Ops.LINEAR: "#7DF4FF", Ops.BINARY: "#404040", Ops.FUNCTION: "#C07788", Ops.CALL: "#00B7C8",
Ops.PARAM: "#14686F", Ops.SOURCE: "#c0c0c0", Ops.LINEAR: "#7DF4FF", Ops.BINARY: "#404040",
Ops.ALLREDUCE: "#ff40a0", Ops.MSELECT: "#d040a0", Ops.MSTACK: "#d040a0", Ops.CONTIGUOUS: "#FFC14D", Ops.ALLREDUCE: "#ff40a0", Ops.MSELECT: "#d040a0", Ops.MSTACK: "#d040a0", Ops.CONTIGUOUS: "#FFC14D",
Ops.BUFFERIZE: "#FF991C", Ops.REWRITE_ERROR: "#ff2e2e", Ops.AFTER: "#8A7866", Ops.END: "#524C46"} Ops.BUFFERIZE: "#FF991C", Ops.REWRITE_ERROR: "#ff2e2e", Ops.AFTER: "#8A7866", Ops.END: "#524C46"}
@ -136,7 +137,7 @@ def uop_to_json(data:VizData, x:UOp) -> dict[int, dict]:
label += f"\n({multirange_str(rngs, color=True)})" label += f"\n({multirange_str(rngs, color=True)})"
if u._shape is not None: if u._shape is not None:
label += f"\n{shape_to_str(u.shape)}" label += f"\n{shape_to_str(u.shape)}"
if u.op is Ops.CALL: if u.op in {Ops.CALL, Ops.FUNCTION}:
label += f"\n{u.src[0].key.hex()[:8]}" label += f"\n{u.src[0].key.hex()[:8]}"
if u.op in {Ops.INDEX, Ops.BUFFERIZE}: if u.op in {Ops.INDEX, Ops.BUFFERIZE}:
if len(u.toposort()) < 30: label += f"\n{u.render()}" if len(u.toposort()) < 30: label += f"\n{u.render()}"