Compare commits

...

3 commits

Author SHA1 Message Date
George Hotz
150ee9eb6d sched docs 2026-04-15 10:34:47 +08:00
George Hotz
5dcdfb0d75 callify to root 2026-04-15 10:20:38 +08:00
George Hotz
bb69860d41 move schedule into schedule 2026-04-15 10:17:31 +08:00
25 changed files with 206 additions and 206 deletions

View file

@ -38,7 +38,7 @@ optim.schedule_step() # this will step the optimizer without running realize
# The weight Tensors have been assigned to, but not yet realized. Everything is still lazy at this point # The weight Tensors have been assigned to, but not yet realized. Everything is still lazy at this point
# l1.uop and l2.uop define a computation graph # l1.uop and l2.uop define a computation graph
from tinygrad.engine.schedule import ExecItem from tinygrad.schedule import ExecItem
schedule: List[ExecItem] = Tensor.schedule(l1, l2) schedule: List[ExecItem] = Tensor.schedule(l1, l2)
print(f"The schedule contains {len(schedule)} items.") print(f"The schedule contains {len(schedule)} items.")

View file

@ -17,9 +17,9 @@ The `UOp` graph specifies the compute in terms of low level tinygrad ops. Not al
## Scheduling ## Scheduling
The [scheduler](https://github.com/tinygrad/tinygrad/tree/master/tinygrad/engine/schedule.py) converts the graph of UOps into a list of `ExecItem`. One `ExecItem` is one kernel on the GPU, and the scheduler is responsible for breaking the large compute graph into subgraphs that can fit in a kernel. `ast` specifies what compute to run, and `bufs` specifies what buffers to run it on. The [scheduler](https://github.com/tinygrad/tinygrad/tree/master/tinygrad/schedule/__init__.py) converts the graph of UOps into a list of `ExecItem`. One `ExecItem` is one kernel on the GPU, and the scheduler is responsible for breaking the large compute graph into subgraphs that can fit in a kernel. `ast` specifies what compute to run, and `bufs` specifies what buffers to run it on.
::: tinygrad.engine.schedule.ExecItem ::: tinygrad.schedule.ExecItem
## Lowering ## Lowering

View file

@ -5,7 +5,7 @@ from tinygrad import Device, nn, Tensor, dtypes
from train_gpt2 import GPT, GPTConfig from train_gpt2 import GPT, GPTConfig
from tinygrad.helpers import DEV, dedup, flatten, getenv, GlobalCounters, to_function_name from tinygrad.helpers import DEV, dedup, flatten, getenv, GlobalCounters, to_function_name
from tinygrad.engine.realize import get_kernel from tinygrad.engine.realize import get_kernel
from tinygrad.engine.memory import memory_planner from tinygrad.schedule.memory import memory_planner
from tinygrad.uop.ops import Ops from tinygrad.uop.ops import Ops
DEV.value = "CPU" DEV.value = "CPU"

View file

@ -4,7 +4,7 @@ from tinygrad import Tensor, GlobalCounters, dtypes, nn, Device, Variable
from tinygrad.helpers import Context, getenv, DEV from tinygrad.helpers import Context, getenv, DEV
from tinygrad.engine.realize import run_schedule from tinygrad.engine.realize import run_schedule
from tinygrad.engine.realize import CompiledRunner, get_program from tinygrad.engine.realize import CompiledRunner, get_program
from tinygrad.engine.schedule import ExecItem from tinygrad.schedule import ExecItem
from tinygrad.renderer import Estimates from tinygrad.renderer import Estimates
from tinygrad.renderer.ptx import PTXRenderer from tinygrad.renderer.ptx import PTXRenderer
from test.helpers import needs_second_gpu from test.helpers import needs_second_gpu

View file

@ -7,7 +7,7 @@ from tinygrad.helpers import Context, dedup, from_mv
from tinygrad.dtype import dtypes from tinygrad.dtype import dtypes
from tinygrad.engine.jit import MultiGraphRunner from tinygrad.engine.jit import MultiGraphRunner
from tinygrad.engine.realize import BufferXfer, get_runner, CompiledRunner from tinygrad.engine.realize import BufferXfer, get_runner, CompiledRunner
from tinygrad.engine.schedule import ExecItem from tinygrad.schedule import ExecItem
from tinygrad.uop.ops import UOp, Ops from tinygrad.uop.ops import UOp, Ops
from test.helpers import needs_second_gpu from test.helpers import needs_second_gpu

View file

@ -4,7 +4,7 @@ from tinygrad import Tensor, Device
from tinygrad.helpers import get_single_element from tinygrad.helpers import get_single_element
from tinygrad.codegen.opt import Opt, OptOps from tinygrad.codegen.opt import Opt, OptOps
from tinygrad.engine.realize import CompiledRunner, get_program from tinygrad.engine.realize import CompiledRunner, get_program
from tinygrad.engine.schedule import ExecItem from tinygrad.schedule import ExecItem
class TestOptGemm(unittest.TestCase): class TestOptGemm(unittest.TestCase):
@classmethod @classmethod

View file

@ -6,7 +6,7 @@ from tinygrad import Tensor, Context, Device, dtypes
from tinygrad.uop.ops import Ops from tinygrad.uop.ops import Ops
from tinygrad.codegen.opt import Opt, OptOps from tinygrad.codegen.opt import Opt, OptOps
from tinygrad.engine.realize import CompiledRunner, get_program from tinygrad.engine.realize import CompiledRunner, get_program
from tinygrad.engine.schedule import ExecItem from tinygrad.schedule import ExecItem
N = 512 N = 512

View file

@ -8,7 +8,7 @@ from tinygrad.device import Buffer, Device
from tinygrad.uop.ops import Ops, UOp, KernelInfo, AxisType from tinygrad.uop.ops import Ops, UOp, KernelInfo, AxisType
from tinygrad.renderer.cstyle import CStyleLanguage from tinygrad.renderer.cstyle import CStyleLanguage
from tinygrad.engine.realize import CompiledRunner, get_program, get_runner from tinygrad.engine.realize import CompiledRunner, get_program, get_runner
from tinygrad.engine.schedule import ExecItem from tinygrad.schedule import ExecItem
from tinygrad.device import is_dtype_supported from tinygrad.device import is_dtype_supported
from tinygrad.codegen.opt import Opt, OptOps from tinygrad.codegen.opt import Opt, OptOps
from tinygrad.renderer.ptx import PTXRenderer from tinygrad.renderer.ptx import PTXRenderer

View file

@ -1,6 +1,6 @@
import unittest import unittest
from tinygrad import Device, Tensor from tinygrad import Device, Tensor
from tinygrad.engine.schedule import create_schedule from tinygrad.schedule import create_schedule
from tinygrad.runtime.ops_amd import AMDDevice from tinygrad.runtime.ops_amd import AMDDevice
class TestAMD(unittest.TestCase): class TestAMD(unittest.TestCase):

View file

@ -2,7 +2,7 @@ import time, unittest
from tinygrad.runtime.support.hip_comgr import compile_hip from tinygrad.runtime.support.hip_comgr import compile_hip
from tinygrad import Tensor from tinygrad import Tensor
from tinygrad.device import Device from tinygrad.device import Device
from tinygrad.engine.schedule import create_schedule from tinygrad.schedule import create_schedule
from tinygrad.codegen.opt.kernel import Kernel from tinygrad.codegen.opt.kernel import Kernel
class TestHIPCompileSpeed(unittest.TestCase): class TestHIPCompileSpeed(unittest.TestCase):

View file

@ -7,7 +7,7 @@ from tinygrad import GlobalCounters, Tensor, Device
from tinygrad.helpers import getenv from tinygrad.helpers import getenv
from tinygrad.nn.state import get_parameters from tinygrad.nn.state import get_parameters
from tinygrad.engine.realize import capturing, run_schedule from tinygrad.engine.realize import capturing, run_schedule
from tinygrad.engine.schedule import linear_to_schedule from tinygrad.schedule import linear_to_schedule
from tinygrad.tensor import _to_np_dtype from tinygrad.tensor import _to_np_dtype
class CLCache: class CLCache:

View file

@ -1,6 +1,6 @@
import gc import gc
from tinygrad import Tensor, UOp, Device, nn from tinygrad import Tensor, UOp, Device, nn
from tinygrad.engine.schedule import schedule_cache from tinygrad.schedule import schedule_cache
from tinygrad.engine.realize import method_cache, get_program from tinygrad.engine.realize import method_cache, get_program
from tinygrad.schedule.indexing import apply_movement_op, _apply_reshape from tinygrad.schedule.indexing import apply_movement_op, _apply_reshape
from tinygrad.uop.divandmod import fold_divmod_general from tinygrad.uop.divandmod import fold_divmod_general

View file

@ -5,7 +5,7 @@ from tinygrad.helpers import Context, getenv, from_mv
from tinygrad.dtype import dtypes from tinygrad.dtype import dtypes
from tinygrad.tensor import Tensor, _to_np_dtype from tinygrad.tensor import Tensor, _to_np_dtype
from tinygrad.engine.realize import BufferXfer, get_runner from tinygrad.engine.realize import BufferXfer, get_runner
from tinygrad.engine.schedule import ExecItem from tinygrad.schedule import ExecItem
from tinygrad.uop.ops import UOp, Ops from tinygrad.uop.ops import UOp, Ops
from tinygrad.engine.jit import apply_graph_to_jit from tinygrad.engine.jit import apply_graph_to_jit

View file

@ -1,7 +1,7 @@
import unittest import unittest
from tinygrad import dtypes from tinygrad import dtypes
from tinygrad.uop.ops import UOp, Ops from tinygrad.uop.ops import UOp, Ops
from tinygrad.engine.memory import memory_plan_rewrite from tinygrad.schedule.memory import memory_plan_rewrite
global_map = {} global_map = {}
held_bufs: set[UOp] = set() held_bufs: set[UOp] = set()

View file

@ -1,7 +1,7 @@
import unittest import unittest
from tinygrad import Tensor, Variable, Context from tinygrad import Tensor, Variable, Context
from tinygrad.helpers import cpu_events from tinygrad.helpers import cpu_events
from tinygrad.engine.schedule import schedule_cache from tinygrad.schedule import schedule_cache
def schedule_one(): def schedule_one():
Tensor([1]).schedule() Tensor([1]).schedule()

View file

@ -2,7 +2,7 @@ import unittest
from tinygrad import Tensor, dtypes from tinygrad import Tensor, dtypes
from tinygrad.tensor import _METADATA from tinygrad.tensor import _METADATA
from tinygrad.engine.realize import capturing from tinygrad.engine.realize import capturing
from tinygrad.engine.schedule import linear_to_schedule from tinygrad.schedule import linear_to_schedule
from tinygrad.helpers import Context from tinygrad.helpers import Context
@unittest.skip("tensor metadata is no longer supported") @unittest.skip("tensor metadata is no longer supported")

View file

@ -3,7 +3,7 @@ import unittest, math, time
from tinygrad import Tensor, Device, dtypes, Context from tinygrad import Tensor, Device, dtypes, Context
from tinygrad.uop.ops import UOp, Ops from tinygrad.uop.ops import UOp, Ops
from tinygrad.engine.realize import get_runner from tinygrad.engine.realize import get_runner
from tinygrad.engine.schedule import ExecItem from tinygrad.schedule import ExecItem
from tinygrad.engine.jit import TinyJit from tinygrad.engine.jit import TinyJit
import numpy as np import numpy as np

View file

@ -1,7 +1,7 @@
import unittest import unittest
from unittest.mock import patch from unittest.mock import patch
from tinygrad import Tensor, UOp from tinygrad import Tensor, UOp
from tinygrad.engine.schedule import schedule_cache from tinygrad.schedule import schedule_cache
from tinygrad.apps.llm import Transformer, TransformerConfig from tinygrad.apps.llm import Transformer, TransformerConfig
TEST_CONFIG = TransformerConfig(num_blocks=1, dim=64, hidden_dim=128, n_heads=2, n_kv_heads=2, TEST_CONFIG = TransformerConfig(num_blocks=1, dim=64, hidden_dim=128, n_heads=2, n_kv_heads=2,

View file

@ -2,7 +2,7 @@ import unittest
import functools import functools
from tinygrad import Tensor, Variable, UOp from tinygrad import Tensor, Variable, UOp
from tinygrad.uop.ops import KernelInfo from tinygrad.uop.ops import KernelInfo
from tinygrad.engine.schedule import schedule_cache from tinygrad.schedule import schedule_cache
def custom_set0_kernel(A:UOp, num:int) -> UOp: def custom_set0_kernel(A:UOp, num:int) -> UOp:
return A[0].set(num).sink(arg=KernelInfo(f"custom_set0_{num}")) return A[0].set(num).sink(arg=KernelInfo(f"custom_set0_{num}"))

View file

@ -6,8 +6,8 @@ from tinygrad.device import Buffer, Compiled, Device, MultiBuffer
from tinygrad.dtype import DType, dtypes from tinygrad.dtype import DType, dtypes
from tinygrad.uop.ops import UOp, PatternMatcher, Variable, sym_infer, Ops, buffers, track_rewrites, graph_rewrite from tinygrad.uop.ops import UOp, PatternMatcher, Variable, sym_infer, Ops, buffers, track_rewrites, graph_rewrite
from tinygrad.engine.realize import ExecItem, capturing, BufferCopy, BufferXfer, EncDec, CompiledRunner, Runner, Estimates from tinygrad.engine.realize import ExecItem, capturing, BufferCopy, BufferXfer, EncDec, CompiledRunner, Runner, Estimates
from tinygrad.engine.memory import memory_plan_rewrite, _collect_bufs from tinygrad.schedule.memory import memory_plan_rewrite, _collect_bufs
from tinygrad.engine.schedule import linear_to_schedule from tinygrad.schedule import linear_to_schedule
from tinygrad.nn.state import get_parameters from tinygrad.nn.state import get_parameters
from tinygrad.schedule.rangeify import mop_cleanup from tinygrad.schedule.rangeify import mop_cleanup
from dataclasses import dataclass from dataclasses import dataclass

View file

@ -1,182 +0,0 @@
import time, inspect
from typing import cast
from collections import deque
from tinygrad.uop.ops import UOp, Ops, buffers, UOpMetaClass, track_rewrites, graph_rewrite, gate_kernel_sink, KernelInfo
from tinygrad.uop.spec import type_verify, tensor_spec
from tinygrad.device import Buffer, MultiBuffer
from tinygrad.helpers import DEBUG, cpu_profile, TracingKey, SPEC, pluralize, SCACHE, BASEDIR, flatten, BEAM, partition
from tinygrad.engine.realize import ExecItem
# **** schedule linearizer
# unwrap VIEW/CAST/etc to find the actual data source (kernel output, buffer, or multi-device op)
# Follows src[0] repeatedly and returns the first node that either has no sources
# or whose op is one of AFTER, BUFFER, PARAM, MSELECT, MSTACK, BIND.
def _unwrap_src(s: UOp) -> UOp:
while len(s.src) and s.op not in {Ops.AFTER, Ops.BUFFER, Ops.PARAM, Ops.MSELECT, Ops.MSTACK, Ops.BIND}: s = s.src[0]
return s
# Classify the sources of an AFTER node, skipping src[0].
# Returns (kernels, deps): kernels are the CALL/END sources, deps are the AFTER sources.
# STORE sources are accepted and dropped; any other op raises AssertionError.
def _split_after(after: UOp) -> tuple[tuple[UOp, ...], tuple[UOp, ...]]:
kernels, remaining = partition(after.src[1:], lambda s: s.op in {Ops.CALL, Ops.END})
deps, remaining = partition(remaining, lambda s: s.op is Ops.AFTER)
# whatever is left must be a STORE
if invalid := [s for s in remaining if s.op is not Ops.STORE]:
raise AssertionError(f"AFTER source should be CALL, END, STORE, or AFTER, not {invalid[0].op}")
return tuple(kernels), tuple(deps)
# Order the kernels reachable from sched_sink and return them as a single LINEAR UOp.
# Phase 1 records producer -> consumer edges between kernels; phase 2 drains a FIFO queue of
# kernels whose producers have all been emitted.
def create_schedule(sched_sink:UOp) -> UOp:
with cpu_profile(TracingKey("toposort sched_sink")):
# build kernel dependency graph: edges from producer kernel to consumer kernels
children: dict[UOp, list[UOp]] = {}
in_degree: dict[UOp, int] = {}
for u in sched_sink.toposort(gate_kernel_sink):
# only AFTER nodes carry kernels
if u.op is not Ops.AFTER: continue
kernels, after_deps = _split_after(u)
for k in kernels:
in_degree.setdefault(k, 0)
# NOTE(review): the message says KERNEL but the op being checked is Ops.CALL
if k.op is Ops.END: assert k.src[0].op is Ops.CALL, f"END src[0] should be KERNEL, not {k.src[0].op}"
# an END wraps the CALL in src[0]; the inputs are the CALL's sources after src[0]
kernel_deps = k.src[0].src[1:] if k.op is Ops.END else k.src[1:]
for s in kernel_deps + after_deps:
match (s := _unwrap_src(s)).op:
case Ops.AFTER:
# every kernel attached to the producing AFTER must run before k
for t in _split_after(s)[0]:
children.setdefault(t, []).append(k)
in_degree[k] += 1
case Ops.MSELECT | Ops.MSTACK:
for ss in s.src:
# look through an MSELECT to the node it selects from
if ss.op is Ops.MSELECT: ss = ss.src[0]
if ss.op not in {Ops.BUFFER, Ops.PARAM}:
assert ss.op is Ops.AFTER, f"ss.op is not AFTER, it's {ss.op}"
for t in _split_after(ss)[0]:
children.setdefault(t, []).append(k)
in_degree[k] += 1
case Ops.BUFFER | Ops.PARAM | Ops.BIND:
pass # BUFFER/PARAM is already realized, BIND is a bound variable (not a buffer dependency)
case _:
raise RuntimeError(f"input to kernel must be AFTER, BUFFER, PARAM, MSELECT, MSTACK, or BIND, not {s.op}")
with cpu_profile(TracingKey("linearize schedule")):
# start from the kernels that have no producers
queue: deque[UOp] = deque(k for k,v in in_degree.items() if v == 0)
linearized: list[UOp] = []
while len(queue):
rk = queue.popleft()
if rk.op is Ops.LINEAR:
# a LINEAR node contributes its sources directly
linearized.extend(rk.src)
else:
k = rk.src[0] if rk.op is Ops.END else rk
assert k.op is Ops.CALL, f"unexpected op in queue: {k.op}"
# BIND sources are skipped; every other source is reduced to its buf_uop
buf_uops = tuple(_unwrap_src(s).buf_uop for s in k.src[1:] if s.op is not Ops.BIND)
linearized.append(k.src[0].call(*buf_uops, metadata=k.arg.metadata))
# release consumers whose last pending producer was rk
for x in children.get(rk, []):
in_degree[x] -= 1
if in_degree[x] == 0: queue.append(x)
return UOp(Ops.LINEAR, src=tuple(linearized))
# Each source of `linear` has the ast in src[0] and the buffer UOps in src[1:].
def linear_to_schedule(linear:UOp) -> list[ExecItem]:
"""Convert a LINEAR UOp to a list of ExecItems."""
schedule: list[ExecItem] = []
for si in linear.src:
ast, buf_uops = si.src[0], si.src[1:]
# create subbuffers if needed
if ast.op is Ops.BUFFER_VIEW:
base = buf_uops[1].buffer
assert isinstance(base, Buffer), "base can't be MultiBuffer"
# register the view in the global `buffers` map under the first buffer UOp
buffers[buf_uops[0]] = base.view(buf_uops[0].arg, ast.dtype, ast.arg[1]*base.dtype.itemsize)
# wrap SINK with BEAM UOp when beam search is enabled
if ast.op is Ops.SINK and BEAM >= 1: ast = UOp(Ops.BEAM, src=(ast,), arg=BEAM.value)
# BIND sources are not buffers and are left out
ubufs = [b.buffer for b in buf_uops if b.op is not Ops.BIND]
metadata = si.arg.metadata
if ast.op is Ops.CUSTOM_FUNCTION and ast.arg == "graph":
# "graph" custom functions get one ExecItem with every MultiBuffer expanded in place
schedule.append(ExecItem(ast, flatten([b.bufs if isinstance(b, MultiBuffer) else [b] for b in ubufs]), metadata))
elif any(isinstance(x, MultiBuffer) for x in ubufs):
assert all(isinstance(x, MultiBuffer) for x in ubufs), "kernel must all be multibuffer"
dnums = [x for x in ast.variables() if x.expr == '_device_num']
# one ExecItem per position across the MultiBuffers; '_device_num' is fixed to that position when the ast uses it
for j, bufs in enumerate(zip(*[x.bufs for x in cast(tuple[MultiBuffer, ...], ubufs)])):
schedule.append(ExecItem(ast, list(bufs), metadata, {dnums[0].expr:j} if len(dnums) else {}))
else:
schedule.append(ExecItem(ast, cast(list[Buffer|None], ubufs), metadata))
return schedule
from tinygrad.engine.memory import memory_plan_rewrite
from tinygrad.engine.realize import capturing
from tinygrad.schedule.rangeify import get_kernel_graph
from tinygrad.helpers import CAPTURING
from tinygrad.uop.ops import PatternMatcher, UPat
# Return the replacement for buffer `b`, creating it with UOp.new_buffer on first sight and
# memoizing it in ctx[0] so repeated occurrences of `b` map to the same new buffer.
def create_new_buffer(ctx:tuple[dict[UOp, UOp], tuple[UOp, ...]], b:UOp):
if (ret:=ctx[0].get(b, None)) is None: ctx[0][b] = ret = UOp.new_buffer(b.device, b.arg, b.dtype)
return ret
# Rewrites applied to a LINEAR taken out of a CALL; ctx is (new-buffer memo dict, the CALL's arguments).
pm_post_sched_cache = PatternMatcher([
# a PARAM is replaced by the call argument at index x.arg
(UPat(Ops.PARAM, name="x"), lambda ctx,x: ctx[1][x.arg]),
# create new BUFFERs for LUNIQUE BUFFERs from rangeify
(UPat(Ops.BUFFER, src=(UPat(Ops.LUNIQUE), UPat(Ops.DEVICE)), name="b"), create_new_buffer),
])
pm_resolve_linear_call = PatternMatcher([
# call LINEAR is resolved here
(UPat(Ops.CALL, src=(UPat(Ops.LINEAR),), name="linear_call", allow_any_len=True), lambda linear_call:
graph_rewrite(linear_call.src[0], pm_post_sched_cache, ctx=({}, linear_call.src[1:]), walk=True, name="params to buffers")),
# LINEAR on LINEAR
# sources that are themselves LINEAR are spliced in; inside the generator `x` is rebound to each source
(UPat(Ops.LINEAR, custom_early_reject={Ops.LINEAR}, name="x"),
lambda x: x.replace(src=tuple(flatten(x.src if x.op is Ops.LINEAR else (x,) for x in x.src)))),
])
# LINEAR results of lower_sink_to_linear, keyed by the `key` of the lowered function
schedule_cache: dict[bytes, UOp] = {}
# Lower a SINK to a LINEAR schedule, reusing schedule_cache when SCACHE is set.
# Returns None for a SINK whose arg is a KernelInfo.
# NOTE(review): the previous comment here mentioned a `ctx` argument, but this function takes none.
def lower_sink_to_linear(function:UOp) -> UOp|None:
st = time.perf_counter()
if isinstance(function.arg, KernelInfo): return None
cache_key = function.key
# sc_ret is only bound when SCACHE is truthy; the print below guards on SCACHE before reading it
if not SCACHE or (sc_ret:=schedule_cache.get(cache_key, None)) is None:
if SPEC: type_verify(function, tensor_spec)
# support recursive CALLs
linear = create_schedule(get_kernel_graph(function))
if SCACHE: schedule_cache[cache_key] = linear
else:
# schedule cache hit
linear = sc_ret
if (DEBUG >= 1 and len(linear.src) > 1) or DEBUG >= 3:
# pick the stack frame to report: the first one under BASEDIR/"apps", or the first one
# outside BASEDIR that is not contextlib.py; "<string>" frames are skipped
for frm in inspect.stack():
if frm.filename == "<string>": continue
if frm.filename.startswith(str(BASEDIR / "apps")): break
if not frm.filename.startswith(str(BASEDIR)) and not frm.filename.endswith("/contextlib.py"): break
else:
# no frame qualified, so no location is printed
frm = None
print(f"scheduled {len(linear.src):5d} kernels in {(time.perf_counter()-st)*1000:8.2f} ms"+\
f" | {' cache hit' if SCACHE and sc_ret is not None else 'CACHE MISS'} {cache_key.hex()[:8]}"+\
f" | {len(UOpMetaClass.ucache):7d} uops in cache"+("" if frm is None else f" | {frm.filename}:{frm.lineno}"))
return linear
# Every SINK is handed to lower_sink_to_linear; a None result leaves the SINK as it is.
pm_schedule = PatternMatcher([
(UPat(Ops.SINK, name="function"), lower_sink_to_linear),
])
# Schedule big_sink and return (ExecItems, var_vals), where var_vals maps variable names to their bound values.
# While a JIT capture is active the schedule is handed to the capturer and the returned list is empty.
@track_rewrites(lambda _,ret: f"Schedule {pluralize('Kernel', len(ret[0]))}")
def complete_create_schedule_with_vars(big_sink:UOp) -> tuple[list[ExecItem], dict[str, int]]:
# big_sink srcs are all the Tensors
linear_call = graph_rewrite(big_sink, pm_schedule, name="schedule to linear", enter_calls=True)
# this recursively resolves the linear_call and allocates buffers
linear = graph_rewrite(linear_call, pm_resolve_linear_call, name="resolve linear call")
# vars used in the schedule
used_vars = set().union(*[{v.expr for v in si.src[0].variables()} for si in linear.src])
# get var_vals
var_vals: dict[str, int] = {}
for b in big_sink.src[1:]:
if b.op is Ops.BIND:
nm = b.src[0].expr
# binds for variables that no kernel uses are ignored
if nm not in used_vars: continue
val = b.src[1].arg
# the same variable bound to two different values is an error
if var_vals.get(nm, val) != val: raise RuntimeError(f"bind mismatch on {nm}, {var_vals[nm]} != {val}")
var_vals[nm] = val
# jit captures this schedule, no need to execute.
if len(capturing) and CAPTURING:
capturing[0].add_linear(linear, var_vals)
return [], var_vals
# BUFFER arguments of the top-level CALL are passed to the memory planner as held buffers
held_bufs = ({b for b in linear_call.src[1:] if b.op is Ops.BUFFER} if linear_call.op is Ops.CALL else set())
linear = memory_plan_rewrite(linear, held_bufs)
# convert LINEAR to ExecItems
schedule: list[ExecItem] = linear_to_schedule(linear)
return schedule, var_vals

View file

@ -0,0 +1,182 @@
import time, inspect
from typing import cast
from collections import deque
from tinygrad.uop.ops import UOp, Ops, buffers, UOpMetaClass, track_rewrites, graph_rewrite, gate_kernel_sink, KernelInfo
from tinygrad.uop.spec import type_verify, tensor_spec
from tinygrad.device import Buffer, MultiBuffer
from tinygrad.helpers import DEBUG, cpu_profile, TracingKey, SPEC, pluralize, SCACHE, BASEDIR, flatten, BEAM, partition
from tinygrad.engine.realize import ExecItem
# **** schedule linearizer
def _unwrap_src(s: UOp) -> UOp:
  """Unwrap VIEW/CAST/etc to find the actual data source (kernel output, buffer, or multi-device op).

  Follows `src[0]` until reaching a node with no sources or one whose op is
  AFTER, BUFFER, PARAM, MSELECT, MSTACK or BIND, and returns that node.
  """
  stop_ops = {Ops.AFTER, Ops.BUFFER, Ops.PARAM, Ops.MSELECT, Ops.MSTACK, Ops.BIND}
  node = s
  while True:
    if not len(node.src) or node.op in stop_ops: return node
    node = node.src[0]
def _split_after(after: UOp) -> tuple[tuple[UOp, ...], tuple[UOp, ...]]:
  """Classify the sources of an AFTER node, skipping `src[0]`.

  Returns `(kernels, deps)`: `kernels` are the CALL/END sources and `deps` the AFTER sources,
  both in source order. STORE sources are accepted and dropped.

  Raises:
    AssertionError: if any source is not CALL, END, STORE or AFTER.
  """
  kernel_srcs: list[UOp] = []
  after_srcs: list[UOp] = []
  unexpected: list[UOp] = []
  for src in after.src[1:]:
    if src.op in {Ops.CALL, Ops.END}: kernel_srcs.append(src)
    elif src.op is Ops.AFTER: after_srcs.append(src)
    elif src.op is not Ops.STORE: unexpected.append(src)
  if unexpected:
    raise AssertionError(f"AFTER source should be CALL, END, STORE, or AFTER, not {unexpected[0].op}")
  return tuple(kernel_srcs), tuple(after_srcs)
def create_schedule(sched_sink:UOp) -> UOp:
with cpu_profile(TracingKey("toposort sched_sink")):
# build kernel dependency graph: edges from producer kernel to consumer kernels
children: dict[UOp, list[UOp]] = {}
in_degree: dict[UOp, int] = {}
for u in sched_sink.toposort(gate_kernel_sink):
if u.op is not Ops.AFTER: continue
kernels, after_deps = _split_after(u)
for k in kernels:
in_degree.setdefault(k, 0)
if k.op is Ops.END: assert k.src[0].op is Ops.CALL, f"END src[0] should be KERNEL, not {k.src[0].op}"
kernel_deps = k.src[0].src[1:] if k.op is Ops.END else k.src[1:]
for s in kernel_deps + after_deps:
match (s := _unwrap_src(s)).op:
case Ops.AFTER:
for t in _split_after(s)[0]:
children.setdefault(t, []).append(k)
in_degree[k] += 1
case Ops.MSELECT | Ops.MSTACK:
for ss in s.src:
if ss.op is Ops.MSELECT: ss = ss.src[0]
if ss.op not in {Ops.BUFFER, Ops.PARAM}:
assert ss.op is Ops.AFTER, f"ss.op is not AFTER, it's {ss.op}"
for t in _split_after(ss)[0]:
children.setdefault(t, []).append(k)
in_degree[k] += 1
case Ops.BUFFER | Ops.PARAM | Ops.BIND:
pass # BUFFER/PARAM is already realized, BIND is a bound variable (not a buffer dependency)
case _:
raise RuntimeError(f"input to kernel must be AFTER, BUFFER, PARAM, MSELECT, MSTACK, or BIND, not {s.op}")
with cpu_profile(TracingKey("linearize schedule")):
queue: deque[UOp] = deque(k for k,v in in_degree.items() if v == 0)
linearized: list[UOp] = []
while len(queue):
rk = queue.popleft()
if rk.op is Ops.LINEAR:
linearized.extend(rk.src)
else:
k = rk.src[0] if rk.op is Ops.END else rk
assert k.op is Ops.CALL, f"unexpected op in queue: {k.op}"
buf_uops = tuple(_unwrap_src(s).buf_uop for s in k.src[1:] if s.op is not Ops.BIND)
linearized.append(k.src[0].call(*buf_uops, metadata=k.arg.metadata))
for x in children.get(rk, []):
in_degree[x] -= 1
if in_degree[x] == 0: queue.append(x)
return UOp(Ops.LINEAR, src=tuple(linearized))
def linear_to_schedule(linear:UOp) -> list[ExecItem]:
  """Convert a LINEAR UOp to a list of ExecItems.

  Each source of `linear` has the ast in `src[0]` and the buffer UOps in `src[1:]`.
  A kernel whose buffers are MultiBuffers produces one ExecItem per position across them;
  every other entry produces exactly one ExecItem.
  """
  items: list[ExecItem] = []
  for si in linear.src:
    ast = si.src[0]
    buf_uops = si.src[1:]

    # create subbuffers if needed
    if ast.op is Ops.BUFFER_VIEW:
      base = buf_uops[1].buffer
      assert isinstance(base, Buffer), "base can't be MultiBuffer"
      buffers[buf_uops[0]] = base.view(buf_uops[0].arg, ast.dtype, ast.arg[1]*base.dtype.itemsize)

    # wrap SINK with BEAM UOp when beam search is enabled
    if ast.op is Ops.SINK and BEAM >= 1:
      ast = UOp(Ops.BEAM, src=(ast,), arg=BEAM.value)

    # BIND sources are not buffers and are left out
    ubufs = [b.buffer for b in buf_uops if b.op is not Ops.BIND]
    metadata = si.arg.metadata

    if ast.op is Ops.CUSTOM_FUNCTION and ast.arg == "graph":
      # "graph" custom functions get a single item with every MultiBuffer expanded in place
      expanded = flatten([b.bufs if isinstance(b, MultiBuffer) else [b] for b in ubufs])
      items.append(ExecItem(ast, expanded, metadata))
      continue

    if not any(isinstance(x, MultiBuffer) for x in ubufs):
      items.append(ExecItem(ast, cast(list[Buffer|None], ubufs), metadata))
      continue

    assert all(isinstance(x, MultiBuffer) for x in ubufs), "kernel must all be multibuffer"
    dnums = [x for x in ast.variables() if x.expr == '_device_num']
    per_device = zip(*[x.bufs for x in cast(tuple[MultiBuffer, ...], ubufs)])
    for j, bufs in enumerate(per_device):
      # '_device_num' is fixed to the position only when the ast actually uses it
      fixed_vars = {dnums[0].expr:j} if len(dnums) else {}
      items.append(ExecItem(ast, list(bufs), metadata, fixed_vars))
  return items
from tinygrad.schedule.memory import memory_plan_rewrite
from tinygrad.engine.realize import capturing
from tinygrad.schedule.rangeify import get_kernel_graph
from tinygrad.helpers import CAPTURING
from tinygrad.uop.ops import PatternMatcher, UPat
def create_new_buffer(ctx:tuple[dict[UOp, UOp], tuple[UOp, ...]], b:UOp):
  """Return the replacement buffer for `b`, creating it on first sight.

  `ctx[0]` is the memo dict: the first call for a given `b` stores a buffer made with
  `UOp.new_buffer(b.device, b.arg, b.dtype)`, and later calls return that same buffer.
  """
  memo = ctx[0]
  replacement = memo.get(b, None)
  if replacement is None:
    replacement = UOp.new_buffer(b.device, b.arg, b.dtype)
    memo[b] = replacement
  return replacement
def _param_to_call_arg(ctx, x):
  # a PARAM is replaced by the call argument at index x.arg (ctx[1] holds the CALL's arguments)
  return ctx[1][x.arg]

# rewrites applied to a LINEAR taken out of a CALL; ctx is (new-buffer memo dict, the CALL's arguments)
pm_post_sched_cache = PatternMatcher([
  (UPat(Ops.PARAM, name="x"), _param_to_call_arg),
  # create new BUFFERs for LUNIQUE BUFFERs from rangeify
  (UPat(Ops.BUFFER, src=(UPat(Ops.LUNIQUE), UPat(Ops.DEVICE)), name="b"), create_new_buffer),
])

def _resolve_linear_call(linear_call):
  # the LINEAR in src[0] gets its PARAMs bound to the remaining sources of the CALL
  call_args = linear_call.src[1:]
  return graph_rewrite(linear_call.src[0], pm_post_sched_cache, ctx=({}, call_args), walk=True, name="params to buffers")

def _flatten_nested_linear(x):
  # sources that are themselves LINEAR are spliced in; everything else is kept as is
  spliced = flatten(child.src if child.op is Ops.LINEAR else (child,) for child in x.src)
  return x.replace(src=tuple(spliced))

pm_resolve_linear_call = PatternMatcher([
  # call LINEAR is resolved here
  (UPat(Ops.CALL, src=(UPat(Ops.LINEAR),), name="linear_call", allow_any_len=True), _resolve_linear_call),
  # LINEAR on LINEAR
  (UPat(Ops.LINEAR, custom_early_reject={Ops.LINEAR}, name="x"), _flatten_nested_linear),
])

# LINEAR results of lower_sink_to_linear, keyed by the `key` of the lowered function
schedule_cache: dict[bytes, UOp] = {}
def lower_sink_to_linear(function:UOp) -> UOp|None:
  """Lower a SINK to a LINEAR schedule, reusing `schedule_cache` when SCACHE is set.

  Returns None for a SINK whose arg is a KernelInfo. With DEBUG >= 1 (and more than one kernel)
  or DEBUG >= 3, prints a one-line summary with the timing, cache status and calling location.
  """
  st = time.perf_counter()
  if isinstance(function.arg, KernelInfo): return None
  cache_key = function.key

  # the cache is only consulted when SCACHE is set
  sc_ret = schedule_cache.get(cache_key, None) if SCACHE else None
  if sc_ret is not None:
    # schedule cache hit
    linear = sc_ret
  else:
    if SPEC: type_verify(function, tensor_spec)
    # support recursive CALLs
    linear = create_schedule(get_kernel_graph(function))
    if SCACHE: schedule_cache[cache_key] = linear

  if (DEBUG >= 1 and len(linear.src) > 1) or DEBUG >= 3:
    # report the first frame under BASEDIR/"apps", or the first frame outside BASEDIR
    # that is not contextlib.py; "<string>" frames are skipped
    frm = None
    for candidate in inspect.stack():
      fn = candidate.filename
      if fn == "<string>": continue
      in_apps = fn.startswith(str(BASEDIR / "apps"))
      if in_apps or not (fn.startswith(str(BASEDIR)) or fn.endswith("/contextlib.py")):
        frm = candidate
        break
    status = ' cache hit' if SCACHE and sc_ret is not None else 'CACHE MISS'
    location = "" if frm is None else f" | {frm.filename}:{frm.lineno}"
    print(f"scheduled {len(linear.src):5d} kernels in {(time.perf_counter()-st)*1000:8.2f} ms"+\
          f" | {status} {cache_key.hex()[:8]}"+\
          f" | {len(UOpMetaClass.ucache):7d} uops in cache"+location)
  return linear
# every SINK is handed to lower_sink_to_linear; a None result leaves the SINK as it is
pm_schedule = PatternMatcher([(UPat(Ops.SINK, name="function"), lower_sink_to_linear)])
@track_rewrites(lambda _,ret: f"Schedule {pluralize('Kernel', len(ret[0]))}")
def complete_create_schedule_with_vars(big_sink:UOp) -> tuple[list[ExecItem], dict[str, int]]:
  """Schedule `big_sink` and return `(schedule, var_vals)`.

  `var_vals` maps the name of every variable used by the schedule to its bound value, taken from
  the BIND sources of `big_sink`. While a JIT capture is active, the LINEAR is handed to the
  capturer and the returned schedule is empty.

  Raises:
    RuntimeError: if one variable is bound to two different values.
  """
  # big_sink srcs are all the Tensors
  linear_call = graph_rewrite(big_sink, pm_schedule, name="schedule to linear", enter_calls=True)

  # this recursively resolves the linear_call and allocates buffers
  linear = graph_rewrite(linear_call, pm_resolve_linear_call, name="resolve linear call")

  # vars used in the schedule
  used_vars = set()
  for si in linear.src:
    used_vars.update(v.expr for v in si.src[0].variables())

  # get var_vals
  var_vals: dict[str, int] = {}
  for b in big_sink.src[1:]:
    if b.op is not Ops.BIND: continue
    nm = b.src[0].expr
    # binds for variables that no kernel uses are ignored
    if nm not in used_vars: continue
    val = b.src[1].arg
    if var_vals.get(nm, val) != val:
      raise RuntimeError(f"bind mismatch on {nm}, {var_vals[nm]} != {val}")
    var_vals[nm] = val

  # jit captures this schedule, no need to execute.
  if len(capturing) and CAPTURING:
    capturing[0].add_linear(linear, var_vals)
    return [], var_vals

  # BUFFER arguments of the top-level CALL are passed to the memory planner as held buffers
  if linear_call.op is Ops.CALL:
    held_bufs = {b for b in linear_call.src[1:] if b.op is Ops.BUFFER}
  else:
    held_bufs = set()
  linear = memory_plan_rewrite(linear, held_bufs)

  # convert LINEAR to ExecItems
  schedule: list[ExecItem] = linear_to_schedule(linear)
  return schedule, var_vals

View file

@ -13,10 +13,10 @@ from tinygrad.gradient import compute_gradient
from tinygrad.mixin import OpMixin, ReductionStr from tinygrad.mixin import OpMixin, ReductionStr
from tinygrad.uop.ops import smax, UOp, Ops, sint, all_metadata, _index_to_concrete_int, sint_to_uop, Variable from tinygrad.uop.ops import smax, UOp, Ops, sint, all_metadata, _index_to_concrete_int, sint_to_uop, Variable
from tinygrad.uop.ops import _broadcast_shape from tinygrad.uop.ops import _broadcast_shape
from tinygrad.engine.schedule import ExecItem, complete_create_schedule_with_vars from tinygrad.schedule import ExecItem, complete_create_schedule_with_vars
from tinygrad.device import Buffer, canonicalize_device from tinygrad.device import Buffer, canonicalize_device
from tinygrad.engine.realize import run_schedule from tinygrad.engine.realize import run_schedule
from tinygrad.engine.callify import transform_to_call from tinygrad.callify import transform_to_call
# *** all in scope Tensors are here. this gets relevant UOps *** # *** all in scope Tensors are here. this gets relevant UOps ***