tinygrad/test/null/test_egraph.py
George Hotz 29b2afa0cb fix cycle
2026-02-09 09:50:55 +08:00

502 lines
19 KiB
Python

import unittest
from tinygrad.dtype import dtypes
from tinygrad.uop.ops import Ops, UOp, GroupOp, PatternMatcher, UPat, graph_rewrite
from tinygrad.uop.egraph import uf_find, uf_union, rewrite_all, EGraph, egraph_saturate, egraph_extract, node_cost, _rebuild_tree
# *** test union-find ***
class TestUnionFind(unittest.TestCase):
def test_find_self(self):
a, b = UOp.const(dtypes.int, 1), UOp.const(dtypes.int, 2)
parent = {a: a, b: b}
self.assertIs(uf_find(parent, a), a)
self.assertIs(uf_find(parent, b), b)
def test_union_basic(self):
a, b = UOp.const(dtypes.int, 1), UOp.const(dtypes.int, 2)
parent = {a: a, b: b}
size = {a: 1, b: 1}
root = uf_union(parent, size, a, b)
self.assertIs(uf_find(parent, a), uf_find(parent, b))
self.assertIs(root, uf_find(parent, a))
def test_union_chain(self):
a, b, c = UOp.const(dtypes.int, 1), UOp.const(dtypes.int, 2), UOp.const(dtypes.int, 3)
parent = {a: a, b: b, c: c}
size = {a: 1, b: 1, c: 1}
uf_union(parent, size, a, b)
uf_union(parent, size, b, c)
self.assertIs(uf_find(parent, a), uf_find(parent, c))
def test_union_idempotent(self):
a, b = UOp.const(dtypes.int, 1), UOp.const(dtypes.int, 2)
parent = {a: a, b: b}
size = {a: 1, b: 1}
r1 = uf_union(parent, size, a, b)
r2 = uf_union(parent, size, a, b)
self.assertIs(r1, r2)
# *** test rewrite_all ***
class TestRewriteAll(unittest.TestCase):
def test_single_match(self):
pm = PatternMatcher([(UPat.var("x") + 0, lambda x: x)])
a = UOp.variable("a", 0, 10)
results = rewrite_all(pm, a + 0)
self.assertEqual(len(results), 1)
self.assertIs(results[0], a)
def test_no_match(self):
pm = PatternMatcher([(UPat.var("x") + 0, lambda x: x)])
a = UOp.variable("a", 0, 10)
b = UOp.variable("b", 0, 10)
results = rewrite_all(pm, a + b)
self.assertEqual(len(results), 0)
def test_multiple_matches(self):
pm = PatternMatcher([
(UPat.var("x") + 0, lambda x: x),
(UPat.var("x") * 1, lambda x: x),
])
a = UOp.variable("a", 0, 10)
results = rewrite_all(pm, a + 0)
self.assertEqual(len(results), 1)
self.assertIs(results[0], a)
def test_both_rules_fire(self):
pm = PatternMatcher([
(UPat.var("x") + UPat.var("x"), lambda x: x * 2),
(UPat.var("x") + UPat.var("x"), lambda x: UOp(Ops.SHL, x.dtype, (x, x.const_like(1)))),
])
a = UOp.variable("a", 0, 10)
results = rewrite_all(pm, a + a)
self.assertEqual(len(results), 2)
def test_const_folding(self):
pm = PatternMatcher([
(UPat(GroupOp.Binary, src=(UPat((Ops.CONST, Ops.VCONST)),)*2, name="a"),
lambda a: a.const_like(a.src[0].arg + a.src[1].arg) if a.op is Ops.ADD else None),
])
results = rewrite_all(pm, UOp.const(dtypes.int, 3) + UOp.const(dtypes.int, 4))
self.assertEqual(len(results), 1)
self.assertEqual(results[0].arg, 7)
# *** test EGraph class ***
class TestEGraphClass(unittest.TestCase):
def test_init(self):
a = UOp.variable("a", 0, 10)
expr = a + 0
eg = EGraph(expr)
self.assertEqual(len(eg.eclass), len(list(expr.toposort())))
self.assertIn(expr, eg.all_nodes)
def test_add_node(self):
a = UOp.variable("a", 0, 10)
eg = EGraph(a)
b = UOp.variable("b", 0, 10)
eg._add_node(b)
self.assertIn(b, eg.all_nodes)
def test_merge(self):
a = UOp.variable("a", 0, 10)
expr = a + 0
eg = EGraph(expr)
result = eg._merge(expr, a)
self.assertIsNotNone(result)
self.assertIs(uf_find(eg.parent, expr), uf_find(eg.parent, a))
def test_merge_idempotent(self):
a = UOp.variable("a", 0, 10)
eg = EGraph(a)
result = eg._merge(a, a)
self.assertIsNone(result)
# *** test egraph_saturate ***
class TestEGraphSaturate(unittest.TestCase):
def test_identity_rules(self):
pm = PatternMatcher([
(UPat.var("x") + 0, lambda x: x),
(UPat.var("x") * 1, lambda x: x),
])
a = UOp.variable("a", 0, 10)
expr = a + 0
eclass = egraph_saturate(expr, pm)
# a+0 and a should be in the same e-class
a_class = expr_class = None
for canon, members in eclass.items():
if a in members: a_class = canon
if expr in members: expr_class = canon
self.assertIsNotNone(a_class)
self.assertIsNotNone(expr_class)
self.assertIs(a_class, expr_class)
def test_const_fold_saturation(self):
from tinygrad.uop.symbolic import symbolic_simple
c2, c3 = UOp.const(dtypes.int, 2), UOp.const(dtypes.int, 3)
expr = c2 + c3
eclass = egraph_saturate(expr, symbolic_simple)
c5 = UOp.const(dtypes.int, 5)
for canon, members in eclass.items():
if expr in members:
self.assertIn(c5, members, f"expected CONST(5) in eclass of 2+3, got {members}")
return
self.fail("expr not found in any eclass")
def test_no_rules_match(self):
pm = PatternMatcher([(UPat.var("x") + 0, lambda x: x)])
a = UOp.variable("a", 0, 10)
b = UOp.variable("b", 0, 10)
eclass = egraph_saturate(a + b, pm)
for canon, members in eclass.items():
self.assertEqual(len(members), 1)
def test_max_iters_respected(self):
pm = PatternMatcher([(UPat.var("x") + 0, lambda x: x)])
a = UOp.variable("a", 0, 10)
expr = a + 0
eclass = egraph_saturate(expr, pm, max_iters=1)
a_class = expr_class = None
for canon, members in eclass.items():
if a in members: a_class = canon
if expr in members: expr_class = canon
self.assertIs(a_class, expr_class)
def test_rebuilding_propagates(self):
"""After a*0 merges with 0, rebuilding should create (0+a) which then matches x+0 -> x."""
pm = PatternMatcher([
(UPat.var("x") * 0, lambda x: x.const_like(0)),
(UPat.var("x") + 0, lambda x: x),
])
a = UOp.variable("a", 0, 10)
expr = (a * 0) + a
eclass = egraph_saturate(expr, pm)
expr_cls = a_cls = None
for canon, members in eclass.items():
if expr in members: expr_cls = canon
if a in members: a_cls = canon
self.assertIsNotNone(expr_cls)
self.assertIsNotNone(a_cls)
self.assertIs(expr_cls, a_cls)
def test_rebuilding_chain(self):
"""((a*0)+0)+b should simplify to b through multiple rebuild steps."""
pm = PatternMatcher([
(UPat.var("x") * 0, lambda x: x.const_like(0)),
(UPat.var("x") + 0, lambda x: x),
])
a = UOp.variable("a", 0, 10)
b = UOp.variable("b", 0, 10)
expr = ((a * 0) + 0) + b
eclass = egraph_saturate(expr, pm)
expr_cls = b_cls = None
for canon, members in eclass.items():
if expr in members: expr_cls = canon
if b in members: b_cls = canon
self.assertIsNotNone(expr_cls)
self.assertIsNotNone(b_cls)
self.assertIs(expr_cls, b_cls)
# *** test egraph_extract ***
class TestEGraphExtract(unittest.TestCase):
def test_extract_identity(self):
pm = PatternMatcher([(UPat.var("x") + 0, lambda x: x)])
a = UOp.variable("a", 0, 10)
self.assertIs(egraph_extract(a + 0, pm), a)
def test_extract_mul_identity(self):
pm = PatternMatcher([(UPat.var("x") * 1, lambda x: x)])
a = UOp.variable("a", 0, 10)
self.assertIs(egraph_extract(a * 1, pm), a)
def test_extract_const_fold(self):
from tinygrad.uop.symbolic import symbolic_simple
result = egraph_extract(UOp.const(dtypes.int, 2) + UOp.const(dtypes.int, 3), symbolic_simple)
self.assertEqual(result.op, Ops.CONST)
self.assertEqual(result.arg, 5)
def test_extract_chain(self):
pm = PatternMatcher([
(UPat.var("x") + 0, lambda x: x),
(UPat.var("x") * 1, lambda x: x),
])
a = UOp.variable("a", 0, 10)
self.assertIs(egraph_extract((a + 0) * 1, pm), a)
def test_extract_no_change(self):
pm = PatternMatcher([(UPat.var("x") + 0, lambda x: x)])
a = UOp.variable("a", 0, 10)
b = UOp.variable("b", 0, 10)
self.assertIs(egraph_extract(a + b, pm), a + b)
def test_extract_prefers_cheaper(self):
pm = PatternMatcher([(UPat.var("x") + UPat.var("x"), lambda x: x * 2)])
a = UOp.variable("a", 0, 10)
result = egraph_extract(a + a, pm)
self.assertEqual(result.op, Ops.ADD) # ADD cost 1 < MUL cost 2
def test_extract_with_symbolic_simple(self):
from tinygrad.uop.symbolic import symbolic_simple
a = UOp.variable("a", 0, 10)
self.assertIs(egraph_extract((a + 0) * 1, symbolic_simple), a)
def test_combine_terms(self):
from tinygrad.uop.symbolic import symbolic
a = UOp.variable("a", 0, 10)
result = egraph_extract(a * 3 + a * 4, symbolic)
self.assertEqual(result.op, Ops.MUL)
self.assertEqual(result.src[1].arg, 7)
# *** tests that REQUIRE rebuilding ***
def test_rebuild_mul_zero_plus(self):
pm = PatternMatcher([
(UPat.var("x") * 0, lambda x: x.const_like(0)),
(UPat.var("x") + 0, lambda x: x),
])
a = UOp.variable("a", 0, 10)
self.assertIs(egraph_extract((a * 0) + a, pm), a)
def test_rebuild_nested_zero(self):
pm = PatternMatcher([
(UPat.var("x") * 0, lambda x: x.const_like(0)),
(UPat.var("x") + 0, lambda x: x),
])
a = UOp.variable("a", 0, 10)
b = UOp.variable("b", 0, 10)
self.assertIs(egraph_extract(((a * 0) + 0) + b, pm), b)
def test_rebuild_distribute_then_fold(self):
pm = PatternMatcher([(UPat.var("x") * 0, lambda x: x.const_like(0))])
a = UOp.variable("a", 0, 10)
b = UOp.variable("b", 0, 10)
result = egraph_extract((a + b) * 0, pm)
self.assertEqual(result.op, Ops.CONST)
self.assertEqual(result.arg, 0)
def test_rebuild_symmetric(self):
pm = PatternMatcher([
(UPat.var("x") * 0, lambda x: x.const_like(0)),
(UPat(GroupOp.Binary, src=(UPat((Ops.CONST, Ops.VCONST)),)*2, name="a"),
lambda a: a.const_like(a.src[0].arg + a.src[1].arg) if a.op is Ops.ADD else None),
])
a = UOp.variable("a", 0, 10)
b = UOp.variable("b", 0, 10)
result = egraph_extract((a * 0) + (b * 0), pm)
self.assertEqual(result.op, Ops.CONST)
self.assertEqual(result.arg, 0)
def test_rebuild_with_real_rules(self):
from tinygrad.uop.symbolic import symbolic_simple
a = UOp.variable("a", 0, 10)
b = UOp.variable("b", 0, 10)
self.assertIs(egraph_extract((a * 0) + (b * 1), symbolic_simple), b)
def test_rebuild_deep_chain(self):
pm = PatternMatcher([
(UPat.var("x") * 0, lambda x: x.const_like(0)),
(UPat.var("x") + 0, lambda x: x),
(UPat(GroupOp.Binary, src=(UPat((Ops.CONST, Ops.VCONST)),)*2, name="a"),
lambda a: a.const_like(a.src[0].arg + a.src[1].arg) if a.op is Ops.ADD else None),
])
a = UOp.variable("a", 0, 10)
b = UOp.variable("b", 0, 10)
c = UOp.variable("c", 0, 10)
self.assertIs(egraph_extract(((a * 0) + (b * 0)) + c, pm), c)
# *** test cost model ***
class TestCostModel(unittest.TestCase):
def test_const_is_free(self):
self.assertEqual(node_cost(UOp.const(dtypes.int, 0)), 0)
def test_add_is_cheap(self):
a = UOp.variable("a", 0, 10)
b = UOp.variable("b", 0, 10)
self.assertEqual(node_cost(a + b), 1)
def test_div_is_expensive(self):
a = UOp.variable("a", 0, 10).cast(dtypes.index)
b = UOp.variable("b", 1, 10).cast(dtypes.index)
self.assertEqual(node_cost(a // b), 5)
def test_mul_more_than_add(self):
a = UOp.variable("a", 0, 10)
b = UOp.variable("b", 0, 10)
self.assertGreater(node_cost(a * b), node_cost(a + b))
# *** test e-graph matches greedy rewrite ***
class TestEGraphVsGreedy(unittest.TestCase):
def test_matches_greedy_identity(self):
from tinygrad.uop.ops import graph_rewrite
from tinygrad.uop.symbolic import symbolic_simple
a = UOp.variable("a", 0, 10)
greedy = graph_rewrite(a + 0, symbolic_simple)
egraph = egraph_extract(a + 0, symbolic_simple)
self.assertIs(greedy, egraph)
def test_matches_greedy_const_fold(self):
from tinygrad.uop.ops import graph_rewrite
from tinygrad.uop.symbolic import symbolic_simple
expr = UOp.const(dtypes.int, 10) + UOp.const(dtypes.int, 20)
greedy = graph_rewrite(expr, symbolic_simple)
egraph = egraph_extract(expr, symbolic_simple)
self.assertEqual(greedy.op, Ops.CONST)
self.assertEqual(egraph.op, Ops.CONST)
self.assertEqual(greedy.arg, egraph.arg)
def test_matches_greedy_double_identity(self):
from tinygrad.uop.ops import graph_rewrite
from tinygrad.uop.symbolic import symbolic_simple
a = UOp.variable("a", 0, 10)
expr = (a + 0) * 1
self.assertIs(graph_rewrite(expr, symbolic_simple), a)
self.assertIs(egraph_extract(expr, symbolic_simple), a)
# *** test e-graph beats greedy (phase-ordering problems) ***
# helper PMs that create phase-ordering traps
_pm_strength_reduce = PatternMatcher([
# strength reduction x*2 -> x+x fires FIRST and destroys the x*c form needed by combine-terms
(UPat.var('x') * UPat.cvar('c', vec=False), lambda x,c: x+x if c.arg == 2 else None),
# combine terms: x*c0 + x*c1 -> x*(c0+c1) can only match if both sides are x*c
(UPat.var('x') * UPat.cvar('c0') + UPat.var('x') * UPat.cvar('c1'), lambda x,c0,c1: x*(c0+c1)),
# constant folding
(UPat(GroupOp.Binary, src=(UPat((Ops.CONST, Ops.VCONST)),)*2, name='a'),
lambda a: a.const_like(a.src[0].arg + a.src[1].arg) if a.op is Ops.ADD else
a.const_like(a.src[0].arg * a.src[1].arg) if a.op is Ops.MUL else None),
(UPat.var('x') + 0, lambda x: x),
(UPat.var('x') * 1, lambda x: x),
])
_pm_shift_reduce = PatternMatcher([
# strength reduction x*2 -> x<<1 fires FIRST and destroys the x*c form
(UPat.var('x') * UPat.cvar('c', vec=False),
lambda x,c: UOp(Ops.SHL, x.dtype, (x, x.const_like(1))) if c.arg == 2 else None),
(UPat.var('x') * UPat.cvar('c0') + UPat.var('x') * UPat.cvar('c1'), lambda x,c0,c1: x*(c0+c1)),
(UPat(GroupOp.Binary, src=(UPat((Ops.CONST, Ops.VCONST)),)*2, name='a'),
lambda a: a.const_like(a.src[0].arg + a.src[1].arg) if a.op is Ops.ADD else
a.const_like(a.src[0].arg * a.src[1].arg) if a.op is Ops.MUL else None),
(UPat.var('x') + 0, lambda x: x),
(UPat.var('x') * 1, lambda x: x),
])
_pm_strength_fold = PatternMatcher([
# strength reduction x*2 -> x+x blocks two-stage folding (x*c1)*c2 -> x*(c1*c2)
(UPat.var('x') * UPat.cvar('c', vec=False), lambda x,c: x+x if c.arg == 2 else None),
((UPat.var('x') * UPat.cvar('c1')) * UPat.cvar('c2'), lambda x,c1,c2: x*(c1*c2)),
(UPat(GroupOp.Binary, src=(UPat((Ops.CONST, Ops.VCONST)),)*2, name='a'),
lambda a: a.const_like(a.src[0].arg * a.src[1].arg) if a.op is Ops.MUL else None),
])
def _total_cost(u:UOp) -> int:
return sum(node_cost(n) for n in u.toposort())
class TestEGraphBeatsGreedy(unittest.TestCase):
"""Tests where the e-graph finds a cheaper result than the greedy rewriter due to phase-ordering.
The core problem: when Rule A fires first and transforms a node, it can destroy the pattern
that Rule B needs to match. Rule B would have led to a cheaper result, but the greedy rewriter
never tries it. The e-graph explores BOTH paths and picks the cheapest.
"""
def test_strength_reduce_blocks_combine(self):
"""a*2 + a*3: strength reduction x*2->x+x destroys the x*c form needed by combine-terms x*c0+x*c1->x*(c0+c1)."""
a = UOp.variable("a", 0, 10)
expr = a * 2 + a * 3
greedy = graph_rewrite(expr, _pm_strength_reduce)
egraph = egraph_extract(expr, _pm_strength_reduce)
# greedy: (a+a) + a*3 (cost 4) — strength reduction destroyed the a*2 pattern
self.assertEqual(greedy.op, Ops.ADD)
self.assertGreater(_total_cost(greedy), _total_cost(egraph))
# egraph: a*5 (cost 2) — combine-terms wins because the e-graph explored both paths
self.assertEqual(egraph.op, Ops.MUL)
self.assertEqual(egraph.src[1].arg, 5)
def test_shift_reduce_blocks_combine(self):
"""a*2 + a*3: shift reduction x*2->x<<1 also destroys the combine-terms pattern."""
a = UOp.variable("a", 0, 10)
expr = a * 2 + a * 3
greedy = graph_rewrite(expr, _pm_shift_reduce)
egraph = egraph_extract(expr, _pm_shift_reduce)
self.assertEqual(greedy.op, Ops.ADD)
self.assertGreater(_total_cost(greedy), _total_cost(egraph))
self.assertEqual(egraph.op, Ops.MUL)
self.assertEqual(egraph.src[1].arg, 5)
def test_strength_reduce_chain(self):
"""a*2 + a*3 + a*4: strength reduction causes greedy to miss the combined a*9."""
a = UOp.variable("a", 0, 10)
expr = a * 2 + a * 3 + a * 4
greedy = graph_rewrite(expr, _pm_strength_reduce)
egraph = egraph_extract(expr, _pm_strength_reduce)
self.assertGreater(_total_cost(greedy), _total_cost(egraph))
def test_strength_reduce_blocks_two_stage_fold(self):
"""(a*2)*3: strength reduction x*2->x+x blocks two-stage constant folding (x*c1)*c2->x*(c1*c2)."""
a = UOp.variable("a", 0, 10)
expr = (a * 2) * 3
greedy = graph_rewrite(expr, _pm_strength_fold)
egraph = egraph_extract(expr, _pm_strength_fold)
# greedy: (a+a)*3 (cost 3) — can't fold constants because *2 was rewritten to +
self.assertGreater(_total_cost(greedy), _total_cost(egraph))
# egraph: a*6 (cost 2) — two-stage folding path was explored
self.assertEqual(egraph.op, Ops.MUL)
self.assertEqual(egraph.src[1].arg, 6)
def test_both_sides_strength_reduced(self):
"""a*2 + a*2: both sides get strength-reduced, blocking combine-terms."""
a = UOp.variable("a", 0, 10)
expr = a * 2 + a * 2
greedy = graph_rewrite(expr, _pm_strength_reduce)
egraph = egraph_extract(expr, _pm_strength_reduce)
# greedy: (a+a)+(a+a) — both a*2 were rewritten before combine could fire
# egraph: a*4 — combine-terms path was found
self.assertEqual(egraph.op, Ops.MUL)
self.assertEqual(egraph.src[1].arg, 4)
# both have cost 2 here (shared subexpression), but egraph result is canonical
self.assertLessEqual(_total_cost(egraph), _total_cost(greedy))
# *** test cycle-breaking in extraction ***
class TestExtractionCycles(unittest.TestCase):
def test_self_referencing_eclass(self):
"""x+0 -> x merges x+0 into x's eclass. Extraction must not recurse on the self-reference."""
pm = PatternMatcher([(UPat.var("x") + 0, lambda x: x)])
a = UOp.variable("a", 0, 10)
self.assertIs(egraph_extract(a + 0, pm), a)
def test_nested_self_referencing_eclass(self):
"""((a+0)+0)+0 — all merge into a's eclass. Deep self-reference chain."""
pm = PatternMatcher([(UPat.var("x") + 0, lambda x: x)])
a = UOp.variable("a", 0, 10)
self.assertIs(egraph_extract(((a + 0) + 0) + 0, pm), a)
def test_mutual_eclass_cycle(self):
"""Two eclasses whose best nodes reference each other — extraction must terminate via cycle-breaking cache."""
x = UOp.variable("x", 0, 10)
y = UOp.variable("y", 0, 10)
one = UOp.const(dtypes.index, 1)
two = UOp.const(dtypes.index, 2)
node1 = x + one # E1's best, child x is in E2
node2 = y + two # E2's best, child y is in E1
eclass_of = {node1: node1, x: node2, node2: node2, y: node1, one: one, two: two}
cost_of = {node1: (2, node1), node2: (2, node2), one: (0, one), two: (0, two)}
# without cycle-breaking cache, this would recurse: E1->E2->E1->...
result = _rebuild_tree(node1, eclass_of, cost_of)
self.assertIsNotNone(result) # just verify it terminates
def test_mutual_rewrite_cycle(self):
"""x+x <-> x*2 mutual rewrite. Both forms in same eclass, extraction picks cheaper (ADD)."""
pm = PatternMatcher([
(UPat.var("x") + UPat.var("x"), lambda x: x * 2),
(UPat.var("x") * UPat.cvar("c", vec=False), lambda x,c: x+x if c.arg == 2 else None),
])
a = UOp.variable("a", 0, 10)
result = egraph_extract(a + a, pm)
self.assertEqual(result.op, Ops.ADD) # ADD cost 1 < MUL cost 2
if __name__ == '__main__':
unittest.main(verbosity=2)