tinygrad/tinygrad/engine/schedule.py
2025-07-21 09:29:44 -07:00

102 lines
3.9 KiB
Python

from typing import cast
from dataclasses import dataclass, field
from collections import deque, defaultdict
from tinygrad.uop.ops import UOp, Variable, Ops, UPat, PatternMatcher, graph_rewrite, buffers
from tinygrad.device import Device, Buffer, MultiBuffer
from tinygrad.helpers import Metadata, unwrap, all_same, merge_dicts
# **** ScheduleItem return type
@dataclass(frozen=True)
class ScheduleItem:
ast: UOp
bufs: tuple[Buffer, ...]
metadata: tuple[Metadata, ...] = ()
fixedvars: dict[Variable, int] = field(default_factory=dict)
# **** unbind Variables
def unbind_view(ctx:list[dict[Variable, int]], x:UOp):
st = unwrap(x.st).simplify()
if any(x.op is Ops.BIND for x in st.vars()):
st, var_vals = st.unbind()
ctx.append(var_vals)
return x.replace(arg=st)
return None
def unbind_bind(ctx:list[dict[Variable, int]], x:UOp):
var, val = x.unbind()
ctx.append({var.replace(src=()):val})
return var
pm_unbind = PatternMatcher([
(UPat(Ops.VIEW, name="x"), unbind_view),
(UPat(Ops.BIND, name="x"), unbind_bind),
])
# **** schedule linearizer
def create_schedule_with_vars(sched_sink:UOp) -> tuple[list[ScheduleItem], dict[Variable, int]]:
# construct the KERNEL children graph based on assigns
children: defaultdict[UOp, list[UOp]] = defaultdict(list)
in_degree: dict[UOp, int] = {}
for u in sched_sink.toposort():
if u.op is not Ops.ASSIGN: continue # anything that's not an ASSIGN doesn't write a kernel, so we can skip
k = u.src[1]
in_degree.setdefault(k, 0)
for s in k.src:
if s.op is Ops.ASSIGN:
children[s.src[1]].append(k)
in_degree[k] += 1
elif s.op in {Ops.MSELECT, Ops.MSTACK}:
for ss in s.src:
if ss.op is Ops.MSELECT: ss = ss.src[0]
if ss.op is not Ops.BUFFER:
assert ss.op is Ops.ASSIGN
children[ss.src[1]].append(k)
in_degree[k] += 1
elif s.op is Ops.BUFFER:
pass # a BUFFER is already realized, nothing to do here
else:
raise RuntimeError(f"input to kernel must be ASSIGN or BUFFER, not {s.op}")
# linearize KERNEL UOps into ScheduleItems in BFS order
def _heuristic(k: UOp):
if k.arg.ast.op is Ops.COPY and not all_same([Device[cast(Buffer, s.buf_uop.buffer).device].group_id for s in k.src]): return 1000
return 0
last_heuristic: int = 0
queues: defaultdict[int, deque[UOp]] = defaultdict(deque)
last_queue: deque[UOp] = deque()
for k,v in in_degree.items():
if v == 0: queues[_heuristic(k)].append(k)
schedule: list[ScheduleItem] = []
var_vals: dict[Variable, int] = {}
while last_queue or any(queues.values()):
if not last_queue: last_heuristic, last_queue = min((it for it in queues.items() if it[1]), key=lambda x: abs(x[0]-last_heuristic))
k = last_queue.popleft()
# unbind var_vals from the kernel
local_var_vals: list[dict[Variable, int]] = []
ast = graph_rewrite(k.arg.ast, pm_unbind, ctx=local_var_vals, name="unbind vars")
var_vals = merge_dicts([var_vals, *local_var_vals])
# create subbuffers if needed
if ast.op is Ops.BUFFER_VIEW:
base = k.src[1].buf_uop.buffer
assert isinstance(base, Buffer), "base can't be MultiBuffer"
buffers[k.src[0]] = base.view(k.size, ast.dtype, ast.arg[1]*base.dtype.itemsize)
ubufs = tuple(s.buf_uop.buffer for s in k.src)
if any(isinstance(x, MultiBuffer) for x in ubufs):
assert all(isinstance(x, MultiBuffer) for x in ubufs), "kernel must all be multibuffer"
dnums = [x for x in ast.variables() if x.arg[0] == '_device_num']
for i,bufs in enumerate(zip(*[x.bufs for x in cast(tuple[MultiBuffer, ...], ubufs)])):
schedule.append(ScheduleItem(ast, bufs, k.arg.metadata, {dnums[0]:i} if len(dnums) else {}))
else:
# ONE -> ONE
schedule.append(ScheduleItem(ast, cast(tuple[Buffer, ...], ubufs), k.arg.metadata))
for x in children[k]:
in_degree[x] -= 1
if in_degree[x] == 0: queues[_heuristic(x)].append(x)
return schedule, var_vals