mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
inline viz get_name [pr] (#10682)
* inline viz get_name [pr] * changing name_fxn makes this simpler * waitUntil dom
This commit is contained in:
parent
86a19e19e8
commit
b515d796fb
4 changed files with 4 additions and 8 deletions
|
|
@ -93,7 +93,7 @@ class TestViz(unittest.TestCase):
|
|||
self.assertEqual(len(ret), 1)
|
||||
|
||||
def test_track_rewrites_name_fxn(self):
|
||||
@track_rewrites(name_fxn=lambda r: f"output_{r}")
|
||||
@track_rewrites(name_fxn=lambda _,ret: f"output_{ret}")
|
||||
def do_rewrite(x:UOp):
|
||||
x = graph_rewrite(x, symbolic)
|
||||
return x.render()
|
||||
|
|
|
|||
|
|
@ -14,7 +14,7 @@ async function main() {
|
|||
try {
|
||||
browser = await puppeteer.launch({ headless: true });
|
||||
const page = await browser.newPage();
|
||||
const res = await page.goto("http://localhost:8000");
|
||||
const res = await page.goto("http://localhost:8000", { waitUntil:"domcontentloaded" });
|
||||
if (res.status() !== 200) throw new Error("Failed to load page");
|
||||
const scheduleSelector = await page.waitForSelector("ul");
|
||||
scheduleSelector.click();
|
||||
|
|
|
|||
|
|
@ -501,10 +501,6 @@ do_fuse = PatternMatcher([
|
|||
(UPat(Ops.REDUCE_AXIS, name="root"), fuse_arange),
|
||||
])
|
||||
|
||||
def get_name(becomes_map:dict[UOp, UOp]) -> str:
|
||||
assigned_kernels = {u.base.buf_uop:u.base.src[1] for u in becomes_map.values() if u.base.op is Ops.ASSIGN}.values()
|
||||
return f"Schedule {pluralize('Kernel', len(set(assigned_kernels)))}"
|
||||
|
||||
add_gbarrier = PatternMatcher([(UPat(GroupOp.All-{Ops.GBARRIER, Ops.ASSIGN}, name="x"),
|
||||
lambda ctx,x: x.replace(tag=1).gbarrier() if x in ctx and x.tag is None else None)])
|
||||
|
||||
|
|
@ -538,7 +534,7 @@ finalize_gbarrier = PatternMatcher([
|
|||
|
||||
remove_tags = PatternMatcher([(UPat(GroupOp.All, name="x"), lambda x: x.replace(tag=None) if x.tag is not None else None)])
|
||||
|
||||
@track_rewrites(name_fxn=get_name)
|
||||
@track_rewrites(name_fxn=lambda big_sink,ret: f"Schedule {pluralize('Kernel',len([u for u in ret[big_sink].toposort() if u.op is Ops.KERNEL]))}")
|
||||
def get_kernelize_map(big_sink:UOp) -> dict[UOp, UOp]:
|
||||
# multi + merge_views + simplify
|
||||
tensor_map = graph_rewrite_map(big_sink, multi_pm+replace_allreduce+do_fuse+merge_views+sym+replace_contiguous, ctx={}, name="merge_views")
|
||||
|
|
|
|||
|
|
@ -910,7 +910,7 @@ def track_rewrites(named=False, name_fxn:Callable|None=None):
|
|||
tracked_keys.append(f"{func.__name__}_{_name_cnt[func.__name__]}" if count_names else args[0])
|
||||
tracked_ctxs.append([])
|
||||
ret = func(*args, **kwargs)
|
||||
if TRACK_MATCH_STATS >= 2 and name_fxn is not None: tracked_keys[-1] = f"{name_fxn(ret)} n{_name_cnt[func.__name__]}"
|
||||
if TRACK_MATCH_STATS >= 2 and name_fxn is not None: tracked_keys[-1] = f"{name_fxn(*args, **kwargs, ret=ret)} n{_name_cnt[func.__name__]}"
|
||||
if getenv("CAPTURE_PROCESS_REPLAY"):
|
||||
# find the unittest frame we're capturing in
|
||||
frm = sys._getframe(1)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue