mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
viz/sqtt: move amd decoder to extra, don't import from ops_amd (#15969)
* don't import from ops_amd * start * cleanup
This commit is contained in:
parent
7787f76dcc
commit
b63e0a5f74
2 changed files with 83 additions and 80 deletions
|
|
@ -1,10 +1,13 @@
|
|||
#!/usr/bin/env python3
|
||||
import ctypes, pathlib, argparse, pickle, dataclasses, threading
|
||||
import ctypes, pathlib, argparse, pickle, dataclasses, threading, itertools
|
||||
from decimal import Decimal
|
||||
from typing import Generator
|
||||
from tinygrad.helpers import temp, unwrap, DEBUG
|
||||
from tinygrad.runtime.ops_amd import ProfileSQTTEvent
|
||||
from tinygrad.runtime.autogen import rocprof
|
||||
from tinygrad.renderer.amd.dsl import Inst
|
||||
from tinygrad.helpers import ProfileEvent, ProfileRangeEvent, ProfilePointEvent
|
||||
from tinygrad.device import ProfileProgramEvent
|
||||
from test.amd.disasm import disasm
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
|
|
@ -126,6 +129,71 @@ def decode(sqtt_evs:list[ProfileSQTTEvent], disasms:dict[str, dict[int, Inst]])
|
|||
raise exc
|
||||
return ROCParseCtx
|
||||
|
||||
def unpack_occ(viz_data, i:int, j:int, key:tuple[str, int], data:list, p:ProfileProgramEvent, target:str) -> dict:
  """Expand the OCC step `j` of ctx `i` into per-CU timeline steps and per-wave instruction steps.

  The first call decodes the SQTT events in `data` against the disassembly of program `p` and appends the new steps
  to `viz_data.ctxs[i]["steps"]`. If steps already exist after index `j`, nothing is decoded and those steps are returned.
  Returns `{"steps": [...]}` holding every step after index `j` with its "data" entry removed.
  """
  from tinygrad.viz.serve import amd_decode, create_step, row_tuple
  steps = viz_data.ctxs[i]["steps"]

  def child_steps() -> dict:
    # everything after this step belongs to it, the payload is not part of the listing
    return {"steps":[{name:val for name,val in step.items() if name != "data"} for step in steps[j+1:]]}

  # a previous call has already populated the children
  if steps[j+1:]: return child_steps()

  # the disassembly table is keyed by program offset, shift it by the load address of the program
  base = unwrap(p.base)
  disasm:dict[int, Inst] = {}
  for addr, inst in amd_decode(unwrap(p.lib), target).items(): disasm[addr+base] = inst
  rctx = decode(data, {p.tag:disasm})
  cu_events:dict[str, list[ProfileEvent]] = {}

  # ** inst traces
  wave_insts:dict[str, dict[str, dict]] = {}
  inst_units:dict[str, itertools.count] = {}
  for wave in rctx.inst_execs.get(key, []):
    wave_loc = wave.wave_loc
    # N counts how many traced waves have been seen at this wave location so far
    run = next(inst_units.setdefault(wave_loc, itertools.count(0)))
    cu_rows = cu_events.setdefault(wave.cu_loc, [])
    loc = f"INST WAVE:{wave.wave_id} N:{run}"
    cu_rows.append(ProfileRangeEvent(f"SIMD:{wave.simd}", loc, Decimal(wave.begin_time), Decimal(wave.end_time)))
    wave_insts.setdefault(wave.cu_loc, {})[f"{wave_loc} N:{run}"] = {"wave":wave, "disasm":disasm, "prg":p, "run_number":run, "loc":loc}

  # ** occ traces (only WAVESTART/WAVEEND)
  units:dict[str, itertools.count] = {}
  wave_start:dict[str, int] = {}
  for occ in rctx.occ_events.get(key, []):
    wave_loc = occ.wave_loc
    # every wave location is registered as a unit, even when its events are skipped below
    units.setdefault(wave_loc, itertools.count(0))
    # locations that have an instruction trace are already on the timeline
    if wave_loc in inst_units: continue
    if occ.start:
      wave_start[wave_loc] = occ.time
      continue
    # an end event closes the range opened by the last start event at the same location
    cu_rows = cu_events.setdefault(occ.cu_loc, [])
    name = f"OCC WAVE:{occ.wave_id} N:{next(units[wave_loc])}"
    cu_rows.append(ProfileRangeEvent(f"SIMD:{occ.simd}", name, Decimal(wave_start.pop(wave_loc)), Decimal(occ.time)))

  # ** split graph by CU
  for cu in sorted(cu_events, key=row_tuple):
    # a point event per unit is placed in front of the ranges of this CU
    markers = [ProfilePointEvent(unit, "start", unit, ts=Decimal(0)) for unit in units]
    steps.append(create_step(f"{cu} {len(cu_events[cu])}", ("/cu-sqtt", i, len(steps)), depth=1, data=markers+cu_events[cu]))
    cu_waves = wave_insts.get(cu, {})
    for wave_name in sorted(cu_waves, key=row_tuple):
      wd = cu_waves[wave_name]
      steps.append(create_step(wave_name.replace(cu, ""), ("/amd-sqtt-insts", i, len(steps)), loc=wd["loc"], depth=2,
                               data={"fxn":unpack_insts, "args":(wd,)}))
  return child_steps()
|
||||
|
||||
def unpack_insts(viz_data, i:int, j:int, data:dict) -> dict:
  """Build the instruction table for one traced wave.

  `data` holds the wave under "wave", the PC to instruction table under "disasm", the program under "prg" and the
  "run_number". The result has one row per entry of the disassembly table with accumulated hit count, duration and
  stall, plus the list of individual hits. `i` and `j` are accepted for the viewer call signature and are not used.
  """
  # Idle: The total time gap between the completion of previous instruction and the beginning of the current instruction.
  # The idle time can be caused by:
  #   * Arbiter loss
  #   * Source or destination register dependency
  #   * Instruction cache miss
  # Stall: The total number of cycles the hardware pipe couldn't issue an instruction.
  # Duration: Total latency in cycles, defined as "Stall time + Issue time" for gfx9 or "Stall time + Execute time" for gfx10+.
  hit_cols = ["N", "Clk", "Idle", "Dur", "Stall"]
  wave = data["wave"]
  busy_until = wave.begin_time
  disasm = data["disasm"]

  # the PC column is relative to the first entry of the disassembly table
  first_pc = next(iter(disasm), None)
  rows:dict[int, dict] = {}
  for pc, inst in disasm.items():
    rows[pc] = {"pc":pc-first_pc, "inst":str(inst), "hit_count":0, "dur":0, "stall":0, "type":"", "hits":{"cols":hit_cols, "rows":[]}}

  for ev in wave.unpack_insts():
    row = rows[ev.pc]
    # the type is the part of the enum name after the last underscore
    if not row.get("type"): row["type"] = str(ev.typ).rsplit("_", 1)[-1]
    nth = row["hit_count"]
    row["hit_count"] = nth + 1
    row["dur"] += ev.dur
    row["stall"] += ev.stall
    # idle is measured from the latest completion time seen so far, never negative
    row["hits"]["rows"].append((nth, ev.time, max(0, ev.time-busy_until), ev.dur, ev.stall))
    busy_until = max(busy_until, ev.time + ev.dur)

  fields = [("Total Cycles", wave.end_time-wave.begin_time), ("SE", wave.se), ("CU", wave.cu), ("SIMD", wave.simd),
            ("Wave ID", wave.wave_id), ("Run number", data["run_number"])]
  summary = [{"label":label, "value":value} for label, value in fields]
  table = [tuple(row.values()) for row in rows.values()]
  return {"rows":table, "cols":["PC", "Instruction", "Hits", "Cycles", "Stall", "Type"], "metadata":[summary],
          "ref":viz_data.ref_map.get(data["prg"].name)}
|
||||
|
||||
def print_data(data:dict) -> None:
|
||||
from tabulate import tabulate
|
||||
# plaintext
|
||||
|
|
|
|||
|
|
@ -320,14 +320,14 @@ def unpack_pmc(e) -> dict:
|
|||
|
||||
# ** on startup, list all the performance counter traces
|
||||
|
||||
def load_amd_counters(data:VizData, profile:list[ProfileEvent]) -> None:
|
||||
from tinygrad.runtime.ops_amd import ProfileSQTTEvent, ProfilePMCEvent
|
||||
def load_amd_counters(data:VizData, profile:list) -> None:
|
||||
counter_events:dict[tuple[int, int], dict] = {}
|
||||
durations:dict[str, list[float]] = {}
|
||||
prg_events:dict[int, ProfileProgramEvent] = {}
|
||||
arch = ""
|
||||
for e in profile:
|
||||
if isinstance(e, (ProfilePMCEvent, ProfileSQTTEvent)): counter_events.setdefault((e.kern, e.exec_tag), {}).setdefault(type(e), []).append(e)
|
||||
if type(e).__name__ in {"ProfilePMCEvent", "ProfileSQTTEvent"}:
|
||||
counter_events.setdefault((e.kern, e.exec_tag), {}).setdefault(type(e).__name__, []).append(e)
|
||||
if isinstance(e, ProfileRangeEvent) and e.device.startswith("AMD") and e.en is not None:
|
||||
durations.setdefault(str(e.name), []).append(float(e.en-e.st))
|
||||
if isinstance(e, ProfileProgramEvent) and e.tag is not None: prg_events[e.tag] = e
|
||||
|
|
@ -340,15 +340,18 @@ def load_amd_counters(data:VizData, profile:list[ProfileEvent]) -> None:
|
|||
name = data.ctxs[r]["prg"].src[0].arg.name if (r:=data.ref_map.get(pname:=prg_events[k].name)) is not None else pname
|
||||
run_number[k] += 1
|
||||
steps:list[dict] = []
|
||||
if (pmc:=v.get(ProfilePMCEvent)):
|
||||
if (pmc:=v.get("ProfilePMCEvent")):
|
||||
steps.append(create_step("PMC", ("/prg-pmc", len(data.ctxs), len(steps)), pmc[0]))
|
||||
all_counters[(name, run_number[k], pname)] = pmc[0]
|
||||
# to decode a SQTT trace, we need the raw stream, program binary and device properties
|
||||
if (sqtt:=v.get(ProfileSQTTEvent)):
|
||||
if (sqtt:=v.get("ProfileSQTTEvent")):
|
||||
for e in sqtt:
|
||||
if e.itrace: steps.append(create_step(f"SE:{e.se} PKTS", (f"/prg-pkts-{e.se}", len(data.ctxs), len(steps)),
|
||||
data=(e.blob, prg_events[k].lib, arch)))
|
||||
steps.append(create_step("OCC", ("/prg-sqtt", len(data.ctxs), len(steps)), ((k, tag), sqtt, prg_events[k], arch)))
|
||||
if e.itrace: steps.append(create_step(f"SE:{e.se} PKTS", (f"/sqtt-{e.se}",len(data.ctxs),len(steps)), data=(e.blob,prg_events[k].lib,arch)))
|
||||
try:
|
||||
from extra.sqtt.roc import unpack_occ
|
||||
steps.append(create_step("OCC", ("/amd-sqtt-occ", len(data.ctxs), len(steps)),
|
||||
data={"fxn":unpack_occ, "args":((k, tag), sqtt, prg_events[k], arch)}))
|
||||
except Exception: pass
|
||||
data.ctxs.append({"name":f"SQTT {name}"+(f" n{run_number[k]}" if run_number[k] > 1 else ""), "steps":steps})
|
||||
|
||||
wave_colors = {"WMMA": "#1F7857", **{x:"#ffffc0" for x in ["VALU", "VINTERP"]}, "SALU": "#cef263", "SMEM": "#ffc0c0", "STORE": "#4fa3cc",
|
||||
|
|
@ -430,38 +433,6 @@ def sqtt_timeline(data:bytes, lib:bytes, target:str) -> Generator[ProfileEvent,
|
|||
else:
|
||||
yield from add(name, p)
|
||||
|
||||
# ** SQTT OCC only unpacks wave start, end time and SIMD location
|
||||
|
||||
def unpack_sqtt(key:tuple[str, int], data:list, p:ProfileProgramEvent,
                target:str) -> tuple[dict[str, list[ProfileEvent]], list[str], dict[str, dict[str, dict]]]:
  """Decode the SQTT events in `data` for program `p` and group the resulting waves by CU.

  Returns a tuple of three values: the range events per CU location, the list of wave locations seen in the
  occupancy events, and for every CU location the traced waves keyed by "<wave location> N:<run>".
  """
  # * init decoder
  from extra.sqtt.roc import decode
  base = unwrap(p.base)
  # the disassembly table is keyed by program offset, shift it by the load address of the program
  disasm:dict[int, Inst] = {}
  for addr, inst in amd_decode(unwrap(p.lib), target).items(): disasm[addr+base] = inst
  rctx = decode(data, {p.tag:disasm})
  cu_events:dict[str, list[ProfileEvent]] = {}

  # * INST waves
  wave_insts:dict[str, dict[str, dict]] = {}
  inst_units:dict[str, itertools.count] = {}
  for wave in rctx.inst_execs.get(key, []):
    wave_loc = wave.wave_loc
    # N counts how many traced waves have been seen at this wave location so far
    run = next(inst_units.setdefault(wave_loc, itertools.count(0)))
    cu_rows = cu_events.setdefault(wave.cu_loc, [])
    loc = f"INST WAVE:{wave.wave_id} N:{run}"
    cu_rows.append(ProfileRangeEvent(f"SIMD:{wave.simd}", loc, Decimal(wave.begin_time), Decimal(wave.end_time)))
    wave_insts.setdefault(wave.cu_loc, {})[f"{wave_loc} N:{run}"] = {"wave":wave, "disasm":disasm, "prg":p, "run_number":run, "loc":loc}

  # * OCC waves
  units:dict[str, itertools.count] = {}
  wave_start:dict[str, int] = {}
  for occ in rctx.occ_events.get(key, []):
    wave_loc = occ.wave_loc
    # every wave location is registered as a unit, even when its events are skipped below
    units.setdefault(wave_loc, itertools.count(0))
    # locations that have an instruction trace are already on the timeline
    if wave_loc in inst_units: continue
    if occ.start:
      wave_start[wave_loc] = occ.time
      continue
    # an end event closes the range opened by the last start event at the same location
    cu_rows = cu_events.setdefault(occ.cu_loc, [])
    name = f"OCC WAVE:{occ.wave_id} N:{next(units[wave_loc])}"
    cu_rows.append(ProfileRangeEvent(f"SIMD:{occ.simd}", name, Decimal(wave_start.pop(wave_loc)), Decimal(occ.time)))
  return cu_events, list(units), wave_insts
|
||||
|
||||
def device_sort_fn(k:str) -> tuple:
|
||||
special = {"GC": 0, "USER": 1, "TINY": 2, "ALLDEVS":100, "DISK": 999}
|
||||
is_memory = k.endswith(" Memory")
|
||||
|
|
@ -653,52 +624,16 @@ def get_render(viz_data:VizData, query:str) -> dict:
|
|||
ret["cols"] = ["Kernel", "Duration", *ret["cols"]]
|
||||
return ret
|
||||
if fmt == "prg-pmc": return unpack_pmc(data)
|
||||
if fmt.startswith("prg-pkts"):
|
||||
if fmt.startswith("sqtt"):
|
||||
ret = {}
|
||||
with soft_err(lambda err:ret.update(err)):
|
||||
if (events:=get_profile(viz_data, list(itertools.islice(sqtt_timeline(*data), getenv("MAX_SQTT_PKTS", 50_000))), sort_fn=row_tuple)):
|
||||
ret = {"value":events, "content_type":"application/octet-stream"}
|
||||
else: ret = {"src":"No SQTT trace on this SE."}
|
||||
return ret
|
||||
if fmt == "prg-sqtt":
|
||||
ret = {}
|
||||
if len((steps:=viz_data.ctxs[i]["steps"])[j+1:]) == 0:
|
||||
with soft_err(lambda err: ret.update(err)):
|
||||
cu_events, units, wave_insts = unpack_sqtt(*data)
|
||||
for cu in sorted(cu_events, key=row_tuple):
|
||||
steps.append(create_step(f"{cu} {len(cu_events[cu])}", ("/cu-sqtt", i, len(steps)), depth=1,
|
||||
data=[ProfilePointEvent(unit, "start", unit, ts=Decimal(0)) for unit in units]+cu_events[cu]))
|
||||
for k in sorted(wave_insts.get(cu, []), key=row_tuple):
|
||||
steps.append(create_step(k.replace(cu, ""), ("/sqtt-insts", i, len(steps)), loc=(data:=wave_insts[cu][k])["loc"], depth=2, data=data))
|
||||
return {**ret, "steps":[{k:v for k,v in s.items() if k != "data"} for s in steps[j+1:]]}
|
||||
# viewers for the amd decoder in extra
|
||||
if fmt.startswith("amd-sqtt"): return data["fxn"](viz_data, i, j, *data["args"])
|
||||
if fmt == "cu-sqtt": return {"value":get_profile(viz_data, data, sort_fn=row_tuple), "content_type":"application/octet-stream"}
|
||||
if fmt == "sqtt-insts":
|
||||
columns = ["PC", "Instruction", "Hits", "Cycles", "Stall", "Type"]
|
||||
inst_columns = ["N", "Clk", "Idle", "Dur", "Stall"]
|
||||
# Idle: The total time gap between the completion of previous instruction and the beginning of the current instruction.
|
||||
# The idle time can be caused by:
|
||||
# * Arbiter loss
|
||||
# * Source or destination register dependency
|
||||
# * Instruction cache miss
|
||||
# Stall: The total number of cycles the hardware pipe couldn't issue an instruction.
|
||||
# Duration: Total latency in cycles, defined as "Stall time + Issue time" for gfx9 or "Stall time + Execute time" for gfx10+.
|
||||
prev_instr = (w:=data["wave"]).begin_time
|
||||
pc_to_inst = data["disasm"]
|
||||
start_pc = None
|
||||
rows:dict[int, dict] = {}
|
||||
for pc, inst in pc_to_inst.items():
|
||||
if start_pc is None: start_pc = pc
|
||||
rows[pc] = {"pc":pc-start_pc, "inst":str(inst), "hit_count":0, "dur":0, "stall":0, "type":"", "hits":{"cols":inst_columns, "rows":[]}}
|
||||
for e in w.unpack_insts():
|
||||
if not (inst:=rows[e.pc]).get("type"): inst["type"] = str(e.typ).split("_")[-1]
|
||||
inst["hit_count"] += 1
|
||||
inst["dur"] += e.dur
|
||||
inst["stall"] += e.stall
|
||||
inst["hits"]["rows"].append((inst["hit_count"]-1, e.time, max(0, e.time-prev_instr), e.dur, e.stall))
|
||||
prev_instr = max(prev_instr, e.time + e.dur)
|
||||
summary = [{"label":"Total Cycles", "value":w.end_time-w.begin_time}, {"label":"SE", "value":w.se}, {"label":"CU", "value":w.cu},
|
||||
{"label":"SIMD", "value":w.simd}, {"label":"Wave ID", "value":w.wave_id}, {"label":"Run number", "value":data["run_number"]}]
|
||||
return {"rows":[tuple(v.values()) for v in rows.values()], "cols":columns, "metadata":[summary], "ref":viz_data.ref_map.get(data["prg"].name)}
|
||||
if fmt == "prg-pma-pkts":
|
||||
ret = {}
|
||||
with soft_err(lambda err:ret.update(err)):
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue