Compare commits

...

1 commit

Author SHA1 Message Date
George Hotz
5fd81a7f67 add a gate to rewrite if there's no rules [pr] 2025-11-30 17:28:58 -08:00

View file

@ -1034,10 +1034,11 @@ class PatternMatcher:
def __add__(self, more:PatternMatcher) -> PatternMatcher: return PatternMatcher(self.patterns+more.patterns)
def rewrite(self, uop:UOp, ctx=None) -> UOp|None:
ler = {u.op for u in uop.src}
for _,match,early_reject in self.pdict.get(uop.op, []):
if not early_reject.issubset(ler): continue
if (ret:=match(uop, ctx)) is not None and ret is not uop: return ret
if len(pats:=self.pdict.get(uop.op, [])):
ler = {u.op for u in uop.src}
for _,match,early_reject in pats:
if not early_reject.issubset(ler): continue
if (ret:=match(uop, ctx)) is not None and ret is not uop: return ret
return None
# *** tracking pattern matcher ***
@ -1119,28 +1120,29 @@ def profile_matches(fxn:Callable):
class TrackedPatternMatcher(PatternMatcher):
def rewrite(self, uop:UOp, ctx=None) -> UOp|None:
ret = None
ler = {u.op for u in uop.src}
for p,match,early_reject in self.pdict.get(uop.op, []):
if p not in match_stats: match_stats[p] = [0,0,0.0,0.0]
st = time.perf_counter()
if not early_reject.issubset(ler):
if len(pats:=self.pdict.get(uop.op, [])):
ret = None
ler = {u.op for u in uop.src}
for p,match,early_reject in pats:
if p not in match_stats: match_stats[p] = [0,0,0.0,0.0]
st = time.perf_counter()
if not early_reject.issubset(ler):
match_stats[p][2] += time.perf_counter()-st
continue
match_stats[p][1] += 1
try: ret = match(uop, ctx)
except Exception:
if TRACK_MATCH_STATS >= 2 and active_rewrites:
active_rewrites[-1].matches.append((uop.trace_num, UOp(Ops.REWRITE_ERROR,src=uop.src,arg=str(sys.exc_info()[1])).trace_num,p.location,0))
raise
if ret is not None and ret is not uop:
match_stats[p][0] += 1
match_stats[p][3] += (et:=time.perf_counter()-st)
if TRACK_MATCH_STATS >= 3: print(f"{et*1e6:7.2f} us -- ", printable(p.location))
if TRACK_MATCH_STATS >= 2 and isinstance(ret, UOp) and active_rewrites:
active_rewrites[-1].matches.append((uop.trace_num, ret.trace_num, p.location, et))
return ret
match_stats[p][2] += time.perf_counter()-st
continue
match_stats[p][1] += 1
try: ret = match(uop, ctx)
except Exception:
if TRACK_MATCH_STATS >= 2 and active_rewrites:
active_rewrites[-1].matches.append((uop.trace_num, UOp(Ops.REWRITE_ERROR,src=uop.src,arg=str(sys.exc_info()[1])).trace_num,p.location,0))
raise
if ret is not None and ret is not uop:
match_stats[p][0] += 1
match_stats[p][3] += (et:=time.perf_counter()-st)
if TRACK_MATCH_STATS >= 3: print(f"{et*1e6:7.2f} us -- ", printable(p.location))
if TRACK_MATCH_STATS >= 2 and isinstance(ret, UOp) and active_rewrites:
active_rewrites[-1].matches.append((uop.trace_num, ret.trace_num, p.location, et))
return ret
match_stats[p][2] += time.perf_counter()-st
return None
@dataclass(frozen=True)