mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
viz: no global state (#15705)
* start viz data
* get_full_rewrites also moves
* update ref_map
* work
* update consumers
* cleaner cli
* linter
* cleanup tests
* back
* better
* sqtt tests
This commit is contained in:
parent
4c1fb18a09
commit
ac027055ef
5 changed files with 111 additions and 99 deletions
|
|
@ -136,7 +136,8 @@ def print_data(data:dict) -> None:
|
|||
|
||||
def main() -> None:
|
||||
import tinygrad.viz.serve as viz
|
||||
viz.ctxs = []
|
||||
from tinygrad.uop.ops import RewriteTrace
|
||||
data = viz.VizData()
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--profile', type=pathlib.Path, metavar="PATH", help='Path to profile (optional file, default: latest profile)',
|
||||
|
|
@ -147,24 +148,24 @@ def main() -> None:
|
|||
|
||||
with args.profile.open("rb") as f: profile = pickle.load(f)
|
||||
|
||||
viz.get_profile(profile)
|
||||
viz.get_profile(profile, data=data)
|
||||
|
||||
# List all kernels
|
||||
if args.kernel is None:
|
||||
for c in viz.ctxs:
|
||||
for c in data.ctxs:
|
||||
print(c["name"])
|
||||
for s in c["steps"]: print(" "+s["name"])
|
||||
return None
|
||||
|
||||
# Find kernel trace
|
||||
trace = next((c for c in viz.ctxs if c["name"] == f"Exec {args.kernel}"), None)
|
||||
trace = next((c for c in data.ctxs if c["name"] == f"SQTT {args.kernel}"), None)
|
||||
if not trace: raise RuntimeError(f"no matching trace for {args.kernel}")
|
||||
n = 0
|
||||
for s in trace["steps"]:
|
||||
if "PKTS" in s["name"]: continue
|
||||
print(s["name"])
|
||||
data = viz.get_render(s["query"])
|
||||
print_data(data)
|
||||
ret = viz.get_render(data, s["query"])
|
||||
print_data(ret)
|
||||
n += 1
|
||||
if n > args.n: break
|
||||
|
||||
|
|
|
|||
|
|
@ -52,16 +52,16 @@ def get(data:dict, key:str):
|
|||
raise RuntimeError(f'item "{key}" not found in list'+(f", did you mean {match[0]!r}?" if match else ''))
|
||||
|
||||
def main(args) -> None:
|
||||
viz.trace = viz.load_pickle(args.rewrites_path, default=RewriteTrace([], [], {}))
|
||||
viz.ctxs = viz.get_rewrites(viz.trace)
|
||||
viz.data = viz.VizData(viz.load_pickle(args.rewrites_path, default=RewriteTrace([], [], {})))
|
||||
viz.load_rewrites(viz.data)
|
||||
|
||||
def format_colored(s:str) -> str: return ansistrip(s) if args.no_color else s
|
||||
|
||||
if args.profile:
|
||||
events:list = viz.load_pickle(args.profile_path, default=[])
|
||||
if (profile_bytes:=viz.get_profile(events)) is None: raise RuntimeError(f"empty profile in {args.profile_path}")
|
||||
if (profile_bytes:=viz.get_profile(events, data=viz.data)) is None: raise RuntimeError(f"empty profile in {args.profile_path}")
|
||||
profile = decode_profile(profile_bytes)
|
||||
profile["layout"].update([(f'{c["name"][5:]}{" SQTT" if s["name"].endswith("PKTS") else ""} {s["name"]}', s["data"]) for c in viz.ctxs
|
||||
profile["layout"].update([(f'{c["name"][5:]}{" SQTT" if s["name"].endswith("PKTS") else ""} {s["name"]}', s["data"]) for c in viz.data.ctxs
|
||||
if c["name"].startswith("SQTT") for s in c["steps"] if s["name"].endswith(("PMC", "PKTS"))])
|
||||
if args.src is None:
|
||||
for k in profile["layout"]:
|
||||
|
|
@ -142,7 +142,7 @@ def main(args) -> None:
|
|||
return None
|
||||
|
||||
# ** Graph rewrites printer
|
||||
rewrites = {c["name"]:{s["name"]:s for s in c["steps"]} for c in viz.ctxs if c.get("steps")}
|
||||
rewrites = {c["name"]:{s["name"]:s for s in c["steps"]} for c in viz.data.ctxs if c.get("steps")}
|
||||
if args.src is None:
|
||||
for k in rewrites: print(f" {format_colored(k)}")
|
||||
return None
|
||||
|
|
@ -150,7 +150,7 @@ def main(args) -> None:
|
|||
if args.item is None:
|
||||
for k,v in steps.items(): print(" "*v["depth"]+k+(f" - {v['match_count']}" if v.get('match_count', 0) else ''))
|
||||
else:
|
||||
data = viz.get_render(get(steps, args.item)["query"])
|
||||
data = viz.get_render(viz.data, get(steps, args.item)["query"])
|
||||
if isinstance(data.get("value"), Iterator):
|
||||
for m in data["value"]:
|
||||
if m.get("uop"): print(f"Input UOp:\n{m['uop']}")
|
||||
|
|
|
|||
|
|
@ -1,15 +1,16 @@
|
|||
import unittest, contextlib
|
||||
from tinygrad import Device, Tensor, Context, TinyJit
|
||||
from tinygrad.device import Compiled, ProfileProgramEvent, ProfileDeviceEvent
|
||||
from tinygrad.viz.serve import load_amd_counters
|
||||
from tinygrad.viz.serve import load_amd_counters, VizData
|
||||
|
||||
@contextlib.contextmanager
|
||||
def save_sqtt():
|
||||
yield (ret:=[])
|
||||
data = VizData()
|
||||
yield data.ctxs
|
||||
Device[Device.DEFAULT].synchronize()
|
||||
Device[Device.DEFAULT]._at_profile_finalize()
|
||||
load_amd_counters(ret, Compiled.profile_events)
|
||||
ret[:] = [r for r in ret if r["name"].startswith("SQTT")]
|
||||
load_amd_counters(data, Compiled.profile_events)
|
||||
data.ctxs[:] = [r for r in data.ctxs if r["name"].startswith("SQTT")]
|
||||
|
||||
@unittest.skipUnless(Device.DEFAULT == "AMD", "only runs on AMD")
|
||||
class TestSQTTProfiler(unittest.TestCase):
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@ from tinygrad.helpers import VIZ, cpu_profile, ProfilePointEvent, unwrap
|
|||
from tinygrad.device import Buffer
|
||||
|
||||
from tinygrad.uop.ops import tracked_keys, tracked_ctxs, uop_fields, active_rewrites, active_group, _name_cnt, RewriteTrace
|
||||
from tinygrad.viz.serve import get_rewrites, get_full_rewrite, uop_to_json
|
||||
from tinygrad.viz.serve import load_rewrites, get_full_rewrite, uop_to_json, VizData
|
||||
|
||||
@track_rewrites(name=True)
|
||||
def exec_rewrite(sink:UOp, pm_lst:list[PatternMatcher], names:None|list[str]=None) -> UOp:
|
||||
|
|
@ -21,19 +21,19 @@ def exec_rewrite(sink:UOp, pm_lst:list[PatternMatcher], names:None|list[str]=Non
|
|||
# small container class for the viz server module
|
||||
class VizTrace:
|
||||
# loader init
|
||||
def __init__(self): self._trace:RewriteTrace|None = None
|
||||
def __init__(self): self._data:VizData|None = None
|
||||
@property
|
||||
def trace(self) -> RewriteTrace: return unwrap(self._trace)
|
||||
def set_trace(self) -> None:
|
||||
self._trace = RewriteTrace(tracked_keys.copy(), tracked_ctxs.copy(), uop_fields.copy())
|
||||
import tinygrad.viz.serve as serve_module
|
||||
serve_module.trace = self._trace
|
||||
def data(self) -> VizData: return unwrap(self._data)
|
||||
def set_data(self) -> None:
|
||||
data = VizData(RewriteTrace(tracked_keys.copy(), tracked_ctxs.copy(), uop_fields.copy()))
|
||||
load_rewrites(data)
|
||||
self._data = data
|
||||
# the API
|
||||
def list_items(self) -> list[dict]: return get_rewrites(self.trace)
|
||||
def list_items(self) -> list[dict]:
|
||||
return self.data.ctxs
|
||||
def get_details(self, rewrite_idx:int, step:int) -> Generator[dict, None, None]:
|
||||
lst = self.list_items()
|
||||
assert len(lst) > rewrite_idx, f"only loaded {len(lst)} traces, expecting at least {rewrite_idx}"
|
||||
return get_full_rewrite(self.trace.rewrites[rewrite_idx][step])
|
||||
assert len(self.data.trace.rewrites) > rewrite_idx, f"only loaded {len(self.data.trace.rewrites)} traces, expecting at least {rewrite_idx}"
|
||||
return get_full_rewrite(self.data, self.data.trace.rewrites[rewrite_idx][step])
|
||||
|
||||
@contextlib.contextmanager
|
||||
def save_viz():
|
||||
|
|
@ -52,7 +52,7 @@ def save_viz():
|
|||
try:
|
||||
yield viz
|
||||
finally:
|
||||
viz.set_trace()
|
||||
viz.set_data()
|
||||
TRACK_MATCH_STATS.value = prev_tms
|
||||
PROFILE.value = prev_profile
|
||||
VIZ.value = prev_viz
|
||||
|
|
@ -194,7 +194,7 @@ class TestViz(unittest.TestCase):
|
|||
class TestStruct:
|
||||
colored_field: str
|
||||
a = UOp(Ops.CUSTOM, arg=TestStruct(colored("xyz", "magenta")+colored("12345", "blue")))
|
||||
a2 = uop_to_json(a)[id(a)]
|
||||
a2 = uop_to_json(a, VizData())[id(a)]
|
||||
self.assertEqual(ansistrip(a2["label"]), f"CUSTOM\n{TestStruct.__qualname__}(colored_field='xyz12345')")
|
||||
|
||||
def test_colored_label_multiline(self):
|
||||
|
|
@ -217,11 +217,11 @@ class TestViz(unittest.TestCase):
|
|||
# use smaller stack limit for faster test (default is 250000)
|
||||
with Context(REWRITE_STACK_LIMIT=100): self.assertRaises(RuntimeError, exec_rewrite, a, [pm])
|
||||
graphs = flatten(x["graph"].values() for x in viz.get_details(0, 0))
|
||||
self.assertEqual(graphs[0], uop_to_json(a)[id(a)])
|
||||
self.assertEqual(graphs[1], uop_to_json(b)[id(b)])
|
||||
self.assertEqual(graphs[0], uop_to_json(a, VizData())[id(a)])
|
||||
self.assertEqual(graphs[1], uop_to_json(b, VizData())[id(b)])
|
||||
# fallback to NOOP with the error message
|
||||
nop = UOp(Ops.NOOP, arg="infinite loop in fixed_point_rewrite")
|
||||
self.assertEqual(graphs[2], uop_to_json(nop)[id(nop)])
|
||||
self.assertEqual(graphs[2], uop_to_json(nop, VizData())[id(nop)])
|
||||
|
||||
def test_const_node_visibility(self):
|
||||
with save_viz() as viz:
|
||||
|
|
@ -241,7 +241,7 @@ class TestViz(unittest.TestCase):
|
|||
c = UOp.const(dtypes.float, 1.0, device="CPU", shape=(3,4)) # creates CONST->RESHAPE->EXPAND chain
|
||||
a = UOp(Ops.DEFINE_VAR, dtypes.float, arg=("a", 0.0, 10.0))
|
||||
alu = a + c
|
||||
graph = uop_to_json(alu)
|
||||
graph = uop_to_json(alu, VizData())
|
||||
# the RESHAPE and EXPAND nodes from the const should not appear in the graph
|
||||
labels = {v["label"].split("\n")[0] for v in graph.values()}
|
||||
self.assertNotIn("RESHAPE", labels)
|
||||
|
|
|
|||
|
|
@ -1,8 +1,9 @@
|
|||
#!/usr/bin/env python3
|
||||
import multiprocessing, pickle, difflib, os, threading, json, time, sys, webbrowser, socket, argparse, functools, codecs, io, struct, re
|
||||
import multiprocessing, pickle, difflib, os, threading, json, time, sys, webbrowser, socket, argparse, codecs, io, struct, re
|
||||
import pathlib, traceback, itertools, socketserver
|
||||
from contextlib import redirect_stdout, redirect_stderr, contextmanager
|
||||
from decimal import Decimal
|
||||
from dataclasses import dataclass, field
|
||||
from urllib.parse import parse_qs, urlparse
|
||||
from http.server import BaseHTTPRequestHandler
|
||||
from typing import Any, TypedDict, TypeVar, Generator, Callable
|
||||
|
|
@ -61,21 +62,27 @@ uops_colors = {Ops.LOAD: "#ffc0c0", Ops.STORE: "#87CEEB", Ops.CONST: "#e0e0e0",
|
|||
def create_step(name:str, query:tuple[str, int, int], data=None, depth:int=0, **kwargs) -> dict:
|
||||
return {"name":name, "query":f"{query[0]}?ctx={query[1]}&step={query[2]}", "data":data, "depth":depth, **kwargs}
|
||||
|
||||
# ** list all saved rewrites
|
||||
@dataclass(frozen=True)
|
||||
class VizData:
|
||||
trace:RewriteTrace = field(default_factory=lambda: RewriteTrace([], [], {}))
|
||||
ctxs:list[dict] = field(default_factory=list)
|
||||
ref_map:dict[Any, int] = field(default_factory=dict)
|
||||
all_uops:dict[int, UOp] = field(default_factory=dict)
|
||||
|
||||
ref_map:dict[Any, int] = {}
|
||||
def get_rewrites(t:RewriteTrace) -> list[dict]:
|
||||
ret = []
|
||||
for i,(k,v) in enumerate(zip(t.keys, t.rewrites)):
|
||||
# ** load all saved rewrites
|
||||
|
||||
def load_rewrites(data:VizData) -> None:
|
||||
assert not data.ctxs and not data.ref_map, "load_rewrites called multiple times"
|
||||
for i,k in enumerate(data.trace.keys):
|
||||
v = data.trace.rewrites[i]
|
||||
steps = [create_step(s.name, ("/graph-rewrites", i, j), loc=s.loc, match_count=len(s.matches), code_line=printable(s.loc),
|
||||
trace=k.tb if j==0 else None, depth=s.depth) for j,s in enumerate(v)]
|
||||
if (p:=get_prg_uop(i)) is not None:
|
||||
if (p:=get_prg_uop(data, i)) is not None:
|
||||
steps.append(create_step("View UOp List", ("/uops", i, len(steps))))
|
||||
steps.append(create_step("View Source", ("/code", i, len(steps)), p.src[3].arg))
|
||||
steps.append(create_step("View Disassembly", ("/asm", i, len(steps)), (k.ret, p.src[4].arg)))
|
||||
for key in k.keys: ref_map[key] = i
|
||||
ret.append({"name":k.display_name, "steps":steps})
|
||||
return ret
|
||||
for key in k.keys: data.ref_map[key] = i
|
||||
data.ctxs.append({"name":k.display_name, "steps":steps})
|
||||
|
||||
# ** get the complete UOp graphs for one rewrite
|
||||
|
||||
|
|
@ -93,9 +100,7 @@ def pystr(u:UOp) -> str:
|
|||
try: return pyrender(u)
|
||||
except Exception: return str(u)
|
||||
|
||||
# all the trace points, initialized after the trace loads
|
||||
ctxs:list[dict] = []
|
||||
def uop_to_json(x:UOp) -> dict[int, dict]:
|
||||
def uop_to_json(x:UOp, data:VizData) -> dict[int, dict]:
|
||||
assert isinstance(x, UOp)
|
||||
graph: dict[int, dict] = {}
|
||||
excluded: set[UOp] = set()
|
||||
|
|
@ -138,7 +143,7 @@ def uop_to_json(x:UOp) -> dict[int, dict]:
|
|||
label += "\n"+' '.join([f"{range_str(s, color=True)}({s.vmax+1})" for s in trngs])
|
||||
except Exception:
|
||||
label += "\n<ISSUE GETTING LABEL>"
|
||||
if (ref:=ref_map.get(u.src[0]) if u.op is Ops.CALL else None) is not None and ctxs: label += f"\ncodegen@{ctxs[ref]['name']}"
|
||||
if (ref:=data.ref_map.get(u.src[0]) if u.op is Ops.CALL else None) is not None: label += f"\ncodegen@{data.ctxs[ref]['name']}"
|
||||
# NOTE: kernel already has metadata in arg
|
||||
if TRACEMETA >= 2 and u.metadata is not None and u.op is not Ops.CALL: label += "\n"+str(u.metadata)
|
||||
# limit SOURCE labels line count
|
||||
|
|
@ -148,28 +153,30 @@ def uop_to_json(x:UOp) -> dict[int, dict]:
|
|||
"ref":ref, "tag":repr(u.tag) if u.tag is not None else None}
|
||||
return graph
|
||||
|
||||
@functools.cache
|
||||
def _reconstruct(a:int, depth:int|None=None):
|
||||
op, dtype, src, arg, *rest = trace.uop_fields[a]
|
||||
def _reconstruct(data:VizData, a:int, depth:int|None=None):
|
||||
if depth is None and a in data.all_uops: return data.all_uops[a]
|
||||
op, dtype, src, arg, *rest = data.trace.uop_fields[a]
|
||||
if depth is not None and depth <= 0: return UOp(op, dtype, (), arg, *rest)
|
||||
return UOp(op, dtype, tuple(_reconstruct(s, None if depth is None else depth-1) for s in src), arg, *rest)
|
||||
ret = UOp(op, dtype, tuple(_reconstruct(data, s, None if depth is None else depth-1) for s in src), arg, *rest)
|
||||
if depth is None: data.all_uops[a] = ret
|
||||
return ret
|
||||
|
||||
def get_full_rewrite(ctx:TrackedGraphRewrite) -> Generator[GraphRewriteDetails, None, None]:
|
||||
next_sink = _reconstruct(ctx.sink)
|
||||
yield {"graph":uop_to_json(next_sink), "uop":pystr(next_sink), "change":None, "diff":None, "upat":None}
|
||||
def get_full_rewrite(data:VizData, ctx:TrackedGraphRewrite) -> Generator[GraphRewriteDetails, None, None]:
|
||||
next_sink = _reconstruct(data, ctx.sink)
|
||||
yield {"graph":uop_to_json(next_sink, data), "uop":pystr(next_sink), "change":None, "diff":None, "upat":None}
|
||||
replaces: dict[UOp, UOp] = {}
|
||||
for u0_num,u1_num,upat_loc,dur in tqdm(ctx.matches):
|
||||
replaces[u0:=_reconstruct(u0_num)] = u1 = _reconstruct(u1_num)
|
||||
replaces[u0:=_reconstruct(data, u0_num)] = u1 = _reconstruct(data, u1_num)
|
||||
try: new_sink = next_sink.substitute(replaces)
|
||||
except RuntimeError as e: new_sink = UOp(Ops.NOOP, arg=str(e))
|
||||
match_repr = f"# {dur*1e6:.2f} us\n"+printable(upat_loc)
|
||||
yield {"graph":(sink_json:=uop_to_json(new_sink)), "uop":pystr(new_sink), "change":[id(x) for x in u1.toposort() if id(x) in sink_json],
|
||||
yield {"graph":(sink_json:=uop_to_json(new_sink, data)), "uop":pystr(new_sink), "change":[id(x) for x in u1.toposort() if id(x) in sink_json],
|
||||
"diff":list(difflib.unified_diff(pystr(u0).splitlines(), pystr(u1).splitlines())), "upat":(upat_loc, match_repr)}
|
||||
if not ctx.bottom_up: next_sink = new_sink
|
||||
|
||||
def get_prg_uop(i:int) -> UOp|None:
|
||||
s = next((s for s in trace.rewrites[i] if s.name == "View Program"), None)
|
||||
return _reconstruct(s.sink, depth=1) if s is not None else None
|
||||
def get_prg_uop(data:VizData, i:int) -> UOp|None:
|
||||
s = next((s for s in data.trace.rewrites[i] if s.name == "View Program"), None)
|
||||
return _reconstruct(data, s.sink, depth=1) if s is not None else None
|
||||
|
||||
# encoder helpers
|
||||
|
||||
|
|
@ -187,31 +194,30 @@ def rel_ts(ts:int|Decimal, start_ts:int, ctx:str="") -> int:
|
|||
|
||||
# Profiler API
|
||||
|
||||
device_ts_diffs:dict[str, Decimal] = {}
|
||||
def cpu_ts_diff(device:str) -> Decimal: return device_ts_diffs.get(device, Decimal(0))
|
||||
def cpu_ts_diff(device_ts_diffs:dict[str, Decimal], device:str) -> Decimal: return device_ts_diffs.get(device, Decimal(0))
|
||||
|
||||
DevEvent = ProfileRangeEvent|ProfileGraphEntry|ProfilePointEvent
|
||||
def flatten_events(profile:list[ProfileEvent]) -> Generator[tuple[Decimal, Decimal, DevEvent], None, None]:
|
||||
def flatten_events(profile:list[ProfileEvent], device_ts_diffs:dict[str, Decimal]) -> Generator[tuple[Decimal, Decimal, DevEvent], None, None]:
|
||||
for e in profile:
|
||||
if isinstance(e, ProfileRangeEvent): yield (e.st+(diff:=cpu_ts_diff(e.device)), (e.en if e.en is not None else e.st)+diff, e)
|
||||
if isinstance(e, ProfileRangeEvent): yield (e.st+(diff:=cpu_ts_diff(device_ts_diffs, e.device)), (e.en if e.en is not None else e.st)+diff, e)
|
||||
elif isinstance(e, ProfilePointEvent): yield (e.ts, e.ts, e)
|
||||
elif isinstance(e, ProfileGraphEvent):
|
||||
cpu_ts = []
|
||||
for ent in e.ents: cpu_ts += [e.sigs[ent.st_id]+(diff:=cpu_ts_diff(ent.device)), e.sigs[ent.en_id]+diff]
|
||||
for ent in e.ents: cpu_ts += [e.sigs[ent.st_id]+(diff:=cpu_ts_diff(device_ts_diffs, ent.device)), e.sigs[ent.en_id]+diff]
|
||||
yield (st:=min(cpu_ts)), (et:=max(cpu_ts)), ProfileRangeEvent(f"{e.ents[0].device.split(':')[0]} Graph", f"batched {len(e.ents)}", st, et)
|
||||
for i,ent in enumerate(e.ents): yield (cpu_ts[i*2], cpu_ts[i*2+1], ent)
|
||||
|
||||
# normalize event timestamps and attach kernel metadata
|
||||
def timeline_layout(dev_events:list[tuple[int, int, float, DevEvent]], start_ts:int, scache:dict[str, int]) -> bytes|None:
|
||||
def timeline_layout(data:VizData, dev_events:list[tuple[int, int, float, DevEvent]], start_ts:int, scache:dict[str, int]) -> bytes|None:
|
||||
events:list[bytes] = []
|
||||
exec_points:dict[str, ProfilePointEvent] = {}
|
||||
for st,et,dur,e in dev_events:
|
||||
if isinstance(e, ProfilePointEvent) and e.name == "exec": exec_points[e.arg["name"]] = e
|
||||
if dur == 0: continue
|
||||
name, fmt, key = e.name, [], None
|
||||
if (ref:=ref_map.get(name)) is not None and ctxs:
|
||||
name = ctxs[ref]["name"]
|
||||
if (p:=get_prg_uop(ref)) is not None and (ei:=exec_points.get(p.src[0].arg.name)) is not None:
|
||||
if (ref:=data.ref_map.get(name)) is not None and ref < len(data.ctxs):
|
||||
name = data.ctxs[ref]["name"]
|
||||
if (p:=get_prg_uop(data, ref)) is not None and (ei:=exec_points.get(p.src[0].arg.name)) is not None:
|
||||
flops = sym_infer((estimates:=p.src[0].arg.estimates).ops, var_vals:=ei.arg['var_vals'])/(t:=dur*1e-6)
|
||||
membw, ldsbw = sym_infer(estimates.mem, var_vals)/t, sym_infer(estimates.lds, var_vals)/t
|
||||
fmt = [f"{flops*1e-9:.0f} GFLOPS" if flops < 1e14 else f"{flops*1e-12:.0f} TFLOPS",
|
||||
|
|
@ -222,7 +228,7 @@ def timeline_layout(dev_events:list[tuple[int, int, float, DevEvent]], start_ts:
|
|||
key = ei.key
|
||||
elif isinstance(e.name, TracingKey):
|
||||
name = e.name.display_name
|
||||
ref = next((v for k in e.name.keys if (v:=ref_map.get(k)) is not None), None)
|
||||
ref = next((v for k in e.name.keys if (v:=data.ref_map.get(k)) is not None), None)
|
||||
if isinstance(e.name.ret, str): fmt.append(e.name.ret)
|
||||
elif isinstance(e.name.ret, int):
|
||||
membw = (nbytes:=e.name.ret) / (dur * 1e-6)
|
||||
|
|
@ -313,7 +319,7 @@ def unpack_pmc(e) -> dict:
|
|||
|
||||
# ** on startup, list all the performance counter traces
|
||||
|
||||
def load_amd_counters(ctxs:list[dict], profile:list[ProfileEvent]) -> None:
|
||||
def load_amd_counters(data:VizData, profile:list[ProfileEvent]) -> None:
|
||||
from tinygrad.runtime.ops_amd import ProfileSQTTEvent, ProfilePMCEvent
|
||||
counter_events:dict[tuple[int, int], dict] = {}
|
||||
durations:dict[str, list[float]] = {}
|
||||
|
|
@ -326,22 +332,23 @@ def load_amd_counters(ctxs:list[dict], profile:list[ProfileEvent]) -> None:
|
|||
if isinstance(e, ProfileProgramEvent) and e.tag is not None: prg_events[e.tag] = e
|
||||
if isinstance(e, ProfileDeviceEvent) and e.device.startswith("AMD"): arch = f"gfx{unwrap(e.props)['gfx_target_version']//1000}"
|
||||
if len(counter_events) == 0: return None
|
||||
ctxs.append({"name":"All Counters", "steps":[create_step("PMC", ("/all-pmc", len(ctxs), 0), (durations, all_counters:={}))]})
|
||||
data.ctxs.append({"name":"All Counters", "steps":[create_step("PMC", ("/all-pmc", len(data.ctxs), 0), (durations, all_counters:={}))]})
|
||||
run_number = {n:0 for n,_ in counter_events}
|
||||
for (k, tag),v in counter_events.items():
|
||||
# use the colored name if it exists
|
||||
name = unwrap(get_prg_uop(r)).src[0].arg.name if (r:=ref_map.get(pname:=prg_events[k].name)) is not None else pname
|
||||
name = unwrap(get_prg_uop(data, r)).src[0].arg.name if (r:=data.ref_map.get(pname:=prg_events[k].name)) is not None else pname
|
||||
run_number[k] += 1
|
||||
steps:list[dict] = []
|
||||
if (pmc:=v.get(ProfilePMCEvent)):
|
||||
steps.append(create_step("PMC", ("/prg-pmc", len(ctxs), len(steps)), pmc))
|
||||
steps.append(create_step("PMC", ("/prg-pmc", len(data.ctxs), len(steps)), pmc))
|
||||
all_counters[(name, run_number[k], pname)] = pmc[0]
|
||||
# to decode a SQTT trace, we need the raw stream, program binary and device properties
|
||||
if (sqtt:=v.get(ProfileSQTTEvent)):
|
||||
for e in sqtt:
|
||||
if e.itrace: steps.append(create_step(f"SE:{e.se} PKTS", (f"/prg-pkts-{e.se}", len(ctxs), len(steps)), data=(e.blob, prg_events[k].lib,arch)))
|
||||
steps.append(create_step("OCC", ("/prg-sqtt", len(ctxs), len(steps)), ((k, tag), sqtt, prg_events[k], arch)))
|
||||
ctxs.append({"name":f"SQTT {name}"+(f" n{run_number[k]}" if run_number[k] > 1 else ""), "steps":steps})
|
||||
if e.itrace: steps.append(create_step(f"SE:{e.se} PKTS", (f"/prg-pkts-{e.se}", len(data.ctxs), len(steps)),
|
||||
data=(e.blob, prg_events[k].lib, arch)))
|
||||
steps.append(create_step("OCC", ("/prg-sqtt", len(data.ctxs), len(steps)), ((k, tag), sqtt, prg_events[k], arch)))
|
||||
data.ctxs.append({"name":f"SQTT {name}"+(f" n{run_number[k]}" if run_number[k] > 1 else ""), "steps":steps})
|
||||
|
||||
wave_colors = {"WMMA": "#1F7857", **{x:"#ffffc0" for x in ["VALU", "VINTERP"]}, "SALU": "#cef263", "SMEM": "#ffc0c0", "STORE": "#4fa3cc",
|
||||
**{x:"#b2b7c9" for x in ["VMEM", "SGMEM"]}, "LDS": "#9fb4a6", "IMMEDIATE": "#f3b44a", "BARRIER": "#d00000",
|
||||
|
|
@ -457,23 +464,25 @@ def device_sort_fn(k:str) -> tuple:
|
|||
dev_base = p[0] if len(p) < 2 or not p[1].isdigit() else f"{p[0]}:{p[1]}"
|
||||
return (is_memory, special.get(p[0], special['ALLDEVS']), dev_base, k)
|
||||
|
||||
def get_profile(profile:list[ProfileEvent], sort_fn:Callable[[str], Any]=device_sort_fn) -> bytes|None:
|
||||
def get_profile(profile:list[ProfileEvent], sort_fn:Callable[[str], Any]=device_sort_fn, data:VizData|None=None) -> bytes|None:
|
||||
if data is None: data = VizData(RewriteTrace([], [], {}))
|
||||
# start by getting the time diffs
|
||||
device_decoders:dict[str, Callable[[list[dict], list[ProfileEvent]], None]] = {}
|
||||
device_ts_diffs:dict[str, Decimal] = {}
|
||||
device_decoders:dict[str, Callable[[VizData, list[ProfileEvent]], None]] = {}
|
||||
for ev in profile:
|
||||
if isinstance(ev, ProfileDeviceEvent):
|
||||
device_ts_diffs[ev.device] = ev.tdiff
|
||||
if (d:=ev.device.split(":")[0]) == "AMD": device_decoders[d] = load_amd_counters
|
||||
if d == "NV": device_decoders[d] = load_nv_counters
|
||||
# load device specific counters
|
||||
for fxn in device_decoders.values(): fxn(ctxs, profile)
|
||||
for fxn in device_decoders.values(): fxn(data, profile)
|
||||
# map events per device
|
||||
dev_events:dict[str, list[tuple[int, int, float, DevEvent]]] = {}
|
||||
markers:list[ProfilePointEvent] = []
|
||||
ext_data:dict[str, Any] = {}
|
||||
start_ts:int|None = None
|
||||
end_ts:int|None = None
|
||||
for ts,en,e in flatten_events(profile):
|
||||
for ts,en,e in flatten_events(profile, device_ts_diffs):
|
||||
dev_events.setdefault(e.device,[]).append((st:=int(ts), et:=int(en), float(en-ts), e))
|
||||
if start_ts is None or st < start_ts: start_ts = st
|
||||
if end_ts is None or et > end_ts: end_ts = et
|
||||
|
|
@ -487,7 +496,7 @@ def get_profile(profile:list[ProfileEvent], sort_fn:Callable[[str], Any]=device_
|
|||
dtype_size:dict[str, int] = {}
|
||||
for k,v in dev_events.items():
|
||||
v.sort(key=lambda e:e[0])
|
||||
layout[k] = timeline_layout(v, start_ts, scache)
|
||||
layout[k] = timeline_layout(data, v, start_ts, scache)
|
||||
layout.update([graph_layout(k, v, start_ts, unwrap(end_ts), peaks, dtype_size, scache)])
|
||||
sorted_layout = sorted([k for k,v in layout.items() if v is not None], key=sort_fn)
|
||||
ret = [b"".join([struct.pack("<B", len(k)), k.encode(), unwrap(layout[k])]) for k in sorted_layout]
|
||||
|
|
@ -498,16 +507,16 @@ def get_profile(profile:list[ProfileEvent], sort_fn:Callable[[str], Any]=device_
|
|||
|
||||
# ** PMA counters
|
||||
|
||||
def load_nv_counters(ctxs:list[dict], profile:list) -> None:
|
||||
def load_nv_counters(data:VizData, profile:list) -> None:
|
||||
steps:list[dict] = []
|
||||
sm_version = {e.device:e.props.get("sm_version", 0x800) for e in profile if isinstance(e, ProfileDeviceEvent) and e.props is not None}
|
||||
run_number:dict[str, int] = {}
|
||||
for e in profile:
|
||||
if type(e).__name__ == "ProfilePMAEvent":
|
||||
run_number[e.kern] = run_num = run_number.get(e.kern, 0)+1
|
||||
steps.append(create_step(f"PMA {e.kern}"+(f"n{run_num}" if run_num>1 else ""), ("/prg-pma-pkts", len(ctxs), len(steps)),
|
||||
steps.append(create_step(f"PMA {e.kern}"+(f"n{run_num}" if run_num>1 else ""), ("/prg-pma-pkts", len(data.ctxs), len(steps)),
|
||||
data=(e.blob, sm_version[e.device])))
|
||||
if steps: ctxs.append({"name":"All Counters", "steps":steps})
|
||||
if steps: data.ctxs.append({"name":"All Counters", "steps":steps})
|
||||
|
||||
def pma_timeline(blob:bytes, sm_version:int) -> list[ProfileEvent]:
|
||||
from extra.nv_pma.decode import decode, decode_tpc_id
|
||||
|
|
@ -586,10 +595,10 @@ def amdgpu_cfg(lib:bytes, target:str) -> dict:
|
|||
from tinygrad.renderer.amd.dsl import Reg
|
||||
for pc, inst in pc_table.items():
|
||||
pc_tokens[pc] = tokens = []
|
||||
for name, field in inst._fields:
|
||||
for name, f in inst._fields:
|
||||
if isinstance(val:=getattr(inst, name), Reg): tokens.append({"st":val.fmt(), "keys":[f"r{val.offset+i}" for i in range(val.sz)], "kind":1})
|
||||
elif name in {"op","opx","opy"}: tokens.append({"st":(op_name:=val.name.lower()), "keys":[op_name], "kind":0})
|
||||
elif name != "encoding" and val != field.default: tokens.append({"st":(s:=repr(val)), "keys":[s], "kind":1})
|
||||
elif name != "encoding" and val != f.default: tokens.append({"st":(s:=repr(val)), "keys":[s], "kind":1})
|
||||
# show a smaller view for repeated instructions in the graph
|
||||
lines:list[str] = []
|
||||
disasm = {pc:str(inst) for pc,inst in pc_table.items()}
|
||||
|
|
@ -616,12 +625,12 @@ def amdgpu_cfg(lib:bytes, target:str) -> dict:
|
|||
|
||||
# ** Main render function to get the complete details about a trace event
|
||||
|
||||
def get_render(query:str) -> dict:
|
||||
def get_render(viz_data:VizData, query:str) -> dict:
|
||||
url = urlparse(query)
|
||||
i, j, fmt = get_int(qs:=parse_qs(url.query), "ctx"), get_int(qs, "step"), url.path.lstrip("/")
|
||||
data = ctxs[i]["steps"][j]["data"]
|
||||
if fmt == "graph-rewrites": return {"value":get_full_rewrite(trace.rewrites[i][j]), "content_type":"text/event-stream"}
|
||||
if fmt == "uops": return {"src":get_stdout(lambda: print_uops(_reconstruct(trace.rewrites[i][j-1].sink).src[2].src)), "lang":"txt"}
|
||||
data = viz_data.ctxs[i]["steps"][j]["data"]
|
||||
if fmt == "graph-rewrites": return {"value":get_full_rewrite(viz_data, viz_data.trace.rewrites[i][j]), "content_type":"text/event-stream"}
|
||||
if fmt == "uops": return {"src":get_stdout(lambda: print_uops(_reconstruct(viz_data, viz_data.trace.rewrites[i][j-1].sink).src[2].src))}
|
||||
if fmt == "code": return {"src":data, "lang":"cpp"}
|
||||
if fmt == "asm":
|
||||
ret:dict = {}
|
||||
|
|
@ -643,13 +652,13 @@ def get_render(query:str) -> dict:
|
|||
if fmt.startswith("prg-pkts"):
|
||||
ret = {}
|
||||
with soft_err(lambda err:ret.update(err)):
|
||||
if (events:=get_profile(list(itertools.islice(sqtt_timeline(*data), getenv("MAX_SQTT_PKTS", 50_000))), sort_fn=row_tuple)):
|
||||
if (events:=get_profile(list(itertools.islice(sqtt_timeline(*data), getenv("MAX_SQTT_PKTS", 50_000))), sort_fn=row_tuple, data=viz_data)):
|
||||
ret = {"value":events, "content_type":"application/octet-stream"}
|
||||
else: ret = {"src":"No SQTT trace on this SE."}
|
||||
return ret
|
||||
if fmt == "prg-sqtt":
|
||||
ret = {}
|
||||
if len((steps:=ctxs[i]["steps"])[j+1:]) == 0:
|
||||
if len((steps:=viz_data.ctxs[i]["steps"])[j+1:]) == 0:
|
||||
with soft_err(lambda err: ret.update(err)):
|
||||
cu_events, units, wave_insts = unpack_sqtt(*data)
|
||||
for cu in sorted(cu_events, key=row_tuple):
|
||||
|
|
@ -658,7 +667,7 @@ def get_render(query:str) -> dict:
|
|||
for k in sorted(wave_insts.get(cu, []), key=row_tuple):
|
||||
steps.append(create_step(k.replace(cu, ""), ("/sqtt-insts", i, len(steps)), loc=(data:=wave_insts[cu][k])["loc"], depth=2, data=data))
|
||||
return {**ret, "steps":[{k:v for k,v in s.items() if k != "data"} for s in steps[j+1:]]}
|
||||
if fmt == "cu-sqtt": return {"value":get_profile(data, sort_fn=row_tuple), "content_type":"application/octet-stream"}
|
||||
if fmt == "cu-sqtt": return {"value":get_profile(data, sort_fn=row_tuple, data=viz_data), "content_type":"application/octet-stream"}
|
||||
if fmt == "sqtt-insts":
|
||||
columns = ["PC", "Instruction", "Hits", "Cycles", "Stall", "Type"]
|
||||
inst_columns = ["N", "Clk", "Idle", "Dur", "Stall"]
|
||||
|
|
@ -685,11 +694,11 @@ def get_render(query:str) -> dict:
|
|||
prev_instr = max(prev_instr, e.time + e.dur)
|
||||
summary = [{"label":"Total Cycles", "value":w.end_time-w.begin_time}, {"label":"SE", "value":w.se}, {"label":"CU", "value":w.cu},
|
||||
{"label":"SIMD", "value":w.simd}, {"label":"Wave ID", "value":w.wave_id}, {"label":"Run number", "value":data["run_number"]}]
|
||||
return {"rows":[tuple(v.values()) for v in rows.values()], "cols":columns, "metadata":[summary], "ref":ref_map.get(data["prg"].name)}
|
||||
return {"rows":[tuple(v.values()) for v in rows.values()], "cols":columns, "metadata":[summary], "ref":viz_data.ref_map.get(data["prg"].name)}
|
||||
if fmt == "prg-pma-pkts":
|
||||
ret = {}
|
||||
with soft_err(lambda err:ret.update(err)):
|
||||
if (events:=get_profile(pma_timeline(*data), sort_fn=row_tuple)): ret = {"value":events, "content_type":"application/octet-stream"}
|
||||
if (events:=get_profile(pma_timeline(*data), row_tuple, data=viz_data)): ret = {"value":events, "content_type":"application/octet-stream"}
|
||||
else: ret = {"src":"No PMA samples found."}
|
||||
return ret
|
||||
return data
|
||||
|
|
@ -712,11 +721,11 @@ class Handler(HTTPRequestHandler):
|
|||
except FileNotFoundError: status_code = 404
|
||||
|
||||
elif url.path == "/ctxs":
|
||||
lst = [{**c, "steps":[{k:v for k, v in s.items() if k != "data"} for s in c["steps"]]} for c in ctxs]
|
||||
lst = [{**c, "steps":[{k:v for k, v in s.items() if k != "data"} for s in c["steps"]]} for c in data.ctxs]
|
||||
ret, content_type = json.dumps(lst).encode(), "application/json"
|
||||
elif url.path == "/get_profile" and profile_ret: ret, content_type = profile_ret, "application/octet-stream"
|
||||
else:
|
||||
if not (render_src:=get_render(self.path)): status_code = 404
|
||||
if not (render_src:=get_render(data, self.path)): status_code = 404
|
||||
else:
|
||||
if "content_type" in render_src: ret, content_type = render_src["value"], render_src["content_type"]
|
||||
else: ret, content_type = json.dumps(render_src).encode(), "application/json"
|
||||
|
|
@ -754,8 +763,9 @@ if __name__ == "__main__":
|
|||
st = time.perf_counter()
|
||||
print("*** viz is starting")
|
||||
|
||||
ctxs = get_rewrites(trace:=load_pickle(args.rewrites_path, default=RewriteTrace([], [], {})))
|
||||
profile_ret = get_profile(load_pickle(args.profile_path, default=[]))
|
||||
data = VizData(load_pickle(args.rewrites_path, default=RewriteTrace([], [], {})))
|
||||
load_rewrites(data)
|
||||
profile_ret = get_profile(load_pickle(args.profile_path, default=[]), data=data)
|
||||
|
||||
server = TCPServerWithReuse(('', PORT), Handler)
|
||||
reloader_thread = threading.Thread(target=reloader)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue