rename rewrites + sink filter + bump to dagre 2.0.0 (#14966)

* bump to dagre 2.0.0

* transform to call

* cleanup names

* get kernel graph

* dagre recursion fix + better error

* add toggle to hide sink nodes

* no sink by default

* revert that

* only hide final sinks

* lol
This commit is contained in:
George Hotz 2026-02-23 22:47:22 +08:00 committed by GitHub
commit 806581f807
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
10 changed files with 69 additions and 842 deletions

2
.gitignore vendored
View file

@ -66,3 +66,5 @@ target
.mypy_cache
mutants
.mutmut-cache
dagre/
graphlib/

View file

@ -10,7 +10,7 @@ Directories are listed in order of how they are processed.
Group UOps into kernels.
::: tinygrad.schedule.rangeify.get_rangeify
::: tinygrad.schedule.rangeify.get_kernel_graph
options:
members: false
show_labels: false

View file

@ -1,5 +1,5 @@
from dataclasses import dataclass, field
from tinygrad.uop.ops import UOp, UPat, PatternMatcher, Ops, GroupOp, graph_rewrite, identity_element
from tinygrad.uop.ops import UOp, UPat, PatternMatcher, Ops, GroupOp, graph_rewrite, identity_element, profile_matches
from tinygrad.dtype import ImageDType
from tinygrad.helpers import prod, DEBUG, argsort, VIZ
@ -125,7 +125,8 @@ pm_replace_buf = PatternMatcher([
(UPat(Ops.BIND, src=(UPat(Ops.DEFINE_VAR), UPat(Ops.CONST)), name="b"), replace_input_buffer),
])
def allocate_global_buffers(big_sink:UOp) -> tuple[UOp, dict[UOp, UOp]]:
@profile_matches
def transform_to_call(big_sink:UOp) -> tuple[UOp, dict[UOp, UOp]]:
# uop list is a list in the original_sink graph and we can map to the tags later
# here we build buffer map
dont_realize = {Ops.CONST, Ops.BUFFER, Ops.BIND, Ops.DEFINE_VAR, Ops.AFTER}
@ -141,5 +142,5 @@ def allocate_global_buffers(big_sink:UOp) -> tuple[UOp, dict[UOp, UOp]]:
# here we construct the final buffer_map. this is everything that will go into the tensor map
graph_rewrite(big_sink, pm_finalize_call, ctx=ctx, name="finalize call")
ret = graph_rewrite(UOp.sink(*ctx.assigns), pm_replace_buf, ctx=ctx, name="replace bufs").call(*ctx.replacements)
if VIZ: graph_rewrite(ret, PatternMatcher([]), name="*** Call")
if VIZ: graph_rewrite(ret, PatternMatcher([]), name="View Call")
return ret, ctx.buffer_map

View file

@ -6,7 +6,7 @@ from tinygrad.uop.spec import type_verify, tensor_spec
from tinygrad.device import Buffer, MultiBuffer
from tinygrad.helpers import DEBUG, cpu_profile, TracingKey, SPEC, pluralize, SCACHE, BASEDIR
from tinygrad.engine.realize import ExecItem
from tinygrad.engine.allocations import allocate_global_buffers
from tinygrad.engine.allocations import transform_to_call
# **** schedule linearizer
@ -63,8 +63,7 @@ def create_schedule(sched_sink:UOp) -> tuple[list[ExecItem], UOp]:
return pre_schedule, UOp.sink(*buf_uops_list)
from tinygrad.engine.memory import memory_planner
from tinygrad.schedule.rangeify import get_rangeify
from tinygrad.schedule.multi import multi_pm
from tinygrad.schedule.rangeify import get_kernel_graph
from tinygrad.uop.ops import PatternMatcher, UPat
def create_new_buffer(ctx:tuple[dict[UOp, UOp], tuple[UOp, ...]], b:UOp):
@ -82,29 +81,18 @@ schedule_cache: dict[bytes, tuple[list[ExecItem], UOp]] = {}
def complete_create_schedule_with_vars(big_sink:UOp) -> tuple[dict[UOp, UOp], list[ExecItem], dict[str, int]]:
# big_sink srcs are all the Tensors
st = time.perf_counter()
big_sink, buffer_map = allocate_global_buffers(big_sink)
big_sink, buffer_map = transform_to_call(big_sink)
# get var_vals
var_vals: dict[str, int] = {}
for i,b in enumerate(big_sink.src[1:]):
if b.op is Ops.BIND:
nm = b.src[0].expr
val = b.src[1].arg
assert nm not in var_vals or var_vals[nm] == val, f"bind mismatch on {nm}, {var_vals[nm]} != {val}"
var_vals[nm] = val
big_sink_cache = big_sink.src[0]
sched_cache_key = big_sink_cache.key
if not SCACHE or (sc_ret:=schedule_cache.get(sched_cache_key, None)) is None:
function = big_sink.src[0]
if not SCACHE or (sc_ret:=schedule_cache.get(function.key, None)) is None:
if SPEC: type_verify(big_sink, tensor_spec)
big_sink_cache = graph_rewrite(big_sink_cache, multi_pm, name="multi_pm", rewrite_into_calls=True)
pre_schedule, buf_uops_sink = create_schedule(get_rangeify(big_sink_cache))
if SCACHE: schedule_cache[sched_cache_key] = (pre_schedule, buf_uops_sink)
pre_schedule, buf_uops_sink = create_schedule(get_kernel_graph(function))
if SCACHE: schedule_cache[function.key] = (pre_schedule, buf_uops_sink)
else:
# schedule cache hit
pre_schedule, buf_uops_sink = sc_ret
# it's a call that we late apply
buf_uops_sink = graph_rewrite(buf_uops_sink, pm_post_sched_cache, ctx=({}, big_sink.src[1:]), name="apply buffers")
buf_uops_sink = graph_rewrite(buf_uops_sink, pm_post_sched_cache, ctx=({}, big_sink.src[1:]), name="params to buffers")
# add bufs to pre_schedule
schedule: list[ExecItem] = []
@ -133,7 +121,18 @@ def complete_create_schedule_with_vars(big_sink:UOp) -> tuple[dict[UOp, UOp], li
else:
frm = None
print(f"scheduled {len(schedule):5d} kernels in {(time.perf_counter()-st)*1000:8.2f} ms"+\
f" | {' cache hit' if SCACHE and sc_ret is not None else 'CACHE MISS'} {sched_cache_key.hex()[:8]}"+\
f" | {' cache hit' if SCACHE and sc_ret is not None else 'CACHE MISS'} {function.key.hex()[:8]}"+\
f" | {len(UOpMetaClass.ucache):7d} uops in cache"+("" if frm is None else f" | {frm.filename}:{frm.lineno}"))
# vars used in the schedule
used_vars = set().union(*[{v.expr for v in si.ast.variables()} for si in schedule])
return buffer_map, schedule, {k:v for k,v in var_vals.items() if k in used_vars}
# get var_vals
var_vals: dict[str, int] = {}
for i,b in enumerate(big_sink.src[1:]):
if b.op is Ops.BIND:
nm = b.src[0].expr
if nm not in used_vars: continue
val = b.src[1].arg
assert nm not in var_vals or var_vals[nm] == val, f"bind mismatch on {nm}, {var_vals[nm]} != {val}"
var_vals[nm] = val
return buffer_map, schedule, var_vals

View file

@ -2,13 +2,14 @@ from dataclasses import dataclass, field, replace
import itertools
from tinygrad.dtype import dtypes, PtrDType, ImageDType, AddrSpace
from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp, resolve, GroupOp, _substitute, KernelInfo
from tinygrad.uop.ops import graph_rewrite, sint, AxisType, BottomUpGate
from tinygrad.uop.ops import graph_rewrite, sint, AxisType, BottomUpGate, profile_matches
from tinygrad.uop.symbolic import symbolic
from tinygrad.helpers import prod, all_same, getenv, dedup, all_int, DEBUG, SPLIT_REDUCEOP, DEBUG_RANGEIFY, VIZ, MAX_KERNEL_BUFFERS
from tinygrad.helpers import PCONTIG, partition, get_single_element
from tinygrad.codegen.simplify import pm_flatten_range, pm_reduce_simplify
from tinygrad.codegen.opt import Opt
from tinygrad.schedule.indexing import run_rangeify, BufferizeOpts, ALWAYS_CONTIGUOUS, IndexingContext, apply_movement_op
from tinygrad.schedule.multi import multi_pm
# creation can recurse a lot
import sys
@ -478,9 +479,10 @@ split_kernels = PatternMatcher([
(UPat((Ops.STORE, Ops.END), name="x"), split_store),
])
def get_rangeify(sink:UOp) -> UOp:
if VIZ: graph_rewrite(sink, PatternMatcher([]), name="View Input Graph")
tsink = graph_rewrite(sink, pm_syntactic_sugar+pm_mops+earliest_rewrites, bottom_up=True, name="earliest rewrites")
@profile_matches
def get_kernel_graph(sink:UOp) -> UOp:
tsink = graph_rewrite(sink, multi_pm, name="multi_pm", rewrite_into_calls=True)
tsink = graph_rewrite(tsink, pm_syntactic_sugar+pm_mops+earliest_rewrites, bottom_up=True, name="earliest rewrites")
# convert movement ops to ranges
tsink, rctx = run_rangeify(tsink, bool(DEBUG_RANGEIFY))

View file

@ -259,7 +259,7 @@ class Tensor(OpMixin):
# this is where the schedule cache should go
becomes_map, schedule, var_vals = complete_create_schedule_with_vars(big_sink)
_apply_map_to_tensors(becomes_map, name="Apply Schedule Map")
_apply_map_to_tensors(becomes_map, name="buffers")
return schedule, var_vals
def schedule(self, *lst:Tensor) -> list[ExecItem]:

View file

@ -859,7 +859,7 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
# TODO: this should replace placeholder
@staticmethod
def param(slot:int, dtype:DType, shape:tuple[sint, ...]|None=None, device=None, vmin_vmax:tuple[PyConst, PyConst]|None=None, name="None"):
def param(slot:int, dtype:DType, shape:tuple[sint, ...]|None=None, device=None, vmin_vmax:tuple[PyConst, PyConst]|None=None, name=None):
src: tuple[UOp, ...] = (UOp(Ops.NOOP) if shape is None else shape_to_shape_arg(shape),) + \
(UOp(Ops.NOOP) if device is None else UOp(Ops.DEVICE, arg=device),)
if vmin_vmax is not None: src += (UOp.const(dtype, vmin_vmax[0]), UOp.const(dtype.scalar(), vmin_vmax[1]))

File diff suppressed because one or more lines are too long

View file

@ -131,10 +131,15 @@ function renderDag(layoutSpec, { recenter }) {
worker = new Worker(workerUrl);
worker.postMessage(layoutSpec);
worker.onmessage = (e) => {
if (e.data.error) {
updateProgress(Status.ERR, "Error in graph layout:\n"+e.data.error);
return;
}
const data = e.data.result;
displaySelection("#graph");
updateProgress(Status.COMPLETE);
drawGraph(e.data);
addTags(d3.select("#edge-labels").selectAll("g").data(e.data.edges).join("g").attr("transform", (e) => {
drawGraph(data);
addTags(d3.select("#edge-labels").selectAll("g").data(data.edges).join("g").attr("transform", (e) => {
// get a point near the end
const [p1, p2] = e.value.points.slice(-2);
const dx = p2.x-p1.x;
@ -751,6 +756,8 @@ const createToggle = (id, text) => {
}
const showIndexing = createToggle("show-indexing", "Show indexing (r)");
const showCallSrc = createToggle("show-call-src", "Show CALL src (c)");
const showSink = createToggle("show-sink", "Show SINK (s)");
showSink.toggle.checked = false;
const showGraph = createToggle("show-graph", "Show graph (g)");
showGraph.toggle.onchange = () => displaySelection(rect("#graph").width > 0 ? "#custom" : "#graph");
@ -900,13 +907,14 @@ async function main() {
// ** center graph
const data = ret[currentRewrite];
const render = (opts) => renderDag({ data, opts }, { recenter:currentRewrite === 0 });
const getOpts = () => ({ showIndexing:showIndexing.toggle.checked, showCallSrc:showCallSrc.toggle.checked });
const getOpts = () => ({ showIndexing:showIndexing.toggle.checked, showCallSrc:showCallSrc.toggle.checked, showSink:showSink.toggle.checked });
render(getOpts());
showIndexing.toggle.onchange = () => render(getOpts());
showCallSrc.toggle.onchange = () => render(getOpts());
showSink.toggle.onchange = () => render(getOpts());
// ** right sidebar metadata
metadata.innerHTML = "";
if (ckey.includes("rewrites")) metadata.append(showIndexing.label, showCallSrc.label);
if (ckey.includes("rewrites")) metadata.append(showIndexing.label, showCallSrc.label, showSink.label);
if (step.code_line != null) metadata.appendChild(codeBlock(step.code_line, "python", { loc:step.loc, wrap:true }));
if (step.trace) {
const trace = d3.create("pre").append("code").classed("hljs", true);
@ -1037,6 +1045,8 @@ document.addEventListener("keydown", (event) => {
if (event.key === "r") showIndexing.toggle.click();
// c key toggles CALL src
if (event.key === "c") showCallSrc.toggle.click();
// s key toggles SINK
if (event.key === "s") showSink.toggle.click();
// g key toggles graph
if (event.key === "g") showGraph.toggle.click();
});

View file

@ -5,11 +5,16 @@ const canvas = new OffscreenCanvas(0, 0);
const ctx = canvas.getContext("2d");
onmessage = (e) => {
const { data, opts } = e.data;
const g = new dagre.graphlib.Graph({ compound: true }).setDefaultEdgeLabel(function() { return {}; });
(data.blocks != null ? layoutCfg : layoutUOp)(g, data, opts);
postMessage(dagre.graphlib.json.write(g));
self.close();
try {
const { data, opts } = e.data;
const g = new dagre.graphlib.Graph({ compound: true }).setDefaultEdgeLabel(function() { return {}; });
(data.blocks != null ? layoutCfg : layoutUOp)(g, data, opts);
postMessage({result: dagre.graphlib.json.write(g)});
self.close();
} catch (err) {
postMessage({error: err.stack || err.message || String(err)});
self.close();
}
}
const layoutCfg = (g, { blocks, paths, pc_tokens }) => {
@ -56,6 +61,12 @@ const layoutUOp = (g, { graph, change }, opts) => {
if (change?.includes(parseInt(k))) g.setParent(k, "overlay");
}
// optionally hide nodes from the layout
if (!opts.showSink) {
for (const n of g.nodes()) {
const node = g.node(n);
if ((node.label === "SINK" || node.label.startsWith("SINK\n")) && (g.successors(n) || []).length === 0) g.removeNode(n);
}
}
if (!opts.showIndexing) {
for (const n of g.nodes()) {
const node = g.node(n);