inline viz get_name [pr] (#10682)

* inline viz get_name [pr]

* changing name_fxn makes this simpler

* waitUntil dom
This commit is contained in:
qazal 2025-06-07 11:16:16 +03:00 committed by GitHub
commit b515d796fb
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 4 additions and 8 deletions

View file

@ -93,7 +93,7 @@ class TestViz(unittest.TestCase):
self.assertEqual(len(ret), 1)
def test_track_rewrites_name_fxn(self):
@track_rewrites(name_fxn=lambda r: f"output_{r}")
@track_rewrites(name_fxn=lambda _,ret: f"output_{ret}")
def do_rewrite(x:UOp):
x = graph_rewrite(x, symbolic)
return x.render()

View file

@ -14,7 +14,7 @@ async function main() {
try {
browser = await puppeteer.launch({ headless: true });
const page = await browser.newPage();
const res = await page.goto("http://localhost:8000");
const res = await page.goto("http://localhost:8000", { waitUntil:"domcontentloaded" });
if (res.status() !== 200) throw new Error("Failed to load page");
const scheduleSelector = await page.waitForSelector("ul");
scheduleSelector.click();

View file

@ -501,10 +501,6 @@ do_fuse = PatternMatcher([
(UPat(Ops.REDUCE_AXIS, name="root"), fuse_arange),
])
def get_name(becomes_map:dict[UOp, UOp]) -> str:
assigned_kernels = {u.base.buf_uop:u.base.src[1] for u in becomes_map.values() if u.base.op is Ops.ASSIGN}.values()
return f"Schedule {pluralize('Kernel', len(set(assigned_kernels)))}"
add_gbarrier = PatternMatcher([(UPat(GroupOp.All-{Ops.GBARRIER, Ops.ASSIGN}, name="x"),
lambda ctx,x: x.replace(tag=1).gbarrier() if x in ctx and x.tag is None else None)])
@ -538,7 +534,7 @@ finalize_gbarrier = PatternMatcher([
remove_tags = PatternMatcher([(UPat(GroupOp.All, name="x"), lambda x: x.replace(tag=None) if x.tag is not None else None)])
@track_rewrites(name_fxn=get_name)
@track_rewrites(name_fxn=lambda big_sink,ret: f"Schedule {pluralize('Kernel',len([u for u in ret[big_sink].toposort() if u.op is Ops.KERNEL]))}")
def get_kernelize_map(big_sink:UOp) -> dict[UOp, UOp]:
# multi + merge_views + simplify
tensor_map = graph_rewrite_map(big_sink, multi_pm+replace_allreduce+do_fuse+merge_views+sym+replace_contiguous, ctx={}, name="merge_views")

View file

@ -910,7 +910,7 @@ def track_rewrites(named=False, name_fxn:Callable|None=None):
tracked_keys.append(f"{func.__name__}_{_name_cnt[func.__name__]}" if count_names else args[0])
tracked_ctxs.append([])
ret = func(*args, **kwargs)
if TRACK_MATCH_STATS >= 2 and name_fxn is not None: tracked_keys[-1] = f"{name_fxn(ret)} n{_name_cnt[func.__name__]}"
if TRACK_MATCH_STATS >= 2 and name_fxn is not None: tracked_keys[-1] = f"{name_fxn(*args, **kwargs, ret=ret)} n{_name_cnt[func.__name__]}"
if getenv("CAPTURE_PROCESS_REPLAY"):
# find the unittest frame we're capturing in
frm = sys._getframe(1)