mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
TrackedPatternMatcher needs to loop [pr] (#7499)
This commit is contained in:
parent
6f93e91deb
commit
d078dcd0c8
1 changed files with 9 additions and 9 deletions
|
|
@ -690,13 +690,13 @@ class TrackedPatternMatcher(PatternMatcher):
|
|||
match_stats[p][2] += time.perf_counter()-st
|
||||
continue
|
||||
match_stats[p][1] += 1
|
||||
if (matches := p.match(uop, {})) and (ret:=(fxn(ctx=ctx, **matches[0]) if has_ctx else fxn(**matches[0]))) is not None:
|
||||
match_stats[p][0] += 1
|
||||
match_stats[p][2] += (et:=time.perf_counter()-st)
|
||||
match_stats[p][3] += et
|
||||
if TRACK_MATCH_STATS >= 3: print(f"{et*1e6:7.2f} us -- ", p.printable())
|
||||
if TRACK_MATCH_STATS >= 2 and len(rewrite_stack) != 0 and isinstance(ret, UOp): rewrite_stack[-1][1][-1].matches.append((uop, ret, p, et))
|
||||
return ret # NOTE: if it returns None, we keep trying to match
|
||||
for match in p.match(uop, {}):
|
||||
if (ret:=(fxn(ctx=ctx, **match) if has_ctx else fxn(**match))) is not None:
|
||||
match_stats[p][0] += 1
|
||||
match_stats[p][3] += (et:=time.perf_counter()-st)
|
||||
if TRACK_MATCH_STATS >= 3: print(f"{et*1e6:7.2f} us -- ", p.printable())
|
||||
if TRACK_MATCH_STATS >= 2 and len(rewrite_stack) != 0 and isinstance(ret, UOp): rewrite_stack[-1][1][-1].matches.append((uop, ret, p, et))
|
||||
return ret # NOTE: if it returns None, we keep trying to match
|
||||
match_stats[p][2] += time.perf_counter()-st
|
||||
if TRACK_MATCH_STATS >= 2 and len(rewrite_stack) != 0: rewrite_stack[-1][1][-1].matches.append((uop, ret, None, 0))
|
||||
return None
|
||||
|
|
@ -717,9 +717,9 @@ if TRACK_MATCH_STATS:
|
|||
ret = [0,0,0.0,0.0]
|
||||
for k,v in sorted(list(match_stats.items()), key=lambda x: x[1][2]):
|
||||
loc_str = f"{k.location[0].split('/')[-1]}:{k.location[1]}"
|
||||
if v[1] != 0: print(f"{v[0]:6d} / {v[1]:7d} -- {v[3]*1000.:9.2f} / {v[2]*1000.:9.2f} ms -- {loc_str:15s}", k.printable())
|
||||
if v[1] != 0: print(f"{v[0]:6d} / {v[1]:7d} -- {v[3]*1000.:9.2f} / {(v[2]+v[3])*1000.:9.2f} ms -- {loc_str:15s}", k.printable())
|
||||
ret = [x+y for x,y in zip(ret, v)]
|
||||
print(f"{ret[0]:6d} / {ret[1]:7d} -- {ret[3]*1000.:9.2f} / {ret[2]*1000.:9.2f} ms -- TOTAL")
|
||||
print(f"{ret[0]:6d} / {ret[1]:7d} -- {ret[3]*1000.:9.2f} / {(ret[2]+ret[3])*1000.:9.2f} ms -- TOTAL")
|
||||
|
||||
# *** simple graph rewrite engine ***
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue