mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
viz: color wmma, one color map for cli and web (#15519)
* viz: color wmma, one color map for cli and web
* op_type
* like uops
* mypy cli
This commit is contained in:
parent
0c3e438229
commit
36a925e2a2
4 changed files with 22 additions and 20 deletions
|
|
@ -72,9 +72,6 @@ def main(args) -> None:
|
|||
if "SQTT" in args.source:
|
||||
# modern terminals support 24-bit color
|
||||
def hex_colored(st:str, color:str) -> str:
  # color is a "#rrggbb" string: each two-character slice after the '#' is parsed as one hex channel
  red, green, blue = (int(color[i:i+2], 16) for i in (1, 3, 5))
  # 38;2;r;g;b selects a 24-bit foreground color, the trailing 0m resets attributes after the text
  return f"\x1b[38;2;{red};{green};{blue}m{st}\x1b[0m"
|
||||
WAVE_COLORS = ((('VALU', 'VINTERP'), '#ffffc0'), (('SALU',), '#cef263'), (('VMEM',), '#b2b7c9'), (('LOAD', 'SMEM'), '#ffc0c0'),
|
||||
(('STORE',), '#4fa3cc'), (('IMMEDIATE',), '#f3b44a'), (('BARRIER',), '#d00000'), (('LDS',), '#9fb4a6'), (('JUMP',), '#ffb703'),
|
||||
(('JUMP_NO',), '#fb8500'), (('MESSAGE',), '#90dbf4'), (('WAVERDY',), '#1a2a2a'))
|
||||
print(f"{'Clk':<12} {'Unit':<20} {'Op':<22} {'Dur':<4} {'Delay':<4} {'Info'}")
|
||||
print("-" * 100)
|
||||
pc_map:dict[int, str] = {}
|
||||
|
|
@ -87,9 +84,9 @@ def main(args) -> None:
|
|||
if inst_st is None: inst_st = int(e.st)
|
||||
assert isinstance(e.name, TracingKey)
|
||||
op_name, info = e.name.display_name, e.name.ret or ""
|
||||
color = next((c for p, c in WAVE_COLORS if any(x in op_name for x in p)), None)
|
||||
color = next((v for k,v in viz.wave_colors.items() if k in op_name), None)
|
||||
op_str = hex_colored(op_name, color) if color and not args.no_color else op_name
|
||||
phase, delay = None, ""
|
||||
phase, delay = None, 0
|
||||
idx = next(pkt_idxs.setdefault(e.device, itertools.count()))
|
||||
if e.device.startswith("WAVE") or e.device == "OTHER":
|
||||
inst = f"0x{(pc:=int(info.replace('PC:', ''))):05x} {pc_map[pc]}" if info else f"{'':7} {op_name}"
|
||||
|
|
@ -100,7 +97,7 @@ def main(args) -> None:
|
|||
phase, delay = "EXEC", int(e.st) - dispatch_st
|
||||
if inst and phase: info = f"{phase:<8} {inst}"
|
||||
unit = e.device.replace(" ", "-")
|
||||
print(f"{int(e.st)-inst_st:<12} {unit:<20} {op_str}{' '*(22-ansilen(op_str))} {int(unwrap(e.en)-e.st):<4} {str(delay):<4} {info}")
|
||||
print(f"{int(e.st)-inst_st:<12} {unit:<20} {op_str}{' '*(22-ansilen(op_str))} {int(unwrap(e.en)-e.st):<4} {str(delay or ''):<4} {info}")
|
||||
return None
|
||||
|
||||
# ** Profiler printer
|
||||
|
|
|
|||
|
|
@ -304,7 +304,7 @@ class ProfileRangeEvent(ProfileEvent): device:str; name:str|TracingKey; st:decim
|
|||
|
||||
@dataclass(frozen=True)
|
||||
class ProfilePointEvent(ProfileEvent):
|
||||
device:str; name:str; key:Any; arg:dict=field(default_factory=dict); ts:decimal.Decimal=field(default_factory=perf_counter_us) # noqa: E702
|
||||
device:str; name:str; key:Any; arg:Any=field(default_factory=dict); ts:decimal.Decimal=field(default_factory=perf_counter_us) # noqa: E702
|
||||
|
||||
cpu_events:list[ProfileEvent] = []
|
||||
@contextlib.contextmanager
|
||||
|
|
|
|||
|
|
@ -199,12 +199,8 @@ function formatCycles(cycles) {
|
|||
|
||||
// format `d` with d3's ".3~s" specifier, then append the optional `unit` suffix
const formatUnit = (d, unit="") => {
  const formatted = d3.format(".3~s")(d);
  return formatted+unit;
};
|
||||
|
||||
const WAVE_COLORS = {VALU:"#ffffc0", SALU:"#cef263", LOAD:"#ffc0c0", STORE:"#4fa3cc", IMMEDIATE:"#f3b44a", BARRIER:"#d00000", JUMP:"#ffb703",
|
||||
JUMP_NO:"#fb8500", MESSAGE:"#90dbf4", VMEM:"#b2b7c9", LDS:"#9fb4a6", WAVERDY:"#1a2a2a"};
|
||||
const waveColor = (op) => {
|
||||
const cat = op.includes("VALU") || op === "VINTERP" ? "VALU" : op.includes("SALU") ? "SALU" : op.includes("VMEM") ? "VMEM"
|
||||
: op.includes("LOAD") || op === "SMEM" ? "LOAD" : op.includes("STORE") ? "STORE" : op;
|
||||
let ret = WAVE_COLORS[cat] ?? "#ffffff";
|
||||
let ret = data.waveColors.find(([pattern]) => op.includes(pattern))?.[1] ?? "#ffffff";
|
||||
if (op.includes("OTHER_") || op.includes("_ALT")) { ret = darkenHex(ret, 75) }
|
||||
if (op.includes("LDS_")) { ret = darkenHex(ret, 25) }
|
||||
return ret
|
||||
|
|
@ -401,6 +397,7 @@ async function renderProfiler(path, opts) {
|
|||
const dur = u32(), tracePeak = u64(), indexLen = u32(), layoutsLen = u32(); data.dur = dur;
|
||||
const textDecoder = new TextDecoder("utf-8");
|
||||
const { strings, dtypeSize, markers, ...extData } = JSON.parse(textDecoder.decode(new Uint8Array(buf, offset, indexLen))); offset += indexLen;
|
||||
for (const [k,v] of Object.entries(extData)) data[k] = v;
|
||||
// place devices on the y axis and set vertical positions
|
||||
const [tickSize, padding, baseOffset] = [5, 8, markers.length ? 14 : 0];
|
||||
const secondaryTick = opts.unit == "clk" ? timeAtCycle : null;
|
||||
|
|
@ -574,7 +571,7 @@ async function renderProfiler(path, opts) {
|
|||
}
|
||||
}
|
||||
for (const m of markers) m.label = m.name.split(/(\s+)/).map(st => ({ st, color:m.color, width:ctx.measureText(st).width }));
|
||||
if (extData.pcMap != null) data.pcMap = extData.pcMap; setFocus(focusedShape);
|
||||
if (data.pcMap != null) setFocus(focusedShape);
|
||||
// secondary axis mapping
|
||||
let instRange = null;
|
||||
for (const [k, { shapes }] of data.tracks) if (!k.includes("Clock") && path.includes("pkts")) {
|
||||
|
|
|
|||
|
|
@ -338,6 +338,10 @@ def load_amd_counters(ctxs:list[dict], profile:list[ProfileEvent]) -> None:
|
|||
steps.append(create_step("SQTT", ("/prg-sqtt", len(ctxs), len(steps)), ((k, tag), sqtt, prg_events[k], arch)))
|
||||
ctxs.append({"name":f"Exec {name}"+(f" n{run_number[k]}" if run_number[k] > 1 else ""), "steps":steps})
|
||||
|
||||
# op-name pattern -> hex color, used by both the cli printer and the web viz (sent to the web as the "waveColors" JSON event).
# NOTE: insertion order matters. consumers take the first key that is a substring of the op name, so a more specific
# key must come before any key it contains or can co-occur with: "JUMP_NO" before "JUMP", and "WMMA" before "VALU"
# (exec names get "_WMMA" appended for coloring, presumably producing names like "VALU_WMMA" - verify against sqtt_timeline).
wave_colors = {"WMMA": "#1F7857", **{x:"#ffffc0" for x in ["VALU", "VINTERP"]}, "SALU": "#cef263", "SMEM": "#ffc0c0", "STORE": "#4fa3cc",
|
||||
**{x:"#b2b7c9" for x in ["VMEM", "SGMEM"]}, "LDS": "#9fb4a6", "IMMEDIATE": "#f3b44a", "BARRIER": "#d00000",
|
||||
"JUMP_NO": "#fb8500", "JUMP": "#ffb703", "WAVERDY": "#1a2a2a"}
|
||||
|
||||
def sqtt_timeline(data:bytes, lib:bytes, target:str) -> Generator[ProfileEvent, None, None]:
|
||||
from tinygrad.renderer.amd.sqtt import (map_insts, InstructionInfo, PacketType, INST, InstOp, VALUINST, IMMEDIATE, IMMEDIATE_MASK, VMEMEXEC,
|
||||
ALUEXEC, INST_RDNA4, InstOpRDNA4, TS_DELTA_OR_MARK, TS_DELTA_OR_MARK_RDNA4, CDNA_INST, InstOpCDNA,
|
||||
|
|
@ -346,7 +350,7 @@ def sqtt_timeline(data:bytes, lib:bytes, target:str) -> Generator[ProfileEvent,
|
|||
row_ends:dict[str, Decimal] = {}
|
||||
row_counts:dict[str, itertools.count] = {}
|
||||
curr_barrier:dict[int, ProfileRangeEvent] = {}
|
||||
exec_pending:dict[str, list[tuple[str, int]]] = {}
|
||||
exec_pending:dict[str, list[tuple[str, str]]] = {}
|
||||
is_cdna = target.startswith("gfx9")
|
||||
dispatch_to_exec = {"WMMA":"VALU", "VALU":"VALU", "VALU1":"VALU", "VALUT":"VALU", "VALUB":"VALU", "VALUINST":"VALU", "VINTERP":"VALU",
|
||||
"SGMEM":"VMEM", "FLAT":"VMEM", "LDS":"LDS", "SALU":"SALU", "SMEM":"SALU", "VMEM":"VMEM"}
|
||||
|
|
@ -358,19 +362,22 @@ def sqtt_timeline(data:bytes, lib:bytes, target:str) -> Generator[ProfileEvent,
|
|||
# exec links to dispatch, dispatch links to PC
|
||||
link = f"PC:{info.pc}" if info else None
|
||||
if isinstance(p, (ALUEXEC, VMEMEXEC)):
|
||||
dispatch_id, duration = exec_pending[name].pop(0)
|
||||
# for execs, timestmap is the completion time
|
||||
dispatch_id, op_type = exec_pending[name].pop(0)
|
||||
# get the number of cycles from the op type
|
||||
duration = int(dur_match.group(1)) if (dur_match:=re.match(r".*_(\d+)$", op_type)) else 1
|
||||
# for execs, timestamp is the completion time
|
||||
start_time, end_time = p._time-duration, p._time
|
||||
link = f"LINK:{dispatch_id}"
|
||||
# add wmma in the exec name for coloring
|
||||
if op_type.startswith("WMMA"): name += "_WMMA"
|
||||
# queue inst dispatches
|
||||
idx = next(row_counts.setdefault(row, itertools.count(0)))
|
||||
if isinstance(p, (VALUINST, INST, INST_RDNA4)) and (exec_type:=dispatch_to_exec.get(name.replace("OTHER_", "").split("_")[0])) is not None:
|
||||
if name.startswith("OTHER_"): exec_type = f"{exec_type}_ALT"
|
||||
# get the number of cycles from the op type
|
||||
duration = int(m.group(1)) if (m:=re.match(r".*_(\d+)$", name)) else 1
|
||||
# detect rdna3 wmma from the asm, only rdna4 has an op type for it
|
||||
if isinstance(p, VALUINST) and (asm:=getattr(unwrap(info).inst, "op_name", "")).startswith("V_WMMA"): duration = 16 if 'IU4' in asm else 32
|
||||
exec_pending.setdefault(exec_type, []).append((f"{row}-{idx}", duration))
|
||||
if isinstance(p, VALUINST) and (asm:=getattr(unwrap(info).inst, "op_name", "")).startswith("V_WMMA"):
|
||||
name = f"WMMA_{16 if 'IU4' in asm else 32}"
|
||||
exec_pending.setdefault(exec_type, []).append((f"{row}-{idx}", name))
|
||||
# construct and yield the event for this packet
|
||||
if row not in row_ends: yield ProfilePointEvent(row, "JSON", "pcMap", pc_map, ts=Decimal(0))
|
||||
yield (e:=ProfileRangeEvent(row, TracingKey(name, ret=link), Decimal(start_time), Decimal(end_time)))
|
||||
|
|
@ -384,6 +391,7 @@ def sqtt_timeline(data:bytes, lib:bytes, target:str) -> Generator[ProfileEvent,
|
|||
if name == "BARRIER": curr_barrier[wave] = e
|
||||
NS_PER_TICK = 10 # 100MHz
|
||||
prev_pair:tuple[int, int]|None = None # (shader, realtime)
|
||||
yield ProfilePointEvent("", "JSON", "waveColors", list(wave_colors.items()), ts=Decimal(0))
|
||||
for p, info in map_insts(data, lib, target):
|
||||
if isinstance(p, (TS_DELTA_OR_MARK, TS_DELTA_OR_MARK_RDNA4)) and p.is_marker:
|
||||
pair = (p._time, p.delta)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue