viz/cli: add --interval (#16363)

* interval support

* add test_interval

* llama uses interval
This commit is contained in:
qazal 2026-05-25 21:35:06 +03:00 committed by GitHub
commit b73d2d17b9
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 28 additions and 2 deletions

View file

@ -3,4 +3,4 @@ export BENCHMARK=5
export EVAL_BS=0
VIZ=${VIZ:--1} FULL_LAYERS=1 DEBUG=0 examples/mlperf/training_submission_v6.0/tinycorp/benchmarks/llama31_8b/implementations/tinybox_8xMI350X/dev_beam.sh
SRC="AMD"; [[ $DEV == NULL* ]] && SRC="NULL"
python -m tinygrad.viz.cli -s "$SRC" -t
python -m tinygrad.viz.cli -s "$SRC" -t --interval "train @ 2" "train @ 3"

View file

@ -1008,5 +1008,22 @@ class TestCLI(unittest.TestCase):
for i,n in enumerate(call_nodes):
assert prgs[i] in n["label"], f"CALL must contain kernel name, got {n['label']}"
def test_interval(self):
def emit_kernel(name:str): Tensor.custom_kernel(Tensor.empty(1, device="NULL"), fxn=lambda _: UOp.sink(arg=KernelInfo(name=name)))[0].realize()
with save_viz() as viz:
emit_kernel("pre_1")
emit_kernel("pre_2")
profile_marker("interval_start")
emit_kernel("target_1")
emit_kernel("target_2")
profile_marker("interval_end")
emit_kernel("post_1")
emit_kernel("post_2")
with write_files(viz) as files, Context(NO_COLOR=1):
flat = run_cli(*files, "-s", "NULL", "--interval", "interval_start", "interval_end")
aggregate = run_cli(*files, "-s", "NULL", "--interval", "interval_start", "interval_end", "-t")
self.assertEqual([s["name"] for s in flat], ["interval_start", "target_1", "target_2", "interval_end"])
self.assertEqual(sorted(s["name"] for s in aggregate), ["target_1", "target_2"])
if __name__ == "__main__":
unittest.main()

View file

@ -52,6 +52,10 @@ def to_str(k:str, v) -> str:
return f"{k}={v}"
def fmt_data(data:dict) -> str: return " ".join((p:=to_str(k, v))+" "*max(0, 14-ansilen(p)) for k,v in data.items())
def marker_st(markers:list[dict], name:str) -> int:
try: return next(e["ts"] for e in markers if e["name"] == name)
except StopIteration: raise RuntimeError(f"marker not found: {name}") from None
def get(data:dict, key:str):
for k,v in data.items():
if ansistrip(k) == key: return v
@ -140,12 +144,15 @@ def main(args) -> None:
# ** Profiler printer
else:
timelines = [(n,l) for n,l in profile["layout"].items() if isinstance(l, dict) and l.get("event_type") == 0]
markers = profile.get("markers", [])
interval:tuple[int, int]|None = None if not args.interval else (marker_st(markers, args.interval[0]), marker_st(markers, args.interval[1]))
def produce_top_kernels() -> Iterator[dict]:
tagged = ((n,e) for n,l in timelines for e in l["events"]) if not args.src else ((args.src[0],e) for e in unwrap(data)["events"])
agg:dict[tuple[str,str], tuple[float, int, int|None, dict[str, float]]] = {} # map (device, kernel name) to (total time, count, ref, est)
est_keys = ("FLOPS", "B/s mem", "B/s lds")
total = 0
for dev,e in tagged:
if interval and not interval[0] <= e["st"] <= interval[1]: continue
et = e["dur"] * 1e-3
t, c, ref, est = agg.get((dev,e["name"]), (0.0, 0, None, {}))
est.update({k:est.get(k, 0.0)+e["fmt"][k]*e["dur"]*1e-6 for k in est_keys if k in e["fmt"]})
@ -166,8 +173,9 @@ def main(args) -> None:
if not args.src:
for n,l in profile["layout"].items():
if not isinstance(l, dict) or l.get("event_type") != 0: yield {"device":"SOURCE", "name":n, "st_ms":0, "ref":None, "ext":None}
marker_stream = sorted([(m["ts"], "MARKER", m) for m in profile.get("markers", [])], key=lambda t:t[0])
marker_stream = sorted([(m["ts"], "MARKER", m) for m in markers], key=lambda t:t[0])
for ts,dev,e in heapq.merge(*event_streams, marker_stream, key=lambda t:t[0]):
if interval is not None and not interval[0] <= ts <= interval[1]: continue
if dev == "MARKER":
yield {"device":dev, "name":fmt_colored(e["name"]), "st_ms":ts*1e-3, "ref":None, "ext":None}
continue
@ -209,6 +217,7 @@ def get_arg_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser(prog="python -m tinygrad.viz.cli")
parser.add_argument("-s", "--src", nargs="+", default=[], metavar="NAME", help="Select a data source (default: all)")
parser.add_argument("--list", "--ls", dest="list", action="store_true", help="List sources")
parser.add_argument("--interval", nargs=2, metavar=("START", "END"), help="Optional start and end marker")
parser.add_argument("-t", nargs="?", type=int, const=20, metavar="COUNT", help="Aggregate top kernels (optional count, default 20)")
parser.add_argument("--profile-path", type=str, metavar="PATH", help="Optional path to profile.pkl (default: latest profile)",
default=temp("profile.pkl", append_user=True))