mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
Compare commits
1 commit
master
...
gate_rewri
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
5fd81a7f67 |
1 changed files with 27 additions and 25 deletions
|
|
@ -1034,10 +1034,11 @@ class PatternMatcher:
|
|||
def __add__(self, more:PatternMatcher) -> PatternMatcher: return PatternMatcher(self.patterns+more.patterns)
|
||||
|
||||
def rewrite(self, uop:UOp, ctx=None) -> UOp|None:
|
||||
ler = {u.op for u in uop.src}
|
||||
for _,match,early_reject in self.pdict.get(uop.op, []):
|
||||
if not early_reject.issubset(ler): continue
|
||||
if (ret:=match(uop, ctx)) is not None and ret is not uop: return ret
|
||||
if len(pats:=self.pdict.get(uop.op, [])):
|
||||
ler = {u.op for u in uop.src}
|
||||
for _,match,early_reject in pats:
|
||||
if not early_reject.issubset(ler): continue
|
||||
if (ret:=match(uop, ctx)) is not None and ret is not uop: return ret
|
||||
return None
|
||||
|
||||
# *** tracking pattern matcher ***
|
||||
|
|
@ -1119,28 +1120,29 @@ def profile_matches(fxn:Callable):
|
|||
|
||||
class TrackedPatternMatcher(PatternMatcher):
|
||||
def rewrite(self, uop:UOp, ctx=None) -> UOp|None:
|
||||
ret = None
|
||||
ler = {u.op for u in uop.src}
|
||||
for p,match,early_reject in self.pdict.get(uop.op, []):
|
||||
if p not in match_stats: match_stats[p] = [0,0,0.0,0.0]
|
||||
st = time.perf_counter()
|
||||
if not early_reject.issubset(ler):
|
||||
if len(pats:=self.pdict.get(uop.op, [])):
|
||||
ret = None
|
||||
ler = {u.op for u in uop.src}
|
||||
for p,match,early_reject in pats:
|
||||
if p not in match_stats: match_stats[p] = [0,0,0.0,0.0]
|
||||
st = time.perf_counter()
|
||||
if not early_reject.issubset(ler):
|
||||
match_stats[p][2] += time.perf_counter()-st
|
||||
continue
|
||||
match_stats[p][1] += 1
|
||||
try: ret = match(uop, ctx)
|
||||
except Exception:
|
||||
if TRACK_MATCH_STATS >= 2 and active_rewrites:
|
||||
active_rewrites[-1].matches.append((uop.trace_num, UOp(Ops.REWRITE_ERROR,src=uop.src,arg=str(sys.exc_info()[1])).trace_num,p.location,0))
|
||||
raise
|
||||
if ret is not None and ret is not uop:
|
||||
match_stats[p][0] += 1
|
||||
match_stats[p][3] += (et:=time.perf_counter()-st)
|
||||
if TRACK_MATCH_STATS >= 3: print(f"{et*1e6:7.2f} us -- ", printable(p.location))
|
||||
if TRACK_MATCH_STATS >= 2 and isinstance(ret, UOp) and active_rewrites:
|
||||
active_rewrites[-1].matches.append((uop.trace_num, ret.trace_num, p.location, et))
|
||||
return ret
|
||||
match_stats[p][2] += time.perf_counter()-st
|
||||
continue
|
||||
match_stats[p][1] += 1
|
||||
try: ret = match(uop, ctx)
|
||||
except Exception:
|
||||
if TRACK_MATCH_STATS >= 2 and active_rewrites:
|
||||
active_rewrites[-1].matches.append((uop.trace_num, UOp(Ops.REWRITE_ERROR,src=uop.src,arg=str(sys.exc_info()[1])).trace_num,p.location,0))
|
||||
raise
|
||||
if ret is not None and ret is not uop:
|
||||
match_stats[p][0] += 1
|
||||
match_stats[p][3] += (et:=time.perf_counter()-st)
|
||||
if TRACK_MATCH_STATS >= 3: print(f"{et*1e6:7.2f} us -- ", printable(p.location))
|
||||
if TRACK_MATCH_STATS >= 2 and isinstance(ret, UOp) and active_rewrites:
|
||||
active_rewrites[-1].matches.append((uop.trace_num, ret.trace_num, p.location, et))
|
||||
return ret
|
||||
match_stats[p][2] += time.perf_counter()-st
|
||||
return None
|
||||
|
||||
@dataclass(frozen=True)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue