viz/cli: emit all runs of selected kernel, json fixes (#16124)

* keep print

* --json in tests, sqtt --json err

* work

* import

* less

* line
This commit is contained in:
qazal 2026-05-10 15:45:51 +03:00 committed by GitHub
commit 39ce780907
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 34 additions and 29 deletions

View file

@ -130,16 +130,14 @@ class TestSQTTMapBase(unittest.TestCase):
def test_sqtt_cli(self):
for pkl_path in sorted((EXAMPLES_DIR/self.target).glob("*.pkl")):
out = run_cli("--profile-path", str(pkl_path), "--ls")
sqtt_traces = [l.strip() for l in out.split("\n") if "SQTT" in l]
sqtt_traces = [l["value"].strip() for l in out if "SQTT" in l["value"]]
for name in sqtt_traces:
out = run_cli("--profile-path", str(pkl_path), "-s", ansistrip(name))
lines = out.split("\n")
self.assertIn("Clk", lines[0])
for r in lines[2:]:
parts = r.split()
self.assertTrue(parts[0].isdigit(), f"expected clock timestamp, got {parts[0]}")
lines = run_cli("--profile-path", str(pkl_path), "-s", ansistrip(name))
self.assertIn("Clk", lines[0]["value"])
waves = [r["clk"] for r in lines[2:] if "WAVE" in r["unit"]]
self.assertEqual(waves, sorted(waves), f"wave timestamps not monotonic in {name}")
with Context(DEBUG=2):
kernels = run_cli("--profile-path", str(pkl_path), "-s", "AMD").split("\n")
kernels = run_cli("--profile-path", str(pkl_path), "-s", "AMD")
self.assertEqual(len(kernels), len(self.examples[pkl_path.stem][1]))
# RDNA3 variant of the SQTT map tests: inherits every test from TestSQTTMapBase and only sets target to "gfx1100".
class TestSQTTMapRDNA3(TestSQTTMapBase): target = "gfx1100"

View file

@ -904,12 +904,12 @@ class TestCfg(unittest.TestCase):
self.get_cfg("jump_back_to_end", k)
# launch viz cli without subprocess
def run_cli(*cli_args) -> str:
def run_cli(*cli_args) -> list[dict]:
from tinygrad.viz.cli import main, get_arg_parser
args = get_arg_parser().parse_args(cli_args)
args = get_arg_parser().parse_args(cli_args+("--json",))
with contextlib.redirect_stdout(buf:=io.StringIO()):
main(args)
return buf.getvalue().strip()
return [json.loads(line) for line in buf.getvalue().strip().splitlines()]
@contextlib.contextmanager
def write_files(viz) -> list[str]:
@ -926,8 +926,8 @@ class TestCLI(unittest.TestCase):
Tensor.empty(1, device="NULL").add(3.0).realize()
with write_files(viz) as files, Context(DEBUG=4):
out = run_cli(*files, "-s", "NULL")
self.assertIn("void E", out)
self.assertIn("marker @ 1", out)
assert any(s.get("value", "").startswith("void E") for s in out)
assert any(s.get("name", "") == "marker @ 1" for s in out)
def test_aggregate(self):
N, CNT = 1024, 5
@ -937,7 +937,7 @@ class TestCLI(unittest.TestCase):
for _ in range(CNT):
(Tensor.empty(N, N, device="NULL").assign(Tensor.empty(N, N, device="NULL"))).realize()
with write_files(viz) as files, Context(NO_COLOR=1):
kernels = [json.loads(line) for line in run_cli(*files, "-s", "NULL", "-t", "--json").splitlines()]
kernels = run_cli(*files, "-s", "NULL", "-t")
self.assertEqual(len(kernels), 2)
gemm_summary = [s for s in kernels if s["name"].startswith("r_")][0]
copy_summary = [s for s in kernels if s["name"].startswith("E_")][0]
@ -956,8 +956,8 @@ class TestCLI(unittest.TestCase):
j = Variable("j", 1, 64).bind(j_val)
Tensor.realize(*f(a[:i], b[:j]))
with write_files(viz) as files:
out = [json.loads(line) for line in run_cli(*files, "-s", "NULL", "--json").splitlines()]
aggregate = [json.loads(line) for line in run_cli(*files, "-s", "NULL", "-t", "--json").splitlines()]
out = run_cli(*files, "-s", "NULL")
aggregate = run_cli(*files, "-s", "NULL", "-t")
self.assertEqual(len(out), 3*2)
# flops increases as N gets larger
gflops = [row["fmt"]["FLOPS"] for row in out]
@ -968,5 +968,17 @@ class TestCLI(unittest.TestCase):
agg_gflops = [row["fmt"]["FLOPS"] for row in aggregate]
assert all(min(gflops) < v < max(gflops) for v in agg_gflops), f"{agg_gflops}"
def test_dedup(self):
  """Selecting a kernel by name emits one row per run, but its debug output only once.

  Runs two NULL-device realizes CNT times, takes the name of the first row the CLI
  reports, then re-runs the CLI with that name as an extra selector under DEBUG=3.
  """
  CNT = 4
  with save_viz() as viz:
    for _ in range(CNT):
      Tensor.empty(4, device="NULL").add(1).realize()
      Tensor.empty(8, device="NULL").add(1).realize()
  with write_files(viz) as files, Context(NO_COLOR=1):
    name = run_cli(*files, "-s", "NULL")[0]["name"]
    with Context(DEBUG=3):
      select = run_cli(*files, "-s", "NULL", name)
    # rows with a non-empty "value" are counted as debug output; exactly one must survive
    debug_rows = [s for s in select if s.get("value")]
    self.assertEqual(len(debug_rows), 1, "debug output was not deduped")
    # rows tagged with the NULL device are counted as runs of the selected kernel
    run_rows = [s for s in select if s.get("device") == "NULL"]
    # report the expected count from CNT so the message stays correct if CNT changes
    self.assertEqual(len(run_rows), CNT, f"expected {CNT} runs for {name}")
if __name__ == "__main__":
unittest.main()

View file

@ -5,8 +5,7 @@ if hasattr(signal, "SIGPIPE"): signal.signal(signal.SIGPIPE, signal.SIG_DFL)
from typing import Iterator
from tinygrad.viz import serve as viz
from tinygrad.uop.ops import RewriteTrace
from tinygrad.helpers import temp, ansistrip, colored, time_to_str, ansilen, ProfilePointEvent, ProfileRangeEvent, TracingKey, unwrap, NO_COLOR
from tinygrad.helpers import DEBUG, Context
from tinygrad.helpers import temp, ansistrip, colored, time_to_str, ansilen, ProfilePointEvent, ProfileRangeEvent, TracingKey, unwrap, NO_COLOR, DEBUG
# profile decoder used in CLI and tests
def decode_profile(data:bytes) -> dict:
@ -83,15 +82,15 @@ def main(args) -> None:
profile = decode_profile(profile_bytes)
profile["layout"].update([(f'{c["name"][5:]}{" SQTT" if s["name"].endswith("PKTS") else ""} {s["name"]}', s["data"]) for c in viz_data.ctxs
if c["name"].startswith("SQTT") for s in c["steps"] if s["name"].endswith(("PMC", "PKTS"))])
if args.list and not args.src: return print("ALL\n"+"\n".join(fmt_colored(k) for k in profile["layout"]))
if args.list and not args.src: return print("\n".join(emit(fmt_colored(k)) for k in ["ALL"]+list(profile["layout"])))
# ** SQTT printer
data = None if not args.src else get(profile["layout"], args.src[0])
if args.src and "SQTT" in args.src[0]:
# modern terminals support 24-bit color
def hex_colored(st:str, color:str) -> str: return f"\x1b[38;2;{int(color[1:3],16)};{int(color[3:5],16)};{int(color[5:7],16)}m{st}\x1b[0m"
print(f"{'Clk':<12} {'Unit':<20} {'Op':<22} {'Dur':<4} {'Delay':<4} {'Info'}")
print("-" * 100)
print(emit(f"{'Clk':<12} {'Unit':<20} {'Op':<22} {'Dur':<4} {'Delay':<4} {'Info'}"))
print(emit("-" * 100))
pc_map:dict[int, str] = {}
pkt_idxs:dict[str, itertools.count] = {}
dispatch_to_inst:dict[str, tuple[str, int]] = {}
@ -188,6 +187,7 @@ def main(args) -> None:
fmt_row = fmt_top if args.t else fmt_all
seen_refs:set[int] = set()
def render_event(k:dict, ls=args.list) -> None:
if len(args.src) > 1 and ansistrip(k["name"]) not in args.src: return None
print(emit(k, to_str=fmt_row))
if k["ref"] is not None and k["ref"] not in seen_refs:
seen_refs.add(k["ref"])
@ -196,14 +196,9 @@ def main(args) -> None:
if DEBUG >= 4 and s["name"] == "View Source": print_step(s)
if DEBUG >= 5 or ls: print(emit(" "*s["depth"]+s["name"]+(f" - {s['match_count']}" if s.get('match_count', 0) else '')))
if DEBUG >= 6: print_step(s)
if DEBUG >= 7 or (len(args.src) > 2 and s["name"] == args.src[2]): print_step(s, reconstruct_matches=True)
if DEBUG >= 7 or s["name"] in args.src: print_step(s, reconstruct_matches=True)
elif DEBUG >= 3 and k.get("ext"): print(emit(k["ext"]))
produce = produce_top_kernels if args.t else produce_all_kernels
if len(args.src) > 1:
k = get({r["name"]:r for r in produce()}, args.src[1])
with Context(DEBUG=max(DEBUG.value, 3)): render_event(k, ls=True)
else:
for k in produce(): render_event(k)
for k in (produce_top_kernels if args.t else produce_all_kernels)(): render_event(k)
def get_arg_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser(prog="python -m tinygrad.viz.cli")

View file

@ -342,7 +342,7 @@ def load_amd_counters(data:VizData, profile:list) -> None:
for e in sqtt:
if e.itrace: steps.append(create_step(f"SE:{e.se} PKTS", (f"/sqtt-{e.se}",len(data.ctxs),len(steps)), data=(e.blob,prg_events[k].lib,arch)))
try:
from extra.sqtt.roc import unpack_occ
with Context(DEBUG=0): from extra.sqtt.roc import unpack_occ
steps.append(create_step("OCC", ("/amd-sqtt-occ", len(data.ctxs), len(steps)),
data={"fxn":unpack_occ, "args":((k, tag), sqtt, prg_events[k], arch)}))
except Exception: pass