mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
Compare commits
5 commits
master
...
fun_w_egra
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
e7c8aaed31 | ||
|
|
a65d9fea74 | ||
|
|
29b2afa0cb | ||
|
|
d70e255c89 | ||
|
|
9e46535ad3 |
4 changed files with 751 additions and 14 deletions
502
test/null/test_egraph.py
Normal file
502
test/null/test_egraph.py
Normal file
|
|
@ -0,0 +1,502 @@
|
||||||
|
import unittest
|
||||||
|
from tinygrad.dtype import dtypes
|
||||||
|
from tinygrad.uop.ops import Ops, UOp, GroupOp, PatternMatcher, UPat, graph_rewrite
|
||||||
|
from tinygrad.uop.egraph import uf_find, uf_union, rewrite_all, EGraph, egraph_saturate, egraph_extract, node_cost, _rebuild_tree
|
||||||
|
|
||||||
|
# *** test union-find ***
|
||||||
|
|
||||||
|
class TestUnionFind(unittest.TestCase):
|
||||||
|
def test_find_self(self):
|
||||||
|
a, b = UOp.const(dtypes.int, 1), UOp.const(dtypes.int, 2)
|
||||||
|
parent = {a: a, b: b}
|
||||||
|
self.assertIs(uf_find(parent, a), a)
|
||||||
|
self.assertIs(uf_find(parent, b), b)
|
||||||
|
|
||||||
|
def test_union_basic(self):
|
||||||
|
a, b = UOp.const(dtypes.int, 1), UOp.const(dtypes.int, 2)
|
||||||
|
parent = {a: a, b: b}
|
||||||
|
size = {a: 1, b: 1}
|
||||||
|
root = uf_union(parent, size, a, b)
|
||||||
|
self.assertIs(uf_find(parent, a), uf_find(parent, b))
|
||||||
|
self.assertIs(root, uf_find(parent, a))
|
||||||
|
|
||||||
|
def test_union_chain(self):
|
||||||
|
a, b, c = UOp.const(dtypes.int, 1), UOp.const(dtypes.int, 2), UOp.const(dtypes.int, 3)
|
||||||
|
parent = {a: a, b: b, c: c}
|
||||||
|
size = {a: 1, b: 1, c: 1}
|
||||||
|
uf_union(parent, size, a, b)
|
||||||
|
uf_union(parent, size, b, c)
|
||||||
|
self.assertIs(uf_find(parent, a), uf_find(parent, c))
|
||||||
|
|
||||||
|
def test_union_idempotent(self):
|
||||||
|
a, b = UOp.const(dtypes.int, 1), UOp.const(dtypes.int, 2)
|
||||||
|
parent = {a: a, b: b}
|
||||||
|
size = {a: 1, b: 1}
|
||||||
|
r1 = uf_union(parent, size, a, b)
|
||||||
|
r2 = uf_union(parent, size, a, b)
|
||||||
|
self.assertIs(r1, r2)
|
||||||
|
|
||||||
|
# *** test rewrite_all ***
|
||||||
|
|
||||||
|
class TestRewriteAll(unittest.TestCase):
|
||||||
|
def test_single_match(self):
|
||||||
|
pm = PatternMatcher([(UPat.var("x") + 0, lambda x: x)])
|
||||||
|
a = UOp.variable("a", 0, 10)
|
||||||
|
results = rewrite_all(pm, a + 0)
|
||||||
|
self.assertEqual(len(results), 1)
|
||||||
|
self.assertIs(results[0], a)
|
||||||
|
|
||||||
|
def test_no_match(self):
|
||||||
|
pm = PatternMatcher([(UPat.var("x") + 0, lambda x: x)])
|
||||||
|
a = UOp.variable("a", 0, 10)
|
||||||
|
b = UOp.variable("b", 0, 10)
|
||||||
|
results = rewrite_all(pm, a + b)
|
||||||
|
self.assertEqual(len(results), 0)
|
||||||
|
|
||||||
|
def test_multiple_matches(self):
|
||||||
|
pm = PatternMatcher([
|
||||||
|
(UPat.var("x") + 0, lambda x: x),
|
||||||
|
(UPat.var("x") * 1, lambda x: x),
|
||||||
|
])
|
||||||
|
a = UOp.variable("a", 0, 10)
|
||||||
|
results = rewrite_all(pm, a + 0)
|
||||||
|
self.assertEqual(len(results), 1)
|
||||||
|
self.assertIs(results[0], a)
|
||||||
|
|
||||||
|
def test_both_rules_fire(self):
|
||||||
|
pm = PatternMatcher([
|
||||||
|
(UPat.var("x") + UPat.var("x"), lambda x: x * 2),
|
||||||
|
(UPat.var("x") + UPat.var("x"), lambda x: UOp(Ops.SHL, x.dtype, (x, x.const_like(1)))),
|
||||||
|
])
|
||||||
|
a = UOp.variable("a", 0, 10)
|
||||||
|
results = rewrite_all(pm, a + a)
|
||||||
|
self.assertEqual(len(results), 2)
|
||||||
|
|
||||||
|
def test_const_folding(self):
|
||||||
|
pm = PatternMatcher([
|
||||||
|
(UPat(GroupOp.Binary, src=(UPat((Ops.CONST, Ops.VCONST)),)*2, name="a"),
|
||||||
|
lambda a: a.const_like(a.src[0].arg + a.src[1].arg) if a.op is Ops.ADD else None),
|
||||||
|
])
|
||||||
|
results = rewrite_all(pm, UOp.const(dtypes.int, 3) + UOp.const(dtypes.int, 4))
|
||||||
|
self.assertEqual(len(results), 1)
|
||||||
|
self.assertEqual(results[0].arg, 7)
|
||||||
|
|
||||||
|
# *** test EGraph class ***
|
||||||
|
|
||||||
|
class TestEGraphClass(unittest.TestCase):
|
||||||
|
def test_init(self):
|
||||||
|
a = UOp.variable("a", 0, 10)
|
||||||
|
expr = a + 0
|
||||||
|
eg = EGraph(expr)
|
||||||
|
self.assertEqual(len(eg.eclass), len(list(expr.toposort())))
|
||||||
|
self.assertIn(expr, eg.all_nodes)
|
||||||
|
|
||||||
|
def test_add_node(self):
|
||||||
|
a = UOp.variable("a", 0, 10)
|
||||||
|
eg = EGraph(a)
|
||||||
|
b = UOp.variable("b", 0, 10)
|
||||||
|
eg._add_node(b)
|
||||||
|
self.assertIn(b, eg.all_nodes)
|
||||||
|
|
||||||
|
def test_merge(self):
|
||||||
|
a = UOp.variable("a", 0, 10)
|
||||||
|
expr = a + 0
|
||||||
|
eg = EGraph(expr)
|
||||||
|
result = eg._merge(expr, a)
|
||||||
|
self.assertIsNotNone(result)
|
||||||
|
self.assertIs(uf_find(eg.parent, expr), uf_find(eg.parent, a))
|
||||||
|
|
||||||
|
def test_merge_idempotent(self):
|
||||||
|
a = UOp.variable("a", 0, 10)
|
||||||
|
eg = EGraph(a)
|
||||||
|
result = eg._merge(a, a)
|
||||||
|
self.assertIsNone(result)
|
||||||
|
|
||||||
|
# *** test egraph_saturate ***
|
||||||
|
|
||||||
|
class TestEGraphSaturate(unittest.TestCase):
|
||||||
|
def test_identity_rules(self):
|
||||||
|
pm = PatternMatcher([
|
||||||
|
(UPat.var("x") + 0, lambda x: x),
|
||||||
|
(UPat.var("x") * 1, lambda x: x),
|
||||||
|
])
|
||||||
|
a = UOp.variable("a", 0, 10)
|
||||||
|
expr = a + 0
|
||||||
|
eclass = egraph_saturate(expr, pm)
|
||||||
|
# a+0 and a should be in the same e-class
|
||||||
|
a_class = expr_class = None
|
||||||
|
for canon, members in eclass.items():
|
||||||
|
if a in members: a_class = canon
|
||||||
|
if expr in members: expr_class = canon
|
||||||
|
self.assertIsNotNone(a_class)
|
||||||
|
self.assertIsNotNone(expr_class)
|
||||||
|
self.assertIs(a_class, expr_class)
|
||||||
|
|
||||||
|
def test_const_fold_saturation(self):
|
||||||
|
from tinygrad.uop.symbolic import symbolic_simple
|
||||||
|
c2, c3 = UOp.const(dtypes.int, 2), UOp.const(dtypes.int, 3)
|
||||||
|
expr = c2 + c3
|
||||||
|
eclass = egraph_saturate(expr, symbolic_simple)
|
||||||
|
c5 = UOp.const(dtypes.int, 5)
|
||||||
|
for canon, members in eclass.items():
|
||||||
|
if expr in members:
|
||||||
|
self.assertIn(c5, members, f"expected CONST(5) in eclass of 2+3, got {members}")
|
||||||
|
return
|
||||||
|
self.fail("expr not found in any eclass")
|
||||||
|
|
||||||
|
def test_no_rules_match(self):
|
||||||
|
pm = PatternMatcher([(UPat.var("x") + 0, lambda x: x)])
|
||||||
|
a = UOp.variable("a", 0, 10)
|
||||||
|
b = UOp.variable("b", 0, 10)
|
||||||
|
eclass = egraph_saturate(a + b, pm)
|
||||||
|
for canon, members in eclass.items():
|
||||||
|
self.assertEqual(len(members), 1)
|
||||||
|
|
||||||
|
def test_max_iters_respected(self):
|
||||||
|
pm = PatternMatcher([(UPat.var("x") + 0, lambda x: x)])
|
||||||
|
a = UOp.variable("a", 0, 10)
|
||||||
|
expr = a + 0
|
||||||
|
eclass = egraph_saturate(expr, pm, max_iters=1)
|
||||||
|
a_class = expr_class = None
|
||||||
|
for canon, members in eclass.items():
|
||||||
|
if a in members: a_class = canon
|
||||||
|
if expr in members: expr_class = canon
|
||||||
|
self.assertIs(a_class, expr_class)
|
||||||
|
|
||||||
|
def test_rebuilding_propagates(self):
|
||||||
|
"""After a*0 merges with 0, rebuilding should create (0+a) which then matches x+0 -> x."""
|
||||||
|
pm = PatternMatcher([
|
||||||
|
(UPat.var("x") * 0, lambda x: x.const_like(0)),
|
||||||
|
(UPat.var("x") + 0, lambda x: x),
|
||||||
|
])
|
||||||
|
a = UOp.variable("a", 0, 10)
|
||||||
|
expr = (a * 0) + a
|
||||||
|
eclass = egraph_saturate(expr, pm)
|
||||||
|
expr_cls = a_cls = None
|
||||||
|
for canon, members in eclass.items():
|
||||||
|
if expr in members: expr_cls = canon
|
||||||
|
if a in members: a_cls = canon
|
||||||
|
self.assertIsNotNone(expr_cls)
|
||||||
|
self.assertIsNotNone(a_cls)
|
||||||
|
self.assertIs(expr_cls, a_cls)
|
||||||
|
|
||||||
|
def test_rebuilding_chain(self):
|
||||||
|
"""((a*0)+0)+b should simplify to b through multiple rebuild steps."""
|
||||||
|
pm = PatternMatcher([
|
||||||
|
(UPat.var("x") * 0, lambda x: x.const_like(0)),
|
||||||
|
(UPat.var("x") + 0, lambda x: x),
|
||||||
|
])
|
||||||
|
a = UOp.variable("a", 0, 10)
|
||||||
|
b = UOp.variable("b", 0, 10)
|
||||||
|
expr = ((a * 0) + 0) + b
|
||||||
|
eclass = egraph_saturate(expr, pm)
|
||||||
|
expr_cls = b_cls = None
|
||||||
|
for canon, members in eclass.items():
|
||||||
|
if expr in members: expr_cls = canon
|
||||||
|
if b in members: b_cls = canon
|
||||||
|
self.assertIsNotNone(expr_cls)
|
||||||
|
self.assertIsNotNone(b_cls)
|
||||||
|
self.assertIs(expr_cls, b_cls)
|
||||||
|
|
||||||
|
# *** test egraph_extract ***
|
||||||
|
|
||||||
|
class TestEGraphExtract(unittest.TestCase):
|
||||||
|
def test_extract_identity(self):
|
||||||
|
pm = PatternMatcher([(UPat.var("x") + 0, lambda x: x)])
|
||||||
|
a = UOp.variable("a", 0, 10)
|
||||||
|
self.assertIs(egraph_extract(a + 0, pm), a)
|
||||||
|
|
||||||
|
def test_extract_mul_identity(self):
|
||||||
|
pm = PatternMatcher([(UPat.var("x") * 1, lambda x: x)])
|
||||||
|
a = UOp.variable("a", 0, 10)
|
||||||
|
self.assertIs(egraph_extract(a * 1, pm), a)
|
||||||
|
|
||||||
|
def test_extract_const_fold(self):
|
||||||
|
from tinygrad.uop.symbolic import symbolic_simple
|
||||||
|
result = egraph_extract(UOp.const(dtypes.int, 2) + UOp.const(dtypes.int, 3), symbolic_simple)
|
||||||
|
self.assertEqual(result.op, Ops.CONST)
|
||||||
|
self.assertEqual(result.arg, 5)
|
||||||
|
|
||||||
|
def test_extract_chain(self):
|
||||||
|
pm = PatternMatcher([
|
||||||
|
(UPat.var("x") + 0, lambda x: x),
|
||||||
|
(UPat.var("x") * 1, lambda x: x),
|
||||||
|
])
|
||||||
|
a = UOp.variable("a", 0, 10)
|
||||||
|
self.assertIs(egraph_extract((a + 0) * 1, pm), a)
|
||||||
|
|
||||||
|
def test_extract_no_change(self):
|
||||||
|
pm = PatternMatcher([(UPat.var("x") + 0, lambda x: x)])
|
||||||
|
a = UOp.variable("a", 0, 10)
|
||||||
|
b = UOp.variable("b", 0, 10)
|
||||||
|
self.assertIs(egraph_extract(a + b, pm), a + b)
|
||||||
|
|
||||||
|
def test_extract_prefers_cheaper(self):
|
||||||
|
pm = PatternMatcher([(UPat.var("x") + UPat.var("x"), lambda x: x * 2)])
|
||||||
|
a = UOp.variable("a", 0, 10)
|
||||||
|
result = egraph_extract(a + a, pm)
|
||||||
|
self.assertEqual(result.op, Ops.ADD) # ADD cost 1 < MUL cost 2
|
||||||
|
|
||||||
|
def test_extract_with_symbolic_simple(self):
|
||||||
|
from tinygrad.uop.symbolic import symbolic_simple
|
||||||
|
a = UOp.variable("a", 0, 10)
|
||||||
|
self.assertIs(egraph_extract((a + 0) * 1, symbolic_simple), a)
|
||||||
|
|
||||||
|
def test_combine_terms(self):
|
||||||
|
from tinygrad.uop.symbolic import symbolic
|
||||||
|
a = UOp.variable("a", 0, 10)
|
||||||
|
result = egraph_extract(a * 3 + a * 4, symbolic)
|
||||||
|
self.assertEqual(result.op, Ops.MUL)
|
||||||
|
self.assertEqual(result.src[1].arg, 7)
|
||||||
|
|
||||||
|
# *** tests that REQUIRE rebuilding ***
|
||||||
|
|
||||||
|
def test_rebuild_mul_zero_plus(self):
|
||||||
|
pm = PatternMatcher([
|
||||||
|
(UPat.var("x") * 0, lambda x: x.const_like(0)),
|
||||||
|
(UPat.var("x") + 0, lambda x: x),
|
||||||
|
])
|
||||||
|
a = UOp.variable("a", 0, 10)
|
||||||
|
self.assertIs(egraph_extract((a * 0) + a, pm), a)
|
||||||
|
|
||||||
|
def test_rebuild_nested_zero(self):
|
||||||
|
pm = PatternMatcher([
|
||||||
|
(UPat.var("x") * 0, lambda x: x.const_like(0)),
|
||||||
|
(UPat.var("x") + 0, lambda x: x),
|
||||||
|
])
|
||||||
|
a = UOp.variable("a", 0, 10)
|
||||||
|
b = UOp.variable("b", 0, 10)
|
||||||
|
self.assertIs(egraph_extract(((a * 0) + 0) + b, pm), b)
|
||||||
|
|
||||||
|
def test_rebuild_distribute_then_fold(self):
|
||||||
|
pm = PatternMatcher([(UPat.var("x") * 0, lambda x: x.const_like(0))])
|
||||||
|
a = UOp.variable("a", 0, 10)
|
||||||
|
b = UOp.variable("b", 0, 10)
|
||||||
|
result = egraph_extract((a + b) * 0, pm)
|
||||||
|
self.assertEqual(result.op, Ops.CONST)
|
||||||
|
self.assertEqual(result.arg, 0)
|
||||||
|
|
||||||
|
def test_rebuild_symmetric(self):
|
||||||
|
pm = PatternMatcher([
|
||||||
|
(UPat.var("x") * 0, lambda x: x.const_like(0)),
|
||||||
|
(UPat(GroupOp.Binary, src=(UPat((Ops.CONST, Ops.VCONST)),)*2, name="a"),
|
||||||
|
lambda a: a.const_like(a.src[0].arg + a.src[1].arg) if a.op is Ops.ADD else None),
|
||||||
|
])
|
||||||
|
a = UOp.variable("a", 0, 10)
|
||||||
|
b = UOp.variable("b", 0, 10)
|
||||||
|
result = egraph_extract((a * 0) + (b * 0), pm)
|
||||||
|
self.assertEqual(result.op, Ops.CONST)
|
||||||
|
self.assertEqual(result.arg, 0)
|
||||||
|
|
||||||
|
def test_rebuild_with_real_rules(self):
|
||||||
|
from tinygrad.uop.symbolic import symbolic_simple
|
||||||
|
a = UOp.variable("a", 0, 10)
|
||||||
|
b = UOp.variable("b", 0, 10)
|
||||||
|
self.assertIs(egraph_extract((a * 0) + (b * 1), symbolic_simple), b)
|
||||||
|
|
||||||
|
def test_rebuild_deep_chain(self):
|
||||||
|
pm = PatternMatcher([
|
||||||
|
(UPat.var("x") * 0, lambda x: x.const_like(0)),
|
||||||
|
(UPat.var("x") + 0, lambda x: x),
|
||||||
|
(UPat(GroupOp.Binary, src=(UPat((Ops.CONST, Ops.VCONST)),)*2, name="a"),
|
||||||
|
lambda a: a.const_like(a.src[0].arg + a.src[1].arg) if a.op is Ops.ADD else None),
|
||||||
|
])
|
||||||
|
a = UOp.variable("a", 0, 10)
|
||||||
|
b = UOp.variable("b", 0, 10)
|
||||||
|
c = UOp.variable("c", 0, 10)
|
||||||
|
self.assertIs(egraph_extract(((a * 0) + (b * 0)) + c, pm), c)
|
||||||
|
|
||||||
|
# *** test cost model ***
|
||||||
|
|
||||||
|
class TestCostModel(unittest.TestCase):
|
||||||
|
def test_const_is_free(self):
|
||||||
|
self.assertEqual(node_cost(UOp.const(dtypes.int, 0)), 0)
|
||||||
|
|
||||||
|
def test_add_is_cheap(self):
|
||||||
|
a = UOp.variable("a", 0, 10)
|
||||||
|
b = UOp.variable("b", 0, 10)
|
||||||
|
self.assertEqual(node_cost(a + b), 1)
|
||||||
|
|
||||||
|
def test_div_is_expensive(self):
|
||||||
|
a = UOp.variable("a", 0, 10).cast(dtypes.index)
|
||||||
|
b = UOp.variable("b", 1, 10).cast(dtypes.index)
|
||||||
|
self.assertEqual(node_cost(a // b), 5)
|
||||||
|
|
||||||
|
def test_mul_more_than_add(self):
|
||||||
|
a = UOp.variable("a", 0, 10)
|
||||||
|
b = UOp.variable("b", 0, 10)
|
||||||
|
self.assertGreater(node_cost(a * b), node_cost(a + b))
|
||||||
|
|
||||||
|
# *** test e-graph matches greedy rewrite ***
|
||||||
|
|
||||||
|
class TestEGraphVsGreedy(unittest.TestCase):
|
||||||
|
def test_matches_greedy_identity(self):
|
||||||
|
from tinygrad.uop.ops import graph_rewrite
|
||||||
|
from tinygrad.uop.symbolic import symbolic_simple
|
||||||
|
a = UOp.variable("a", 0, 10)
|
||||||
|
greedy = graph_rewrite(a + 0, symbolic_simple)
|
||||||
|
egraph = egraph_extract(a + 0, symbolic_simple)
|
||||||
|
self.assertIs(greedy, egraph)
|
||||||
|
|
||||||
|
def test_matches_greedy_const_fold(self):
|
||||||
|
from tinygrad.uop.ops import graph_rewrite
|
||||||
|
from tinygrad.uop.symbolic import symbolic_simple
|
||||||
|
expr = UOp.const(dtypes.int, 10) + UOp.const(dtypes.int, 20)
|
||||||
|
greedy = graph_rewrite(expr, symbolic_simple)
|
||||||
|
egraph = egraph_extract(expr, symbolic_simple)
|
||||||
|
self.assertEqual(greedy.op, Ops.CONST)
|
||||||
|
self.assertEqual(egraph.op, Ops.CONST)
|
||||||
|
self.assertEqual(greedy.arg, egraph.arg)
|
||||||
|
|
||||||
|
def test_matches_greedy_double_identity(self):
|
||||||
|
from tinygrad.uop.ops import graph_rewrite
|
||||||
|
from tinygrad.uop.symbolic import symbolic_simple
|
||||||
|
a = UOp.variable("a", 0, 10)
|
||||||
|
expr = (a + 0) * 1
|
||||||
|
self.assertIs(graph_rewrite(expr, symbolic_simple), a)
|
||||||
|
self.assertIs(egraph_extract(expr, symbolic_simple), a)
|
||||||
|
|
||||||
|
# *** test e-graph beats greedy (phase-ordering problems) ***
|
||||||
|
|
||||||
|
# helper PMs that create phase-ordering traps
|
||||||
|
_pm_strength_reduce = PatternMatcher([
|
||||||
|
# strength reduction x*2 -> x+x fires FIRST and destroys the x*c form needed by combine-terms
|
||||||
|
(UPat.var('x') * UPat.cvar('c', vec=False), lambda x,c: x+x if c.arg == 2 else None),
|
||||||
|
# combine terms: x*c0 + x*c1 -> x*(c0+c1) can only match if both sides are x*c
|
||||||
|
(UPat.var('x') * UPat.cvar('c0') + UPat.var('x') * UPat.cvar('c1'), lambda x,c0,c1: x*(c0+c1)),
|
||||||
|
# constant folding
|
||||||
|
(UPat(GroupOp.Binary, src=(UPat((Ops.CONST, Ops.VCONST)),)*2, name='a'),
|
||||||
|
lambda a: a.const_like(a.src[0].arg + a.src[1].arg) if a.op is Ops.ADD else
|
||||||
|
a.const_like(a.src[0].arg * a.src[1].arg) if a.op is Ops.MUL else None),
|
||||||
|
(UPat.var('x') + 0, lambda x: x),
|
||||||
|
(UPat.var('x') * 1, lambda x: x),
|
||||||
|
])
|
||||||
|
|
||||||
|
_pm_shift_reduce = PatternMatcher([
|
||||||
|
# strength reduction x*2 -> x<<1 fires FIRST and destroys the x*c form
|
||||||
|
(UPat.var('x') * UPat.cvar('c', vec=False),
|
||||||
|
lambda x,c: UOp(Ops.SHL, x.dtype, (x, x.const_like(1))) if c.arg == 2 else None),
|
||||||
|
(UPat.var('x') * UPat.cvar('c0') + UPat.var('x') * UPat.cvar('c1'), lambda x,c0,c1: x*(c0+c1)),
|
||||||
|
(UPat(GroupOp.Binary, src=(UPat((Ops.CONST, Ops.VCONST)),)*2, name='a'),
|
||||||
|
lambda a: a.const_like(a.src[0].arg + a.src[1].arg) if a.op is Ops.ADD else
|
||||||
|
a.const_like(a.src[0].arg * a.src[1].arg) if a.op is Ops.MUL else None),
|
||||||
|
(UPat.var('x') + 0, lambda x: x),
|
||||||
|
(UPat.var('x') * 1, lambda x: x),
|
||||||
|
])
|
||||||
|
|
||||||
|
_pm_strength_fold = PatternMatcher([
|
||||||
|
# strength reduction x*2 -> x+x blocks two-stage folding (x*c1)*c2 -> x*(c1*c2)
|
||||||
|
(UPat.var('x') * UPat.cvar('c', vec=False), lambda x,c: x+x if c.arg == 2 else None),
|
||||||
|
((UPat.var('x') * UPat.cvar('c1')) * UPat.cvar('c2'), lambda x,c1,c2: x*(c1*c2)),
|
||||||
|
(UPat(GroupOp.Binary, src=(UPat((Ops.CONST, Ops.VCONST)),)*2, name='a'),
|
||||||
|
lambda a: a.const_like(a.src[0].arg * a.src[1].arg) if a.op is Ops.MUL else None),
|
||||||
|
])
|
||||||
|
|
||||||
|
def _total_cost(u:UOp) -> int:
|
||||||
|
return sum(node_cost(n) for n in u.toposort())
|
||||||
|
|
||||||
|
class TestEGraphBeatsGreedy(unittest.TestCase):
|
||||||
|
"""Tests where the e-graph finds a cheaper result than the greedy rewriter due to phase-ordering.
|
||||||
|
|
||||||
|
The core problem: when Rule A fires first and transforms a node, it can destroy the pattern
|
||||||
|
that Rule B needs to match. Rule B would have led to a cheaper result, but the greedy rewriter
|
||||||
|
never tries it. The e-graph explores BOTH paths and picks the cheapest.
|
||||||
|
"""
|
||||||
|
def test_strength_reduce_blocks_combine(self):
|
||||||
|
"""a*2 + a*3: strength reduction x*2->x+x destroys the x*c form needed by combine-terms x*c0+x*c1->x*(c0+c1)."""
|
||||||
|
a = UOp.variable("a", 0, 10)
|
||||||
|
expr = a * 2 + a * 3
|
||||||
|
greedy = graph_rewrite(expr, _pm_strength_reduce)
|
||||||
|
egraph = egraph_extract(expr, _pm_strength_reduce)
|
||||||
|
# greedy: (a+a) + a*3 (cost 4) — strength reduction destroyed the a*2 pattern
|
||||||
|
self.assertEqual(greedy.op, Ops.ADD)
|
||||||
|
self.assertGreater(_total_cost(greedy), _total_cost(egraph))
|
||||||
|
# egraph: a*5 (cost 2) — combine-terms wins because the e-graph explored both paths
|
||||||
|
self.assertEqual(egraph.op, Ops.MUL)
|
||||||
|
self.assertEqual(egraph.src[1].arg, 5)
|
||||||
|
|
||||||
|
def test_shift_reduce_blocks_combine(self):
|
||||||
|
"""a*2 + a*3: shift reduction x*2->x<<1 also destroys the combine-terms pattern."""
|
||||||
|
a = UOp.variable("a", 0, 10)
|
||||||
|
expr = a * 2 + a * 3
|
||||||
|
greedy = graph_rewrite(expr, _pm_shift_reduce)
|
||||||
|
egraph = egraph_extract(expr, _pm_shift_reduce)
|
||||||
|
self.assertEqual(greedy.op, Ops.ADD)
|
||||||
|
self.assertGreater(_total_cost(greedy), _total_cost(egraph))
|
||||||
|
self.assertEqual(egraph.op, Ops.MUL)
|
||||||
|
self.assertEqual(egraph.src[1].arg, 5)
|
||||||
|
|
||||||
|
def test_strength_reduce_chain(self):
|
||||||
|
"""a*2 + a*3 + a*4: strength reduction causes greedy to miss the combined a*9."""
|
||||||
|
a = UOp.variable("a", 0, 10)
|
||||||
|
expr = a * 2 + a * 3 + a * 4
|
||||||
|
greedy = graph_rewrite(expr, _pm_strength_reduce)
|
||||||
|
egraph = egraph_extract(expr, _pm_strength_reduce)
|
||||||
|
self.assertGreater(_total_cost(greedy), _total_cost(egraph))
|
||||||
|
|
||||||
|
def test_strength_reduce_blocks_two_stage_fold(self):
|
||||||
|
"""(a*2)*3: strength reduction x*2->x+x blocks two-stage constant folding (x*c1)*c2->x*(c1*c2)."""
|
||||||
|
a = UOp.variable("a", 0, 10)
|
||||||
|
expr = (a * 2) * 3
|
||||||
|
greedy = graph_rewrite(expr, _pm_strength_fold)
|
||||||
|
egraph = egraph_extract(expr, _pm_strength_fold)
|
||||||
|
# greedy: (a+a)*3 (cost 3) — can't fold constants because *2 was rewritten to +
|
||||||
|
self.assertGreater(_total_cost(greedy), _total_cost(egraph))
|
||||||
|
# egraph: a*6 (cost 2) — two-stage folding path was explored
|
||||||
|
self.assertEqual(egraph.op, Ops.MUL)
|
||||||
|
self.assertEqual(egraph.src[1].arg, 6)
|
||||||
|
|
||||||
|
def test_both_sides_strength_reduced(self):
|
||||||
|
"""a*2 + a*2: both sides get strength-reduced, blocking combine-terms."""
|
||||||
|
a = UOp.variable("a", 0, 10)
|
||||||
|
expr = a * 2 + a * 2
|
||||||
|
greedy = graph_rewrite(expr, _pm_strength_reduce)
|
||||||
|
egraph = egraph_extract(expr, _pm_strength_reduce)
|
||||||
|
# greedy: (a+a)+(a+a) — both a*2 were rewritten before combine could fire
|
||||||
|
# egraph: a*4 — combine-terms path was found
|
||||||
|
self.assertEqual(egraph.op, Ops.MUL)
|
||||||
|
self.assertEqual(egraph.src[1].arg, 4)
|
||||||
|
# both have cost 2 here (shared subexpression), but egraph result is canonical
|
||||||
|
self.assertLessEqual(_total_cost(egraph), _total_cost(greedy))
|
||||||
|
|
||||||
|
# *** test cycle-breaking in extraction ***
|
||||||
|
|
||||||
|
class TestExtractionCycles(unittest.TestCase):
|
||||||
|
def test_self_referencing_eclass(self):
|
||||||
|
"""x+0 -> x merges x+0 into x's eclass. Extraction must not recurse on the self-reference."""
|
||||||
|
pm = PatternMatcher([(UPat.var("x") + 0, lambda x: x)])
|
||||||
|
a = UOp.variable("a", 0, 10)
|
||||||
|
self.assertIs(egraph_extract(a + 0, pm), a)
|
||||||
|
|
||||||
|
def test_nested_self_referencing_eclass(self):
|
||||||
|
"""((a+0)+0)+0 — all merge into a's eclass. Deep self-reference chain."""
|
||||||
|
pm = PatternMatcher([(UPat.var("x") + 0, lambda x: x)])
|
||||||
|
a = UOp.variable("a", 0, 10)
|
||||||
|
self.assertIs(egraph_extract(((a + 0) + 0) + 0, pm), a)
|
||||||
|
|
||||||
|
def test_mutual_eclass_cycle(self):
|
||||||
|
"""Two eclasses whose best nodes reference each other — extraction must terminate via cycle-breaking cache."""
|
||||||
|
x = UOp.variable("x", 0, 10)
|
||||||
|
y = UOp.variable("y", 0, 10)
|
||||||
|
one = UOp.const(dtypes.index, 1)
|
||||||
|
two = UOp.const(dtypes.index, 2)
|
||||||
|
node1 = x + one # E1's best, child x is in E2
|
||||||
|
node2 = y + two # E2's best, child y is in E1
|
||||||
|
eclass_of = {node1: node1, x: node2, node2: node2, y: node1, one: one, two: two}
|
||||||
|
cost_of = {node1: (2, node1), node2: (2, node2), one: (0, one), two: (0, two)}
|
||||||
|
# without cycle-breaking cache, this would recurse: E1->E2->E1->...
|
||||||
|
result = _rebuild_tree(node1, eclass_of, cost_of)
|
||||||
|
self.assertIsNotNone(result) # just verify it terminates
|
||||||
|
|
||||||
|
def test_mutual_rewrite_cycle(self):
|
||||||
|
"""x+x <-> x*2 mutual rewrite. Both forms in same eclass, extraction picks cheaper (ADD)."""
|
||||||
|
pm = PatternMatcher([
|
||||||
|
(UPat.var("x") + UPat.var("x"), lambda x: x * 2),
|
||||||
|
(UPat.var("x") * UPat.cvar("c", vec=False), lambda x,c: x+x if c.arg == 2 else None),
|
||||||
|
])
|
||||||
|
a = UOp.variable("a", 0, 10)
|
||||||
|
result = egraph_extract(a + a, pm)
|
||||||
|
self.assertEqual(result.op, Ops.ADD) # ADD cost 1 < MUL cost 2
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
unittest.main(verbosity=2)
|
||||||
|
|
@ -1,7 +1,7 @@
|
||||||
from typing import cast
|
from typing import cast
|
||||||
from dataclasses import replace
|
from dataclasses import replace
|
||||||
import itertools
|
import itertools
|
||||||
from tinygrad.helpers import DISABLE_FAST_IDIV, EMULATED_DTYPES, DEVECTORIZE, TRANSCENDENTAL, SPEC, DEBUG, VIZ, TracingKey, Context
|
from tinygrad.helpers import DISABLE_FAST_IDIV, EMULATED_DTYPES, DEVECTORIZE, TRANSCENDENTAL, SPEC, DEBUG, VIZ, EGRAPH, TracingKey, Context
|
||||||
from tinygrad.uop.ops import PatternMatcher, graph_rewrite, UOp, pm_lower_index_dtype, Ops, UPat, track_rewrites, KernelInfo, pyrender
|
from tinygrad.uop.ops import PatternMatcher, graph_rewrite, UOp, pm_lower_index_dtype, Ops, UPat, track_rewrites, KernelInfo, pyrender
|
||||||
from tinygrad.uop.spec import type_verify, program_spec, kernel_spec
|
from tinygrad.uop.spec import type_verify, program_spec, kernel_spec
|
||||||
from tinygrad.renderer import Renderer, ProgramSpec
|
from tinygrad.renderer import Renderer, ProgramSpec
|
||||||
|
|
@ -22,6 +22,13 @@ from tinygrad.codegen.simplify import pm_simplify_ranges, pm_flatten_range, pm_s
|
||||||
from tinygrad.schedule.rangeify import pm_add_buffers_local, rangeify_codegen, pm_mops, pm_syntactic_sugar
|
from tinygrad.schedule.rangeify import pm_add_buffers_local, rangeify_codegen, pm_mops, pm_syntactic_sugar
|
||||||
from tinygrad.codegen.late.linearizer import CFGContext, pm_split_ends, pm_add_control_flow, linearize
|
from tinygrad.codegen.late.linearizer import CFGContext, pm_split_ends, pm_add_control_flow, linearize
|
||||||
|
|
||||||
|
def _sym_rewrite(sink:UOp, sym_pm:PatternMatcher, extra_pm:PatternMatcher|None=None, ctx=None, name:str|None=None) -> UOp:
|
||||||
|
"""Symbolic rewrite: uses e-graph extraction when EGRAPH is set, otherwise greedy graph_rewrite."""
|
||||||
|
if EGRAPH:
|
||||||
|
from tinygrad.uop.egraph import egraph_rewrite
|
||||||
|
return egraph_rewrite(sink, sym_pm, extra_pm, ctx=ctx, name=name)
|
||||||
|
return graph_rewrite(sink, sym_pm+extra_pm if extra_pm is not None else sym_pm, ctx=ctx, name=name)
|
||||||
|
|
||||||
def full_rewrite_to_sink(sink:UOp, ren:Renderer|None=None, optimize:bool=True) -> UOp:
|
def full_rewrite_to_sink(sink:UOp, ren:Renderer|None=None, optimize:bool=True) -> UOp:
|
||||||
if ren is None: ren = Renderer()
|
if ren is None: ren = Renderer()
|
||||||
|
|
||||||
|
|
@ -41,7 +48,7 @@ def full_rewrite_to_sink(sink:UOp, ren:Renderer|None=None, optimize:bool=True) -
|
||||||
sink = graph_rewrite(sink, pm_split_ranges+pm_flatten_range, ctx={}, name="split ranges")
|
sink = graph_rewrite(sink, pm_split_ranges+pm_flatten_range, ctx={}, name="split ranges")
|
||||||
|
|
||||||
# symbolic (NOTE: this is a requirement for pm_simplify_ranges to be correct)
|
# symbolic (NOTE: this is a requirement for pm_simplify_ranges to be correct)
|
||||||
sink = graph_rewrite(sink, sym+pm_flatten_range, name="initial symbolic")
|
sink = _sym_rewrite(sink, sym, pm_flatten_range, name="initial symbolic")
|
||||||
|
|
||||||
# optimize (schedule) the AST
|
# optimize (schedule) the AST
|
||||||
sink = graph_rewrite(sink, pm_simplify_ranges, name="simplify ranges")
|
sink = graph_rewrite(sink, pm_simplify_ranges, name="simplify ranges")
|
||||||
|
|
@ -53,10 +60,10 @@ def full_rewrite_to_sink(sink:UOp, ren:Renderer|None=None, optimize:bool=True) -
|
||||||
sink = apply_opts(sink, ren)
|
sink = apply_opts(sink, ren)
|
||||||
|
|
||||||
# ** expander (expand_rewrite) **
|
# ** expander (expand_rewrite) **
|
||||||
sink = graph_rewrite(sink, sym+pm_move_where_on_load, name="postopt symbolic")
|
sink = _sym_rewrite(sink, sym, pm_move_where_on_load, name="postopt symbolic")
|
||||||
|
|
||||||
# expand
|
# expand
|
||||||
sink = graph_rewrite(sink, sym+pm_pre_expander+pm_group_for_reduce+expander, name="expander")
|
sink = _sym_rewrite(sink, sym, pm_pre_expander+pm_group_for_reduce+expander, name="expander")
|
||||||
|
|
||||||
# add locals
|
# add locals
|
||||||
sink = graph_rewrite(sink, pm_add_buffers_local+rangeify_codegen, ctx=itertools.count(0), name="add local buffers")
|
sink = graph_rewrite(sink, pm_add_buffers_local+rangeify_codegen, ctx=itertools.count(0), name="add local buffers")
|
||||||
|
|
@ -74,32 +81,33 @@ def full_rewrite_to_sink(sink:UOp, ren:Renderer|None=None, optimize:bool=True) -
|
||||||
sink = graph_rewrite(sink, pm_add_loads, name="** add loads (code)")
|
sink = graph_rewrite(sink, pm_add_loads, name="** add loads (code)")
|
||||||
|
|
||||||
# devectorize (TODO: does this need opts?)
|
# devectorize (TODO: does this need opts?)
|
||||||
if DEVECTORIZE >= 2: pm_devectorize = sym+load_store_folding+load_store_indexing
|
if DEVECTORIZE >= 2: pm_devec_extra = load_store_folding+load_store_indexing
|
||||||
elif DEVECTORIZE: pm_devectorize = sym+devectorize+load_store_folding+correct_load_store+load_store_indexing
|
elif DEVECTORIZE: pm_devec_extra = devectorize+load_store_folding+correct_load_store+load_store_indexing
|
||||||
else: pm_devectorize = sym+load_store_folding+correct_load_store+load_store_indexing
|
else: pm_devec_extra = load_store_folding+correct_load_store+load_store_indexing
|
||||||
if DEVECTORIZE >= 0: sink = graph_rewrite(sink, pm_devectorize, ctx=ren, name="devectorize")
|
if DEVECTORIZE >= 0: sink = _sym_rewrite(sink, sym, pm_devec_extra, ctx=ren, name="devectorize")
|
||||||
|
|
||||||
# lower the index dtype to a concrete int
|
# lower the index dtype to a concrete int
|
||||||
sink = graph_rewrite(sink, pm_lower_index_dtype+load_store_indexing+gep_pushing, ctx=ren.device, name="lower all index dtypes")
|
sink = graph_rewrite(sink, pm_lower_index_dtype+load_store_indexing+gep_pushing, ctx=ren.device, name="lower all index dtypes")
|
||||||
sink = graph_rewrite(sink, symbolic, name="post index symbolic")
|
sink = _sym_rewrite(sink, symbolic, name="post index symbolic")
|
||||||
|
|
||||||
# optional pre matcher
|
# optional pre matcher
|
||||||
if ren.pre_matcher is not None: sink = graph_rewrite(sink, ren.pre_matcher, name="pre_matcher")
|
if ren.pre_matcher is not None: sink = graph_rewrite(sink, ren.pre_matcher, name="pre_matcher")
|
||||||
|
|
||||||
# decompositions
|
# decompositions
|
||||||
supported_ops = tuple(ren.code_for_op.keys())
|
supported_ops = tuple(ren.code_for_op.keys())
|
||||||
pm_decomp = symbolic_simple+get_late_rewrite_patterns(supported_ops, ren.device, bool(DISABLE_FAST_IDIV))
|
pm_decomp_extra = get_late_rewrite_patterns(supported_ops, ren.device, bool(DISABLE_FAST_IDIV))
|
||||||
pm_transcendental = symbolic_simple+get_transcendental_patterns(supported_ops, TRANSCENDENTAL>=2)
|
pm_transcend_extra = get_transcendental_patterns(supported_ops, TRANSCENDENTAL>=2)
|
||||||
sink = graph_rewrite(sink, pm_decomp, ctx=ren.device, name="decompositions")
|
sink = _sym_rewrite(sink, symbolic_simple, pm_decomp_extra, ctx=ren.device, name="decompositions")
|
||||||
if not is_dtype_supported(dtypes.long, ren.device) or dtypes.long in EMULATED_DTYPES.tolist(dtypes):
|
if not is_dtype_supported(dtypes.long, ren.device) or dtypes.long in EMULATED_DTYPES.tolist(dtypes):
|
||||||
sink = graph_rewrite(sink, pm_long_decomp, name="decomp long -> int", bottom_up=True)
|
sink = graph_rewrite(sink, pm_long_decomp, name="decomp long -> int", bottom_up=True)
|
||||||
for fr, to in [(fr, next((to for to in promo_lattice[fr] if is_dtype_supported(to, ren.device)), dtypes.float))
|
for fr, to in [(fr, next((to for to in promo_lattice[fr] if is_dtype_supported(to, ren.device)), dtypes.float))
|
||||||
for fr in EMULATED_DTYPES.tolist(dtypes) if fr in dtypes.floats]:
|
for fr in EMULATED_DTYPES.tolist(dtypes) if fr in dtypes.floats]:
|
||||||
sink = graph_rewrite(sink, pm_float_decomp, ctx=(fr, to), name=f"decomp {fr} -> {to}", bottom_up=True)
|
sink = graph_rewrite(sink, pm_float_decomp, ctx=(fr, to), name=f"decomp {fr} -> {to}", bottom_up=True)
|
||||||
sink = graph_rewrite(sink, pm_transcendental, ctx=ren.device, name="transcendental")
|
sink = _sym_rewrite(sink, symbolic_simple, pm_transcend_extra, ctx=ren.device, name="transcendental")
|
||||||
|
|
||||||
# final rules for the renderer (without sym)
|
# final rules for the renderer (without sym)
|
||||||
extra_matcher = ren.extra_matcher if ren.extra_matcher is not None else PatternMatcher([])
|
extra_matcher = ren.extra_matcher if ren.extra_matcher is not None else PatternMatcher([])
|
||||||
|
pm_decomp = symbolic_simple+pm_decomp_extra
|
||||||
pm_final_rewrite = pm_decomp+pm_render+extra_matcher+pm_split_ends
|
pm_final_rewrite = pm_decomp+pm_render+extra_matcher+pm_split_ends
|
||||||
sink = graph_rewrite(sink, pm_final_rewrite, ctx=ren.device, name="final rewrite")
|
sink = graph_rewrite(sink, pm_final_rewrite, ctx=ren.device, name="final rewrite")
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -180,7 +180,7 @@ SPLIT_REDUCEOP, NO_MEMORY_PLANNER, LRU = ContextVar("SPLIT_REDUCEOP", 1), Contex
|
||||||
RING, ALL2ALL = ContextVar("RING", 1), ContextVar("ALL2ALL", 0)
|
RING, ALL2ALL = ContextVar("RING", 1), ContextVar("ALL2ALL", 0)
|
||||||
CACHELEVEL, IGNORE_BEAM_CACHE, DEVECTORIZE = ContextVar("CACHELEVEL", 2), ContextVar("IGNORE_BEAM_CACHE", 0), ContextVar("DEVECTORIZE", 1)
|
CACHELEVEL, IGNORE_BEAM_CACHE, DEVECTORIZE = ContextVar("CACHELEVEL", 2), ContextVar("IGNORE_BEAM_CACHE", 0), ContextVar("DEVECTORIZE", 1)
|
||||||
VALIDATE_WITH_CPU, DISABLE_FAST_IDIV = ContextVar("VALIDATE_WITH_CPU", 0), ContextVar("DISABLE_FAST_IDIV", 0)
|
VALIDATE_WITH_CPU, DISABLE_FAST_IDIV = ContextVar("VALIDATE_WITH_CPU", 0), ContextVar("DISABLE_FAST_IDIV", 0)
|
||||||
CORRECT_DIVMOD_FOLDING, FUSE_OPTIM = ContextVar("CORRECT_DIVMOD_FOLDING", 0), ContextVar("FUSE_OPTIM", 0)
|
CORRECT_DIVMOD_FOLDING, FUSE_OPTIM, EGRAPH = ContextVar("CORRECT_DIVMOD_FOLDING", 0), ContextVar("FUSE_OPTIM", 0), ContextVar("EGRAPH", 0)
|
||||||
ALLOW_DEVICE_USAGE, MAX_BUFFER_SIZE = ContextVar("ALLOW_DEVICE_USAGE", 1), ContextVar("MAX_BUFFER_SIZE", 0)
|
ALLOW_DEVICE_USAGE, MAX_BUFFER_SIZE = ContextVar("ALLOW_DEVICE_USAGE", 1), ContextVar("MAX_BUFFER_SIZE", 0)
|
||||||
EMULATE, EMULATED_DTYPES = ContextVar("EMULATE", ""), ContextVar("EMULATED_DTYPES", "")
|
EMULATE, EMULATED_DTYPES = ContextVar("EMULATE", ""), ContextVar("EMULATED_DTYPES", "")
|
||||||
CPU_COUNT = ContextVar("CPU_COUNT", max(1, len(os.sched_getaffinity(0)) if hasattr(os, "sched_getaffinity") else (os.cpu_count() or 1)))
|
CPU_COUNT = ContextVar("CPU_COUNT", max(1, len(os.sched_getaffinity(0)) if hasattr(os, "sched_getaffinity") else (os.cpu_count() or 1)))
|
||||||
|
|
|
||||||
227
tinygrad/uop/egraph.py
Normal file
227
tinygrad/uop/egraph.py
Normal file
|
|
@ -0,0 +1,227 @@
|
||||||
|
# e-graph (equality saturation) for UOp rewriting
|
||||||
|
# instead of greedy first-match rewriting, we explore ALL equivalent forms and extract the cheapest
|
||||||
|
from __future__ import annotations
|
||||||
|
from tinygrad.uop.ops import UOp, Ops, PatternMatcher, graph_rewrite
|
||||||
|
|
||||||
|
# *** union-find (keyed by UOp identity) ***
|
||||||
|
|
||||||
|
def uf_find(parent:dict[UOp, UOp], x:UOp) -> UOp:
|
||||||
|
while parent[x] is not x:
|
||||||
|
parent[x] = parent[parent[x]]
|
||||||
|
x = parent[x] # path compression
|
||||||
|
return x
|
||||||
|
|
||||||
|
def uf_union(parent:dict[UOp, UOp], size:dict[UOp, int], a:UOp, b:UOp) -> UOp:
|
||||||
|
a, b = uf_find(parent, a), uf_find(parent, b)
|
||||||
|
if a is b: return a
|
||||||
|
if size[a] < size[b]: a, b = b, a # merge smaller into larger
|
||||||
|
parent[b] = a
|
||||||
|
size[a] += size[b]
|
||||||
|
return a
|
||||||
|
|
||||||
|
# *** e-graph core ***
|
||||||
|
|
||||||
|
def rewrite_all(pm:PatternMatcher, uop:UOp, ctx=None) -> list[UOp]:
|
||||||
|
"""Apply ALL matching rewrite rules to uop, returning every distinct result."""
|
||||||
|
results: list[UOp] = []
|
||||||
|
seen: dict[UOp, None] = {}
|
||||||
|
for _, match, early_reject in pm.pdict.get(uop.op, []):
|
||||||
|
if not early_reject.issubset({u.op for u in uop.src}): continue
|
||||||
|
try: ret = match(uop, ctx)
|
||||||
|
except Exception: continue # skip rules that crash on this node (e.g. division by zero in divmod folding)
|
||||||
|
if ret is not None and ret is not uop and ret not in seen:
|
||||||
|
results.append(ret)
|
||||||
|
seen[ret] = None
|
||||||
|
return results
|
||||||
|
|
||||||
|
class EGraph:
|
||||||
|
"""E-graph with full equality saturation (including rebuilding)."""
|
||||||
|
__slots__ = ("parent", "size", "eclass", "eclass_uses", "all_nodes")
|
||||||
|
def __init__(self, root:UOp):
|
||||||
|
nodes = list(root.toposort())
|
||||||
|
self.parent: dict[UOp, UOp] = {u: u for u in nodes}
|
||||||
|
self.size: dict[UOp, int] = {u: 1 for u in nodes}
|
||||||
|
self.eclass: dict[UOp, dict[UOp, None]] = {u: {u: None} for u in nodes} # canonical -> members
|
||||||
|
# canonical eclass representative -> dict of nodes that USE this eclass as a child
|
||||||
|
self.eclass_uses: dict[UOp, dict[UOp, None]] = {u: {} for u in nodes}
|
||||||
|
self.all_nodes: dict[UOp, None] = dict.fromkeys(nodes)
|
||||||
|
# build initial parent-child uses
|
||||||
|
for u in nodes:
|
||||||
|
for s in u.src:
|
||||||
|
canon = uf_find(self.parent, s)
|
||||||
|
self.eclass_uses.setdefault(canon, {})[u] = None
|
||||||
|
|
||||||
|
def _add_node(self, u:UOp):
|
||||||
|
"""Register a new UOp (and its subtree) in the e-graph."""
|
||||||
|
for sub in u.toposort():
|
||||||
|
if sub in self.parent: continue
|
||||||
|
self.parent[sub] = sub
|
||||||
|
self.size[sub] = 1
|
||||||
|
self.eclass[sub] = {sub: None}
|
||||||
|
self.all_nodes[sub] = None
|
||||||
|
self.eclass_uses[sub] = {}
|
||||||
|
for s in sub.src:
|
||||||
|
canon = uf_find(self.parent, s)
|
||||||
|
self.eclass_uses.setdefault(canon, {})[sub] = None
|
||||||
|
|
||||||
|
def _merge(self, a:UOp, b:UOp) -> UOp|None:
|
||||||
|
"""Merge two e-classes. Returns the winner, or None if already merged."""
|
||||||
|
ra, rb = uf_find(self.parent, a), uf_find(self.parent, b)
|
||||||
|
if ra is rb: return None
|
||||||
|
winner = uf_union(self.parent, self.size, ra, rb)
|
||||||
|
loser = rb if winner is ra else ra
|
||||||
|
self.eclass[winner] = {**self.eclass[winner], **self.eclass[loser]}
|
||||||
|
# merge uses
|
||||||
|
winner_uses = self.eclass_uses.setdefault(winner, {})
|
||||||
|
winner_uses.update(self.eclass_uses.pop(loser, {}))
|
||||||
|
del self.eclass[loser]
|
||||||
|
return winner
|
||||||
|
|
||||||
|
def _canonical(self, u:UOp) -> UOp:
|
||||||
|
"""Rebuild node with canonical representative for each child's eclass."""
|
||||||
|
if not u.src: return u
|
||||||
|
new_src = []
|
||||||
|
for s in u.src:
|
||||||
|
canon = uf_find(self.parent, s)
|
||||||
|
members = self.eclass.get(canon)
|
||||||
|
if members is not None:
|
||||||
|
best = min(members, key=lambda m: (len(m.src), m.op.value, m.arg if isinstance(m.arg, (int, float, str)) else 0))
|
||||||
|
new_src.append(best)
|
||||||
|
else:
|
||||||
|
new_src.append(s)
|
||||||
|
new_src_tuple = tuple(new_src)
|
||||||
|
if new_src_tuple == u.src: return u
|
||||||
|
return UOp(u.op, u.dtype, new_src_tuple, u.arg, u.tag)
|
||||||
|
|
||||||
|
def _rebuild(self, dirty:dict[UOp, None]) -> list[tuple[UOp, UOp]]:
|
||||||
|
"""Rebuild parents of dirty eclasses, creating canonical versions."""
|
||||||
|
new_equalities: list[tuple[UOp, UOp]] = []
|
||||||
|
affected: dict[UOp, None] = {}
|
||||||
|
for d in dirty:
|
||||||
|
canon = uf_find(self.parent, d)
|
||||||
|
affected.update(self.eclass_uses.get(canon, {}))
|
||||||
|
for u in affected:
|
||||||
|
rebuilt = self._canonical(u)
|
||||||
|
if rebuilt is not u:
|
||||||
|
if rebuilt in self.parent and uf_find(self.parent, rebuilt) is uf_find(self.parent, u): continue
|
||||||
|
self._add_node(rebuilt)
|
||||||
|
new_equalities.append((u, rebuilt))
|
||||||
|
return new_equalities
|
||||||
|
|
||||||
|
def egraph_saturate(root:UOp, pm:PatternMatcher, max_iters:int=10, ctx=None) -> dict[UOp, dict[UOp, None]]:
|
||||||
|
"""Build an e-graph with full equality saturation (with rebuilding). Returns eclass map."""
|
||||||
|
eg = EGraph(root)
|
||||||
|
node_limit = len(eg.all_nodes) * 3 # stop growing at 3x initial size to prevent combinatorial blowup
|
||||||
|
worklist: dict[UOp, None] = dict(eg.all_nodes) # nodes to match rules on
|
||||||
|
for _ in range(max_iters):
|
||||||
|
# phase 1: match rules only on worklist nodes
|
||||||
|
new_equalities: list[tuple[UOp, UOp]] = []
|
||||||
|
next_worklist: dict[UOp, None] = {}
|
||||||
|
prev_nodes = dict(eg.all_nodes)
|
||||||
|
for u in list(worklist):
|
||||||
|
if len(eg.all_nodes) >= node_limit: break
|
||||||
|
for new in rewrite_all(pm, u, ctx):
|
||||||
|
if new in eg.parent and uf_find(eg.parent, new) is uf_find(eg.parent, u): continue
|
||||||
|
eg._add_node(new)
|
||||||
|
new_equalities.append((u, new))
|
||||||
|
# all newly added nodes (including sub-nodes of rewrite results) go on next worklist
|
||||||
|
for u in eg.all_nodes:
|
||||||
|
if u not in prev_nodes: next_worklist[u] = None
|
||||||
|
if not new_equalities: break
|
||||||
|
|
||||||
|
# phase 2: merge eclasses, then rebuild canonical forms (no rule matching in rebuild)
|
||||||
|
while new_equalities:
|
||||||
|
dirty: dict[UOp, None] = {}
|
||||||
|
for a, b in new_equalities:
|
||||||
|
merged = eg._merge(a, b)
|
||||||
|
if merged is not None: dirty[merged] = None
|
||||||
|
if not dirty: break
|
||||||
|
new_equalities = eg._rebuild(dirty)
|
||||||
|
for _, b in new_equalities: next_worklist[b] = None
|
||||||
|
worklist = next_worklist
|
||||||
|
|
||||||
|
return eg.eclass
|
||||||
|
|
||||||
|
# *** cost model ***
|
||||||
|
|
||||||
|
OP_COST: dict[Ops, int] = {
|
||||||
|
Ops.CONST: 0, Ops.VCONST: 0, Ops.DEFINE_VAR: 0,
|
||||||
|
Ops.ADD: 1, Ops.MUL: 2, Ops.SUB: 1, Ops.NEG: 1,
|
||||||
|
Ops.IDIV: 5, Ops.MOD: 5, Ops.FDIV: 3,
|
||||||
|
Ops.SHL: 1, Ops.SHR: 1,
|
||||||
|
Ops.AND: 1, Ops.OR: 1, Ops.XOR: 1,
|
||||||
|
Ops.MAX: 1, Ops.CMPLT: 1, Ops.CMPNE: 1, Ops.CMPEQ: 1,
|
||||||
|
Ops.CAST: 1, Ops.BITCAST: 1,
|
||||||
|
Ops.WHERE: 2, Ops.MULACC: 2,
|
||||||
|
Ops.EXP2: 8, Ops.LOG2: 8, Ops.SIN: 8, Ops.SQRT: 4, Ops.RECIPROCAL: 3,
|
||||||
|
Ops.POW: 10, Ops.TRUNC: 1,
|
||||||
|
}
|
||||||
|
|
||||||
|
def node_cost(u:UOp) -> int:
|
||||||
|
c = OP_COST.get(u.op, 3)
|
||||||
|
# tiebreaker: penalize non-canonical operand order (consts should be on the right for commutative ops)
|
||||||
|
if len(u.src) == 2 and u.src[0].op is Ops.CONST and u.src[1].op is not Ops.CONST: c += 1
|
||||||
|
return c
|
||||||
|
|
||||||
|
# *** extraction ***
|
||||||
|
|
||||||
|
def egraph_extract(root:UOp, pm:PatternMatcher, max_iters:int=10, ctx=None) -> UOp:
|
||||||
|
"""Run equality saturation on root, then extract the cheapest equivalent expression."""
|
||||||
|
eclass = egraph_saturate(root, pm, max_iters, ctx)
|
||||||
|
|
||||||
|
# build eclass lookup: node -> canonical eclass representative
|
||||||
|
eclass_of: dict[UOp, UOp] = {}
|
||||||
|
for canon, members in eclass.items():
|
||||||
|
for u in members: eclass_of[u] = canon
|
||||||
|
|
||||||
|
all_nodes: list[UOp] = [u for members in eclass.values() for u in members]
|
||||||
|
|
||||||
|
# bottom-up DP: for each eclass, find the cheapest representative
|
||||||
|
cost_of: dict[UOp, tuple[int, UOp]] = {} # eclass_canon -> (cost, best_uop)
|
||||||
|
|
||||||
|
depth_cache: dict[UOp, int] = {}
|
||||||
|
def _depth(u:UOp) -> int:
|
||||||
|
if u in depth_cache: return depth_cache[u]
|
||||||
|
depth_cache[u] = 0 # break cycles
|
||||||
|
depth_cache[u] = (1 + max((_depth(s) for s in u.src), default=0)) if u.src else 0
|
||||||
|
return depth_cache[u]
|
||||||
|
|
||||||
|
for u in sorted(all_nodes, key=_depth):
|
||||||
|
canon = eclass_of[u]
|
||||||
|
child_cost = 0
|
||||||
|
for s in u.src:
|
||||||
|
if (s_canon := eclass_of.get(s)) is not None and s_canon in cost_of: child_cost += cost_of[s_canon][0]
|
||||||
|
else: child_cost += node_cost(s)
|
||||||
|
total = node_cost(u) + child_cost
|
||||||
|
if canon not in cost_of or total < cost_of[canon][0]:
|
||||||
|
cost_of[canon] = (total, u)
|
||||||
|
|
||||||
|
root_canon = eclass_of.get(root)
|
||||||
|
if root_canon is not None and root_canon in cost_of: return _rebuild_tree(cost_of[root_canon][1], eclass_of, cost_of)
|
||||||
|
return root
|
||||||
|
|
||||||
|
def _rebuild_tree(u:UOp, eclass_of:dict[UOp, UOp], cost_of:dict[UOp, tuple[int, UOp]], cache:dict[UOp, UOp]|None=None) -> UOp:
|
||||||
|
"""Recursively rebuild a UOp tree, picking the cheapest representative for each child's eclass."""
|
||||||
|
if not u.src: return u
|
||||||
|
if cache is None: cache = {}
|
||||||
|
new_src = []
|
||||||
|
for s in u.src:
|
||||||
|
s_canon = eclass_of.get(s)
|
||||||
|
if s_canon is not None and s_canon in cost_of:
|
||||||
|
if s_canon in cache: new_src.append(cache[s_canon])
|
||||||
|
else:
|
||||||
|
cache[s_canon] = s # placeholder breaks cycles
|
||||||
|
cache[s_canon] = _rebuild_tree(cost_of[s_canon][1], eclass_of, cost_of, cache)
|
||||||
|
new_src.append(cache[s_canon])
|
||||||
|
else:
|
||||||
|
new_src.append(_rebuild_tree(s, eclass_of, cost_of, cache))
|
||||||
|
new_src_tuple = tuple(new_src)
|
||||||
|
return u if new_src_tuple == u.src else UOp(u.op, u.dtype, new_src_tuple, u.arg, u.tag)
|
||||||
|
|
||||||
|
# *** graph-level rewrite: drop-in replacement for graph_rewrite when EGRAPH is set ***
|
||||||
|
|
||||||
|
def egraph_rewrite(sink:UOp, sym_pm:PatternMatcher, extra_pm:PatternMatcher|None=None, ctx=None, name:str|None=None) -> UOp:
|
||||||
|
"""Replace graph_rewrite(sink, sym+extra, ctx) with e-graph extraction for sym, then greedy for the rest."""
|
||||||
|
combined = sym_pm+extra_pm if extra_pm is not None else sym_pm
|
||||||
|
sink = egraph_extract(sink, combined, ctx=ctx)
|
||||||
|
return graph_rewrite(sink, combined, ctx=ctx, name=name)
|
||||||
Loading…
Add table
Add a link
Reference in a new issue