mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
fix viz with untracked graph_rewrite calls (#8298)
* fix viz with untracked graph_rewrite calls * mark as green
This commit is contained in:
parent
5977a3d8a6
commit
673a76398a
2 changed files with 3 additions and 4 deletions
|
|
@ -114,7 +114,6 @@ class TestViz(unittest.TestCase):
|
|||
self.assertIs(ret[0], a.sqrt().sin()) # only rewrite
|
||||
|
||||
# NOTE: calling graph_rewrite when the function isn't decorated with track_rewrites should not VIZ
|
||||
@unittest.expectedFailure
|
||||
def test_rewrite_without_context(self):
|
||||
def untracked_graph_rewrite(sink): return graph_rewrite(sink, symbolic)
|
||||
@track_rewrites(named=True)
|
||||
|
|
|
|||
|
|
@ -878,11 +878,11 @@ class TrackedPatternMatcher(PatternMatcher):
|
|||
match_stats[p][0] += 1
|
||||
match_stats[p][3] += (et:=time.perf_counter()-st)
|
||||
if TRACK_MATCH_STATS >= 3: print(f"{et*1e6:7.2f} us -- ", p.printable())
|
||||
if TRACK_MATCH_STATS >= 2 and isinstance(ret, UOp):
|
||||
if TRACK_MATCH_STATS >= 2 and isinstance(ret, UOp) and len(tracked_ctxs) != 0:
|
||||
with Context(PICKLE_BUFFERS=0): tracked_ctxs[-1][-1].matches.append((pickle.dumps(uop), pickle.dumps(ret), p, et))
|
||||
return ret # NOTE: if it returns None, we keep trying to match
|
||||
match_stats[p][2] += time.perf_counter()-st
|
||||
if TRACK_MATCH_STATS >= 2:
|
||||
if TRACK_MATCH_STATS >= 2 and len(tracked_ctxs) != 0:
|
||||
with Context(PICKLE_BUFFERS=0): tracked_ctxs[-1][-1].matches.append((pickle.dumps(uop), None, None, 0))
|
||||
return None
|
||||
|
||||
|
|
@ -934,7 +934,7 @@ class RewriteContext:
|
|||
return ret
|
||||
|
||||
def graph_rewrite(sink:UOp, pm:PatternMatcher, ctx=None, bottom_up=False) -> UOp:
|
||||
if TRACK_MATCH_STATS >= 2 and not bottom_up: # TODO: make viz work with bottom_up=True
|
||||
if TRACK_MATCH_STATS >= 2 and not bottom_up and len(tracked_ctxs) != 0: # TODO: make viz work with bottom_up=True
|
||||
with Context(PICKLE_BUFFERS=0):
|
||||
tracked_ctxs[-1].append(TrackedGraphRewrite(((frm:=sys._getframe(1)).f_code.co_filename, frm.f_lineno), pickle.dumps(sink)))
|
||||
return RewriteContext(pm, ctx).bottom_up_rewrite(sink) if bottom_up else RewriteContext(pm, ctx).rewrite(sink)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue