viz: view source working even if compile failed (#16657)

* failing test

* hard

* ret_dict

* switch to _data for tests too

* update sqtt

* start work

* Ops.LINEAR looks good

* baseline with depth works

* support depth

* types

* @needs_tracked_pm

* update, marg can error too

* unwrap_or goes to many more places

* move things to soft_err

* soft_err everywhere needed

* diff cleanup

* use list

* rewrite it

* change

* update depth number

* small comment change
This commit is contained in:
qazal 2026-06-18 16:34:53 +08:00 committed by GitHub
commit b753fb5e4c
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 73 additions and 21 deletions

View file

@ -224,7 +224,9 @@ jobs:
- name: Run NULL backend tests
run: DEV=NULL python -m pytest -n=auto test/null/ --durations=20
- name: Run targeted tests on NULL backend
run: DEV=NULL python3 -m unittest test.backend.test_multitensor.TestMultiTensor.test_data_parallel_resnet_train_step
run: |
DEV=NULL python3 -m unittest test.backend.test_multitensor.TestMultiTensor.test_data_parallel_resnet_train_step
DEV=NULL VIZ=1 python3 -m pytest -n=auto test/null/test_viz.py
# TODO: too slow
# - name: Run SDXL on NULL backend
# run: DEV=NULL DEBUG=1 python3 examples/sdxl.py --seed 0 --noshow --timing --fakeweights

View file

@ -7,7 +7,7 @@ from tinygrad.uop.ops import UOp, UPat, Ops, PatternMatcher, TrackedPatternMatch
from tinygrad.uop.symbolic import sym
from tinygrad.dtype import dtypes, AddrSpace
from tinygrad.helpers import colored, ansistrip, flatten, TracingKey, ProfileRangeEvent, ProfileEvent, Context, cpu_events, profile_marker
from tinygrad.helpers import cpu_profile, ProfilePointEvent, unwrap
from tinygrad.helpers import cpu_profile, ProfilePointEvent, unwrap, VIZ
from tinygrad.device import Buffer
from tinygrad.uop.ops import tracked_keys, tracked_ctxs, uop_fields, active_rewrites, active_group, _name_cnt, RewriteTrace
@ -47,6 +47,8 @@ def save_viz():
yield viz
viz.set_data()
# decorator for tests that rely on TrackedPatternMatcher: they are skipped unless VIZ is truthy
needs_tracked_pm = unittest.skipUnless(VIZ, "using TrackedPatternMatcher requires global VIZ=1")
class TestViz(unittest.TestCase):
def test_simple(self):
with save_viz() as viz:
@ -462,6 +464,35 @@ class TestVizIntegration(unittest.TestCase):
self.assertGreater(len(events), 0)
self.assertEqual([e["st"] for e in events], [graph_st+i*events[0]["dur"] for i in range(len(events))])
@needs_tracked_pm
def test_view_source(self):
# A kernel whose realize() raises must still list the "View UOp List", "View Source" and "View Disassembly" steps,
# and each of them must render something useful.
def custom_fn(X:UOp):
X = X.flatten()
i = UOp.range(X.numel(), 0)
# the CUSTOMI arg carries the text `undeclared_name`, which is later expected to show up in the rendered source
custom_op = UOp(Ops.CUSTOMI, src=(X[i],), arg="{} + undeclared_name")
return X[i].store(custom_op).end(i).sink(arg=KernelInfo(name=f"custom_fn_{X.numel()}"))
x = Tensor.custom_kernel(Tensor.empty(1, device="CPU"), fxn=custom_fn)[0]
with save_viz() as viz:
# realize is expected to raise; the exception object is kept for the disassembly check at the end
with self.assertRaises(Exception) as e:
x.realize()
lst = viz.list_items()
# NOTE(review): assumes the last listed item is the one produced for this kernel -- confirm
codegen_idx = len(lst)-1
steps = lst[codegen_idx]["steps"]
# locate the three views by their step name; None means the step was not listed
lin_idx = next((i for i,s in enumerate(steps) if s["name"] == "View UOp List"), None)
src_idx = next((i for i,s in enumerate(steps) if s["name"] == "View Source"), None)
bin_idx = next((i for i,s in enumerate(steps) if s["name"] == "View Disassembly"), None)
assert all(i is not None for i in [lin_idx, src_idx, bin_idx]), f"linear, source and disasm must be visible in {steps}"
# Ops.LINEAR renders
lin_render = get_render(viz.data, steps[lin_idx]["query"])["src"]
self.assertIn("Ops.SINK", lin_render)
self.assertIn("Ops.CUSTOMI", lin_render)
# Ops.SOURCE renders
src_render = get_render(viz.data, steps[src_idx]["query"])["src"]
self.assertIn("undeclared_name", src_render)
# Ops.BINARY shows the error message since compile failed
bin_render = get_render(viz.data, steps[bin_idx]["query"])["src"]
# only the exception's class name is checked, not its full message
self.assertIn(type(e.exception).__name__, bin_render)
from tinygrad.device import ProfileDeviceEvent, ProfileGraphEvent, ProfileGraphEntry
from tinygrad.viz.serve import get_profile
from tinygrad.viz.cli import decode_profile
@ -763,6 +794,7 @@ from tinygrad.runtime.autogen.amd.rdna3.ins import (s_add_u32, s_branch, s_cbran
s_cmp_eq_u64, s_code_end, s_endpgm, s_mov_b32, s_nop)
from extra.gemm.amd_asm_matmul import Kernel
@needs_tracked_pm
class TestCfg(unittest.TestCase):
def setUp(self): self.arch = "gfx1100"
@ -961,6 +993,7 @@ def write_files(viz) -> list[str]:
yield ["--rewrites-path", str(r), "--profile-path", str(p)]
class TestCLI(unittest.TestCase):
@needs_tracked_pm
def test_reconstruct_debug(self):
with save_viz() as viz:
Tensor.empty(1, device="NULL").add(2.0).realize()
@ -1023,6 +1056,7 @@ class TestCLI(unittest.TestCase):
self.assertEqual(len([s for s in select if s.get("value")]), 1, "debug output was not deduped")
self.assertEqual(len([s for s in select if s.get("device") == "NULL"]), CNT, f"expected 4 runs for {name}")
@needs_tracked_pm
def test_call_graph(self):
@function(precompile=True)
def f(x):

View file

@ -33,7 +33,7 @@ class HTTPRequestHandler(BaseHTTPRequestHandler):
self.send_header("Cache-Control", "no-cache")
self.end_headers()
for r in source:
self.wfile.write(f"data: {json.dumps(r)}\n\n".encode("utf-8"))
self.wfile.write(f"data: {json.dumps(filter_keys(r))}\n\n".encode("utf-8"))
self.wfile.flush()
self.wfile.write("data: [DONE]\n\n".encode("utf-8"))
# pass if client closed connection
@ -85,11 +85,11 @@ def load_rewrites(data:VizData) -> None:
steps.append(create_step(s.name, ("/graph-rewrites", i, j), loc=s.loc, match_count=len(s.matches), code_line=printable(s.loc),
trace=k.tb if j==0 else None, depth=s.depth))
# get source and binary from Ops.PROGRAM
if s.name == "View Program":
ki = (p:=_reconstruct(data, s.sink, depth=1)).src[0].arg
steps.append(create_step("View UOp List", ("/uops", i, len(steps))))
steps.append(create_step("View Source", ("/code", i, len(steps)), p.src[3].arg))
steps.append(create_step("View Disassembly", ("/asm", i, len(steps)), (k.ret, p.src[4].arg)))
if s.name == "linearize/render":
steps.append(create_step("View UOp List", ("/uops", i, len(steps)), j, depth=s.depth))
steps.append(create_step("View Source", ("/code", i, len(steps)), j, depth=s.depth))
steps.append(create_step("View Disassembly", ("/asm", i, len(steps)), (k.ret, j), depth=s.depth))
if s.name == "View Program": ki = _reconstruct(data, s.sink, depth=1).src[0].arg
for key in k.keys: data.ref_map[canonicalize_ast(key) if isinstance(key, UOp) else key] = i
data.ctxs.append({"name":k.display_name, "steps":steps, "ki":ki})
@ -101,6 +101,7 @@ class GraphRewriteDetails(TypedDict):
diff: list[str]|None # diff of the single UOp that changed
change: list[int]|None # the new UOp id + all its parents ids
upat: tuple[tuple[str, int], str]|None # [loc, source_code] of the matched UPat
_sink: UOp
def shape_to_str(s:tuple[sint, ...]):
  """Return the entries of `s`, each passed through srender, joined by commas and wrapped in parentheses."""
  rendered = [srender(dim) for dim in s]
  return f"({','.join(rendered)})"
def mask_to_str(s:tuple[tuple[sint, sint], ...]):
  """Return the entries of `s`, each formatted with shape_to_str, joined by commas and wrapped in parentheses."""
  rendered = [shape_to_str(pair) for pair in s]
  return f"({','.join(rendered)})"
@ -127,7 +128,8 @@ def uop_to_json(data:VizData, x:UOp) -> dict[int, dict]:
if u.op in {*GroupOp.Movement, Ops.PARAM}: excluded.update(s for s in u.src if s.op is Ops.STACK and all(x.op is Ops.CONST for x in s.src))
for u in toposort:
argst = codecs.decode(str(u.arg), "unicode_escape")
if u.op in GroupOp.Movement: argst = (mask_to_str if u.op in {Ops.SHRINK, Ops.PAD} else shape_to_str)(u.marg)
with soft_err():
if u.op in GroupOp.Movement and u.marg: argst = (mask_to_str if u.op in {Ops.SHRINK, Ops.PAD} else shape_to_str)(u.marg)
if u.op is Ops.BINARY: argst = f"<{len(u.arg)} bytes>"
if u.op is Ops.CONST and dtypes.is_float(u.dtype): argst = f"{u.arg:g}"
wrap_len = 200 if u.op is Ops.SOURCE else 80
@ -163,9 +165,10 @@ def uop_to_json(data:VizData, x:UOp) -> dict[int, dict]:
# limit SOURCE labels line count
if u.op is Ops.SOURCE and len(lines:=label.split("\n")) > 40:
label = "\n".join(lines[:30]) + "\n..."
addrspace_color:str|None = None
with soft_err(): addrspace_color = addrspace_colors.get(u.addrspace, None) if u.addrspace is not None else None
graph[id(u)] = {"label":label, "src":[(i,id(x)) for i,x in enumerate(u.src)], "exclude":u in excluded, "color":uops_colors.get(u.op, "#ffffff"),
"ref":ref, "tag":repr(u.tag) if u.tag is not None else None,
"addrspace":addrspace_colors.get(u.addrspace, None) if u.addrspace is not None else None}
"ref":ref, "tag":repr(u.tag) if u.tag is not None else None, "addrspace":addrspace_color}
return graph
def _reconstruct(data:VizData, a:int, depth:int|None=None):
@ -176,19 +179,24 @@ def _reconstruct(data:VizData, a:int, depth:int|None=None):
if depth is None: data.all_uops[a] = ret
return ret
def get_full_rewrite(data:VizData, ctx:TrackedGraphRewrite) -> Generator[GraphRewriteDetails, None, None]:
next_sink = _reconstruct(data, ctx.sink)
yield {"graph":uop_to_json(data, next_sink), "uop":pystr(next_sink), "change":None, "diff":None, "upat":None}
def get_full_rewrite(data:VizData, ctx:TrackedGraphRewrite, depth:int|None=None) -> Generator[GraphRewriteDetails, None, None]:
next_sink = _reconstruct(data, ctx.sink, depth=depth)
yield {"graph":uop_to_json(data, next_sink), "uop":pystr(next_sink), "change":None, "diff":None, "upat":None, "_sink":next_sink}
replaces: dict[UOp, UOp] = {}
for u0_num,u1_num,upat_loc,dur in tqdm(ctx.matches, disable=not ctx.matches):
replaces[u0:=_reconstruct(data, u0_num)] = u1 = _reconstruct(data, u1_num)
replaces[u0:=_reconstruct(data, u0_num, depth=depth)] = u1 = _reconstruct(data, u1_num, depth=depth)
try: new_sink = next_sink.substitute(replaces)
except RuntimeError as e: new_sink = UOp(Ops.NOOP, arg=str(e))
match_repr = f"# {dur*1e6:.2f} us\n"+printable(upat_loc)
yield {"graph":(sink_json:=uop_to_json(data, new_sink)), "uop":pystr(new_sink), "change":[id(x) for x in u1.toposort() if id(x) in sink_json],
"diff":list(difflib.unified_diff(pystr(u0).splitlines(), pystr(u1).splitlines())), "upat":(upat_loc, match_repr)}
"diff":list(difflib.unified_diff(pystr(u0).splitlines(), pystr(u1).splitlines())), "upat":(upat_loc, match_repr), "_sink":new_sink}
if not ctx.bottom_up: next_sink = new_sink
def get_sink_at(upats:tuple[str, ...], viz_data:VizData, ctx:TrackedGraphRewrite, depth:int|None=None) -> UOp|None:
  """Return the "_sink" of the first rewrite step whose match text contains any of the names in `upats`.

  Steps come from get_full_rewrite(viz_data, ctx, depth=depth); steps whose "upat" is None are ignored.
  Returns None when no step matches.
  """
  for step in get_full_rewrite(viz_data, ctx, depth=depth):
    matched = step["upat"]
    if matched is None: continue
    # the second element of "upat" is the text that gets searched for each name
    if any(name in matched[1] for name in upats): return step["_sink"]
  return None
# encoder helpers
def enum_str(s, cache:dict[str, int]) -> int:
@ -283,9 +291,10 @@ def graph_layout(k:str, dev_events:list[tuple[int, int, float, DevEvent]], start
# by default, VIZ does not start when there is an error
# use this to instead display the traceback to the user
@contextmanager
def soft_err(fn:Callable):
def soft_err(fn:Callable|None=None):
try: yield
except Exception: fn({"src":traceback.format_exc()})
except Exception:
if fn is not None: fn({"src":traceback.format_exc()})
def row_tuple(row:str) -> tuple[tuple[int, int], ...]:
  """Turn a row name into a tuple of (int, int) pairs.

  A row containing "Clock" gives ((0, 0),). Otherwise each whitespace-separated token gives one pair:
  "X:N..." becomes (ord of the first character before the colon, int of the text after the first colon),
  and a token without a colon becomes (999, 999).
  """
  if "Clock" in row: return ((0, 0),)
  pairs = []
  for token in row.split():
    parts = token.split(":")
    pairs.append((ord(parts[0][0]), int(parts[1])) if len(parts) > 1 else (999, 999))
  return tuple(pairs)
@ -609,11 +618,18 @@ def get_render(viz_data:VizData, query:str) -> dict:
i, j, fmt = get_int(qs:=parse_qs(url.query), "ctx"), get_int(qs, "step"), url.path.lstrip("/")
data = viz_data.ctxs[i]["steps"][j]["_data"]
if fmt == "graph-rewrites": return {"value":get_full_rewrite(viz_data, viz_data.trace.rewrites[i][j]), "content_type":"text/event-stream"}
if fmt == "uops": return {"src":get_stdout(lambda: print_uops(_reconstruct(viz_data, viz_data.trace.rewrites[i][j-1].sink).src[2].src))}
if fmt == "code": return {"src":data, "lang":"cpp"}
if fmt == "uops":
if (sink:=get_sink_at(("do_linearize",), viz_data, viz_data.trace.rewrites[i][data])) is None: return {"src":"No linear found"}
return {"src":sink.arg} if sink.op is Ops.REWRITE_ERROR else {"src":get_stdout(lambda: print_uops(list(unwrap(sink).src[2].src)))}
if fmt == "code":
if (sink:=get_sink_at(("do_render",), viz_data, viz_data.trace.rewrites[i][data], depth=1)) is None: return {"src":"No source found"}
return {"src":sink.arg} if sink.op is Ops.REWRITE_ERROR else {"src":sink.src[3].arg, "lang":"cpp"}
if fmt == "asm":
ret:dict = {}
renderer, lib = data
renderer, idx = data
if (sink:=get_sink_at(("do_compile","do_assemble"), viz_data, viz_data.trace.rewrites[i][idx], depth=1)) is None: return {"src":"No binary found"}
if sink.op is Ops.REWRITE_ERROR: return {"src":sink.arg}
lib:bytes = sink.src[4].arg
if renderer.target.arch.startswith("gfx"):
with soft_err(lambda err: ret.update(err)): ret.update(amdgpu_cfg(lib, renderer.target.arch))
else: ret["src"] = get_stdout(lambda: renderer.compiler.disassemble(lib))