viz/sqtt: move amd decoder to extra, don't import from ops_amd (#15969)

* don't import from ops_amd

* start

* cleanup
This commit is contained in:
qazal 2026-04-29 18:49:15 +03:00 committed by GitHub
commit b63e0a5f74
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 83 additions and 80 deletions

View file

@ -1,10 +1,13 @@
#!/usr/bin/env python3
import ctypes, pathlib, argparse, pickle, dataclasses, threading
import ctypes, pathlib, argparse, pickle, dataclasses, threading, itertools
from decimal import Decimal
from typing import Generator
from tinygrad.helpers import temp, unwrap, DEBUG
from tinygrad.runtime.ops_amd import ProfileSQTTEvent
from tinygrad.runtime.autogen import rocprof
from tinygrad.renderer.amd.dsl import Inst
from tinygrad.helpers import ProfileEvent, ProfileRangeEvent, ProfilePointEvent
from tinygrad.device import ProfileProgramEvent
from test.amd.disasm import disasm
@dataclasses.dataclass(frozen=True)
@ -126,6 +129,71 @@ def decode(sqtt_evs:list[ProfileSQTTEvent], disasms:dict[str, dict[int, Inst]])
raise exc
return ROCParseCtx
def unpack_occ(viz_data, i:int, j:int, key:tuple[str, int], data:list, p:ProfileProgramEvent, target:str) -> dict:
  """Decode the SQTT events of one kernel run and append one step per CU (plus one per traced wave) to context `i`.

  The decode only runs when no steps exist after index `j`; otherwise the steps that are already there are returned.

  Args:
    viz_data: object whose `ctxs[i]["steps"]` list is read and appended to in place.
    i: index of the context in `viz_data.ctxs`.
    j: index of the step this call belongs to; everything after it is returned.
    key: lookup key into the decoder's `inst_execs` and `occ_events` tables.
    data: SQTT events handed to `decode`.
    p: program event supplying `base`, `lib` and `tag`.
    target: target string forwarded to `amd_decode`.

  Returns:
    `{"steps": [...]}` with every step after `j`, each without its "data" entry.
  """
  from tinygrad.viz.serve import amd_decode, create_step, row_tuple
  steps = viz_data.ctxs[i]["steps"]
  def children() -> dict:
    # steps following j, with the "data" entry left out
    return {"steps":[{name:val for name,val in step.items() if name != "data"} for step in steps[j+1:]]}
  # sub-steps were already created, nothing to decode
  if steps[j+1:]: return children()
  base = unwrap(p.base)
  # key the decoded instructions by address offset by the program base
  disasm:dict[int, Inst] = {addr+base:inst for addr,inst in amd_decode(unwrap(p.lib), target).items()}
  rctx = decode(data, {p.tag:disasm})
  per_cu:dict[str, list[ProfileEvent]] = {}
  # ** inst traces
  traced:dict[str, dict[str, dict]] = {}
  inst_runs:dict[str, itertools.count] = {}
  for wave in rctx.inst_execs.get(key, []):
    wave_loc = wave.wave_loc
    if wave_loc not in inst_runs: inst_runs[wave_loc] = itertools.count(0)
    # n-th time this wave slot shows up in the instruction trace
    run = next(inst_runs[wave_loc])
    events = per_cu.setdefault(wave.cu_loc, [])
    label = f"INST WAVE:{wave.wave_id} N:{run}"
    events.append(ProfileRangeEvent(f"SIMD:{wave.simd}", label, Decimal(wave.begin_time), Decimal(wave.end_time)))
    traced.setdefault(wave.cu_loc, {})[f"{wave_loc} N:{run}"] = {"wave":wave, "disasm":disasm, "prg":p, "run_number":run, "loc":label}
  # ** occ traces (only WAVESTART/WAVEEND)
  occ_runs:dict[str, itertools.count] = {}
  started:dict[str, int] = {}
  for occ in rctx.occ_events.get(key, []):
    wave_loc = occ.wave_loc
    # every wave slot is registered as a unit, even the ones skipped below
    if wave_loc not in occ_runs: occ_runs[wave_loc] = itertools.count(0)
    # slots that have an instruction trace already got their range events above
    if wave_loc in inst_runs: continue
    if occ.start:
      started[wave_loc] = occ.time
      continue
    # end event: close the range opened by the last start on this slot
    events = per_cu.setdefault(occ.cu_loc, [])
    label = f"OCC WAVE:{occ.wave_id} N:{next(occ_runs[wave_loc])}"
    events.append(ProfileRangeEvent(f"SIMD:{occ.simd}", label, Decimal(started.pop(wave_loc)), Decimal(occ.time)))
  # ** split graph by CU
  for cu in sorted(per_cu, key=row_tuple):
    # one point event per registered unit, placed at ts=0, ahead of the CU's range events
    markers = [ProfilePointEvent(unit, "start", unit, ts=Decimal(0)) for unit in occ_runs]
    steps.append(create_step(f"{cu} {len(per_cu[cu])}", ("/cu-sqtt", i, len(steps)), depth=1, data=markers+per_cu[cu]))
    cu_waves = traced.get(cu, [])
    for name in sorted(cu_waves, key=row_tuple):
      wd = cu_waves[name]
      # the instruction table is produced later by calling "fxn" with "args"
      steps.append(create_step(name.replace(cu, ""), ("/amd-sqtt-insts", i, len(steps)), loc=wd["loc"], depth=2,
                               data={"fxn":unpack_insts, "args":(wd,)}))
  return children()
def unpack_insts(viz_data, i:int, j:int, data:dict) -> dict:
  """Build the per-instruction table for one traced wave.

  Args:
    viz_data: object whose `ref_map` is queried with the program name.
    i: unused here; part of the shared viewer call signature.
    j: unused here; part of the shared viewer call signature.
    data: dict with "wave", "disasm" (pc -> instruction), "prg" and "run_number".

  Returns:
    A dict with "rows" (one tuple per disassembled instruction), "cols", "metadata" (a wave summary) and "ref".
  """
  wave = data["wave"]
  busy_until = wave.begin_time
  pc_to_inst = data["disasm"]
  hit_cols = ["N", "Clk", "Idle", "Dur", "Stall"]
  # Idle is the gap between the previous instruction completing and this one beginning. It can be caused by:
  #   * Arbiter loss
  #   * Source or destination register dependency
  #   * Instruction cache miss
  # Stall is the number of cycles the hardware pipe couldn't issue an instruction.
  # Duration is the total latency in cycles: "Stall time + Issue time" for gfx9, "Stall time + Execute time" for gfx10+.
  # displayed PCs are offsets from the first entry of the disassembly
  origin = next(iter(pc_to_inst), None)
  table:dict[int, dict] = {
    pc: {"pc":pc-origin, "inst":str(inst), "hit_count":0, "dur":0, "stall":0, "type":"", "hits":{"cols":hit_cols, "rows":[]}}
    for pc, inst in pc_to_inst.items()}
  for ev in wave.unpack_insts():
    row = table[ev.pc]
    # the type is taken from the first hit only
    if not row.get("type"): row["type"] = str(ev.typ).split("_")[-1]
    nth = row["hit_count"]
    row["hit_count"] = nth + 1
    row["dur"] += ev.dur
    row["stall"] += ev.stall
    idle = max(0, ev.time-busy_until)
    row["hits"]["rows"].append((nth, ev.time, idle, ev.dur, ev.stall))
    busy_until = max(busy_until, ev.time + ev.dur)
  summary = [{"label":label, "value":value} for label, value in (
    ("Total Cycles", wave.end_time-wave.begin_time), ("SE", wave.se), ("CU", wave.cu),
    ("SIMD", wave.simd), ("Wave ID", wave.wave_id), ("Run number", data["run_number"]))]
  return {"rows":[tuple(row.values()) for row in table.values()], "cols":["PC", "Instruction", "Hits", "Cycles", "Stall", "Type"],
          "metadata":[summary], "ref":viz_data.ref_map.get(data["prg"].name)}
def print_data(data:dict) -> None:
from tabulate import tabulate
# plaintext

View file

@ -320,14 +320,14 @@ def unpack_pmc(e) -> dict:
# ** on startup, list all the performance counter traces
def load_amd_counters(data:VizData, profile:list[ProfileEvent]) -> None:
from tinygrad.runtime.ops_amd import ProfileSQTTEvent, ProfilePMCEvent
def load_amd_counters(data:VizData, profile:list) -> None:
counter_events:dict[tuple[int, int], dict] = {}
durations:dict[str, list[float]] = {}
prg_events:dict[int, ProfileProgramEvent] = {}
arch = ""
for e in profile:
if isinstance(e, (ProfilePMCEvent, ProfileSQTTEvent)): counter_events.setdefault((e.kern, e.exec_tag), {}).setdefault(type(e), []).append(e)
if type(e).__name__ in {"ProfilePMCEvent", "ProfileSQTTEvent"}:
counter_events.setdefault((e.kern, e.exec_tag), {}).setdefault(type(e).__name__, []).append(e)
if isinstance(e, ProfileRangeEvent) and e.device.startswith("AMD") and e.en is not None:
durations.setdefault(str(e.name), []).append(float(e.en-e.st))
if isinstance(e, ProfileProgramEvent) and e.tag is not None: prg_events[e.tag] = e
@ -340,15 +340,18 @@ def load_amd_counters(data:VizData, profile:list[ProfileEvent]) -> None:
name = data.ctxs[r]["prg"].src[0].arg.name if (r:=data.ref_map.get(pname:=prg_events[k].name)) is not None else pname
run_number[k] += 1
steps:list[dict] = []
if (pmc:=v.get(ProfilePMCEvent)):
if (pmc:=v.get("ProfilePMCEvent")):
steps.append(create_step("PMC", ("/prg-pmc", len(data.ctxs), len(steps)), pmc[0]))
all_counters[(name, run_number[k], pname)] = pmc[0]
# to decode a SQTT trace, we need the raw stream, program binary and device properties
if (sqtt:=v.get(ProfileSQTTEvent)):
if (sqtt:=v.get("ProfileSQTTEvent")):
for e in sqtt:
if e.itrace: steps.append(create_step(f"SE:{e.se} PKTS", (f"/prg-pkts-{e.se}", len(data.ctxs), len(steps)),
data=(e.blob, prg_events[k].lib, arch)))
steps.append(create_step("OCC", ("/prg-sqtt", len(data.ctxs), len(steps)), ((k, tag), sqtt, prg_events[k], arch)))
if e.itrace: steps.append(create_step(f"SE:{e.se} PKTS", (f"/sqtt-{e.se}",len(data.ctxs),len(steps)), data=(e.blob,prg_events[k].lib,arch)))
try:
from extra.sqtt.roc import unpack_occ
steps.append(create_step("OCC", ("/amd-sqtt-occ", len(data.ctxs), len(steps)),
data={"fxn":unpack_occ, "args":((k, tag), sqtt, prg_events[k], arch)}))
except Exception: pass
data.ctxs.append({"name":f"SQTT {name}"+(f" n{run_number[k]}" if run_number[k] > 1 else ""), "steps":steps})
wave_colors = {"WMMA": "#1F7857", **{x:"#ffffc0" for x in ["VALU", "VINTERP"]}, "SALU": "#cef263", "SMEM": "#ffc0c0", "STORE": "#4fa3cc",
@ -430,38 +433,6 @@ def sqtt_timeline(data:bytes, lib:bytes, target:str) -> Generator[ProfileEvent,
else:
yield from add(name, p)
# ** SQTT OCC only unpacks wave start, end time and SIMD location
def unpack_sqtt(key:tuple[str, int], data:list, p:ProfileProgramEvent,
                target:str) -> tuple[dict[str, list[ProfileEvent]], list[str], dict[str, dict[str, dict]]]:
  """Decode the SQTT events of one kernel run into per-CU timeline events.

  Args:
    key: lookup key into the decoder's `inst_execs` and `occ_events` tables.
    data: SQTT events handed to `decode`.
    p: program event supplying `base`, `lib` and `tag`.
    target: target string forwarded to `amd_decode`.

  Returns:
    A 3-tuple: range events grouped by CU location, the list of wave locations seen in the occupancy events,
    and per-CU instruction-trace info keyed by "<wave_loc> N:<run>".
  """
  # * init decoder
  from extra.sqtt.roc import decode
  base = unwrap(p.base)
  # key the decoded instructions by address offset by the program base
  disasm:dict[int, Inst] = {addr+base:inst for addr, inst in amd_decode(unwrap(p.lib), target).items()}
  rctx = decode(data, {p.tag:disasm})
  per_cu:dict[str, list[ProfileEvent]] = {}
  # * INST waves
  traced:dict[str, dict[str, dict]] = {}
  inst_runs:dict[str, itertools.count] = {}
  for wave in rctx.inst_execs.get(key, []):
    wave_loc = wave.wave_loc
    if wave_loc not in inst_runs: inst_runs[wave_loc] = itertools.count(0)
    # n-th time this wave slot shows up in the instruction trace
    run = next(inst_runs[wave_loc])
    events = per_cu.setdefault(wave.cu_loc, [])
    label = f"INST WAVE:{wave.wave_id} N:{run}"
    events.append(ProfileRangeEvent(f"SIMD:{wave.simd}", label, Decimal(wave.begin_time), Decimal(wave.end_time)))
    traced.setdefault(wave.cu_loc, {})[f"{wave_loc} N:{run}"] = {"wave":wave, "disasm":disasm, "prg":p, "run_number":run, "loc":label}
  # * OCC waves
  occ_runs:dict[str, itertools.count] = {}
  started:dict[str, int] = {}
  for occ in rctx.occ_events.get(key, []):
    wave_loc = occ.wave_loc
    # every wave slot is registered as a unit, even the ones skipped below
    if wave_loc not in occ_runs: occ_runs[wave_loc] = itertools.count(0)
    # slots that have an instruction trace already got their range events above
    if wave_loc in inst_runs: continue
    if occ.start:
      started[wave_loc] = occ.time
      continue
    # end event: close the range opened by the last start on this slot
    events = per_cu.setdefault(occ.cu_loc, [])
    label = f"OCC WAVE:{occ.wave_id} N:{next(occ_runs[wave_loc])}"
    events.append(ProfileRangeEvent(f"SIMD:{occ.simd}", label, Decimal(started.pop(wave_loc)), Decimal(occ.time)))
  return per_cu, list(occ_runs), traced
def device_sort_fn(k:str) -> tuple:
special = {"GC": 0, "USER": 1, "TINY": 2, "ALLDEVS":100, "DISK": 999}
is_memory = k.endswith(" Memory")
@ -653,52 +624,16 @@ def get_render(viz_data:VizData, query:str) -> dict:
ret["cols"] = ["Kernel", "Duration", *ret["cols"]]
return ret
if fmt == "prg-pmc": return unpack_pmc(data)
if fmt.startswith("prg-pkts"):
if fmt.startswith("sqtt"):
ret = {}
with soft_err(lambda err:ret.update(err)):
if (events:=get_profile(viz_data, list(itertools.islice(sqtt_timeline(*data), getenv("MAX_SQTT_PKTS", 50_000))), sort_fn=row_tuple)):
ret = {"value":events, "content_type":"application/octet-stream"}
else: ret = {"src":"No SQTT trace on this SE."}
return ret
if fmt == "prg-sqtt":
ret = {}
if len((steps:=viz_data.ctxs[i]["steps"])[j+1:]) == 0:
with soft_err(lambda err: ret.update(err)):
cu_events, units, wave_insts = unpack_sqtt(*data)
for cu in sorted(cu_events, key=row_tuple):
steps.append(create_step(f"{cu} {len(cu_events[cu])}", ("/cu-sqtt", i, len(steps)), depth=1,
data=[ProfilePointEvent(unit, "start", unit, ts=Decimal(0)) for unit in units]+cu_events[cu]))
for k in sorted(wave_insts.get(cu, []), key=row_tuple):
steps.append(create_step(k.replace(cu, ""), ("/sqtt-insts", i, len(steps)), loc=(data:=wave_insts[cu][k])["loc"], depth=2, data=data))
return {**ret, "steps":[{k:v for k,v in s.items() if k != "data"} for s in steps[j+1:]]}
# viewers for the amd decoder in extra
if fmt.startswith("amd-sqtt"): return data["fxn"](viz_data, i, j, *data["args"])
if fmt == "cu-sqtt": return {"value":get_profile(viz_data, data, sort_fn=row_tuple), "content_type":"application/octet-stream"}
if fmt == "sqtt-insts":
columns = ["PC", "Instruction", "Hits", "Cycles", "Stall", "Type"]
inst_columns = ["N", "Clk", "Idle", "Dur", "Stall"]
# Idle: The total time gap between the completion of previous instruction and the beginning of the current instruction.
# The idle time can be caused by:
# * Arbiter loss
# * Source or destination register dependency
# * Instruction cache miss
# Stall: The total number of cycles the hardware pipe couldn't issue an instruction.
# Duration: Total latency in cycles, defined as "Stall time + Issue time" for gfx9 or "Stall time + Execute time" for gfx10+.
prev_instr = (w:=data["wave"]).begin_time
pc_to_inst = data["disasm"]
start_pc = None
rows:dict[int, dict] = {}
for pc, inst in pc_to_inst.items():
if start_pc is None: start_pc = pc
rows[pc] = {"pc":pc-start_pc, "inst":str(inst), "hit_count":0, "dur":0, "stall":0, "type":"", "hits":{"cols":inst_columns, "rows":[]}}
for e in w.unpack_insts():
if not (inst:=rows[e.pc]).get("type"): inst["type"] = str(e.typ).split("_")[-1]
inst["hit_count"] += 1
inst["dur"] += e.dur
inst["stall"] += e.stall
inst["hits"]["rows"].append((inst["hit_count"]-1, e.time, max(0, e.time-prev_instr), e.dur, e.stall))
prev_instr = max(prev_instr, e.time + e.dur)
summary = [{"label":"Total Cycles", "value":w.end_time-w.begin_time}, {"label":"SE", "value":w.se}, {"label":"CU", "value":w.cu},
{"label":"SIMD", "value":w.simd}, {"label":"Wave ID", "value":w.wave_id}, {"label":"Run number", "value":data["run_number"]}]
return {"rows":[tuple(v.values()) for v in rows.values()], "cols":columns, "metadata":[summary], "ref":viz_data.ref_map.get(data["prg"].name)}
if fmt == "prg-pma-pkts":
ret = {}
with soft_err(lambda err:ret.update(err)):