viz: only store kernel info (#16641)

This commit is contained in:
qazal 2026-06-17 15:21:57 +08:00 committed by GitHub
commit c7055d658f
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -80,18 +80,18 @@ def load_rewrites(data:VizData) -> None:
assert not data.ctxs and not data.ref_map, "load_rewrites called multiple times"
for i,k in enumerate(data.trace.keys):
steps:list[dict] = []
p:UOp|None = None
ki:KernelInfo|None = None
for j,s in enumerate(data.trace.rewrites[i]):
steps.append(create_step(s.name, ("/graph-rewrites", i, j), loc=s.loc, match_count=len(s.matches), code_line=printable(s.loc),
trace=k.tb if j==0 else None, depth=s.depth))
# get source and binary from Ops.PROGRAM
if s.name == "View Program":
p = _reconstruct(data, s.sink, depth=1)
ki = (p:=_reconstruct(data, s.sink, depth=1)).src[0].arg
steps.append(create_step("View UOp List", ("/uops", i, len(steps))))
steps.append(create_step("View Source", ("/code", i, len(steps)), p.src[3].arg))
steps.append(create_step("View Disassembly", ("/asm", i, len(steps)), (k.ret, p.src[4].arg)))
for key in k.keys: data.ref_map[canonicalize_ast(key) if isinstance(key, UOp) else key] = i
data.ctxs.append({"name":k.display_name, "steps":steps, "prg":p})
data.ctxs.append({"name":k.display_name, "steps":steps, "ki":ki})
# ** get the complete UOp graphs for one rewrite
@ -229,7 +229,7 @@ def timeline_layout(data:VizData, dev_events:list[tuple[int, int, float, DevEven
fmt:dict = {}
if (ref:=data.ref_map.get(name)) is not None and ref < len(data.ctxs):
name = data.ctxs[ref]["name"]
if (p:=data.ctxs[ref].get("prg")) is not None and (ki:=p.src[0].arg).estimates is not None and ei is not None:
if (ki:=data.ctxs[ref].get("ki")) is not None and ki.estimates is not None and ei is not None:
fmt["FLOPS"] = int(sym_infer(ki.estimates.ops, var_vals:=ei.arg['var_vals'])/(t:=dur*1e-6))
fmt["B/s mem"], fmt["B/s lds"] = int(sym_infer(ki.estimates.mem, var_vals)/t), int(sym_infer(ki.estimates.lds, var_vals)/t)
if ei.arg["metadata"]: fmt["metadata"] = ",".join([str(m) for m in ei.arg['metadata']+["batched" if isinstance(e,ProfileGraphEntry) else ""]])
@ -341,7 +341,7 @@ def load_amd_counters(data:VizData, profile:list) -> None:
run_number = {n:0 for n,_ in counter_events}
for (k, tag),v in counter_events.items():
# use the colored name if it exists
name = data.ctxs[r]["prg"].src[0].arg.name if (r:=data.ref_map.get(pname:=prg_events[k].name)) is not None else pname
name = data.ctxs[r]["ki"].name if (r:=data.ref_map.get(pname:=prg_events[k].name)) is not None else pname
run_number[k] += 1
steps:list[dict] = []
if (pmc:=v.get("ProfilePMCEvent")):