mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
viz: view source working even if compile failed (#16657)
* failing test * hard * ret_dict * switch to _data for tests too * update sqtt * start work * Ops.LINEAR looks good * baseline with depth works * support depth * types * @needs_tracked_pm * update, marg can error too * unwrap_or goes to many more places * move things to soft_err * soft_err everywhere needed * diff cleanup * use list * rewrite it * change * update depth number * small comment change
This commit is contained in:
parent
31094a794f
commit
b753fb5e4c
3 changed files with 73 additions and 21 deletions
4
.github/workflows/test.yml
vendored
4
.github/workflows/test.yml
vendored
|
|
@ -224,7 +224,9 @@ jobs:
|
|||
- name: Run NULL backend tests
|
||||
run: DEV=NULL python -m pytest -n=auto test/null/ --durations=20
|
||||
- name: Run targeted tests on NULL backend
|
||||
run: DEV=NULL python3 -m unittest test.backend.test_multitensor.TestMultiTensor.test_data_parallel_resnet_train_step
|
||||
run: |
|
||||
DEV=NULL python3 -m unittest test.backend.test_multitensor.TestMultiTensor.test_data_parallel_resnet_train_step
|
||||
DEV=NULL VIZ=1 python3 -m pytest -n=auto test/null/test_viz.py
|
||||
# TODO: too slow
|
||||
# - name: Run SDXL on NULL backend
|
||||
# run: DEV=NULL DEBUG=1 python3 examples/sdxl.py --seed 0 --noshow --timing --fakeweights
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@ from tinygrad.uop.ops import UOp, UPat, Ops, PatternMatcher, TrackedPatternMatch
|
|||
from tinygrad.uop.symbolic import sym
|
||||
from tinygrad.dtype import dtypes, AddrSpace
|
||||
from tinygrad.helpers import colored, ansistrip, flatten, TracingKey, ProfileRangeEvent, ProfileEvent, Context, cpu_events, profile_marker
|
||||
from tinygrad.helpers import cpu_profile, ProfilePointEvent, unwrap
|
||||
from tinygrad.helpers import cpu_profile, ProfilePointEvent, unwrap, VIZ
|
||||
from tinygrad.device import Buffer
|
||||
|
||||
from tinygrad.uop.ops import tracked_keys, tracked_ctxs, uop_fields, active_rewrites, active_group, _name_cnt, RewriteTrace
|
||||
|
|
@ -47,6 +47,8 @@ def save_viz():
|
|||
yield viz
|
||||
viz.set_data()
|
||||
|
||||
needs_tracked_pm = unittest.skipUnless(VIZ, "using TrackedPatternMatcher requires global VIZ=1")
|
||||
|
||||
class TestViz(unittest.TestCase):
|
||||
def test_simple(self):
|
||||
with save_viz() as viz:
|
||||
|
|
@ -462,6 +464,35 @@ class TestVizIntegration(unittest.TestCase):
|
|||
self.assertGreater(len(events), 0)
|
||||
self.assertEqual([e["st"] for e in events], [graph_st+i*events[0]["dur"] for i in range(len(events))])
|
||||
|
||||
@needs_tracked_pm
|
||||
def test_view_source(self):
|
||||
def custom_fn(X:UOp):
|
||||
X = X.flatten()
|
||||
i = UOp.range(X.numel(), 0)
|
||||
custom_op = UOp(Ops.CUSTOMI, src=(X[i],), arg="{} + undeclared_name")
|
||||
return X[i].store(custom_op).end(i).sink(arg=KernelInfo(name=f"custom_fn_{X.numel()}"))
|
||||
x = Tensor.custom_kernel(Tensor.empty(1, device="CPU"), fxn=custom_fn)[0]
|
||||
with save_viz() as viz:
|
||||
with self.assertRaises(Exception) as e:
|
||||
x.realize()
|
||||
lst = viz.list_items()
|
||||
codegen_idx = len(lst)-1
|
||||
steps = lst[codegen_idx]["steps"]
|
||||
lin_idx = next((i for i,s in enumerate(steps) if s["name"] == "View UOp List"), None)
|
||||
src_idx = next((i for i,s in enumerate(steps) if s["name"] == "View Source"), None)
|
||||
bin_idx = next((i for i,s in enumerate(steps) if s["name"] == "View Disassembly"), None)
|
||||
assert all(i is not None for i in [lin_idx, src_idx, bin_idx]), f"linear, source and disasm must be visible in {steps}"
|
||||
# Ops.LINEAR renders
|
||||
lin_render = get_render(viz.data, steps[lin_idx]["query"])["src"]
|
||||
self.assertIn("Ops.SINK", lin_render)
|
||||
self.assertIn("Ops.CUSTOMI", lin_render)
|
||||
# Ops.SOURCE renders
|
||||
src_render = get_render(viz.data, steps[src_idx]["query"])["src"]
|
||||
self.assertIn("undeclared_name", src_render)
|
||||
# Ops.BINARY shows the error message since compile failed
|
||||
bin_render = get_render(viz.data, steps[bin_idx]["query"])["src"]
|
||||
self.assertIn(type(e.exception).__name__, bin_render)
|
||||
|
||||
from tinygrad.device import ProfileDeviceEvent, ProfileGraphEvent, ProfileGraphEntry
|
||||
from tinygrad.viz.serve import get_profile
|
||||
from tinygrad.viz.cli import decode_profile
|
||||
|
|
@ -763,6 +794,7 @@ from tinygrad.runtime.autogen.amd.rdna3.ins import (s_add_u32, s_branch, s_cbran
|
|||
s_cmp_eq_u64, s_code_end, s_endpgm, s_mov_b32, s_nop)
|
||||
from extra.gemm.amd_asm_matmul import Kernel
|
||||
|
||||
@needs_tracked_pm
|
||||
class TestCfg(unittest.TestCase):
|
||||
def setUp(self): self.arch = "gfx1100"
|
||||
|
||||
|
|
@ -961,6 +993,7 @@ def write_files(viz) -> list[str]:
|
|||
yield ["--rewrites-path", str(r), "--profile-path", str(p)]
|
||||
|
||||
class TestCLI(unittest.TestCase):
|
||||
@needs_tracked_pm
|
||||
def test_reconstruct_debug(self):
|
||||
with save_viz() as viz:
|
||||
Tensor.empty(1, device="NULL").add(2.0).realize()
|
||||
|
|
@ -1023,6 +1056,7 @@ class TestCLI(unittest.TestCase):
|
|||
self.assertEqual(len([s for s in select if s.get("value")]), 1, "debug output was not deduped")
|
||||
self.assertEqual(len([s for s in select if s.get("device") == "NULL"]), CNT, f"expected 4 runs for {name}")
|
||||
|
||||
@needs_tracked_pm
|
||||
def test_call_graph(self):
|
||||
@function(precompile=True)
|
||||
def f(x):
|
||||
|
|
|
|||
|
|
@ -33,7 +33,7 @@ class HTTPRequestHandler(BaseHTTPRequestHandler):
|
|||
self.send_header("Cache-Control", "no-cache")
|
||||
self.end_headers()
|
||||
for r in source:
|
||||
self.wfile.write(f"data: {json.dumps(r)}\n\n".encode("utf-8"))
|
||||
self.wfile.write(f"data: {json.dumps(filter_keys(r))}\n\n".encode("utf-8"))
|
||||
self.wfile.flush()
|
||||
self.wfile.write("data: [DONE]\n\n".encode("utf-8"))
|
||||
# pass if client closed connection
|
||||
|
|
@ -85,11 +85,11 @@ def load_rewrites(data:VizData) -> None:
|
|||
steps.append(create_step(s.name, ("/graph-rewrites", i, j), loc=s.loc, match_count=len(s.matches), code_line=printable(s.loc),
|
||||
trace=k.tb if j==0 else None, depth=s.depth))
|
||||
# get source and binary from Ops.PROGRAM
|
||||
if s.name == "View Program":
|
||||
ki = (p:=_reconstruct(data, s.sink, depth=1)).src[0].arg
|
||||
steps.append(create_step("View UOp List", ("/uops", i, len(steps))))
|
||||
steps.append(create_step("View Source", ("/code", i, len(steps)), p.src[3].arg))
|
||||
steps.append(create_step("View Disassembly", ("/asm", i, len(steps)), (k.ret, p.src[4].arg)))
|
||||
if s.name == "linearize/render":
|
||||
steps.append(create_step("View UOp List", ("/uops", i, len(steps)), j, depth=s.depth))
|
||||
steps.append(create_step("View Source", ("/code", i, len(steps)), j, depth=s.depth))
|
||||
steps.append(create_step("View Disassembly", ("/asm", i, len(steps)), (k.ret, j), depth=s.depth))
|
||||
if s.name == "View Program": ki = _reconstruct(data, s.sink, depth=1).src[0].arg
|
||||
for key in k.keys: data.ref_map[canonicalize_ast(key) if isinstance(key, UOp) else key] = i
|
||||
data.ctxs.append({"name":k.display_name, "steps":steps, "ki":ki})
|
||||
|
||||
|
|
@ -101,6 +101,7 @@ class GraphRewriteDetails(TypedDict):
|
|||
diff: list[str]|None # diff of the single UOp that changed
|
||||
change: list[int]|None # the new UOp id + all its parents ids
|
||||
upat: tuple[tuple[str, int], str]|None # [loc, source_code] of the matched UPat
|
||||
_sink: UOp
|
||||
|
||||
def shape_to_str(s:tuple[sint, ...]): return "(" + ','.join(srender(x) for x in s) + ")"
|
||||
def mask_to_str(s:tuple[tuple[sint, sint], ...]): return "(" + ','.join(shape_to_str(x) for x in s) + ")"
|
||||
|
|
@ -127,7 +128,8 @@ def uop_to_json(data:VizData, x:UOp) -> dict[int, dict]:
|
|||
if u.op in {*GroupOp.Movement, Ops.PARAM}: excluded.update(s for s in u.src if s.op is Ops.STACK and all(x.op is Ops.CONST for x in s.src))
|
||||
for u in toposort:
|
||||
argst = codecs.decode(str(u.arg), "unicode_escape")
|
||||
if u.op in GroupOp.Movement: argst = (mask_to_str if u.op in {Ops.SHRINK, Ops.PAD} else shape_to_str)(u.marg)
|
||||
with soft_err():
|
||||
if u.op in GroupOp.Movement and u.marg: argst = (mask_to_str if u.op in {Ops.SHRINK, Ops.PAD} else shape_to_str)(u.marg)
|
||||
if u.op is Ops.BINARY: argst = f"<{len(u.arg)} bytes>"
|
||||
if u.op is Ops.CONST and dtypes.is_float(u.dtype): argst = f"{u.arg:g}"
|
||||
wrap_len = 200 if u.op is Ops.SOURCE else 80
|
||||
|
|
@ -163,9 +165,10 @@ def uop_to_json(data:VizData, x:UOp) -> dict[int, dict]:
|
|||
# limit SOURCE labels line count
|
||||
if u.op is Ops.SOURCE and len(lines:=label.split("\n")) > 40:
|
||||
label = "\n".join(lines[:30]) + "\n..."
|
||||
addrspace_color:str|None = None
|
||||
with soft_err(): addrspace_color = addrspace_colors.get(u.addrspace, None) if u.addrspace is not None else None
|
||||
graph[id(u)] = {"label":label, "src":[(i,id(x)) for i,x in enumerate(u.src)], "exclude":u in excluded, "color":uops_colors.get(u.op, "#ffffff"),
|
||||
"ref":ref, "tag":repr(u.tag) if u.tag is not None else None,
|
||||
"addrspace":addrspace_colors.get(u.addrspace, None) if u.addrspace is not None else None}
|
||||
"ref":ref, "tag":repr(u.tag) if u.tag is not None else None, "addrspace":addrspace_color}
|
||||
return graph
|
||||
|
||||
def _reconstruct(data:VizData, a:int, depth:int|None=None):
|
||||
|
|
@ -176,19 +179,24 @@ def _reconstruct(data:VizData, a:int, depth:int|None=None):
|
|||
if depth is None: data.all_uops[a] = ret
|
||||
return ret
|
||||
|
||||
def get_full_rewrite(data:VizData, ctx:TrackedGraphRewrite) -> Generator[GraphRewriteDetails, None, None]:
|
||||
next_sink = _reconstruct(data, ctx.sink)
|
||||
yield {"graph":uop_to_json(data, next_sink), "uop":pystr(next_sink), "change":None, "diff":None, "upat":None}
|
||||
def get_full_rewrite(data:VizData, ctx:TrackedGraphRewrite, depth:int|None=None) -> Generator[GraphRewriteDetails, None, None]:
|
||||
next_sink = _reconstruct(data, ctx.sink, depth=depth)
|
||||
yield {"graph":uop_to_json(data, next_sink), "uop":pystr(next_sink), "change":None, "diff":None, "upat":None, "_sink":next_sink}
|
||||
replaces: dict[UOp, UOp] = {}
|
||||
for u0_num,u1_num,upat_loc,dur in tqdm(ctx.matches, disable=not ctx.matches):
|
||||
replaces[u0:=_reconstruct(data, u0_num)] = u1 = _reconstruct(data, u1_num)
|
||||
replaces[u0:=_reconstruct(data, u0_num, depth=depth)] = u1 = _reconstruct(data, u1_num, depth=depth)
|
||||
try: new_sink = next_sink.substitute(replaces)
|
||||
except RuntimeError as e: new_sink = UOp(Ops.NOOP, arg=str(e))
|
||||
match_repr = f"# {dur*1e6:.2f} us\n"+printable(upat_loc)
|
||||
yield {"graph":(sink_json:=uop_to_json(data, new_sink)), "uop":pystr(new_sink), "change":[id(x) for x in u1.toposort() if id(x) in sink_json],
|
||||
"diff":list(difflib.unified_diff(pystr(u0).splitlines(), pystr(u1).splitlines())), "upat":(upat_loc, match_repr)}
|
||||
"diff":list(difflib.unified_diff(pystr(u0).splitlines(), pystr(u1).splitlines())), "upat":(upat_loc, match_repr), "_sink":new_sink}
|
||||
if not ctx.bottom_up: next_sink = new_sink
|
||||
|
||||
def get_sink_at(upats:tuple[str, ...], viz_data:VizData, ctx:TrackedGraphRewrite, depth:int|None=None) -> UOp|None:
|
||||
for s in get_full_rewrite(viz_data, ctx, depth=depth):
|
||||
if s["upat"] is not None and any(n in s["upat"][1] for n in upats): return s["_sink"]
|
||||
return None
|
||||
|
||||
# encoder helpers
|
||||
|
||||
def enum_str(s, cache:dict[str, int]) -> int:
|
||||
|
|
@ -283,9 +291,10 @@ def graph_layout(k:str, dev_events:list[tuple[int, int, float, DevEvent]], start
|
|||
# by default, VIZ does not start when there is an error
|
||||
# use this to instead display the traceback to the user
|
||||
@contextmanager
|
||||
def soft_err(fn:Callable):
|
||||
def soft_err(fn:Callable|None=None):
|
||||
try: yield
|
||||
except Exception: fn({"src":traceback.format_exc()})
|
||||
except Exception:
|
||||
if fn is not None: fn({"src":traceback.format_exc()})
|
||||
|
||||
def row_tuple(row:str) -> tuple[tuple[int, int], ...]:
|
||||
return ((0, 0),) if "Clock" in row else tuple((ord(ss[0][0]), int(ss[1])) if len(ss:=x.split(":"))>1 else (999,999) for x in row.split())
|
||||
|
|
@ -609,11 +618,18 @@ def get_render(viz_data:VizData, query:str) -> dict:
|
|||
i, j, fmt = get_int(qs:=parse_qs(url.query), "ctx"), get_int(qs, "step"), url.path.lstrip("/")
|
||||
data = viz_data.ctxs[i]["steps"][j]["_data"]
|
||||
if fmt == "graph-rewrites": return {"value":get_full_rewrite(viz_data, viz_data.trace.rewrites[i][j]), "content_type":"text/event-stream"}
|
||||
if fmt == "uops": return {"src":get_stdout(lambda: print_uops(_reconstruct(viz_data, viz_data.trace.rewrites[i][j-1].sink).src[2].src))}
|
||||
if fmt == "code": return {"src":data, "lang":"cpp"}
|
||||
if fmt == "uops":
|
||||
if (sink:=get_sink_at(("do_linearize",), viz_data, viz_data.trace.rewrites[i][data])) is None: return {"src":"No linear found"}
|
||||
return {"src":sink.arg} if sink.op is Ops.REWRITE_ERROR else {"src":get_stdout(lambda: print_uops(list(unwrap(sink).src[2].src)))}
|
||||
if fmt == "code":
|
||||
if (sink:=get_sink_at(("do_render",), viz_data, viz_data.trace.rewrites[i][data], depth=1)) is None: return {"src":"No source found"}
|
||||
return {"src":sink.arg} if sink.op is Ops.REWRITE_ERROR else {"src":sink.src[3].arg, "lang":"cpp"}
|
||||
if fmt == "asm":
|
||||
ret:dict = {}
|
||||
renderer, lib = data
|
||||
renderer, idx = data
|
||||
if (sink:=get_sink_at(("do_compile","do_assemble"), viz_data, viz_data.trace.rewrites[i][idx], depth=1)) is None: return {"src":"No binary found"}
|
||||
if sink.op is Ops.REWRITE_ERROR: return {"src":sink.arg}
|
||||
lib:bytes = sink.src[4].arg
|
||||
if renderer.target.arch.startswith("gfx"):
|
||||
with soft_err(lambda err: ret.update(err)): ret.update(amdgpu_cfg(lib, renderer.target.arch))
|
||||
else: ret["src"] = get_stdout(lambda: renderer.compiler.disassemble(lib))
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue