explicit fixed point rewrite (#11685)

* explicit fixed point rewrite

* local cache

* fix that
This commit is contained in:
George Hotz 2025-08-15 11:08:41 -07:00 committed by GitHub
commit 4ab9fb2edd
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -756,16 +756,6 @@ class PatternMatcher:
if (ret:=match(uop, ctx)) is not None and ret is not uop: return ret
return None
def fixed_point_rewrite(self, uop:UOp, ctx=None) -> UOp:
# apply rewrite rules until a fixed point is reached. may return `uop` itself if PatternMatcher doesn't match
new_n: UOp|None = uop
seen = set()
while new_n is not None:
if new_n in seen: raise RuntimeError("infinite loop in fixed_point_rewrite")
seen.add(new_n)
last_n, new_n = new_n, self.rewrite(new_n, ctx)
return last_n
# *** non-blocking UOp tracker ***
ucount = itertools.count()
@ -906,10 +896,21 @@ class RewriteNotReady(Exception): pass
class RewriteContext:
def __init__(self, pm, bpm, ctx=None):
self.pm: PatternMatcher|None = pm
self.pm_cache: dict[UOp, UOp|None] = {}
self.bpm: PatternMatcher|None = bpm
self.bpm_cache: dict[UOp, UOp|None] = {}
self.ctx = ctx
self.replace: dict[UOp, UOp] = {}
self.skip_0: dict[UOp, None] = {} # NOTE: this is needed for RewriteNotReady. it also detects some infinite loops
def cached_pm_rewrite(self, x:UOp):
if (ret:=self.pm_cache.get(x,False)) is not False: return ret
ret = self.pm_cache[x] = cast(PatternMatcher, self.pm).rewrite(x, self.ctx)
return ret
def cached_bpm_rewrite(self, x:UOp):
if (ret:=self.bpm_cache.get(x,False)) is not False: return ret
ret = self.bpm_cache[x] = cast(PatternMatcher, self.bpm).rewrite(x, self.ctx)
return ret
def unified_rewrite(self, root:UOp) -> UOp:
stack: list[tuple[UOp, int, UOp]] = [(root, 0, root)]
@ -919,18 +920,23 @@ class RewriteContext:
if n in self.replace: continue # skip any nodes we have seen
try:
if stage == 0:
if n in self.skip_0: continue
# if bottom up, we rewrite this node early. in both cases, we add its parents to the stack
if self.bpm is not None: new_n = self.bpm.fixed_point_rewrite(new_n, self.ctx)
if self.bpm is not None:
# apply rewrite rules until a fixed point is reached. may return `uop` itself if PatternMatcher doesn't match
test_n: UOp|None = n
seen = set()
while test_n is not None:
if test_n in seen: raise RuntimeError("infinite loop in fixed_point_rewrite")
seen.add(test_n)
new_n, test_n = test_n, self.cached_bpm_rewrite(test_n)
stack.append((n, 1, new_n))
for x in reversed(new_n.src): stack.append((x, 0, x))
self.skip_0[n] = None
elif stage == 1:
try: new_src = tuple([self.replace[x] for x in new_n.src])
except KeyError: raise RewriteNotReady # pylint: disable=raise-missing-from
if new_src == new_n.src:
# if top down, do the rewrite. if no rewrite or bottom up, we are done rewriting this node so we add it to the dict
if self.pm is None or (new_src_n:=self.pm.rewrite(new_n, self.ctx)) is None:
if self.pm is None or (new_src_n:=self.cached_pm_rewrite(new_n)) is None:
self.replace[n] = new_n
continue
else:
@ -942,7 +948,7 @@ class RewriteContext:
else:
# in stage 2, we link the result of new_n to the result of n
try: self.replace[n] = self.replace[new_n]
except KeyError: raise RuntimeError("infinite loop in graph_rewrite (explicit)") # pylint: disable=raise-missing-from
except KeyError: raise RewriteNotReady # pylint: disable=raise-missing-from
except RewriteNotReady:
# retry this later
stack.insert(0, (n, stage, new_n))