viz/cli: schedule renderer (#16101)

* simpler steps

* work

* work

* iterate

* faster

* better

* simplify more

* sys stdin

* less

* work

* work and mv

* better

* seen bufs

* all call graphs

* print query

* ux

* param to buffer / buffer_view

* work

* respect NO_COLOR in uop_to_json

* less

* render uops

* rm custom renderer

* call can't pyrender.

* unrelated diff

* assert

* 5
This commit is contained in:
qazal 2026-05-10 19:56:16 +03:00 committed by GitHub
commit 2dd84416bf
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 34 additions and 11 deletions

View file

@ -320,7 +320,7 @@ class TestVizGC(unittest.TestCase):
# VIZ integrates with other parts of tinygrad
from tinygrad import Tensor, Device, TinyJit, Variable
from tinygrad import Tensor, Device, TinyJit, Variable, function
class TestVizIntegration(unittest.TestCase):
# codegen supports rendering of code blocks
@ -340,15 +340,18 @@ class TestVizIntegration(unittest.TestCase):
c1 = Tensor.empty(4).add(1)
c2 = Tensor.empty(8).add(1)
sched = c1.schedule_linear(c2)
prgs = [to_program(si.src[0], Device[Device.DEFAULT].renderer).arg.name for si in sched.src]
with Context(NO_COLOR=0):
prgs = [to_program(si.src[0], Device[Device.DEFAULT].renderer).arg.name for si in sched.src]
lst = viz.list_items()
sched_idx = next(i for i,l in enumerate(lst) if l["name"].startswith("Schedule"))
viz_kernel = next(i for i,s in enumerate(lst[sched_idx]["steps"]) if s["name"] == "View Kernel Graph")
graph = next(viz.get_details(sched_idx, viz_kernel))["graph"]
with Context(NO_COLOR=1):
graph = next(viz.get_details(sched_idx, viz_kernel))["graph"]
call_nodes = [n for n in graph.values() if n["label"].startswith("CALL")]
for i,n in enumerate(call_nodes):
assert n["ref"] is not None
self.assertEqual(lst[n["ref"]]["name"], prgs[i])
assert ansistrip(prgs[i]) in n["label"], f"CALL must contain kernel name, got {n['label']}"
@Context(TRACEMETA=2)
def test_metadata_tracing(self):
@ -980,5 +983,23 @@ class TestCLI(unittest.TestCase):
self.assertEqual(len([s for s in select if s.get("value")]), 1, "debug output was not deduped")
self.assertEqual(len([s for s in select if s.get("device") == "NULL"]), CNT, f"expected 4 runs for {name}")
def test_call_graph(self):
@function(precompile=True)
def f(x):
r = x.sum(axis=1).reshape(32, 1).expand(32, 32).contiguous()
return x + r
# turn of scache because this test requires a complete schedule rewrite
with save_viz() as viz, Context(SCACHE=0):
f(f(Tensor.empty(32, 32, device="NULL"))).realize()
with write_files(viz) as files, Context(NO_COLOR=1):
prgs = [s["name"] for s in run_cli(*files, "-s", "NULL")]
with Context(DEBUG=5):
out = run_cli(*files, "-s", "TINY")
i = next(i for i,s in enumerate(out) if s.get("value", "").lstrip() == "View Kernel Graph")
# next print is the CALL graph, CLI outputs exactly as web in TestVizIntegration.test_link_sched_codegen
call_nodes = [n for n in out[i+1].values() if n["label"].startswith("CALL")]
for i,n in enumerate(call_nodes):
assert prgs[i] in n["label"], f"CALL must contain kernel name, got {n['label']}"
if __name__ == "__main__":
unittest.main()

View file

@ -4,6 +4,7 @@ os.environ["VIZ"] = "0"
if hasattr(signal, "SIGPIPE"): signal.signal(signal.SIGPIPE, signal.SIG_DFL)
from typing import Iterator
from tinygrad.viz import serve as viz
from tinygrad.viz.serve import fmt_colored
from tinygrad.uop.ops import RewriteTrace
from tinygrad.helpers import temp, ansistrip, colored, time_to_str, ansilen, ProfilePointEvent, ProfileRangeEvent, TracingKey, unwrap, NO_COLOR, DEBUG
@ -45,8 +46,6 @@ def decode_profile(data:bytes) -> dict:
for k,rep,num,mode in [u("<IIIB") for _ in range(u("<I")[0])]]}})
return {"dur":total_dur, "peak":global_peak, "layout":layout, "markers":markers}
def fmt_colored(s:str) -> str: return ansistrip(s) if NO_COLOR else s
def to_str(k:str, v) -> str:
if k == "FLOPS" or k.startswith("B/s"): return f"{v*1e-9:.0f} G{k}" if v < 1e13 else f"{v*1e-12:.0f} T{k}"
if k == "B": return next((f"{v/s:.0f} {u}" for s,u in ((1e9,"GB"),(1e6,"MB"),(1e3,"KB")) if v>=s), f"{v:.0f} B")
@ -65,11 +64,11 @@ def main(args) -> None:
def emit(val, to_str=str) -> str: return json.dumps(val if isinstance(val, dict) else {"value":val}) if args.json else to_str(val)
def print_step(step:dict, reconstruct_matches=False) -> None:
def print_step(step:dict, print_graph=False, reconstruct_matches=False) -> None:
data = viz.get_render(viz_data, step["query"])
if isinstance(data.get("value"), Iterator):
for m in data["value"]:
if m.get("uop"): print(emit(m["uop"]))
if "uop" in m: print(emit(m["graph"] if print_graph else m["uop"]))
if not reconstruct_matches: return None
if m.get("diff"):
loc = pathlib.Path(m["upat"][0][0])
@ -191,11 +190,11 @@ def main(args) -> None:
print(emit(k, to_str=fmt_row))
if k["ref"] is not None and k["ref"] not in seen_refs:
seen_refs.add(k["ref"])
for s in viz_data.ctxs[k["ref"]]["steps"]:
for i,s in enumerate(viz_data.ctxs[k["ref"]]["steps"]):
if DEBUG >= 3 and s["name"] == "View Base AST": print_step(s)
if DEBUG >= 4 and s["name"] == "View Source": print_step(s)
if DEBUG >= 5 or ls: print(emit(" "*s["depth"]+s["name"]+(f" - {s['match_count']}" if s.get('match_count', 0) else '')))
if DEBUG >= 6: print_step(s)
if DEBUG >= 6 or (DEBUG >= 5 and s["name"] == "View Kernel Graph"): print_step(s, print_graph=True)
if DEBUG >= 7 or s["name"] in args.src: print_step(s, reconstruct_matches=True)
elif DEBUG >= 3 and k.get("ext"): print(emit(k["ext"]))
for k in (produce_top_kernels if args.t else produce_all_kernels)(): render_event(k)

View file

@ -8,7 +8,7 @@ from urllib.parse import parse_qs, urlparse
from http.server import BaseHTTPRequestHandler
from typing import Any, TypedDict, TypeVar, Generator, Callable
from tinygrad.helpers import colored, getenv, tqdm, unwrap, word_wrap, TRACEMETA, ProfileEvent, ProfileRangeEvent, TracingKey, ProfilePointEvent, temp
from tinygrad.helpers import printable, Context, START_TIME
from tinygrad.helpers import printable, Context, START_TIME, NO_COLOR, ansistrip
from tinygrad.renderer.amd.dsl import Inst
from tinygrad.renderer.amd import detect_format
@ -105,6 +105,8 @@ def pystr(u:UOp) -> str:
try: return pyrender(u)
except Exception: return str(u)
def fmt_colored(s:str) -> str: return ansistrip(s) if NO_COLOR else s
def uop_to_json(data:VizData, x:UOp) -> dict[int, dict]:
assert isinstance(x, UOp)
graph: dict[int, dict] = {}
@ -148,7 +150,8 @@ def uop_to_json(data:VizData, x:UOp) -> dict[int, dict]:
label += "\n"+' '.join([f"{range_str(s, color=True)}({s.vmax+1})" for s in trngs])
except Exception:
label += "\n<ISSUE GETTING LABEL>"
if (ref:=data.ref_map.get(u.src[0]) if u.op in {Ops.CALL, Ops.FUNCTION} else None) is not None: label += f"\ncodegen@{data.ctxs[ref]['name']}"
ref = data.ref_map.get(u.src[0]) if u.op in {Ops.CALL, Ops.FUNCTION} else None
if ref is not None: label += f"\ncodegen@{fmt_colored(data.ctxs[ref]['name'])}"
# NOTE: kernel already has metadata in arg
if TRACEMETA >= 2 and u.metadata is not None and u.op not in {Ops.CALL, Ops.FUNCTION}: label += "\n"+str(u.metadata)
# limit SOURCE labels line count