mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
go
This commit is contained in:
parent
a65d9fea74
commit
e7c8aaed31
1 changed files with 10 additions and 16 deletions
|
|
@ -93,8 +93,8 @@ class EGraph:
|
|||
if new_src_tuple == u.src: return u
|
||||
return UOp(u.op, u.dtype, new_src_tuple, u.arg, u.tag)
|
||||
|
||||
def _rebuild(self, dirty:dict[UOp, None], pm:PatternMatcher, ctx=None) -> list[tuple[UOp, UOp]]:
|
||||
"""Rebuild parents of dirty eclasses, creating canonical versions and matching rules."""
|
||||
def _rebuild(self, dirty:dict[UOp, None]) -> list[tuple[UOp, UOp]]:
|
||||
"""Rebuild parents of dirty eclasses, creating canonical versions."""
|
||||
new_equalities: list[tuple[UOp, UOp]] = []
|
||||
affected: dict[UOp, None] = {}
|
||||
for d in dirty:
|
||||
|
|
@ -106,15 +106,12 @@ class EGraph:
|
|||
if rebuilt in self.parent and uf_find(self.parent, rebuilt) is uf_find(self.parent, u): continue
|
||||
self._add_node(rebuilt)
|
||||
new_equalities.append((u, rebuilt))
|
||||
for new in rewrite_all(pm, rebuilt, ctx):
|
||||
if new in self.parent and uf_find(self.parent, new) is uf_find(self.parent, rebuilt): continue
|
||||
self._add_node(new)
|
||||
new_equalities.append((rebuilt, new))
|
||||
return new_equalities
|
||||
|
||||
def egraph_saturate(root:UOp, pm:PatternMatcher, max_iters:int=10, ctx=None) -> dict[UOp, dict[UOp, None]]:
|
||||
"""Build an e-graph with full equality saturation (with rebuilding). Returns eclass map."""
|
||||
eg = EGraph(root)
|
||||
node_limit = len(eg.all_nodes) * 3 # stop growing at 3x initial size to prevent combinatorial blowup
|
||||
worklist: dict[UOp, None] = dict(eg.all_nodes) # nodes to match rules on
|
||||
for _ in range(max_iters):
|
||||
# phase 1: match rules only on worklist nodes
|
||||
|
|
@ -122,6 +119,7 @@ def egraph_saturate(root:UOp, pm:PatternMatcher, max_iters:int=10, ctx=None) ->
|
|||
next_worklist: dict[UOp, None] = {}
|
||||
prev_nodes = dict(eg.all_nodes)
|
||||
for u in list(worklist):
|
||||
if len(eg.all_nodes) >= node_limit: break
|
||||
for new in rewrite_all(pm, u, ctx):
|
||||
if new in eg.parent and uf_find(eg.parent, new) is uf_find(eg.parent, u): continue
|
||||
eg._add_node(new)
|
||||
|
|
@ -129,19 +127,17 @@ def egraph_saturate(root:UOp, pm:PatternMatcher, max_iters:int=10, ctx=None) ->
|
|||
# all newly added nodes (including sub-nodes of rewrite results) go on next worklist
|
||||
for u in eg.all_nodes:
|
||||
if u not in prev_nodes: next_worklist[u] = None
|
||||
if not new_equalities:
|
||||
break
|
||||
if not new_equalities: break
|
||||
|
||||
# phase 2: merge and rebuild until no new merges
|
||||
# phase 2: merge eclasses, then rebuild canonical forms (no rule matching in rebuild)
|
||||
while new_equalities:
|
||||
dirty: dict[UOp, None] = {}
|
||||
for a, b in new_equalities:
|
||||
merged = eg._merge(a, b)
|
||||
if merged is not None: dirty[merged] = None
|
||||
if not dirty: break
|
||||
# phase 3: rebuild parents of dirty eclasses, add rebuilt nodes to next worklist
|
||||
new_equalities = eg._rebuild(dirty, pm, ctx)
|
||||
for a, b in new_equalities: next_worklist[b] = None
|
||||
new_equalities = eg._rebuild(dirty)
|
||||
for _, b in new_equalities: next_worklist[b] = None
|
||||
worklist = next_worklist
|
||||
|
||||
return eg.eclass
|
||||
|
|
@ -226,8 +222,6 @@ def _rebuild_tree(u:UOp, eclass_of:dict[UOp, UOp], cost_of:dict[UOp, tuple[int,
|
|||
|
||||
def egraph_rewrite(sink:UOp, sym_pm:PatternMatcher, extra_pm:PatternMatcher|None=None, ctx=None, name:str|None=None) -> UOp:
|
||||
"""Replace graph_rewrite(sink, sym+extra, ctx) with e-graph extraction for sym, then greedy for the rest."""
|
||||
sink = egraph_extract(sink, sym_pm)
|
||||
# run greedy with the full combined matcher to catch enabling transformations the e-graph skipped
|
||||
combined = sym_pm+extra_pm if extra_pm is not None else sym_pm
|
||||
sink = graph_rewrite(sink, combined, ctx=ctx, name=name)
|
||||
return sink
|
||||
sink = egraph_extract(sink, combined, ctx=ctx)
|
||||
return graph_rewrite(sink, combined, ctx=ctx, name=name)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue