fix viz with untracked graph_rewrite calls (#8298)

* fix viz with untracked graph_rewrite calls

* mark as green
This commit is contained in:
qazal 2024-12-17 23:37:53 +02:00 committed by GitHub
commit 673a76398a
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 3 additions and 4 deletions

View file

@ -114,7 +114,6 @@ class TestViz(unittest.TestCase):
self.assertIs(ret[0], a.sqrt().sin()) # only rewrite
# NOTE: calling graph_rewrite when the function isn't decorated with track_rewrites should not VIZ
@unittest.expectedFailure
def test_rewrite_without_context(self):
def untracked_graph_rewrite(sink): return graph_rewrite(sink, symbolic)
@track_rewrites(named=True)

View file

@ -878,11 +878,11 @@ class TrackedPatternMatcher(PatternMatcher):
match_stats[p][0] += 1
match_stats[p][3] += (et:=time.perf_counter()-st)
if TRACK_MATCH_STATS >= 3: print(f"{et*1e6:7.2f} us -- ", p.printable())
if TRACK_MATCH_STATS >= 2 and isinstance(ret, UOp):
if TRACK_MATCH_STATS >= 2 and isinstance(ret, UOp) and len(tracked_ctxs) != 0:
with Context(PICKLE_BUFFERS=0): tracked_ctxs[-1][-1].matches.append((pickle.dumps(uop), pickle.dumps(ret), p, et))
return ret # NOTE: if it returns None, we keep trying to match
match_stats[p][2] += time.perf_counter()-st
if TRACK_MATCH_STATS >= 2:
if TRACK_MATCH_STATS >= 2 and len(tracked_ctxs) != 0:
with Context(PICKLE_BUFFERS=0): tracked_ctxs[-1][-1].matches.append((pickle.dumps(uop), None, None, 0))
return None
@ -934,7 +934,7 @@ class RewriteContext:
return ret
def graph_rewrite(sink:UOp, pm:PatternMatcher, ctx=None, bottom_up=False) -> UOp:
if TRACK_MATCH_STATS >= 2 and not bottom_up: # TODO: make viz work with bottom_up=True
if TRACK_MATCH_STATS >= 2 and not bottom_up and len(tracked_ctxs) != 0: # TODO: make viz work with bottom_up=True
with Context(PICKLE_BUFFERS=0):
tracked_ctxs[-1].append(TrackedGraphRewrite(((frm:=sys._getframe(1)).f_code.co_filename, frm.f_lineno), pickle.dumps(sink)))
return RewriteContext(pm, ctx).bottom_up_rewrite(sink) if bottom_up else RewriteContext(pm, ctx).rewrite(sink)