mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
rename rewrites + sink filter + bump to dagre 2.0.0 (#14966)
* bump to dagre 2.0.0
* transform to call
* cleanup names
* get kernel graph
* dagre recursion fix + better error
* add toggle to hide sink nodes
* no sink by default
* revert that
* only hide final sinks
* lol
This commit is contained in:
parent
d86f1d66b5
commit
806581f807
10 changed files with 69 additions and 842 deletions
2
.gitignore
vendored
2
.gitignore
vendored
|
|
@ -66,3 +66,5 @@ target
|
|||
.mypy_cache
|
||||
mutants
|
||||
.mutmut-cache
|
||||
dagre/
|
||||
graphlib/
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@ Directories are listed in order of how they are processed.
|
|||
|
||||
Group UOps into kernels.
|
||||
|
||||
::: tinygrad.schedule.rangeify.get_rangeify
|
||||
::: tinygrad.schedule.rangeify.get_kernel_graph
|
||||
options:
|
||||
members: false
|
||||
show_labels: false
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
from dataclasses import dataclass, field
|
||||
from tinygrad.uop.ops import UOp, UPat, PatternMatcher, Ops, GroupOp, graph_rewrite, identity_element
|
||||
from tinygrad.uop.ops import UOp, UPat, PatternMatcher, Ops, GroupOp, graph_rewrite, identity_element, profile_matches
|
||||
from tinygrad.dtype import ImageDType
|
||||
from tinygrad.helpers import prod, DEBUG, argsort, VIZ
|
||||
|
||||
|
|
@ -125,7 +125,8 @@ pm_replace_buf = PatternMatcher([
|
|||
(UPat(Ops.BIND, src=(UPat(Ops.DEFINE_VAR), UPat(Ops.CONST)), name="b"), replace_input_buffer),
|
||||
])
|
||||
|
||||
def allocate_global_buffers(big_sink:UOp) -> tuple[UOp, dict[UOp, UOp]]:
|
||||
@profile_matches
|
||||
def transform_to_call(big_sink:UOp) -> tuple[UOp, dict[UOp, UOp]]:
|
||||
# uop list is a list in the original_sink graph and we can map to the tags later
|
||||
# here we build buffer map
|
||||
dont_realize = {Ops.CONST, Ops.BUFFER, Ops.BIND, Ops.DEFINE_VAR, Ops.AFTER}
|
||||
|
|
@ -141,5 +142,5 @@ def allocate_global_buffers(big_sink:UOp) -> tuple[UOp, dict[UOp, UOp]]:
|
|||
# here we construct the final buffer_map. this is everything that will go into the tensor map
|
||||
graph_rewrite(big_sink, pm_finalize_call, ctx=ctx, name="finalize call")
|
||||
ret = graph_rewrite(UOp.sink(*ctx.assigns), pm_replace_buf, ctx=ctx, name="replace bufs").call(*ctx.replacements)
|
||||
if VIZ: graph_rewrite(ret, PatternMatcher([]), name="*** Call")
|
||||
if VIZ: graph_rewrite(ret, PatternMatcher([]), name="View Call")
|
||||
return ret, ctx.buffer_map
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@ from tinygrad.uop.spec import type_verify, tensor_spec
|
|||
from tinygrad.device import Buffer, MultiBuffer
|
||||
from tinygrad.helpers import DEBUG, cpu_profile, TracingKey, SPEC, pluralize, SCACHE, BASEDIR
|
||||
from tinygrad.engine.realize import ExecItem
|
||||
from tinygrad.engine.allocations import allocate_global_buffers
|
||||
from tinygrad.engine.allocations import transform_to_call
|
||||
|
||||
# **** schedule linearizer
|
||||
|
||||
|
|
@ -63,8 +63,7 @@ def create_schedule(sched_sink:UOp) -> tuple[list[ExecItem], UOp]:
|
|||
return pre_schedule, UOp.sink(*buf_uops_list)
|
||||
|
||||
from tinygrad.engine.memory import memory_planner
|
||||
from tinygrad.schedule.rangeify import get_rangeify
|
||||
from tinygrad.schedule.multi import multi_pm
|
||||
from tinygrad.schedule.rangeify import get_kernel_graph
|
||||
from tinygrad.uop.ops import PatternMatcher, UPat
|
||||
|
||||
def create_new_buffer(ctx:tuple[dict[UOp, UOp], tuple[UOp, ...]], b:UOp):
|
||||
|
|
@ -82,29 +81,18 @@ schedule_cache: dict[bytes, tuple[list[ExecItem], UOp]] = {}
|
|||
def complete_create_schedule_with_vars(big_sink:UOp) -> tuple[dict[UOp, UOp], list[ExecItem], dict[str, int]]:
|
||||
# big_sink srcs are all the Tensors
|
||||
st = time.perf_counter()
|
||||
big_sink, buffer_map = allocate_global_buffers(big_sink)
|
||||
big_sink, buffer_map = transform_to_call(big_sink)
|
||||
|
||||
# get var_vals
|
||||
var_vals: dict[str, int] = {}
|
||||
for i,b in enumerate(big_sink.src[1:]):
|
||||
if b.op is Ops.BIND:
|
||||
nm = b.src[0].expr
|
||||
val = b.src[1].arg
|
||||
assert nm not in var_vals or var_vals[nm] == val, f"bind mismatch on {nm}, {var_vals[nm]} != {val}"
|
||||
var_vals[nm] = val
|
||||
|
||||
big_sink_cache = big_sink.src[0]
|
||||
sched_cache_key = big_sink_cache.key
|
||||
if not SCACHE or (sc_ret:=schedule_cache.get(sched_cache_key, None)) is None:
|
||||
function = big_sink.src[0]
|
||||
if not SCACHE or (sc_ret:=schedule_cache.get(function.key, None)) is None:
|
||||
if SPEC: type_verify(big_sink, tensor_spec)
|
||||
big_sink_cache = graph_rewrite(big_sink_cache, multi_pm, name="multi_pm", rewrite_into_calls=True)
|
||||
pre_schedule, buf_uops_sink = create_schedule(get_rangeify(big_sink_cache))
|
||||
if SCACHE: schedule_cache[sched_cache_key] = (pre_schedule, buf_uops_sink)
|
||||
pre_schedule, buf_uops_sink = create_schedule(get_kernel_graph(function))
|
||||
if SCACHE: schedule_cache[function.key] = (pre_schedule, buf_uops_sink)
|
||||
else:
|
||||
# schedule cache hit
|
||||
pre_schedule, buf_uops_sink = sc_ret
|
||||
# it's a call that we late apply
|
||||
buf_uops_sink = graph_rewrite(buf_uops_sink, pm_post_sched_cache, ctx=({}, big_sink.src[1:]), name="apply buffers")
|
||||
buf_uops_sink = graph_rewrite(buf_uops_sink, pm_post_sched_cache, ctx=({}, big_sink.src[1:]), name="params to buffers")
|
||||
|
||||
# add bufs to pre_schedule
|
||||
schedule: list[ExecItem] = []
|
||||
|
|
@ -133,7 +121,18 @@ def complete_create_schedule_with_vars(big_sink:UOp) -> tuple[dict[UOp, UOp], li
|
|||
else:
|
||||
frm = None
|
||||
print(f"scheduled {len(schedule):5d} kernels in {(time.perf_counter()-st)*1000:8.2f} ms"+\
|
||||
f" | {' cache hit' if SCACHE and sc_ret is not None else 'CACHE MISS'} {sched_cache_key.hex()[:8]}"+\
|
||||
f" | {' cache hit' if SCACHE and sc_ret is not None else 'CACHE MISS'} {function.key.hex()[:8]}"+\
|
||||
f" | {len(UOpMetaClass.ucache):7d} uops in cache"+("" if frm is None else f" | {frm.filename}:{frm.lineno}"))
|
||||
|
||||
# vars used in the schedule
|
||||
used_vars = set().union(*[{v.expr for v in si.ast.variables()} for si in schedule])
|
||||
return buffer_map, schedule, {k:v for k,v in var_vals.items() if k in used_vars}
|
||||
# get var_vals
|
||||
var_vals: dict[str, int] = {}
|
||||
for i,b in enumerate(big_sink.src[1:]):
|
||||
if b.op is Ops.BIND:
|
||||
nm = b.src[0].expr
|
||||
if nm not in used_vars: continue
|
||||
val = b.src[1].arg
|
||||
assert nm not in var_vals or var_vals[nm] == val, f"bind mismatch on {nm}, {var_vals[nm]} != {val}"
|
||||
var_vals[nm] = val
|
||||
return buffer_map, schedule, var_vals
|
||||
|
|
|
|||
|
|
@ -2,13 +2,14 @@ from dataclasses import dataclass, field, replace
|
|||
import itertools
|
||||
from tinygrad.dtype import dtypes, PtrDType, ImageDType, AddrSpace
|
||||
from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp, resolve, GroupOp, _substitute, KernelInfo
|
||||
from tinygrad.uop.ops import graph_rewrite, sint, AxisType, BottomUpGate
|
||||
from tinygrad.uop.ops import graph_rewrite, sint, AxisType, BottomUpGate, profile_matches
|
||||
from tinygrad.uop.symbolic import symbolic
|
||||
from tinygrad.helpers import prod, all_same, getenv, dedup, all_int, DEBUG, SPLIT_REDUCEOP, DEBUG_RANGEIFY, VIZ, MAX_KERNEL_BUFFERS
|
||||
from tinygrad.helpers import PCONTIG, partition, get_single_element
|
||||
from tinygrad.codegen.simplify import pm_flatten_range, pm_reduce_simplify
|
||||
from tinygrad.codegen.opt import Opt
|
||||
from tinygrad.schedule.indexing import run_rangeify, BufferizeOpts, ALWAYS_CONTIGUOUS, IndexingContext, apply_movement_op
|
||||
from tinygrad.schedule.multi import multi_pm
|
||||
|
||||
# creation can recurse a lot
|
||||
import sys
|
||||
|
|
@ -478,9 +479,10 @@ split_kernels = PatternMatcher([
|
|||
(UPat((Ops.STORE, Ops.END), name="x"), split_store),
|
||||
])
|
||||
|
||||
def get_rangeify(sink:UOp) -> UOp:
|
||||
if VIZ: graph_rewrite(sink, PatternMatcher([]), name="View Input Graph")
|
||||
tsink = graph_rewrite(sink, pm_syntactic_sugar+pm_mops+earliest_rewrites, bottom_up=True, name="earliest rewrites")
|
||||
@profile_matches
|
||||
def get_kernel_graph(sink:UOp) -> UOp:
|
||||
tsink = graph_rewrite(sink, multi_pm, name="multi_pm", rewrite_into_calls=True)
|
||||
tsink = graph_rewrite(tsink, pm_syntactic_sugar+pm_mops+earliest_rewrites, bottom_up=True, name="earliest rewrites")
|
||||
|
||||
# convert movement ops to ranges
|
||||
tsink, rctx = run_rangeify(tsink, bool(DEBUG_RANGEIFY))
|
||||
|
|
|
|||
|
|
@ -259,7 +259,7 @@ class Tensor(OpMixin):
|
|||
|
||||
# this is where the schedule cache should go
|
||||
becomes_map, schedule, var_vals = complete_create_schedule_with_vars(big_sink)
|
||||
_apply_map_to_tensors(becomes_map, name="Apply Schedule Map")
|
||||
_apply_map_to_tensors(becomes_map, name="buffers")
|
||||
return schedule, var_vals
|
||||
|
||||
def schedule(self, *lst:Tensor) -> list[ExecItem]:
|
||||
|
|
|
|||
|
|
@ -859,7 +859,7 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
|
|||
|
||||
# TODO: this should replace placeholder
|
||||
@staticmethod
|
||||
def param(slot:int, dtype:DType, shape:tuple[sint, ...]|None=None, device=None, vmin_vmax:tuple[PyConst, PyConst]|None=None, name="None"):
|
||||
def param(slot:int, dtype:DType, shape:tuple[sint, ...]|None=None, device=None, vmin_vmax:tuple[PyConst, PyConst]|None=None, name=None):
|
||||
src: tuple[UOp, ...] = (UOp(Ops.NOOP) if shape is None else shape_to_shape_arg(shape),) + \
|
||||
(UOp(Ops.NOOP) if device is None else UOp(Ops.DEVICE, arg=device),)
|
||||
if vmin_vmax is not None: src += (UOp.const(dtype, vmin_vmax[0]), UOp.const(dtype.scalar(), vmin_vmax[1]))
|
||||
|
|
|
|||
File diff suppressed because one or more lines are too long
|
|
@ -131,10 +131,15 @@ function renderDag(layoutSpec, { recenter }) {
|
|||
worker = new Worker(workerUrl);
|
||||
worker.postMessage(layoutSpec);
|
||||
worker.onmessage = (e) => {
|
||||
if (e.data.error) {
|
||||
updateProgress(Status.ERR, "Error in graph layout:\n"+e.data.error);
|
||||
return;
|
||||
}
|
||||
const data = e.data.result;
|
||||
displaySelection("#graph");
|
||||
updateProgress(Status.COMPLETE);
|
||||
drawGraph(e.data);
|
||||
addTags(d3.select("#edge-labels").selectAll("g").data(e.data.edges).join("g").attr("transform", (e) => {
|
||||
drawGraph(data);
|
||||
addTags(d3.select("#edge-labels").selectAll("g").data(data.edges).join("g").attr("transform", (e) => {
|
||||
// get a point near the end
|
||||
const [p1, p2] = e.value.points.slice(-2);
|
||||
const dx = p2.x-p1.x;
|
||||
|
|
@ -751,6 +756,8 @@ const createToggle = (id, text) => {
|
|||
}
|
||||
const showIndexing = createToggle("show-indexing", "Show indexing (r)");
|
||||
const showCallSrc = createToggle("show-call-src", "Show CALL src (c)");
|
||||
const showSink = createToggle("show-sink", "Show SINK (s)");
|
||||
showSink.toggle.checked = false;
|
||||
const showGraph = createToggle("show-graph", "Show graph (g)");
|
||||
showGraph.toggle.onchange = () => displaySelection(rect("#graph").width > 0 ? "#custom" : "#graph");
|
||||
|
||||
|
|
@ -900,13 +907,14 @@ async function main() {
|
|||
// ** center graph
|
||||
const data = ret[currentRewrite];
|
||||
const render = (opts) => renderDag({ data, opts }, { recenter:currentRewrite === 0 });
|
||||
const getOpts = () => ({ showIndexing:showIndexing.toggle.checked, showCallSrc:showCallSrc.toggle.checked });
|
||||
const getOpts = () => ({ showIndexing:showIndexing.toggle.checked, showCallSrc:showCallSrc.toggle.checked, showSink:showSink.toggle.checked });
|
||||
render(getOpts());
|
||||
showIndexing.toggle.onchange = () => render(getOpts());
|
||||
showCallSrc.toggle.onchange = () => render(getOpts());
|
||||
showSink.toggle.onchange = () => render(getOpts());
|
||||
// ** right sidebar metadata
|
||||
metadata.innerHTML = "";
|
||||
if (ckey.includes("rewrites")) metadata.append(showIndexing.label, showCallSrc.label);
|
||||
if (ckey.includes("rewrites")) metadata.append(showIndexing.label, showCallSrc.label, showSink.label);
|
||||
if (step.code_line != null) metadata.appendChild(codeBlock(step.code_line, "python", { loc:step.loc, wrap:true }));
|
||||
if (step.trace) {
|
||||
const trace = d3.create("pre").append("code").classed("hljs", true);
|
||||
|
|
@ -1037,6 +1045,8 @@ document.addEventListener("keydown", (event) => {
|
|||
if (event.key === "r") showIndexing.toggle.click();
|
||||
// c key toggles CALL src
|
||||
if (event.key === "c") showCallSrc.toggle.click();
|
||||
// s key toggles SINK
|
||||
if (event.key === "s") showSink.toggle.click();
|
||||
// g key toggles graph
|
||||
if (event.key === "g") showGraph.toggle.click();
|
||||
});
|
||||
|
|
|
|||
|
|
@ -5,11 +5,16 @@ const canvas = new OffscreenCanvas(0, 0);
|
|||
const ctx = canvas.getContext("2d");
|
||||
|
||||
onmessage = (e) => {
|
||||
const { data, opts } = e.data;
|
||||
const g = new dagre.graphlib.Graph({ compound: true }).setDefaultEdgeLabel(function() { return {}; });
|
||||
(data.blocks != null ? layoutCfg : layoutUOp)(g, data, opts);
|
||||
postMessage(dagre.graphlib.json.write(g));
|
||||
self.close();
|
||||
try {
|
||||
const { data, opts } = e.data;
|
||||
const g = new dagre.graphlib.Graph({ compound: true }).setDefaultEdgeLabel(function() { return {}; });
|
||||
(data.blocks != null ? layoutCfg : layoutUOp)(g, data, opts);
|
||||
postMessage({result: dagre.graphlib.json.write(g)});
|
||||
self.close();
|
||||
} catch (err) {
|
||||
postMessage({error: err.stack || err.message || String(err)});
|
||||
self.close();
|
||||
}
|
||||
}
|
||||
|
||||
const layoutCfg = (g, { blocks, paths, pc_tokens }) => {
|
||||
|
|
@ -56,6 +61,12 @@ const layoutUOp = (g, { graph, change }, opts) => {
|
|||
if (change?.includes(parseInt(k))) g.setParent(k, "overlay");
|
||||
}
|
||||
// optionally hide nodes from the layout
|
||||
if (!opts.showSink) {
|
||||
for (const n of g.nodes()) {
|
||||
const node = g.node(n);
|
||||
if ((node.label === "SINK" || node.label.startsWith("SINK\n")) && (g.successors(n) || []).length === 0) g.removeNode(n);
|
||||
}
|
||||
}
|
||||
if (!opts.showIndexing) {
|
||||
for (const n of g.nodes()) {
|
||||
const node = g.node(n);
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue