mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
viz/cli: add DEBUG, optional number of rows (#15777)
* tabulate switch * support DEBUG * --top * improve * work * feedback * 0 * print_kernel both ways * simplify
This commit is contained in:
parent
2d196fb9bb
commit
0e69388f6b
2 changed files with 33 additions and 18 deletions
|
|
@ -13,8 +13,13 @@ Flags: VIZ=-1 to only save the trace to a file, VIZ=1 also launches a web server
|
|||
|
||||
Use `extra/viz/cli.py --profile` to list all sources.
|
||||
|
||||
List top slowest kernels on a source: `--profile -s "AMD"`
|
||||
List samples of a kernel on a source: `--profile -s "AMD" -i E_3 | head 4`
|
||||
```bash
|
||||
# View top 40 slowest kernels and their AST (DEBUG=4 to see source code)
|
||||
DEBUG=3 extra/viz/cli.py --profile -s AMD --top 40
|
||||
|
||||
# View all runs of a kernel
|
||||
extra/viz/cli.py --profile -s AMD -i E_3 | head -n 4
|
||||
```
|
||||
|
||||
## Inspect codegen and PatternMatcher
|
||||
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@ from typing import Iterator
|
|||
from tinygrad.viz import serve as viz
|
||||
from tinygrad.uop.ops import RewriteTrace
|
||||
from tinygrad.helpers import temp, ansistrip, colored, time_to_str, ansilen, ProfilePointEvent, ProfileRangeEvent, TracingKey, unwrap, NO_COLOR
|
||||
from tinygrad.helpers import DEBUG
|
||||
|
||||
# profile decoder used in CLI and tests
|
||||
def decode_profile(data:bytes) -> dict:
|
||||
|
|
@ -64,8 +65,8 @@ def main(args) -> None:
|
|||
profile["layout"].update([(f'{c["name"][5:]}{" SQTT" if s["name"].endswith("PKTS") else ""} {s["name"]}', s["data"]) for c in viz_data.ctxs
|
||||
if c["name"].startswith("SQTT") for s in c["steps"] if s["name"].endswith(("PMC", "PKTS"))])
|
||||
if args.src is None:
|
||||
for k in profile["layout"]:
|
||||
print(f" {format_colored(k)}")
|
||||
print("Select a source with -s")
|
||||
for k in profile["layout"]: print(f" {format_colored(k)}")
|
||||
return None
|
||||
|
||||
# ** SQTT printer
|
||||
|
|
@ -111,8 +112,10 @@ def main(args) -> None:
|
|||
elif args.item == r[0]:
|
||||
rows = r[2]["rows"] if len(r) > 2 else [r[:2]]
|
||||
cols = r[2]["cols"] if len(r) > 2 else cols
|
||||
from tabulate import tabulate
|
||||
print(tabulate(rows, headers=cols, tablefmt="github"))
|
||||
data = [[x for x in cols], *[[str(x) for x in r] for r in rows]]
|
||||
widths = [max(len(r[i]) for r in data) for i in range(len(cols))]
|
||||
def fmt(r): return "| "+" | ".join(x+" "*(w-len(x)) for x,w in zip(r, widths))+" |"
|
||||
print(fmt(data[0])+"\n"+fmt(["-"*w for w in widths])+"\n"+("\n".join([fmt(row) for row in data[1:]])))
|
||||
return None
|
||||
|
||||
# ** Memory printer
|
||||
|
|
@ -127,8 +130,11 @@ def main(args) -> None:
|
|||
return None
|
||||
|
||||
# ** Profiler printer
|
||||
agg:dict[str, tuple[float, int]] = {}
|
||||
total = 0
|
||||
agg:dict[str, tuple[float, int, int|None]] = {}
|
||||
total, first = 0, True
|
||||
def print_kernel(ref:int) -> None:
|
||||
if DEBUG >= 3: print(viz._reconstruct(viz_data, viz_data.trace.rewrites[ref][0].sink).pyrender())
|
||||
if DEBUG >= 4: print(viz_data.ctxs[ref]["prg"].src[3].arg)
|
||||
for e in data.get("events", []):
|
||||
et = e["dur"] * 1e-6
|
||||
if args.item is not None:
|
||||
|
|
@ -136,20 +142,23 @@ def main(args) -> None:
|
|||
ptm = colored(time_to_str(et, w=9), "yellow" if et > 0.01 else None)
|
||||
name = e["name"] + (" " * (46 - ansilen(e["name"])))
|
||||
print(f"{format_colored(name)} {ptm}/{et*1e3:9.2f}ms " + e.get("fmt", "").replace("\n", " | ") + " ")
|
||||
if first:
|
||||
if e["ref"] is not None: print_kernel(e["ref"])
|
||||
first = False
|
||||
else:
|
||||
t, c = agg.get(e["name"], (0.0, 0))
|
||||
agg[e["name"]] = (t+et, c+1)
|
||||
t, c, ref = agg.get(e["name"], (0.0, 0, None))
|
||||
agg[e["name"]] = (t+et, c+1, e["ref"])
|
||||
total += et
|
||||
if agg and total > 0:
|
||||
from tabulate import tabulate
|
||||
items = sorted(agg.items(), key=lambda kv:kv[1][0], reverse=True)
|
||||
num_rows = 20
|
||||
table = [[format_colored(name), time_to_str(t, w=9), c, f"{(t/total*100.0):.2f}%"] for name,(t,c) in items[:num_rows]]
|
||||
if items[num_rows:]:
|
||||
other_t = sum(t for _,(t,_) in items[num_rows:])
|
||||
other_c = sum(c for _,(_,c) in items[num_rows:])
|
||||
table.append(["Other", time_to_str(other_t, w=9), other_c, f"{(other_t/total*100.0):.2f}%"])
|
||||
print(tabulate(table, headers=["name", "total", "count", "pct"], tablefmt="github"))
|
||||
num_rows = args.top
|
||||
for name,(t,c,ref) in items[:num_rows]:
|
||||
print(f"{format_colored(name)}{' ' * max(0, 36 - ansilen(name))} {time_to_str(t, w=9)} {c:7d} {t/total*100.0:6.2f}%")
|
||||
if ref is not None: print_kernel(ref)
|
||||
if num_rows > 0 and items[num_rows:]:
|
||||
other_t = sum(t for _,(t,_,_) in items[num_rows:])
|
||||
other_c = sum(c for _,(_,c,_) in items[num_rows:])
|
||||
print(f"{'Other':<36} {time_to_str(other_t, w=9)} {other_c:7d} {other_t/total*100.0:6.2f}%")
|
||||
return None
|
||||
|
||||
# ** Graph rewrites printer
|
||||
|
|
@ -180,6 +189,7 @@ def get_arg_parser() -> argparse.ArgumentParser:
|
|||
g_opts = parser.add_argument_group("optional args")
|
||||
g_opts.add_argument("-s", "--src", type=str, default=None, metavar="NAME", help="Select a data source (default: list all sources)")
|
||||
g_opts.add_argument("-i", "--item", type=str, default=None, metavar="NAME", help="Select an item within the source (default: list all items)")
|
||||
g_opts.add_argument("--top", type=int, default=20, metavar="COUNT", help="Number of top rows to print (default: 20, set -1 to print all)")
|
||||
g_opts.add_argument("--profile-path", type=pathlib.Path, metavar="PATH", help="Path to profile.pkl (optional file, default: latest profile)",
|
||||
default=pathlib.Path(temp("profile.pkl", append_user=True)))
|
||||
g_opts.add_argument("--rewrites-path", type=pathlib.Path, metavar="PATH", help="Path to rewrites.pkl (optional file, default: latest rewrites)",
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue