mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
explicit fixed point rewrite (#11685)
* explicit fixed point rewrite * local cache * fix that
This commit is contained in:
parent
5d6963c968
commit
4ab9fb2edd
1 changed files with 22 additions and 16 deletions
|
|
@ -756,16 +756,6 @@ class PatternMatcher:
|
|||
if (ret:=match(uop, ctx)) is not None and ret is not uop: return ret
|
||||
return None
|
||||
|
||||
def fixed_point_rewrite(self, uop:UOp, ctx=None) -> UOp:
|
||||
# apply rewrite rules until a fixed point is reached. may return `uop` itself if PatternMatcher doesn't match
|
||||
new_n: UOp|None = uop
|
||||
seen = set()
|
||||
while new_n is not None:
|
||||
if new_n in seen: raise RuntimeError("infinite loop in fixed_point_rewrite")
|
||||
seen.add(new_n)
|
||||
last_n, new_n = new_n, self.rewrite(new_n, ctx)
|
||||
return last_n
|
||||
|
||||
# *** non-blocking UOp tracker ***
|
||||
|
||||
ucount = itertools.count()
|
||||
|
|
@ -906,10 +896,21 @@ class RewriteNotReady(Exception): pass
|
|||
class RewriteContext:
|
||||
def __init__(self, pm, bpm, ctx=None):
|
||||
self.pm: PatternMatcher|None = pm
|
||||
self.pm_cache: dict[UOp, UOp|None] = {}
|
||||
self.bpm: PatternMatcher|None = bpm
|
||||
self.bpm_cache: dict[UOp, UOp|None] = {}
|
||||
self.ctx = ctx
|
||||
self.replace: dict[UOp, UOp] = {}
|
||||
self.skip_0: dict[UOp, None] = {} # NOTE: this is needed for RewriteNotReady. it also detects some infinite loops
|
||||
|
||||
def cached_pm_rewrite(self, x:UOp):
|
||||
if (ret:=self.pm_cache.get(x,False)) is not False: return ret
|
||||
ret = self.pm_cache[x] = cast(PatternMatcher, self.pm).rewrite(x, self.ctx)
|
||||
return ret
|
||||
|
||||
def cached_bpm_rewrite(self, x:UOp):
|
||||
if (ret:=self.bpm_cache.get(x,False)) is not False: return ret
|
||||
ret = self.bpm_cache[x] = cast(PatternMatcher, self.bpm).rewrite(x, self.ctx)
|
||||
return ret
|
||||
|
||||
def unified_rewrite(self, root:UOp) -> UOp:
|
||||
stack: list[tuple[UOp, int, UOp]] = [(root, 0, root)]
|
||||
|
|
@ -919,18 +920,23 @@ class RewriteContext:
|
|||
if n in self.replace: continue # skip any nodes we have seen
|
||||
try:
|
||||
if stage == 0:
|
||||
if n in self.skip_0: continue
|
||||
# if bottom up, we rewrite this node early. in both cases, we add its parents to the stack
|
||||
if self.bpm is not None: new_n = self.bpm.fixed_point_rewrite(new_n, self.ctx)
|
||||
if self.bpm is not None:
|
||||
# apply rewrite rules until a fixed point is reached. may return `uop` itself if PatternMatcher doesn't match
|
||||
test_n: UOp|None = n
|
||||
seen = set()
|
||||
while test_n is not None:
|
||||
if test_n in seen: raise RuntimeError("infinite loop in fixed_point_rewrite")
|
||||
seen.add(test_n)
|
||||
new_n, test_n = test_n, self.cached_bpm_rewrite(test_n)
|
||||
stack.append((n, 1, new_n))
|
||||
for x in reversed(new_n.src): stack.append((x, 0, x))
|
||||
self.skip_0[n] = None
|
||||
elif stage == 1:
|
||||
try: new_src = tuple([self.replace[x] for x in new_n.src])
|
||||
except KeyError: raise RewriteNotReady # pylint: disable=raise-missing-from
|
||||
if new_src == new_n.src:
|
||||
# if top down, do the rewrite. if no rewrite or bottom up, we are done rewriting this node so we add it to the dict
|
||||
if self.pm is None or (new_src_n:=self.pm.rewrite(new_n, self.ctx)) is None:
|
||||
if self.pm is None or (new_src_n:=self.cached_pm_rewrite(new_n)) is None:
|
||||
self.replace[n] = new_n
|
||||
continue
|
||||
else:
|
||||
|
|
@ -942,7 +948,7 @@ class RewriteContext:
|
|||
else:
|
||||
# in stage 2, we link the result of new_n to the result of n
|
||||
try: self.replace[n] = self.replace[new_n]
|
||||
except KeyError: raise RuntimeError("infinite loop in graph_rewrite (explicit)") # pylint: disable=raise-missing-from
|
||||
except KeyError: raise RewriteNotReady # pylint: disable=raise-missing-from
|
||||
except RewriteNotReady:
|
||||
# retry this later
|
||||
stack.insert(0, (n, stage, new_n))
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue