Compare commits

...

5 commits

Author SHA1 Message Date
George Hotz
e7c8aaed31 go 2026-02-10 11:55:02 +08:00
George Hotz
a65d9fea74 speed 2026-02-09 23:40:06 +08:00
George Hotz
29b2afa0cb fix cycle 2026-02-09 09:50:55 +08:00
George Hotz
d70e255c89 speed + deterministic 2026-02-09 09:39:02 +08:00
George Hotz
9e46535ad3 play with some basic egraph stuff 2026-02-09 09:04:10 +08:00
4 changed files with 751 additions and 14 deletions

502
test/null/test_egraph.py Normal file
View file

@ -0,0 +1,502 @@
import unittest
from tinygrad.dtype import dtypes
from tinygrad.uop.ops import Ops, UOp, GroupOp, PatternMatcher, UPat, graph_rewrite
from tinygrad.uop.egraph import uf_find, uf_union, rewrite_all, EGraph, egraph_saturate, egraph_extract, node_cost, _rebuild_tree
# *** test union-find ***
class TestUnionFind(unittest.TestCase):
def test_find_self(self):
a, b = UOp.const(dtypes.int, 1), UOp.const(dtypes.int, 2)
parent = {a: a, b: b}
self.assertIs(uf_find(parent, a), a)
self.assertIs(uf_find(parent, b), b)
def test_union_basic(self):
a, b = UOp.const(dtypes.int, 1), UOp.const(dtypes.int, 2)
parent = {a: a, b: b}
size = {a: 1, b: 1}
root = uf_union(parent, size, a, b)
self.assertIs(uf_find(parent, a), uf_find(parent, b))
self.assertIs(root, uf_find(parent, a))
def test_union_chain(self):
a, b, c = UOp.const(dtypes.int, 1), UOp.const(dtypes.int, 2), UOp.const(dtypes.int, 3)
parent = {a: a, b: b, c: c}
size = {a: 1, b: 1, c: 1}
uf_union(parent, size, a, b)
uf_union(parent, size, b, c)
self.assertIs(uf_find(parent, a), uf_find(parent, c))
def test_union_idempotent(self):
a, b = UOp.const(dtypes.int, 1), UOp.const(dtypes.int, 2)
parent = {a: a, b: b}
size = {a: 1, b: 1}
r1 = uf_union(parent, size, a, b)
r2 = uf_union(parent, size, a, b)
self.assertIs(r1, r2)
# *** test rewrite_all ***
class TestRewriteAll(unittest.TestCase):
def test_single_match(self):
pm = PatternMatcher([(UPat.var("x") + 0, lambda x: x)])
a = UOp.variable("a", 0, 10)
results = rewrite_all(pm, a + 0)
self.assertEqual(len(results), 1)
self.assertIs(results[0], a)
def test_no_match(self):
pm = PatternMatcher([(UPat.var("x") + 0, lambda x: x)])
a = UOp.variable("a", 0, 10)
b = UOp.variable("b", 0, 10)
results = rewrite_all(pm, a + b)
self.assertEqual(len(results), 0)
def test_multiple_matches(self):
pm = PatternMatcher([
(UPat.var("x") + 0, lambda x: x),
(UPat.var("x") * 1, lambda x: x),
])
a = UOp.variable("a", 0, 10)
results = rewrite_all(pm, a + 0)
self.assertEqual(len(results), 1)
self.assertIs(results[0], a)
def test_both_rules_fire(self):
pm = PatternMatcher([
(UPat.var("x") + UPat.var("x"), lambda x: x * 2),
(UPat.var("x") + UPat.var("x"), lambda x: UOp(Ops.SHL, x.dtype, (x, x.const_like(1)))),
])
a = UOp.variable("a", 0, 10)
results = rewrite_all(pm, a + a)
self.assertEqual(len(results), 2)
def test_const_folding(self):
pm = PatternMatcher([
(UPat(GroupOp.Binary, src=(UPat((Ops.CONST, Ops.VCONST)),)*2, name="a"),
lambda a: a.const_like(a.src[0].arg + a.src[1].arg) if a.op is Ops.ADD else None),
])
results = rewrite_all(pm, UOp.const(dtypes.int, 3) + UOp.const(dtypes.int, 4))
self.assertEqual(len(results), 1)
self.assertEqual(results[0].arg, 7)
# *** test EGraph class ***
class TestEGraphClass(unittest.TestCase):
def test_init(self):
a = UOp.variable("a", 0, 10)
expr = a + 0
eg = EGraph(expr)
self.assertEqual(len(eg.eclass), len(list(expr.toposort())))
self.assertIn(expr, eg.all_nodes)
def test_add_node(self):
a = UOp.variable("a", 0, 10)
eg = EGraph(a)
b = UOp.variable("b", 0, 10)
eg._add_node(b)
self.assertIn(b, eg.all_nodes)
def test_merge(self):
a = UOp.variable("a", 0, 10)
expr = a + 0
eg = EGraph(expr)
result = eg._merge(expr, a)
self.assertIsNotNone(result)
self.assertIs(uf_find(eg.parent, expr), uf_find(eg.parent, a))
def test_merge_idempotent(self):
a = UOp.variable("a", 0, 10)
eg = EGraph(a)
result = eg._merge(a, a)
self.assertIsNone(result)
# *** test egraph_saturate ***
class TestEGraphSaturate(unittest.TestCase):
def test_identity_rules(self):
pm = PatternMatcher([
(UPat.var("x") + 0, lambda x: x),
(UPat.var("x") * 1, lambda x: x),
])
a = UOp.variable("a", 0, 10)
expr = a + 0
eclass = egraph_saturate(expr, pm)
# a+0 and a should be in the same e-class
a_class = expr_class = None
for canon, members in eclass.items():
if a in members: a_class = canon
if expr in members: expr_class = canon
self.assertIsNotNone(a_class)
self.assertIsNotNone(expr_class)
self.assertIs(a_class, expr_class)
def test_const_fold_saturation(self):
from tinygrad.uop.symbolic import symbolic_simple
c2, c3 = UOp.const(dtypes.int, 2), UOp.const(dtypes.int, 3)
expr = c2 + c3
eclass = egraph_saturate(expr, symbolic_simple)
c5 = UOp.const(dtypes.int, 5)
for canon, members in eclass.items():
if expr in members:
self.assertIn(c5, members, f"expected CONST(5) in eclass of 2+3, got {members}")
return
self.fail("expr not found in any eclass")
def test_no_rules_match(self):
pm = PatternMatcher([(UPat.var("x") + 0, lambda x: x)])
a = UOp.variable("a", 0, 10)
b = UOp.variable("b", 0, 10)
eclass = egraph_saturate(a + b, pm)
for canon, members in eclass.items():
self.assertEqual(len(members), 1)
def test_max_iters_respected(self):
pm = PatternMatcher([(UPat.var("x") + 0, lambda x: x)])
a = UOp.variable("a", 0, 10)
expr = a + 0
eclass = egraph_saturate(expr, pm, max_iters=1)
a_class = expr_class = None
for canon, members in eclass.items():
if a in members: a_class = canon
if expr in members: expr_class = canon
self.assertIs(a_class, expr_class)
def test_rebuilding_propagates(self):
"""After a*0 merges with 0, rebuilding should create (0+a) which then matches x+0 -> x."""
pm = PatternMatcher([
(UPat.var("x") * 0, lambda x: x.const_like(0)),
(UPat.var("x") + 0, lambda x: x),
])
a = UOp.variable("a", 0, 10)
expr = (a * 0) + a
eclass = egraph_saturate(expr, pm)
expr_cls = a_cls = None
for canon, members in eclass.items():
if expr in members: expr_cls = canon
if a in members: a_cls = canon
self.assertIsNotNone(expr_cls)
self.assertIsNotNone(a_cls)
self.assertIs(expr_cls, a_cls)
def test_rebuilding_chain(self):
"""((a*0)+0)+b should simplify to b through multiple rebuild steps."""
pm = PatternMatcher([
(UPat.var("x") * 0, lambda x: x.const_like(0)),
(UPat.var("x") + 0, lambda x: x),
])
a = UOp.variable("a", 0, 10)
b = UOp.variable("b", 0, 10)
expr = ((a * 0) + 0) + b
eclass = egraph_saturate(expr, pm)
expr_cls = b_cls = None
for canon, members in eclass.items():
if expr in members: expr_cls = canon
if b in members: b_cls = canon
self.assertIsNotNone(expr_cls)
self.assertIsNotNone(b_cls)
self.assertIs(expr_cls, b_cls)
# *** test egraph_extract ***
class TestEGraphExtract(unittest.TestCase):
def test_extract_identity(self):
pm = PatternMatcher([(UPat.var("x") + 0, lambda x: x)])
a = UOp.variable("a", 0, 10)
self.assertIs(egraph_extract(a + 0, pm), a)
def test_extract_mul_identity(self):
pm = PatternMatcher([(UPat.var("x") * 1, lambda x: x)])
a = UOp.variable("a", 0, 10)
self.assertIs(egraph_extract(a * 1, pm), a)
def test_extract_const_fold(self):
from tinygrad.uop.symbolic import symbolic_simple
result = egraph_extract(UOp.const(dtypes.int, 2) + UOp.const(dtypes.int, 3), symbolic_simple)
self.assertEqual(result.op, Ops.CONST)
self.assertEqual(result.arg, 5)
def test_extract_chain(self):
pm = PatternMatcher([
(UPat.var("x") + 0, lambda x: x),
(UPat.var("x") * 1, lambda x: x),
])
a = UOp.variable("a", 0, 10)
self.assertIs(egraph_extract((a + 0) * 1, pm), a)
def test_extract_no_change(self):
pm = PatternMatcher([(UPat.var("x") + 0, lambda x: x)])
a = UOp.variable("a", 0, 10)
b = UOp.variable("b", 0, 10)
self.assertIs(egraph_extract(a + b, pm), a + b)
def test_extract_prefers_cheaper(self):
pm = PatternMatcher([(UPat.var("x") + UPat.var("x"), lambda x: x * 2)])
a = UOp.variable("a", 0, 10)
result = egraph_extract(a + a, pm)
self.assertEqual(result.op, Ops.ADD) # ADD cost 1 < MUL cost 2
def test_extract_with_symbolic_simple(self):
from tinygrad.uop.symbolic import symbolic_simple
a = UOp.variable("a", 0, 10)
self.assertIs(egraph_extract((a + 0) * 1, symbolic_simple), a)
def test_combine_terms(self):
from tinygrad.uop.symbolic import symbolic
a = UOp.variable("a", 0, 10)
result = egraph_extract(a * 3 + a * 4, symbolic)
self.assertEqual(result.op, Ops.MUL)
self.assertEqual(result.src[1].arg, 7)
# *** tests that REQUIRE rebuilding ***
def test_rebuild_mul_zero_plus(self):
pm = PatternMatcher([
(UPat.var("x") * 0, lambda x: x.const_like(0)),
(UPat.var("x") + 0, lambda x: x),
])
a = UOp.variable("a", 0, 10)
self.assertIs(egraph_extract((a * 0) + a, pm), a)
def test_rebuild_nested_zero(self):
pm = PatternMatcher([
(UPat.var("x") * 0, lambda x: x.const_like(0)),
(UPat.var("x") + 0, lambda x: x),
])
a = UOp.variable("a", 0, 10)
b = UOp.variable("b", 0, 10)
self.assertIs(egraph_extract(((a * 0) + 0) + b, pm), b)
def test_rebuild_distribute_then_fold(self):
pm = PatternMatcher([(UPat.var("x") * 0, lambda x: x.const_like(0))])
a = UOp.variable("a", 0, 10)
b = UOp.variable("b", 0, 10)
result = egraph_extract((a + b) * 0, pm)
self.assertEqual(result.op, Ops.CONST)
self.assertEqual(result.arg, 0)
def test_rebuild_symmetric(self):
pm = PatternMatcher([
(UPat.var("x") * 0, lambda x: x.const_like(0)),
(UPat(GroupOp.Binary, src=(UPat((Ops.CONST, Ops.VCONST)),)*2, name="a"),
lambda a: a.const_like(a.src[0].arg + a.src[1].arg) if a.op is Ops.ADD else None),
])
a = UOp.variable("a", 0, 10)
b = UOp.variable("b", 0, 10)
result = egraph_extract((a * 0) + (b * 0), pm)
self.assertEqual(result.op, Ops.CONST)
self.assertEqual(result.arg, 0)
def test_rebuild_with_real_rules(self):
from tinygrad.uop.symbolic import symbolic_simple
a = UOp.variable("a", 0, 10)
b = UOp.variable("b", 0, 10)
self.assertIs(egraph_extract((a * 0) + (b * 1), symbolic_simple), b)
def test_rebuild_deep_chain(self):
pm = PatternMatcher([
(UPat.var("x") * 0, lambda x: x.const_like(0)),
(UPat.var("x") + 0, lambda x: x),
(UPat(GroupOp.Binary, src=(UPat((Ops.CONST, Ops.VCONST)),)*2, name="a"),
lambda a: a.const_like(a.src[0].arg + a.src[1].arg) if a.op is Ops.ADD else None),
])
a = UOp.variable("a", 0, 10)
b = UOp.variable("b", 0, 10)
c = UOp.variable("c", 0, 10)
self.assertIs(egraph_extract(((a * 0) + (b * 0)) + c, pm), c)
# *** test cost model ***
class TestCostModel(unittest.TestCase):
def test_const_is_free(self):
self.assertEqual(node_cost(UOp.const(dtypes.int, 0)), 0)
def test_add_is_cheap(self):
a = UOp.variable("a", 0, 10)
b = UOp.variable("b", 0, 10)
self.assertEqual(node_cost(a + b), 1)
def test_div_is_expensive(self):
a = UOp.variable("a", 0, 10).cast(dtypes.index)
b = UOp.variable("b", 1, 10).cast(dtypes.index)
self.assertEqual(node_cost(a // b), 5)
def test_mul_more_than_add(self):
a = UOp.variable("a", 0, 10)
b = UOp.variable("b", 0, 10)
self.assertGreater(node_cost(a * b), node_cost(a + b))
# *** test e-graph matches greedy rewrite ***
class TestEGraphVsGreedy(unittest.TestCase):
def test_matches_greedy_identity(self):
from tinygrad.uop.ops import graph_rewrite
from tinygrad.uop.symbolic import symbolic_simple
a = UOp.variable("a", 0, 10)
greedy = graph_rewrite(a + 0, symbolic_simple)
egraph = egraph_extract(a + 0, symbolic_simple)
self.assertIs(greedy, egraph)
def test_matches_greedy_const_fold(self):
from tinygrad.uop.ops import graph_rewrite
from tinygrad.uop.symbolic import symbolic_simple
expr = UOp.const(dtypes.int, 10) + UOp.const(dtypes.int, 20)
greedy = graph_rewrite(expr, symbolic_simple)
egraph = egraph_extract(expr, symbolic_simple)
self.assertEqual(greedy.op, Ops.CONST)
self.assertEqual(egraph.op, Ops.CONST)
self.assertEqual(greedy.arg, egraph.arg)
def test_matches_greedy_double_identity(self):
from tinygrad.uop.ops import graph_rewrite
from tinygrad.uop.symbolic import symbolic_simple
a = UOp.variable("a", 0, 10)
expr = (a + 0) * 1
self.assertIs(graph_rewrite(expr, symbolic_simple), a)
self.assertIs(egraph_extract(expr, symbolic_simple), a)
# *** test e-graph beats greedy (phase-ordering problems) ***
# helper PMs that create phase-ordering traps
_pm_strength_reduce = PatternMatcher([
# strength reduction x*2 -> x+x fires FIRST and destroys the x*c form needed by combine-terms
(UPat.var('x') * UPat.cvar('c', vec=False), lambda x,c: x+x if c.arg == 2 else None),
# combine terms: x*c0 + x*c1 -> x*(c0+c1) can only match if both sides are x*c
(UPat.var('x') * UPat.cvar('c0') + UPat.var('x') * UPat.cvar('c1'), lambda x,c0,c1: x*(c0+c1)),
# constant folding
(UPat(GroupOp.Binary, src=(UPat((Ops.CONST, Ops.VCONST)),)*2, name='a'),
lambda a: a.const_like(a.src[0].arg + a.src[1].arg) if a.op is Ops.ADD else
a.const_like(a.src[0].arg * a.src[1].arg) if a.op is Ops.MUL else None),
(UPat.var('x') + 0, lambda x: x),
(UPat.var('x') * 1, lambda x: x),
])
_pm_shift_reduce = PatternMatcher([
# strength reduction x*2 -> x<<1 fires FIRST and destroys the x*c form
(UPat.var('x') * UPat.cvar('c', vec=False),
lambda x,c: UOp(Ops.SHL, x.dtype, (x, x.const_like(1))) if c.arg == 2 else None),
(UPat.var('x') * UPat.cvar('c0') + UPat.var('x') * UPat.cvar('c1'), lambda x,c0,c1: x*(c0+c1)),
(UPat(GroupOp.Binary, src=(UPat((Ops.CONST, Ops.VCONST)),)*2, name='a'),
lambda a: a.const_like(a.src[0].arg + a.src[1].arg) if a.op is Ops.ADD else
a.const_like(a.src[0].arg * a.src[1].arg) if a.op is Ops.MUL else None),
(UPat.var('x') + 0, lambda x: x),
(UPat.var('x') * 1, lambda x: x),
])
_pm_strength_fold = PatternMatcher([
# strength reduction x*2 -> x+x blocks two-stage folding (x*c1)*c2 -> x*(c1*c2)
(UPat.var('x') * UPat.cvar('c', vec=False), lambda x,c: x+x if c.arg == 2 else None),
((UPat.var('x') * UPat.cvar('c1')) * UPat.cvar('c2'), lambda x,c1,c2: x*(c1*c2)),
(UPat(GroupOp.Binary, src=(UPat((Ops.CONST, Ops.VCONST)),)*2, name='a'),
lambda a: a.const_like(a.src[0].arg * a.src[1].arg) if a.op is Ops.MUL else None),
])
def _total_cost(u:UOp) -> int:
return sum(node_cost(n) for n in u.toposort())
class TestEGraphBeatsGreedy(unittest.TestCase):
"""Tests where the e-graph finds a cheaper result than the greedy rewriter due to phase-ordering.
The core problem: when Rule A fires first and transforms a node, it can destroy the pattern
that Rule B needs to match. Rule B would have led to a cheaper result, but the greedy rewriter
never tries it. The e-graph explores BOTH paths and picks the cheapest.
"""
def test_strength_reduce_blocks_combine(self):
"""a*2 + a*3: strength reduction x*2->x+x destroys the x*c form needed by combine-terms x*c0+x*c1->x*(c0+c1)."""
a = UOp.variable("a", 0, 10)
expr = a * 2 + a * 3
greedy = graph_rewrite(expr, _pm_strength_reduce)
egraph = egraph_extract(expr, _pm_strength_reduce)
# greedy: (a+a) + a*3 (cost 4) — strength reduction destroyed the a*2 pattern
self.assertEqual(greedy.op, Ops.ADD)
self.assertGreater(_total_cost(greedy), _total_cost(egraph))
# egraph: a*5 (cost 2) — combine-terms wins because the e-graph explored both paths
self.assertEqual(egraph.op, Ops.MUL)
self.assertEqual(egraph.src[1].arg, 5)
def test_shift_reduce_blocks_combine(self):
"""a*2 + a*3: shift reduction x*2->x<<1 also destroys the combine-terms pattern."""
a = UOp.variable("a", 0, 10)
expr = a * 2 + a * 3
greedy = graph_rewrite(expr, _pm_shift_reduce)
egraph = egraph_extract(expr, _pm_shift_reduce)
self.assertEqual(greedy.op, Ops.ADD)
self.assertGreater(_total_cost(greedy), _total_cost(egraph))
self.assertEqual(egraph.op, Ops.MUL)
self.assertEqual(egraph.src[1].arg, 5)
def test_strength_reduce_chain(self):
"""a*2 + a*3 + a*4: strength reduction causes greedy to miss the combined a*9."""
a = UOp.variable("a", 0, 10)
expr = a * 2 + a * 3 + a * 4
greedy = graph_rewrite(expr, _pm_strength_reduce)
egraph = egraph_extract(expr, _pm_strength_reduce)
self.assertGreater(_total_cost(greedy), _total_cost(egraph))
def test_strength_reduce_blocks_two_stage_fold(self):
"""(a*2)*3: strength reduction x*2->x+x blocks two-stage constant folding (x*c1)*c2->x*(c1*c2)."""
a = UOp.variable("a", 0, 10)
expr = (a * 2) * 3
greedy = graph_rewrite(expr, _pm_strength_fold)
egraph = egraph_extract(expr, _pm_strength_fold)
# greedy: (a+a)*3 (cost 3) — can't fold constants because *2 was rewritten to +
self.assertGreater(_total_cost(greedy), _total_cost(egraph))
# egraph: a*6 (cost 2) — two-stage folding path was explored
self.assertEqual(egraph.op, Ops.MUL)
self.assertEqual(egraph.src[1].arg, 6)
def test_both_sides_strength_reduced(self):
"""a*2 + a*2: both sides get strength-reduced, blocking combine-terms."""
a = UOp.variable("a", 0, 10)
expr = a * 2 + a * 2
greedy = graph_rewrite(expr, _pm_strength_reduce)
egraph = egraph_extract(expr, _pm_strength_reduce)
# greedy: (a+a)+(a+a) — both a*2 were rewritten before combine could fire
# egraph: a*4 — combine-terms path was found
self.assertEqual(egraph.op, Ops.MUL)
self.assertEqual(egraph.src[1].arg, 4)
# both have cost 2 here (shared subexpression), but egraph result is canonical
self.assertLessEqual(_total_cost(egraph), _total_cost(greedy))
# *** test cycle-breaking in extraction ***
class TestExtractionCycles(unittest.TestCase):
def test_self_referencing_eclass(self):
"""x+0 -> x merges x+0 into x's eclass. Extraction must not recurse on the self-reference."""
pm = PatternMatcher([(UPat.var("x") + 0, lambda x: x)])
a = UOp.variable("a", 0, 10)
self.assertIs(egraph_extract(a + 0, pm), a)
def test_nested_self_referencing_eclass(self):
"""((a+0)+0)+0 — all merge into a's eclass. Deep self-reference chain."""
pm = PatternMatcher([(UPat.var("x") + 0, lambda x: x)])
a = UOp.variable("a", 0, 10)
self.assertIs(egraph_extract(((a + 0) + 0) + 0, pm), a)
def test_mutual_eclass_cycle(self):
"""Two eclasses whose best nodes reference each other — extraction must terminate via cycle-breaking cache."""
x = UOp.variable("x", 0, 10)
y = UOp.variable("y", 0, 10)
one = UOp.const(dtypes.index, 1)
two = UOp.const(dtypes.index, 2)
node1 = x + one # E1's best, child x is in E2
node2 = y + two # E2's best, child y is in E1
eclass_of = {node1: node1, x: node2, node2: node2, y: node1, one: one, two: two}
cost_of = {node1: (2, node1), node2: (2, node2), one: (0, one), two: (0, two)}
# without cycle-breaking cache, this would recurse: E1->E2->E1->...
result = _rebuild_tree(node1, eclass_of, cost_of)
self.assertIsNotNone(result) # just verify it terminates
def test_mutual_rewrite_cycle(self):
"""x+x <-> x*2 mutual rewrite. Both forms in same eclass, extraction picks cheaper (ADD)."""
pm = PatternMatcher([
(UPat.var("x") + UPat.var("x"), lambda x: x * 2),
(UPat.var("x") * UPat.cvar("c", vec=False), lambda x,c: x+x if c.arg == 2 else None),
])
a = UOp.variable("a", 0, 10)
result = egraph_extract(a + a, pm)
self.assertEqual(result.op, Ops.ADD) # ADD cost 1 < MUL cost 2
if __name__ == '__main__':
unittest.main(verbosity=2)

View file

@ -1,7 +1,7 @@
from typing import cast
from dataclasses import replace
import itertools
from tinygrad.helpers import DISABLE_FAST_IDIV, EMULATED_DTYPES, DEVECTORIZE, TRANSCENDENTAL, SPEC, DEBUG, VIZ, TracingKey, Context
from tinygrad.helpers import DISABLE_FAST_IDIV, EMULATED_DTYPES, DEVECTORIZE, TRANSCENDENTAL, SPEC, DEBUG, VIZ, EGRAPH, TracingKey, Context
from tinygrad.uop.ops import PatternMatcher, graph_rewrite, UOp, pm_lower_index_dtype, Ops, UPat, track_rewrites, KernelInfo, pyrender
from tinygrad.uop.spec import type_verify, program_spec, kernel_spec
from tinygrad.renderer import Renderer, ProgramSpec
@ -22,6 +22,13 @@ from tinygrad.codegen.simplify import pm_simplify_ranges, pm_flatten_range, pm_s
from tinygrad.schedule.rangeify import pm_add_buffers_local, rangeify_codegen, pm_mops, pm_syntactic_sugar
from tinygrad.codegen.late.linearizer import CFGContext, pm_split_ends, pm_add_control_flow, linearize
def _sym_rewrite(sink:UOp, sym_pm:PatternMatcher, extra_pm:PatternMatcher|None=None, ctx=None, name:str|None=None) -> UOp:
"""Symbolic rewrite: uses e-graph extraction when EGRAPH is set, otherwise greedy graph_rewrite."""
if EGRAPH:
from tinygrad.uop.egraph import egraph_rewrite
return egraph_rewrite(sink, sym_pm, extra_pm, ctx=ctx, name=name)
return graph_rewrite(sink, sym_pm+extra_pm if extra_pm is not None else sym_pm, ctx=ctx, name=name)
def full_rewrite_to_sink(sink:UOp, ren:Renderer|None=None, optimize:bool=True) -> UOp:
if ren is None: ren = Renderer()
@ -41,7 +48,7 @@ def full_rewrite_to_sink(sink:UOp, ren:Renderer|None=None, optimize:bool=True) -
sink = graph_rewrite(sink, pm_split_ranges+pm_flatten_range, ctx={}, name="split ranges")
# symbolic (NOTE: this is a requirement for pm_simplify_ranges to be correct)
sink = graph_rewrite(sink, sym+pm_flatten_range, name="initial symbolic")
sink = _sym_rewrite(sink, sym, pm_flatten_range, name="initial symbolic")
# optimize (schedule) the AST
sink = graph_rewrite(sink, pm_simplify_ranges, name="simplify ranges")
@ -53,10 +60,10 @@ def full_rewrite_to_sink(sink:UOp, ren:Renderer|None=None, optimize:bool=True) -
sink = apply_opts(sink, ren)
# ** expander (expand_rewrite) **
sink = graph_rewrite(sink, sym+pm_move_where_on_load, name="postopt symbolic")
sink = _sym_rewrite(sink, sym, pm_move_where_on_load, name="postopt symbolic")
# expand
sink = graph_rewrite(sink, sym+pm_pre_expander+pm_group_for_reduce+expander, name="expander")
sink = _sym_rewrite(sink, sym, pm_pre_expander+pm_group_for_reduce+expander, name="expander")
# add locals
sink = graph_rewrite(sink, pm_add_buffers_local+rangeify_codegen, ctx=itertools.count(0), name="add local buffers")
@ -74,32 +81,33 @@ def full_rewrite_to_sink(sink:UOp, ren:Renderer|None=None, optimize:bool=True) -
sink = graph_rewrite(sink, pm_add_loads, name="** add loads (code)")
# devectorize (TODO: does this need opts?)
if DEVECTORIZE >= 2: pm_devectorize = sym+load_store_folding+load_store_indexing
elif DEVECTORIZE: pm_devectorize = sym+devectorize+load_store_folding+correct_load_store+load_store_indexing
else: pm_devectorize = sym+load_store_folding+correct_load_store+load_store_indexing
if DEVECTORIZE >= 0: sink = graph_rewrite(sink, pm_devectorize, ctx=ren, name="devectorize")
if DEVECTORIZE >= 2: pm_devec_extra = load_store_folding+load_store_indexing
elif DEVECTORIZE: pm_devec_extra = devectorize+load_store_folding+correct_load_store+load_store_indexing
else: pm_devec_extra = load_store_folding+correct_load_store+load_store_indexing
if DEVECTORIZE >= 0: sink = _sym_rewrite(sink, sym, pm_devec_extra, ctx=ren, name="devectorize")
# lower the index dtype to a concrete int
sink = graph_rewrite(sink, pm_lower_index_dtype+load_store_indexing+gep_pushing, ctx=ren.device, name="lower all index dtypes")
sink = graph_rewrite(sink, symbolic, name="post index symbolic")
sink = _sym_rewrite(sink, symbolic, name="post index symbolic")
# optional pre matcher
if ren.pre_matcher is not None: sink = graph_rewrite(sink, ren.pre_matcher, name="pre_matcher")
# decompositions
supported_ops = tuple(ren.code_for_op.keys())
pm_decomp = symbolic_simple+get_late_rewrite_patterns(supported_ops, ren.device, bool(DISABLE_FAST_IDIV))
pm_transcendental = symbolic_simple+get_transcendental_patterns(supported_ops, TRANSCENDENTAL>=2)
sink = graph_rewrite(sink, pm_decomp, ctx=ren.device, name="decompositions")
pm_decomp_extra = get_late_rewrite_patterns(supported_ops, ren.device, bool(DISABLE_FAST_IDIV))
pm_transcend_extra = get_transcendental_patterns(supported_ops, TRANSCENDENTAL>=2)
sink = _sym_rewrite(sink, symbolic_simple, pm_decomp_extra, ctx=ren.device, name="decompositions")
if not is_dtype_supported(dtypes.long, ren.device) or dtypes.long in EMULATED_DTYPES.tolist(dtypes):
sink = graph_rewrite(sink, pm_long_decomp, name="decomp long -> int", bottom_up=True)
for fr, to in [(fr, next((to for to in promo_lattice[fr] if is_dtype_supported(to, ren.device)), dtypes.float))
for fr in EMULATED_DTYPES.tolist(dtypes) if fr in dtypes.floats]:
sink = graph_rewrite(sink, pm_float_decomp, ctx=(fr, to), name=f"decomp {fr} -> {to}", bottom_up=True)
sink = graph_rewrite(sink, pm_transcendental, ctx=ren.device, name="transcendental")
sink = _sym_rewrite(sink, symbolic_simple, pm_transcend_extra, ctx=ren.device, name="transcendental")
# final rules for the renderer (without sym)
extra_matcher = ren.extra_matcher if ren.extra_matcher is not None else PatternMatcher([])
pm_decomp = symbolic_simple+pm_decomp_extra
pm_final_rewrite = pm_decomp+pm_render+extra_matcher+pm_split_ends
sink = graph_rewrite(sink, pm_final_rewrite, ctx=ren.device, name="final rewrite")

View file

@ -180,7 +180,7 @@ SPLIT_REDUCEOP, NO_MEMORY_PLANNER, LRU = ContextVar("SPLIT_REDUCEOP", 1), Contex
RING, ALL2ALL = ContextVar("RING", 1), ContextVar("ALL2ALL", 0)
CACHELEVEL, IGNORE_BEAM_CACHE, DEVECTORIZE = ContextVar("CACHELEVEL", 2), ContextVar("IGNORE_BEAM_CACHE", 0), ContextVar("DEVECTORIZE", 1)
VALIDATE_WITH_CPU, DISABLE_FAST_IDIV = ContextVar("VALIDATE_WITH_CPU", 0), ContextVar("DISABLE_FAST_IDIV", 0)
CORRECT_DIVMOD_FOLDING, FUSE_OPTIM = ContextVar("CORRECT_DIVMOD_FOLDING", 0), ContextVar("FUSE_OPTIM", 0)
CORRECT_DIVMOD_FOLDING, FUSE_OPTIM, EGRAPH = ContextVar("CORRECT_DIVMOD_FOLDING", 0), ContextVar("FUSE_OPTIM", 0), ContextVar("EGRAPH", 0)
ALLOW_DEVICE_USAGE, MAX_BUFFER_SIZE = ContextVar("ALLOW_DEVICE_USAGE", 1), ContextVar("MAX_BUFFER_SIZE", 0)
EMULATE, EMULATED_DTYPES = ContextVar("EMULATE", ""), ContextVar("EMULATED_DTYPES", "")
CPU_COUNT = ContextVar("CPU_COUNT", max(1, len(os.sched_getaffinity(0)) if hasattr(os, "sched_getaffinity") else (os.cpu_count() or 1)))

227
tinygrad/uop/egraph.py Normal file
View file

@ -0,0 +1,227 @@
# e-graph (equality saturation) for UOp rewriting
# instead of greedy first-match rewriting, we explore ALL equivalent forms and extract the cheapest
from __future__ import annotations
from tinygrad.uop.ops import UOp, Ops, PatternMatcher, graph_rewrite
# *** union-find (keyed by UOp identity) ***
def uf_find(parent:dict[UOp, UOp], x:UOp) -> UOp:
while parent[x] is not x:
parent[x] = parent[parent[x]]
x = parent[x] # path compression
return x
def uf_union(parent:dict[UOp, UOp], size:dict[UOp, int], a:UOp, b:UOp) -> UOp:
a, b = uf_find(parent, a), uf_find(parent, b)
if a is b: return a
if size[a] < size[b]: a, b = b, a # merge smaller into larger
parent[b] = a
size[a] += size[b]
return a
# *** e-graph core ***
def rewrite_all(pm:PatternMatcher, uop:UOp, ctx=None) -> list[UOp]:
"""Apply ALL matching rewrite rules to uop, returning every distinct result."""
results: list[UOp] = []
seen: dict[UOp, None] = {}
for _, match, early_reject in pm.pdict.get(uop.op, []):
if not early_reject.issubset({u.op for u in uop.src}): continue
try: ret = match(uop, ctx)
except Exception: continue # skip rules that crash on this node (e.g. division by zero in divmod folding)
if ret is not None and ret is not uop and ret not in seen:
results.append(ret)
seen[ret] = None
return results
class EGraph:
"""E-graph with full equality saturation (including rebuilding)."""
__slots__ = ("parent", "size", "eclass", "eclass_uses", "all_nodes")
def __init__(self, root:UOp):
nodes = list(root.toposort())
self.parent: dict[UOp, UOp] = {u: u for u in nodes}
self.size: dict[UOp, int] = {u: 1 for u in nodes}
self.eclass: dict[UOp, dict[UOp, None]] = {u: {u: None} for u in nodes} # canonical -> members
# canonical eclass representative -> dict of nodes that USE this eclass as a child
self.eclass_uses: dict[UOp, dict[UOp, None]] = {u: {} for u in nodes}
self.all_nodes: dict[UOp, None] = dict.fromkeys(nodes)
# build initial parent-child uses
for u in nodes:
for s in u.src:
canon = uf_find(self.parent, s)
self.eclass_uses.setdefault(canon, {})[u] = None
def _add_node(self, u:UOp):
"""Register a new UOp (and its subtree) in the e-graph."""
for sub in u.toposort():
if sub in self.parent: continue
self.parent[sub] = sub
self.size[sub] = 1
self.eclass[sub] = {sub: None}
self.all_nodes[sub] = None
self.eclass_uses[sub] = {}
for s in sub.src:
canon = uf_find(self.parent, s)
self.eclass_uses.setdefault(canon, {})[sub] = None
def _merge(self, a:UOp, b:UOp) -> UOp|None:
"""Merge two e-classes. Returns the winner, or None if already merged."""
ra, rb = uf_find(self.parent, a), uf_find(self.parent, b)
if ra is rb: return None
winner = uf_union(self.parent, self.size, ra, rb)
loser = rb if winner is ra else ra
self.eclass[winner] = {**self.eclass[winner], **self.eclass[loser]}
# merge uses
winner_uses = self.eclass_uses.setdefault(winner, {})
winner_uses.update(self.eclass_uses.pop(loser, {}))
del self.eclass[loser]
return winner
def _canonical(self, u:UOp) -> UOp:
"""Rebuild node with canonical representative for each child's eclass."""
if not u.src: return u
new_src = []
for s in u.src:
canon = uf_find(self.parent, s)
members = self.eclass.get(canon)
if members is not None:
best = min(members, key=lambda m: (len(m.src), m.op.value, m.arg if isinstance(m.arg, (int, float, str)) else 0))
new_src.append(best)
else:
new_src.append(s)
new_src_tuple = tuple(new_src)
if new_src_tuple == u.src: return u
return UOp(u.op, u.dtype, new_src_tuple, u.arg, u.tag)
def _rebuild(self, dirty:dict[UOp, None]) -> list[tuple[UOp, UOp]]:
"""Rebuild parents of dirty eclasses, creating canonical versions."""
new_equalities: list[tuple[UOp, UOp]] = []
affected: dict[UOp, None] = {}
for d in dirty:
canon = uf_find(self.parent, d)
affected.update(self.eclass_uses.get(canon, {}))
for u in affected:
rebuilt = self._canonical(u)
if rebuilt is not u:
if rebuilt in self.parent and uf_find(self.parent, rebuilt) is uf_find(self.parent, u): continue
self._add_node(rebuilt)
new_equalities.append((u, rebuilt))
return new_equalities
def egraph_saturate(root:UOp, pm:PatternMatcher, max_iters:int=10, ctx=None) -> dict[UOp, dict[UOp, None]]:
"""Build an e-graph with full equality saturation (with rebuilding). Returns eclass map."""
eg = EGraph(root)
node_limit = len(eg.all_nodes) * 3 # stop growing at 3x initial size to prevent combinatorial blowup
worklist: dict[UOp, None] = dict(eg.all_nodes) # nodes to match rules on
for _ in range(max_iters):
# phase 1: match rules only on worklist nodes
new_equalities: list[tuple[UOp, UOp]] = []
next_worklist: dict[UOp, None] = {}
prev_nodes = dict(eg.all_nodes)
for u in list(worklist):
if len(eg.all_nodes) >= node_limit: break
for new in rewrite_all(pm, u, ctx):
if new in eg.parent and uf_find(eg.parent, new) is uf_find(eg.parent, u): continue
eg._add_node(new)
new_equalities.append((u, new))
# all newly added nodes (including sub-nodes of rewrite results) go on next worklist
for u in eg.all_nodes:
if u not in prev_nodes: next_worklist[u] = None
if not new_equalities: break
# phase 2: merge eclasses, then rebuild canonical forms (no rule matching in rebuild)
while new_equalities:
dirty: dict[UOp, None] = {}
for a, b in new_equalities:
merged = eg._merge(a, b)
if merged is not None: dirty[merged] = None
if not dirty: break
new_equalities = eg._rebuild(dirty)
for _, b in new_equalities: next_worklist[b] = None
worklist = next_worklist
return eg.eclass
# *** cost model ***
OP_COST: dict[Ops, int] = {
Ops.CONST: 0, Ops.VCONST: 0, Ops.DEFINE_VAR: 0,
Ops.ADD: 1, Ops.MUL: 2, Ops.SUB: 1, Ops.NEG: 1,
Ops.IDIV: 5, Ops.MOD: 5, Ops.FDIV: 3,
Ops.SHL: 1, Ops.SHR: 1,
Ops.AND: 1, Ops.OR: 1, Ops.XOR: 1,
Ops.MAX: 1, Ops.CMPLT: 1, Ops.CMPNE: 1, Ops.CMPEQ: 1,
Ops.CAST: 1, Ops.BITCAST: 1,
Ops.WHERE: 2, Ops.MULACC: 2,
Ops.EXP2: 8, Ops.LOG2: 8, Ops.SIN: 8, Ops.SQRT: 4, Ops.RECIPROCAL: 3,
Ops.POW: 10, Ops.TRUNC: 1,
}
def node_cost(u:UOp) -> int:
c = OP_COST.get(u.op, 3)
# tiebreaker: penalize non-canonical operand order (consts should be on the right for commutative ops)
if len(u.src) == 2 and u.src[0].op is Ops.CONST and u.src[1].op is not Ops.CONST: c += 1
return c
# *** extraction ***
def egraph_extract(root:UOp, pm:PatternMatcher, max_iters:int=10, ctx=None) -> UOp:
"""Run equality saturation on root, then extract the cheapest equivalent expression."""
eclass = egraph_saturate(root, pm, max_iters, ctx)
# build eclass lookup: node -> canonical eclass representative
eclass_of: dict[UOp, UOp] = {}
for canon, members in eclass.items():
for u in members: eclass_of[u] = canon
all_nodes: list[UOp] = [u for members in eclass.values() for u in members]
# bottom-up DP: for each eclass, find the cheapest representative
cost_of: dict[UOp, tuple[int, UOp]] = {} # eclass_canon -> (cost, best_uop)
depth_cache: dict[UOp, int] = {}
def _depth(u:UOp) -> int:
if u in depth_cache: return depth_cache[u]
depth_cache[u] = 0 # break cycles
depth_cache[u] = (1 + max((_depth(s) for s in u.src), default=0)) if u.src else 0
return depth_cache[u]
for u in sorted(all_nodes, key=_depth):
canon = eclass_of[u]
child_cost = 0
for s in u.src:
if (s_canon := eclass_of.get(s)) is not None and s_canon in cost_of: child_cost += cost_of[s_canon][0]
else: child_cost += node_cost(s)
total = node_cost(u) + child_cost
if canon not in cost_of or total < cost_of[canon][0]:
cost_of[canon] = (total, u)
root_canon = eclass_of.get(root)
if root_canon is not None and root_canon in cost_of: return _rebuild_tree(cost_of[root_canon][1], eclass_of, cost_of)
return root
def _rebuild_tree(u:UOp, eclass_of:dict[UOp, UOp], cost_of:dict[UOp, tuple[int, UOp]], cache:dict[UOp, UOp]|None=None) -> UOp:
"""Recursively rebuild a UOp tree, picking the cheapest representative for each child's eclass."""
if not u.src: return u
if cache is None: cache = {}
new_src = []
for s in u.src:
s_canon = eclass_of.get(s)
if s_canon is not None and s_canon in cost_of:
if s_canon in cache: new_src.append(cache[s_canon])
else:
cache[s_canon] = s # placeholder breaks cycles
cache[s_canon] = _rebuild_tree(cost_of[s_canon][1], eclass_of, cost_of, cache)
new_src.append(cache[s_canon])
else:
new_src.append(_rebuild_tree(s, eclass_of, cost_of, cache))
new_src_tuple = tuple(new_src)
return u if new_src_tuple == u.src else UOp(u.op, u.dtype, new_src_tuple, u.arg, u.tag)
# *** graph-level rewrite: drop-in replacement for graph_rewrite when EGRAPH is set ***
def egraph_rewrite(sink:UOp, sym_pm:PatternMatcher, extra_pm:PatternMatcher|None=None, ctx=None, name:str|None=None) -> UOp:
"""Replace graph_rewrite(sink, sym+extra, ctx) with e-graph extraction for sym, then greedy for the rest."""
combined = sym_pm+extra_pm if extra_pm is not None else sym_pm
sink = egraph_extract(sink, combined, ctx=ctx)
return graph_rewrite(sink, combined, ctx=ctx, name=name)