mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
Compare commits
1 commit
master
...
gate_rewri
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
5fd81a7f67 |
1 changed files with 27 additions and 25 deletions
|
|
@ -1034,10 +1034,11 @@ class PatternMatcher:
|
||||||
def __add__(self, more:PatternMatcher) -> PatternMatcher: return PatternMatcher(self.patterns+more.patterns)
|
def __add__(self, more:PatternMatcher) -> PatternMatcher: return PatternMatcher(self.patterns+more.patterns)
|
||||||
|
|
||||||
def rewrite(self, uop:UOp, ctx=None) -> UOp|None:
|
def rewrite(self, uop:UOp, ctx=None) -> UOp|None:
|
||||||
ler = {u.op for u in uop.src}
|
if len(pats:=self.pdict.get(uop.op, [])):
|
||||||
for _,match,early_reject in self.pdict.get(uop.op, []):
|
ler = {u.op for u in uop.src}
|
||||||
if not early_reject.issubset(ler): continue
|
for _,match,early_reject in pats:
|
||||||
if (ret:=match(uop, ctx)) is not None and ret is not uop: return ret
|
if not early_reject.issubset(ler): continue
|
||||||
|
if (ret:=match(uop, ctx)) is not None and ret is not uop: return ret
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# *** tracking pattern matcher ***
|
# *** tracking pattern matcher ***
|
||||||
|
|
@ -1119,28 +1120,29 @@ def profile_matches(fxn:Callable):
|
||||||
|
|
||||||
class TrackedPatternMatcher(PatternMatcher):
|
class TrackedPatternMatcher(PatternMatcher):
|
||||||
def rewrite(self, uop:UOp, ctx=None) -> UOp|None:
|
def rewrite(self, uop:UOp, ctx=None) -> UOp|None:
|
||||||
ret = None
|
if len(pats:=self.pdict.get(uop.op, [])):
|
||||||
ler = {u.op for u in uop.src}
|
ret = None
|
||||||
for p,match,early_reject in self.pdict.get(uop.op, []):
|
ler = {u.op for u in uop.src}
|
||||||
if p not in match_stats: match_stats[p] = [0,0,0.0,0.0]
|
for p,match,early_reject in pats:
|
||||||
st = time.perf_counter()
|
if p not in match_stats: match_stats[p] = [0,0,0.0,0.0]
|
||||||
if not early_reject.issubset(ler):
|
st = time.perf_counter()
|
||||||
|
if not early_reject.issubset(ler):
|
||||||
|
match_stats[p][2] += time.perf_counter()-st
|
||||||
|
continue
|
||||||
|
match_stats[p][1] += 1
|
||||||
|
try: ret = match(uop, ctx)
|
||||||
|
except Exception:
|
||||||
|
if TRACK_MATCH_STATS >= 2 and active_rewrites:
|
||||||
|
active_rewrites[-1].matches.append((uop.trace_num, UOp(Ops.REWRITE_ERROR,src=uop.src,arg=str(sys.exc_info()[1])).trace_num,p.location,0))
|
||||||
|
raise
|
||||||
|
if ret is not None and ret is not uop:
|
||||||
|
match_stats[p][0] += 1
|
||||||
|
match_stats[p][3] += (et:=time.perf_counter()-st)
|
||||||
|
if TRACK_MATCH_STATS >= 3: print(f"{et*1e6:7.2f} us -- ", printable(p.location))
|
||||||
|
if TRACK_MATCH_STATS >= 2 and isinstance(ret, UOp) and active_rewrites:
|
||||||
|
active_rewrites[-1].matches.append((uop.trace_num, ret.trace_num, p.location, et))
|
||||||
|
return ret
|
||||||
match_stats[p][2] += time.perf_counter()-st
|
match_stats[p][2] += time.perf_counter()-st
|
||||||
continue
|
|
||||||
match_stats[p][1] += 1
|
|
||||||
try: ret = match(uop, ctx)
|
|
||||||
except Exception:
|
|
||||||
if TRACK_MATCH_STATS >= 2 and active_rewrites:
|
|
||||||
active_rewrites[-1].matches.append((uop.trace_num, UOp(Ops.REWRITE_ERROR,src=uop.src,arg=str(sys.exc_info()[1])).trace_num,p.location,0))
|
|
||||||
raise
|
|
||||||
if ret is not None and ret is not uop:
|
|
||||||
match_stats[p][0] += 1
|
|
||||||
match_stats[p][3] += (et:=time.perf_counter()-st)
|
|
||||||
if TRACK_MATCH_STATS >= 3: print(f"{et*1e6:7.2f} us -- ", printable(p.location))
|
|
||||||
if TRACK_MATCH_STATS >= 2 and isinstance(ret, UOp) and active_rewrites:
|
|
||||||
active_rewrites[-1].matches.append((uop.trace_num, ret.trace_num, p.location, et))
|
|
||||||
return ret
|
|
||||||
match_stats[p][2] += time.perf_counter()-st
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue