Compare commits

...

3 commits

Author SHA1 Message Date
George Hotz
8395071f77 recursive stuff works 2026-02-24 15:15:36 +08:00
George Hotz
de3e901b71 works but bad 2026-02-24 14:40:39 +08:00
George Hotz
ae2410e10e add callify method 2026-02-24 11:44:33 +08:00
3 changed files with 206 additions and 40 deletions

120
test/unit/test_callify.py Normal file
View file

@ -0,0 +1,120 @@
import unittest
from tinygrad import Tensor, dtypes
class TestCallify(unittest.TestCase):
def test_basic(self):
a = Tensor([1.,2,3])
b = Tensor([4.,5,6])
out = a + b
out.callify()
self.assertListEqual(out.tolist(), [5.0, 7.0, 9.0])
def test_const(self):
out = Tensor(2.0) + Tensor(3.0)
out.callify()
self.assertEqual(out.item(), 5.0)
def test_sum(self):
out = Tensor.ones(16).contiguous().sum()
out.callify()
self.assertEqual(out.item(), 16.0)
def test_multi_output(self):
a = Tensor([1.,2,3])
b = Tensor([4.,5,6])
c = a + b
d = a * b
c.callify(d)
self.assertListEqual(c.tolist(), [5.0, 7.0, 9.0])
self.assertListEqual(d.tolist(), [4.0, 10.0, 18.0])
def test_two_callify_independent(self):
a = Tensor([1.,2,3])
b = Tensor([4.,5,6])
c = a + b
c.callify()
d = Tensor([10.,20,30])
e = Tensor([1.,1,1])
f = d - e
f.callify()
self.assertListEqual(c.tolist(), [5.0, 7.0, 9.0])
self.assertListEqual(f.tolist(), [9.0, 19.0, 29.0])
def test_two_callify_shared_input(self):
a = Tensor([1.,2,3]).contiguous().realize()
b = a + 1
b.callify()
c = a * 2
c.callify()
self.assertListEqual(b.tolist(), [2.0, 3.0, 4.0])
self.assertListEqual(c.tolist(), [2.0, 4.0, 6.0])
def test_chained_callify(self):
a = Tensor([1.,2,3])
b = a + 1
b.callify()
b.realize()
c = b + 1
c.callify()
self.assertListEqual(c.tolist(), [3.0, 4.0, 5.0])
def test_gemm(self):
a = Tensor.ones(8, 8).contiguous()
b = Tensor.eye(8).contiguous()
out = a @ b
out.callify()
lst = out.tolist()
for y in range(8):
for x in range(8):
self.assertEqual(lst[y][x], 1.0)
def test_int_dtype(self):
a = Tensor([1,2,3], dtype=dtypes.int)
b = Tensor([4,5,6], dtype=dtypes.int)
out = a + b
out.callify()
self.assertListEqual(out.tolist(), [5, 7, 9])
def test_callify_then_schedule(self):
a = Tensor([1.,2,3])
b = Tensor([4.,5,6])
out = a + b
out.callify()
schedule = out.schedule()
self.assertGreater(len(schedule), 0)
self.assertListEqual(out.tolist(), [5.0, 7.0, 9.0])
def test_reduce(self):
out = Tensor([1.,2,3,4]).sum()
out.callify()
self.assertEqual(out.item(), 10.0)
def test_multiple_ops(self):
a = Tensor([1.,2,3])
b = Tensor([4.,5,6])
out = (a + b) * (a - b)
out.callify()
self.assertListEqual(out.tolist(), [-15.0, -21.0, -27.0])
def test_double_callify(self):
a = Tensor([1.,2,3])
b = Tensor([4.,5,6])
out = a + b
out.callify()
out.callify()
self.assertListEqual(out.tolist(), [5.0, 7.0, 9.0])
def test_double_callify_multi_output(self):
a = Tensor([1.,2,3])
b = Tensor([4.,5,6])
c = a + b
d = a * b
c.callify(d)
c.callify(d)
self.assertListEqual(c.tolist(), [5.0, 7.0, 9.0])
self.assertListEqual(d.tolist(), [4.0, 10.0, 18.0])
if __name__ == "__main__":
unittest.main()

View file

@ -1,12 +1,11 @@
import time, inspect
from typing import cast
from collections import deque
from tinygrad.uop.ops import UOp, Ops, buffers, UOpMetaClass, track_rewrites, graph_rewrite, gate_kernel_sink
from tinygrad.uop.ops import UOp, Ops, KernelInfo, buffers, UOpMetaClass, track_rewrites, graph_rewrite, gate_kernel_sink
from tinygrad.uop.spec import type_verify, tensor_spec
from tinygrad.device import Buffer, MultiBuffer
from tinygrad.helpers import DEBUG, cpu_profile, TracingKey, SPEC, pluralize, SCACHE, BASEDIR
from tinygrad.engine.realize import ExecItem
from tinygrad.engine.allocations import transform_to_call
# **** schedule linearizer
@ -68,43 +67,48 @@ def create_new_buffer(ctx:tuple[dict[UOp, UOp], tuple[UOp, ...]], b:UOp):
return ret
pm_post_sched_cache = PatternMatcher([
(UPat(Ops.PARAM, name="x"), lambda ctx,x: ctx[1][x.arg]),
# tag=True prevents re-matching after replacement (needed when PARAMs replace with PARAMs in nested callify)
(UPat(Ops.PARAM, name="x"), lambda ctx,x: ctx[1][x.arg].replace(tag=True) if x.tag is None else None),
# create new BUFFERs for LUNIQUE BUFFERs from rangeify
(UPat(Ops.BUFFER, src=(UPat(Ops.LUNIQUE), UPat(Ops.DEVICE)), name="b"), create_new_buffer),
])
schedule_cache: dict[bytes, UOp] = {}
@track_rewrites(lambda _,ret: f"Schedule {pluralize('Kernel', len(ret[1]))}")
def complete_create_schedule_with_vars(big_sink:UOp) -> tuple[dict[UOp, UOp], list[ExecItem], dict[str, int]]:
# big_sink srcs are all the Tensors
st = time.perf_counter()
big_sink, buffer_map = transform_to_call(big_sink)
function = big_sink.src[0]
if not SCACHE or (sc_ret:=schedule_cache.get(function.key, None)) is None:
if SPEC: type_verify(big_sink, tensor_spec)
def _resolve_params(linear:UOp, params:tuple[UOp, ...]) -> UOp:
"""Replace PARAMs in a LINEAR with the given params (BUFFERs or outer PARAMs), also handling LUNIQUE BUFFERs."""
from tinygrad.uop.ops import _remove_all_tags
linear = graph_rewrite(linear, pm_post_sched_cache, ctx=({}, params), name="params to buffers")
return graph_rewrite(linear, _remove_all_tags, name="remove tags")
def rewrite_call_to_linear(ctx:list, call:UOp) -> UOp|None:
"""Rewrite rule: CALL(SINK, *params) -> LINEAR(...) with caching. Only matches top-level CALLs from transform_to_call."""
function = call.src[0]
if function.op is not Ops.SINK or isinstance(function.arg, KernelInfo): return None
# recursively schedule any nested CALLs inside the function (from nested callify)
inner_start = len(ctx)
function = graph_rewrite(function, pm_schedule, ctx=ctx, name="schedule nested calls")
if not SCACHE or (linear:=schedule_cache.get(function.key, None)) is None:
if SPEC: type_verify(call.replace(src=(function,)+call.src[1:]), tensor_spec)
linear = create_schedule(get_kernel_graph(function))
if SCACHE: schedule_cache[function.key] = linear
else:
# schedule cache hit
linear = sc_ret
# late apply params to buffers (tag=True prevents PARAM->PARAM cycles in nested callify)
linear = _resolve_params(linear, call.src[1:])
# resolve remaining PARAMs in inner LINEARs from nested CALLs using this call's params
for i in range(inner_start, len(ctx)):
inner_call, inner_linear = ctx[i]
ctx[i] = (inner_call, _resolve_params(inner_linear, call.src[1:]))
ctx.append((call, linear))
return linear
# it's a call that we late apply
linear = graph_rewrite(linear, pm_post_sched_cache, ctx=({}, big_sink.src[1:]), name="params to buffers")
pm_schedule = PatternMatcher([
(UPat(Ops.CALL, name="call"), rewrite_call_to_linear),
# strip AFTER(buf, LINEAR) -> buf after scheduling
(UPat(Ops.AFTER, src=(UPat(name="buf"), UPat(Ops.LINEAR))), lambda ctx,buf: buf),
])
# vars used in the schedule
used_vars = set().union(*[{v.expr for v in si.src[0].variables()} for si in linear.src])
# get var_vals
var_vals: dict[str, int] = {}
for b in big_sink.src[1:]:
if b.op is Ops.BIND:
nm = b.src[0].expr
if nm not in used_vars: continue
val = b.src[1].arg
assert nm not in var_vals or var_vals[nm] == val, f"bind mismatch on {nm}, {var_vals[nm]} != {val}"
var_vals[nm] = val
# convert LINEAR to ExecItems
def linear_to_schedule(linear:UOp) -> list[ExecItem]:
"""Convert a LINEAR UOp to a list of ExecItems."""
schedule: list[ExecItem] = []
for si in linear.src:
ast, buf_uops = si.src[0], si.src[1:]
@ -122,6 +126,34 @@ def complete_create_schedule_with_vars(big_sink:UOp) -> tuple[dict[UOp, UOp], li
schedule.append(ExecItem(ast, list(bufs), metadata, {dnums[0].expr:j} if len(dnums) else {}))
else:
schedule.append(ExecItem(ast, list(ubufs), metadata))
return schedule
# strip AFTER(buf, LINEAR) -> buf, used by _apply_map_to_tensors to clean up scope tensors after scheduling
@track_rewrites(lambda _,ret: f"Schedule {pluralize('Kernel', len(ret[1]))}")
def complete_create_schedule_with_vars(big_sink:UOp) -> tuple[list[UOp], list[ExecItem], dict[str, int]]:
st = time.perf_counter()
# rewrite CALLs to LINEARs and strip AFTERs
call_linear_pairs: list[tuple[UOp, UOp]] = []
graph_rewrite(big_sink, pm_schedule, ctx=call_linear_pairs, name="schedule calls")
# collect ExecItems from all LINEARs
schedule: list[ExecItem] = []
for _, linear in call_linear_pairs:
schedule.extend(linear_to_schedule(linear))
# get var_vals from CALL params
used_vars = set().union(*[{v.expr for v in si.src[0].variables()} for _, linear in call_linear_pairs for si in linear.src])
var_vals: dict[str, int] = {}
for call, _ in call_linear_pairs:
for b in call.src[1:]:
if b.op is Ops.BIND:
nm = b.src[0].expr
if nm not in used_vars: continue
val = b.src[1].arg
assert nm not in var_vals or var_vals[nm] == val, f"bind mismatch on {nm}, {var_vals[nm]} != {val}"
var_vals[nm] = val
with cpu_profile(TracingKey("memory planner")): schedule = memory_planner(schedule)
if (DEBUG >= 1 and len(schedule) > 1) or DEBUG >= 3:
@ -131,7 +163,6 @@ def complete_create_schedule_with_vars(big_sink:UOp) -> tuple[dict[UOp, UOp], li
else:
frm = None
print(f"scheduled {len(schedule):5d} kernels in {(time.perf_counter()-st)*1000:8.2f} ms"+\
f" | {' cache hit' if SCACHE and sc_ret is not None else 'CACHE MISS'} {function.key.hex()[:8]}"+\
f" | {len(UOpMetaClass.ucache):7d} uops in cache"+("" if frm is None else f" | {frm.filename}:{frm.lineno}"))
return buffer_map, schedule, var_vals
return [call for call, _ in call_linear_pairs], schedule, var_vals

View file

@ -13,9 +13,11 @@ from tinygrad.gradient import compute_gradient
from tinygrad.mixin import OpMixin
from tinygrad.mixin.movement import _align_left
from tinygrad.uop.ops import smax, smin, resolve, UOp, Ops, sint, identity_element, all_metadata, _index_to_concrete_int, sint_to_uop, Variable
from tinygrad.uop.ops import PatternMatcher, UPat
from tinygrad.engine.schedule import ExecItem, complete_create_schedule_with_vars
from tinygrad.device import Device, Buffer
from tinygrad.engine.realize import run_schedule
from tinygrad.engine.allocations import transform_to_call
# TODO: this should be the only usage of Device
def canonicalize_device(device:str|tuple|list|None) -> str|tuple[str, ...]:
@ -25,7 +27,8 @@ def canonicalize_device(device:str|tuple|list|None) -> str|tuple[str, ...]:
all_tensors: dict[weakref.ref[Tensor], None] = {}
_pending_assigns: dict[UOp, list[UOp]] = {} # buffer_uop -> [assign_uops in insertion order]
def _apply_map_to_tensors(applied_map:dict[UOp, UOp], name:str) -> None:
_pm_strip_after_noop = PatternMatcher([(UPat(Ops.AFTER, src=(UPat(name="buf"), UPat(Ops.NOOP))), lambda ctx,buf: buf)])
def _apply_map_to_tensors(applied_map:dict[UOp, UOp], name:str, extra_pm:PatternMatcher|None=None) -> None:
with cpu_profile(TracingKey(name), "TINY"):
# get tensors in scope
in_scope: dict[UOp, bool] = {}
@ -34,7 +37,7 @@ def _apply_map_to_tensors(applied_map:dict[UOp, UOp], name:str) -> None:
# get all Tensors and apply the map
sink = UOp.sink(*[t.uop for t in scope_tensors])
new_sink = sink.substitute(applied_map, name=f"substitute {name}")
new_sink = sink.substitute(applied_map, name=f"substitute {name}", extra_pm=extra_pm)
# set the relevant uop to the realized UOps
for t,s,ns in zip(scope_tensors, sink.src, new_sink.src):
@ -249,17 +252,25 @@ class Tensor(OpMixin):
"""
return [Tensor(u, device=u.device) for u in UOp.custom_kernel(*[t.uop for t in (self,)+lst], fxn=fxn, grad_fxn=grad_fxn)]
def callify(self, *lst:Tensor) -> Tensor:
big_sink = UOp.sink(*[x.uop for x in (self,)+lst])
big_sink, buffer_map = transform_to_call(big_sink)
_apply_map_to_tensors({x:y.after(big_sink) for x,y in buffer_map.items()}, name="callify")
return self
def schedule_with_vars(self, *lst:Tensor) -> tuple[list[ExecItem], dict[str, int]]:
"""
Creates the schedule needed to realize these Tensor(s), with Variables.
NOTE: A Tensor can only be scheduled once.
"""
big_sink = UOp.sink(*[x.uop for x in (self,)+lst])
# this is where the schedule cache should go
becomes_map, schedule, var_vals = complete_create_schedule_with_vars(big_sink)
_apply_map_to_tensors(becomes_map, name="buffers")
# collect existing CALLs before callify (so we can clean them up in other tensors that share them)
pre_calls = {u for t in (self,)+lst for u in t.uop.toposort() if u.op is Ops.CALL}
self.callify(*lst)
calls, schedule, var_vals = complete_create_schedule_with_vars(UOp.sink(*[x.uop for x in (self,)+lst]))
# replace scheduled CALLs with NOOP so AFTER(buf, CALL) -> AFTER(buf, NOOP) -> buf in scope tensors
# include pre-existing CALLs too (they were reconstructed inside callify, but other tensors still reference the originals)
_apply_map_to_tensors({c:UOp(Ops.NOOP) for c in set(calls) | pre_calls}, name="buffers", extra_pm=_pm_strip_after_noop)
return schedule, var_vals
def schedule(self, *lst:Tensor) -> list[ExecItem]:
@ -278,8 +289,12 @@ class Tensor(OpMixin):
# recursively realize pending assigns that this assign's value depends on
for u in assign_uop.toposort():
if u.op is Ops.BUFFER and u in _pending_assigns: _realize_pending(u)
becomes_map, schedule, var_vals = complete_create_schedule_with_vars(UOp.sink(assign_uop))
_apply_map_to_tensors(becomes_map, name="Apply Pending Assign")
sink = UOp.sink(assign_uop)
call, buffer_map = transform_to_call(sink)
callified_sink = UOp.sink(*[buffer_map.get(s, s).after(call) for s in sink.src])
calls, schedule, var_vals = complete_create_schedule_with_vars(callified_sink)
becomes_map = {**buffer_map, **{c:UOp(Ops.NOOP) for c in calls}}
_apply_map_to_tensors(becomes_map, name="Apply Pending Assign", extra_pm=_pm_strip_after_noop)
run_schedule(schedule, var_vals, do_update_stats=do_update_stats)
# update remaining pending assigns so they reference realized buffers instead of stale lazy graphs
if becomes_map: