mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
viz/cli: add --interval (#16363)
* interval support
* add test_interval
* llama uses interval
This commit is contained in:
parent
2ab90f31b1
commit
b73d2d17b9
3 changed files with 28 additions and 2 deletions
|
|
@ -3,4 +3,4 @@ export BENCHMARK=5
|
|||
export EVAL_BS=0
|
||||
VIZ=${VIZ:--1} FULL_LAYERS=1 DEBUG=0 examples/mlperf/training_submission_v6.0/tinycorp/benchmarks/llama31_8b/implementations/tinybox_8xMI350X/dev_beam.sh
|
||||
SRC="AMD"; [[ $DEV == NULL* ]] && SRC="NULL"
|
||||
python -m tinygrad.viz.cli -s "$SRC" -t
|
||||
python -m tinygrad.viz.cli -s "$SRC" -t --interval "train @ 2" "train @ 3"
|
||||
|
|
|
|||
|
|
@ -1008,5 +1008,22 @@ class TestCLI(unittest.TestCase):
|
|||
for i,n in enumerate(call_nodes):
|
||||
assert prgs[i] in n["label"], f"CALL must contain kernel name, got {n['label']}"
|
||||
|
||||
def test_interval(self):
  """Check that `--interval START END` limits CLI output to events between the two markers."""
  # NOTE(review): the extracted source lost its indentation; the nesting below is rebuilt from
  # data flow (`viz` is produced by the first `with` and consumed by the second) -- confirm upstream.

  # Realize one custom kernel on the NULL device whose sink is named `name`.
  def emit_kernel(name:str): Tensor.custom_kernel(Tensor.empty(1, device="NULL"), fxn=lambda _: UOp.sink(arg=KernelInfo(name=name)))[0].realize()

  with save_viz() as viz:
    # two kernels before the interval opens
    emit_kernel("pre_1")
    emit_kernel("pre_2")
    profile_marker("interval_start")
    # the only kernels the interval filter should keep
    emit_kernel("target_1")
    emit_kernel("target_2")
    profile_marker("interval_end")
    # two kernels after the interval closes
    emit_kernel("post_1")
    emit_kernel("post_2")

  with write_files(viz) as files, Context(NO_COLOR=1):
    # flat listing: both boundary markers and the kernels between them, in order
    flat = run_cli(*files, "-s", "NULL", "--interval", "interval_start", "interval_end")
    # aggregated listing (-t): kernels only, order not asserted, hence the sort below
    aggregate = run_cli(*files, "-s", "NULL", "--interval", "interval_start", "interval_end", "-t")
    self.assertEqual([s["name"] for s in flat], ["interval_start", "target_1", "target_2", "interval_end"])
    self.assertEqual(sorted(s["name"] for s in aggregate), ["target_1", "target_2"])
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
|
|
|||
|
|
@ -52,6 +52,10 @@ def to_str(k:str, v) -> str:
|
|||
return f"{k}={v}"
|
||||
def fmt_data(data:dict) -> str:
  """Format every key/value pair of `data` as one space-separated line of padded cells.

  Each pair is rendered by `to_str`, then right-padded with spaces until `ansilen` reports
  a width of at least 14; cells that are already wider are left untouched.
  """
  cells = []
  for key, value in data.items():
    cell = to_str(key, value)
    padding = max(0, 14 - ansilen(cell))
    cells.append(cell + " " * padding)
  return " ".join(cells)
|
||||
|
||||
def marker_st(markers:list[dict], name:str) -> int:
  """Return the timestamp of the first marker named `name`.

  Args:
    markers: marker records; each one is read through its "name" and "ts" keys.
    name: marker name to look for (exact string equality).

  Returns:
    The "ts" value of the first record, in list order, whose "name" equals `name`.

  Raises:
    RuntimeError: if no record in `markers` carries that name.
  """
  try:
    return next(e["ts"] for e in markers if e["name"] == name)
  except StopIteration:
    # `from None` drops the internal StopIteration so only the lookup failure is reported
    raise RuntimeError(f"marker not found: {name}") from None
|
||||
|
||||
def get(data:dict, key:str):
|
||||
for k,v in data.items():
|
||||
if ansistrip(k) == key: return v
|
||||
|
|
@ -140,12 +144,15 @@ def main(args) -> None:
|
|||
# ** Profiler printer
|
||||
else:
|
||||
timelines = [(n,l) for n,l in profile["layout"].items() if isinstance(l, dict) and l.get("event_type") == 0]
|
||||
markers = profile.get("markers", [])
|
||||
interval:tuple[int, int]|None = None if not args.interval else (marker_st(markers, args.interval[0]), marker_st(markers, args.interval[1]))
|
||||
def produce_top_kernels() -> Iterator[dict]:
|
||||
tagged = ((n,e) for n,l in timelines for e in l["events"]) if not args.src else ((args.src[0],e) for e in unwrap(data)["events"])
|
||||
agg:dict[tuple[str,str], tuple[float, int, int|None, dict[str, float]]] = {} # map (device, kernel name) to (total time, count, ref, est)
|
||||
est_keys = ("FLOPS", "B/s mem", "B/s lds")
|
||||
total = 0
|
||||
for dev,e in tagged:
|
||||
if interval and not interval[0] <= e["st"] <= interval[1]: continue
|
||||
et = e["dur"] * 1e-3
|
||||
t, c, ref, est = agg.get((dev,e["name"]), (0.0, 0, None, {}))
|
||||
est.update({k:est.get(k, 0.0)+e["fmt"][k]*e["dur"]*1e-6 for k in est_keys if k in e["fmt"]})
|
||||
|
|
@ -166,8 +173,9 @@ def main(args) -> None:
|
|||
if not args.src:
|
||||
for n,l in profile["layout"].items():
|
||||
if not isinstance(l, dict) or l.get("event_type") != 0: yield {"device":"SOURCE", "name":n, "st_ms":0, "ref":None, "ext":None}
|
||||
marker_stream = sorted([(m["ts"], "MARKER", m) for m in profile.get("markers", [])], key=lambda t:t[0])
|
||||
marker_stream = sorted([(m["ts"], "MARKER", m) for m in markers], key=lambda t:t[0])
|
||||
for ts,dev,e in heapq.merge(*event_streams, marker_stream, key=lambda t:t[0]):
|
||||
if interval is not None and not interval[0] <= ts <= interval[1]: continue
|
||||
if dev == "MARKER":
|
||||
yield {"device":dev, "name":fmt_colored(e["name"]), "st_ms":ts*1e-3, "ref":None, "ext":None}
|
||||
continue
|
||||
|
|
@ -209,6 +217,7 @@ def get_arg_parser() -> argparse.ArgumentParser:
|
|||
parser = argparse.ArgumentParser(prog="python -m tinygrad.viz.cli")
|
||||
parser.add_argument("-s", "--src", nargs="+", default=[], metavar="NAME", help="Select a data source (default: all)")
|
||||
parser.add_argument("--list", "--ls", dest="list", action="store_true", help="List sources")
|
||||
parser.add_argument("--interval", nargs=2, metavar=("START", "END"), help="Optional start and end marker")
|
||||
parser.add_argument("-t", nargs="?", type=int, const=20, metavar="COUNT", help="Aggregate top kernels (optional count, default 20)")
|
||||
parser.add_argument("--profile-path", type=str, metavar="PATH", help="Optional path to profile.pkl (default: latest profile)",
|
||||
default=temp("profile.pkl", append_user=True))
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue