mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
47 lines
2.3 KiB
Python
Executable file
47 lines
2.3 KiB
Python
Executable file
#!/usr/bin/env python3
|
|
# Usage: DEBUG=5 python -m tinygrad.viz.cli --json | ./extra/viz/kernel_graph.py > /tmp/kernel_graph.txt
|
|
import argparse, json, sys, itertools
|
|
from tinygrad.helpers import ansistrip
|
|
from tinygrad.viz.cli import fmt_all
|
|
|
|
def get_node(graph:dict, key):
  """Return the node stored under `key` in a JSON-decoded graph dict.

  JSON object keys always decode as strings, while the ids this script looks up come out of "src" lists
  (and may be JSON numbers), so the key is normalized with str() before indexing. Raises KeyError if absent.
  """
  node_id = str(key)
  return graph[node_id]
|
|
|
|
def _print_call_sources(graph:dict, call:dict, sched_num:int, unique:dict) -> None:
  """Print one "SRC <i> id=<num>-<sched_num> <label>" line per entry of call["src"], skipping the first entry.

  A source whose label starts with "AFTER" is unwrapped by repeatedly following its first src entry.
  `unique` maps each resolved node id to a small integer; it is shared by every CALL of one schedule graph,
  so a source used by several CALLs prints the same id.
  """
  for idx, (_, nid) in enumerate(call["src"][1:]):
    node = get_node(graph, nid)
    while node["label"].startswith("AFTER"):
      nid = node["src"][0][1]
      node = get_node(graph, nid)
    num = unique.setdefault(nid, len(unique))
    print(f"SRC {idx} id={num}-{sched_num} {' '.join(node['label'].splitlines())}")

def _print_access_patterns(graph:dict, root) -> None:
  """Depth-first walk (each node id visited once) from node id `root`, printing every INDEX node reached.

  Each printed line is the INDEX node's label lines, then "SRC", then the label lines (minus the first)
  of the INDEX node's first source.
  """
  pending, visited = [root], set()
  while pending:
    nid = pending.pop()
    if nid in visited: continue
    visited.add(nid)
    node = get_node(graph, nid)
    if node["label"].startswith("INDEX"):
      target = get_node(graph, node["src"][0][1])
      print(" ".join([*node["label"].splitlines(), "SRC", *target["label"].splitlines()[1:]]))
    pending.extend([entry[1] for entry in node["src"]])

if __name__ == "__main__":
  # Reads one JSON document per stdin line. Graphs whose first value is a labelled node dict are treated as
  # schedule graphs and their CALL nodes are summarized; graphs carrying a "ref" are pretty-printed via fmt_all.
  parser = argparse.ArgumentParser(description="print CALL graph from DEBUG=5 tinygrad.viz.cli --json output")
  parser.add_argument("kernel", type=str, nargs="?", default="ALL", metavar="NAME", help="Kernel name to stop at (default: print all kernels)")
  args = parser.parse_args()
  print_all = args.kernel == "ALL"

  # set to the matching CALL's "ref" once the requested kernel is found; from then on only graphs with that ref print
  ref:int|None = None
  sched_counter = itertools.count(0)
  for raw in sys.stdin:
    if not raw.strip(): continue
    graph = json.loads(raw)

    graph_ref = graph.get("ref")
    if graph_ref is not None and (print_all or graph_ref == ref):
      print(fmt_all(graph), f"ref={graph.get('ref')}")
      # the line right after a printed ref graph is consumed here; its "value" is printed when truthy
      value = json.loads(next(sys.stdin, "{}")).get("value")
      if value: print(value)

    # once the kernel has been found, schedule graphs are no longer summarized
    if ref is not None: continue
    first = next(iter(graph.values()), {})
    if not isinstance(first, dict) or "label" not in first: continue

    sched_num = next(sched_counter)
    unique:dict[int, int] = {}
    for node in graph.values():
      if not node["label"].startswith("CALL"): continue
      label_lines = node["label"].splitlines()
      # the CALL line and its kernel name from codegen
      print(f"{label_lines[0]:<12} {label_lines[-1]}")
      # sources (buffer, param, multi)
      _print_call_sources(graph, node, sched_num, unique)
      # access patterns reachable from the first src entry
      _print_access_patterns(graph, node["src"][0][1])
      # NOTE: substring match on the ANSI-stripped label; stops summarizing this graph at the first match
      if not print_all and args.kernel in ansistrip(node["label"]):
        ref = node["ref"]
        break
|