mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
viz/cli: emit all runs of selected kernel, json fixes (#16124)
* keep print * --json in tests, sqtt --json err * work * import * less * line
This commit is contained in:
parent
51c7dafb0d
commit
39ce780907
4 changed files with 34 additions and 29 deletions
|
|
@ -130,16 +130,14 @@ class TestSQTTMapBase(unittest.TestCase):
|
|||
def test_sqtt_cli(self):
|
||||
for pkl_path in sorted((EXAMPLES_DIR/self.target).glob("*.pkl")):
|
||||
out = run_cli("--profile-path", str(pkl_path), "--ls")
|
||||
sqtt_traces = [l.strip() for l in out.split("\n") if "SQTT" in l]
|
||||
sqtt_traces = [l["value"].strip() for l in out if "SQTT" in l["value"]]
|
||||
for name in sqtt_traces:
|
||||
out = run_cli("--profile-path", str(pkl_path), "-s", ansistrip(name))
|
||||
lines = out.split("\n")
|
||||
self.assertIn("Clk", lines[0])
|
||||
for r in lines[2:]:
|
||||
parts = r.split()
|
||||
self.assertTrue(parts[0].isdigit(), f"expected clock timestamp, got {parts[0]}")
|
||||
lines = run_cli("--profile-path", str(pkl_path), "-s", ansistrip(name))
|
||||
self.assertIn("Clk", lines[0]["value"])
|
||||
waves = [r["clk"] for r in lines[2:] if "WAVE" in r["unit"]]
|
||||
self.assertEqual(waves, sorted(waves), f"wave timestamps not monotonic in {name}")
|
||||
with Context(DEBUG=2):
|
||||
kernels = run_cli("--profile-path", str(pkl_path), "-s", "AMD").split("\n")
|
||||
kernels = run_cli("--profile-path", str(pkl_path), "-s", "AMD")
|
||||
self.assertEqual(len(kernels), len(self.examples[pkl_path.stem][1]))
|
||||
|
||||
class TestSQTTMapRDNA3(TestSQTTMapBase): target = "gfx1100"
|
||||
|
|
|
|||
|
|
@ -904,12 +904,12 @@ class TestCfg(unittest.TestCase):
|
|||
self.get_cfg("jump_back_to_end", k)
|
||||
|
||||
# launch viz cli without subprocess
|
||||
def run_cli(*cli_args) -> str:
|
||||
def run_cli(*cli_args) -> list[dict]:
|
||||
from tinygrad.viz.cli import main, get_arg_parser
|
||||
args = get_arg_parser().parse_args(cli_args)
|
||||
args = get_arg_parser().parse_args(cli_args+("--json",))
|
||||
with contextlib.redirect_stdout(buf:=io.StringIO()):
|
||||
main(args)
|
||||
return buf.getvalue().strip()
|
||||
return [json.loads(line) for line in buf.getvalue().strip().splitlines()]
|
||||
|
||||
@contextlib.contextmanager
|
||||
def write_files(viz) -> list[str]:
|
||||
|
|
@ -926,8 +926,8 @@ class TestCLI(unittest.TestCase):
|
|||
Tensor.empty(1, device="NULL").add(3.0).realize()
|
||||
with write_files(viz) as files, Context(DEBUG=4):
|
||||
out = run_cli(*files, "-s", "NULL")
|
||||
self.assertIn("void E", out)
|
||||
self.assertIn("marker @ 1", out)
|
||||
assert any(s.get("value", "").startswith("void E") for s in out)
|
||||
assert any(s.get("name", "") == "marker @ 1" for s in out)
|
||||
|
||||
def test_aggregate(self):
|
||||
N, CNT = 1024, 5
|
||||
|
|
@ -937,7 +937,7 @@ class TestCLI(unittest.TestCase):
|
|||
for _ in range(CNT):
|
||||
(Tensor.empty(N, N, device="NULL").assign(Tensor.empty(N, N, device="NULL"))).realize()
|
||||
with write_files(viz) as files, Context(NO_COLOR=1):
|
||||
kernels = [json.loads(line) for line in run_cli(*files, "-s", "NULL", "-t", "--json").splitlines()]
|
||||
kernels = run_cli(*files, "-s", "NULL", "-t")
|
||||
self.assertEqual(len(kernels), 2)
|
||||
gemm_summary = [s for s in kernels if s["name"].startswith("r_")][0]
|
||||
copy_summary = [s for s in kernels if s["name"].startswith("E_")][0]
|
||||
|
|
@ -956,8 +956,8 @@ class TestCLI(unittest.TestCase):
|
|||
j = Variable("j", 1, 64).bind(j_val)
|
||||
Tensor.realize(*f(a[:i], b[:j]))
|
||||
with write_files(viz) as files:
|
||||
out = [json.loads(line) for line in run_cli(*files, "-s", "NULL", "--json").splitlines()]
|
||||
aggregate = [json.loads(line) for line in run_cli(*files, "-s", "NULL", "-t", "--json").splitlines()]
|
||||
out = run_cli(*files, "-s", "NULL")
|
||||
aggregate = run_cli(*files, "-s", "NULL", "-t")
|
||||
self.assertEqual(len(out), 3*2)
|
||||
# flops increases as N gets larger
|
||||
gflops = [row["fmt"]["FLOPS"] for row in out]
|
||||
|
|
@ -968,5 +968,17 @@ class TestCLI(unittest.TestCase):
|
|||
agg_gflops = [row["fmt"]["FLOPS"] for row in aggregate]
|
||||
assert all(min(gflops) < v < max(gflops) for v in agg_gflops), f"{agg_gflops}"
|
||||
|
||||
def test_dedup(self):
|
||||
with save_viz() as viz:
|
||||
for _ in range(CNT:=4):
|
||||
Tensor.empty(4, device="NULL").add(1).realize()
|
||||
Tensor.empty(8, device="NULL").add(1).realize()
|
||||
with write_files(viz) as files, Context(NO_COLOR=1):
|
||||
name = run_cli(*files, "-s", "NULL")[0]["name"]
|
||||
with Context(DEBUG=3):
|
||||
select = run_cli(*files, "-s", "NULL", name)
|
||||
self.assertEqual(len([s for s in select if s.get("value")]), 1, "debug output was not deduped")
|
||||
self.assertEqual(len([s for s in select if s.get("device") == "NULL"]), CNT, f"expected 4 runs for {name}")
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
|
|
|||
|
|
@ -5,8 +5,7 @@ if hasattr(signal, "SIGPIPE"): signal.signal(signal.SIGPIPE, signal.SIG_DFL)
|
|||
from typing import Iterator
|
||||
from tinygrad.viz import serve as viz
|
||||
from tinygrad.uop.ops import RewriteTrace
|
||||
from tinygrad.helpers import temp, ansistrip, colored, time_to_str, ansilen, ProfilePointEvent, ProfileRangeEvent, TracingKey, unwrap, NO_COLOR
|
||||
from tinygrad.helpers import DEBUG, Context
|
||||
from tinygrad.helpers import temp, ansistrip, colored, time_to_str, ansilen, ProfilePointEvent, ProfileRangeEvent, TracingKey, unwrap, NO_COLOR, DEBUG
|
||||
|
||||
# profile decoder used in CLI and tests
|
||||
def decode_profile(data:bytes) -> dict:
|
||||
|
|
@ -83,15 +82,15 @@ def main(args) -> None:
|
|||
profile = decode_profile(profile_bytes)
|
||||
profile["layout"].update([(f'{c["name"][5:]}{" SQTT" if s["name"].endswith("PKTS") else ""} {s["name"]}', s["data"]) for c in viz_data.ctxs
|
||||
if c["name"].startswith("SQTT") for s in c["steps"] if s["name"].endswith(("PMC", "PKTS"))])
|
||||
if args.list and not args.src: return print("ALL\n"+"\n".join(fmt_colored(k) for k in profile["layout"]))
|
||||
if args.list and not args.src: return print("\n".join(emit(fmt_colored(k)) for k in ["ALL"]+list(profile["layout"])))
|
||||
|
||||
# ** SQTT printer
|
||||
data = None if not args.src else get(profile["layout"], args.src[0])
|
||||
if args.src and "SQTT" in args.src[0]:
|
||||
# modern terminals support 24-bit color
|
||||
def hex_colored(st:str, color:str) -> str: return f"\x1b[38;2;{int(color[1:3],16)};{int(color[3:5],16)};{int(color[5:7],16)}m{st}\x1b[0m"
|
||||
print(f"{'Clk':<12} {'Unit':<20} {'Op':<22} {'Dur':<4} {'Delay':<4} {'Info'}")
|
||||
print("-" * 100)
|
||||
print(emit(f"{'Clk':<12} {'Unit':<20} {'Op':<22} {'Dur':<4} {'Delay':<4} {'Info'}"))
|
||||
print(emit("-" * 100))
|
||||
pc_map:dict[int, str] = {}
|
||||
pkt_idxs:dict[str, itertools.count] = {}
|
||||
dispatch_to_inst:dict[str, tuple[str, int]] = {}
|
||||
|
|
@ -188,6 +187,7 @@ def main(args) -> None:
|
|||
fmt_row = fmt_top if args.t else fmt_all
|
||||
seen_refs:set[int] = set()
|
||||
def render_event(k:dict, ls=args.list) -> None:
|
||||
if len(args.src) > 1 and ansistrip(k["name"]) not in args.src: return None
|
||||
print(emit(k, to_str=fmt_row))
|
||||
if k["ref"] is not None and k["ref"] not in seen_refs:
|
||||
seen_refs.add(k["ref"])
|
||||
|
|
@ -196,14 +196,9 @@ def main(args) -> None:
|
|||
if DEBUG >= 4 and s["name"] == "View Source": print_step(s)
|
||||
if DEBUG >= 5 or ls: print(emit(" "*s["depth"]+s["name"]+(f" - {s['match_count']}" if s.get('match_count', 0) else '')))
|
||||
if DEBUG >= 6: print_step(s)
|
||||
if DEBUG >= 7 or (len(args.src) > 2 and s["name"] == args.src[2]): print_step(s, reconstruct_matches=True)
|
||||
if DEBUG >= 7 or s["name"] in args.src: print_step(s, reconstruct_matches=True)
|
||||
elif DEBUG >= 3 and k.get("ext"): print(emit(k["ext"]))
|
||||
produce = produce_top_kernels if args.t else produce_all_kernels
|
||||
if len(args.src) > 1:
|
||||
k = get({r["name"]:r for r in produce()}, args.src[1])
|
||||
with Context(DEBUG=max(DEBUG.value, 3)): render_event(k, ls=True)
|
||||
else:
|
||||
for k in produce(): render_event(k)
|
||||
for k in (produce_top_kernels if args.t else produce_all_kernels)(): render_event(k)
|
||||
|
||||
def get_arg_parser() -> argparse.ArgumentParser:
|
||||
parser = argparse.ArgumentParser(prog="python -m tinygrad.viz.cli")
|
||||
|
|
|
|||
|
|
@ -342,7 +342,7 @@ def load_amd_counters(data:VizData, profile:list) -> None:
|
|||
for e in sqtt:
|
||||
if e.itrace: steps.append(create_step(f"SE:{e.se} PKTS", (f"/sqtt-{e.se}",len(data.ctxs),len(steps)), data=(e.blob,prg_events[k].lib,arch)))
|
||||
try:
|
||||
from extra.sqtt.roc import unpack_occ
|
||||
with Context(DEBUG=0): from extra.sqtt.roc import unpack_occ
|
||||
steps.append(create_step("OCC", ("/amd-sqtt-occ", len(data.ctxs), len(steps)),
|
||||
data={"fxn":unpack_occ, "args":((k, tag), sqtt, prg_events[k], arch)}))
|
||||
except Exception: pass
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue