Compare commits

...

11 commits

Author SHA1 Message Date
George Hotz
f8ff531b68 AQL in mmapeak 2025-11-16 10:15:45 -08:00
George Hotz
c9254c32df skip flaky test 2025-11-16 09:58:52 -08:00
George Hotz
ddaaeb16de lil cleanup 2025-11-16 09:48:19 -08:00
George Hotz
5fba9ccb85 split them 2025-11-16 09:35:40 -08:00
George Hotz
2136c76fa8 that 2025-11-16 09:31:16 -08:00
George Hotz
f9f5fd2b41 more filter 2025-11-16 09:20:20 -08:00
George Hotz
dac087f2b7 improve parser 2025-11-16 09:08:42 -08:00
George Hotz
1129e4c5d5 parse print new 2025-11-16 09:04:41 -08:00
George Hotz
aaec1130a3 sep VIZ/PROFILE 2025-11-16 08:59:50 -08:00
George Hotz
24c7f38105 more minimal runner 2025-11-16 08:48:13 -08:00
George Hotz
22432917d3 more work parsing SQTT 2025-11-16 08:28:38 -08:00
9 changed files with 349 additions and 123 deletions

View file

@ -1,8 +1,10 @@
import pathlib import os, pathlib
# TODO: there is a timing bug without this
os.environ["AMD_AQL"] = "1"
from tinygrad.device import Device from tinygrad.device import Device
from tinygrad.runtime.ops_amd import AMDProgram, HIPCompiler from tinygrad.runtime.ops_amd import AMDProgram, HIPCompiler
import time
import os
NUM_WORKGROUPS = 96 NUM_WORKGROUPS = 96
WAVE_SIZE = 32 WAVE_SIZE = 32
@ -44,9 +46,9 @@ if __name__=="__main__":
raise RuntimeError("Error while initiating AMD device") raise RuntimeError("Error while initiating AMD device")
COMPILER = HIPCompiler(DEV.arch) COMPILER = HIPCompiler(DEV.arch)
if DEV.arch in {'gfx1100', 'gfx1103'}: if DEV.arch in {'gfx1100', 'gfx1103', 'gfx1151'}:
if DEV.arch == 'gfx1103': if DEV.arch == 'gfx1103': NUM_WORKGROUPS = 8
NUM_WORKGROUPS = 8 if DEV.arch == 'gfx1151': NUM_WORKGROUPS = 40
launchBenchmark("v_wmma_bf16_16x16x16_bf16", (7,8,15)) launchBenchmark("v_wmma_bf16_16x16x16_bf16", (7,8,15))
launchBenchmark("v_wmma_f16_16x16x16_f16", (7,8,15)) launchBenchmark("v_wmma_f16_16x16x16_f16", (7,8,15))
launchBenchmark("v_wmma_f32_16x16x16_bf16", (7,8,15)) launchBenchmark("v_wmma_f32_16x16x16_bf16", (7,8,15))

View file

@ -0,0 +1,99 @@
import os
# Configure the run through environment variables before the tinygrad imports below.
# NOTE(review): presumably tinygrad reads these when it is imported -- confirm.
# NOTE: assigning PYTHONPATH here does not change this interpreter's sys.path;
# it is only inherited by child processes.
os.environ["PYTHONPATH"] = "."
os.environ["SQTT"] = "1"
# Respect a DEV already chosen by the caller; fall back to AMD otherwise.
if "DEV" not in os.environ: os.environ["DEV"] = "AMD"
os.environ["PROFILE"] = "1"
os.environ["AMD_LLVM"] = "0"
from dataclasses import replace
import atexit, contextlib
from tinygrad.helpers import system, getenv
from tinygrad.runtime.ops_amd import AMDProgram
from extra.sqtt.roc import decode, WaveExec, ProfileSQTTEvent
from tinygrad.device import Device, ProfileDeviceEvent
from extra.sqtt.attempt_sqtt_parse import parse_sqtt_print_packets
def set_power(x):
  """Run `amd-smi set -l <x>` under sudo; `x` is passed through verbatim as the level name."""
  command = f"sudo /opt/rocm/bin/amd-smi set -l {x}"
  system(command)
@atexit.register
def reset_power():
  """Registered with atexit: hands the level back to "auto" when the interpreter exits."""
  set_power("auto")
# Pin the level for the whole run; reset_power (registered with atexit) restores "auto" on exit.
set_power("stable_std")
# Handle to the AMD device used by save_sqtt and run_asm below.
dev = Device["AMD"]
@contextlib.contextmanager
def save_sqtt():
  """
  Context manager that collects SQTT instruction traces for the code run inside it.

  Yields an initially empty dict. Once the `with` body has finished, the dict is
  filled in place with `inst_execs` from decoding the device's profile events, so
  it is only meaningful after the block exits. Every ProfileSQTTEvent is printed
  with its blob stripped, and the packets of events whose `se` is 0 are dumped.

  Raises AssertionError when decoding produces no instruction executions.
  """
  # drop events from earlier runs so only the body's events get decoded
  dev.profile_events.clear()
  traces:dict[str, list[WaveExec]] = {}
  yield traces

  # the decoder is also handed a device event carrying the device properties
  collected = dev.profile_events+[ProfileDeviceEvent("AMD", props=dev.device_props())]
  decoded = decode(collected)
  assert len(decoded.inst_execs) > 0, "empty sqtt output"
  traces.update(decoded.inst_execs)

  for ev in collected:
    if not isinstance(ev, ProfileSQTTEvent): continue
    print(replace(ev, blob=b''))
    if ev.se != 0: continue
    # FILTER=0 in the environment turns the opcode filter off
    parse_sqtt_print_packets(ev.blob, filter=[0xf, 0x11, 0x12, 0x14] if getenv("FILTER", 1) else None)
# Assembly source skeleton for a single kernel named `matmul`; run_asm substitutes
# the INSTRUCTION placeholder with the instructions under test before compiling.
template = """.text
.globl matmul
.p2align 8
.type matmul,@function
matmul:
INSTRUCTION
s_endpgm
.rodata
.p2align 6
.amdhsa_kernel matmul
.amdhsa_next_free_vgpr .amdgcn.next_free_vgpr
.amdhsa_next_free_sgpr .amdgcn.next_free_sgpr
.amdhsa_wavefront_size32 1
.end_amdhsa_kernel
.amdgpu_metadata
---
amdhsa.version:
- 1
- 0
amdhsa.kernels:
- .name: matmul
.symbol: matmul.kd
.kernarg_segment_size: 0
.group_segment_fixed_size: 0
.private_segment_fixed_size: 0
.kernarg_segment_align: 4
.wavefront_size: 32
.sgpr_count: 8
.vgpr_count: 32
.max_flat_workgroup_size: 1024
...
.end_amdgpu_metadata
"""
def run_asm(src):
  """
  Assemble the instruction strings in `src` into the `matmul` kernel and launch it once.

  The instructions are joined with newlines and substituted for the INSTRUCTION
  placeholder in `template`. The launch uses a single workgroup of one 32-wide
  wave and is issued with wait=True.
  """
  workgroups, wave_width, waves = 1, 32, 1
  asm = template.replace("INSTRUCTION", '\n'.join(src))
  binary = dev.compiler.compile(asm)
  # return value is discarded; presumably called for its printed disassembly
  dev.compiler.disassemble(binary)
  kernel = AMDProgram(dev, "matmul", binary)
  kernel(global_size=(workgroups,1,1), local_size=(wave_width*waves,1,1), wait=True)
if __name__ == "__main__":
  # four v_add_f32 instructions, each with its own source and destination registers
  program = [
    "v_add_f32_e32 v1 v0 v0",
    "v_add_f32_e32 v3 v2 v2",
    "v_add_f32_e32 v5 v4 v4",
    "v_add_f32_e32 v7 v6 v6",
  ]*1  # raise the multiplier to repeat the sequence
  with save_sqtt() as sqtt:
    run_asm(program)

View file

@ -1,44 +1,48 @@
import pickle import pickle
from hexdump import hexdump
from extra.sqtt.roc import decode, ProfileSQTTEvent from extra.sqtt.roc import decode, ProfileSQTTEvent
from tinygrad.helpers import getenv
# Instruction packets (one per ISA op) # Instruction packets (one per ISA op)
# NOTE: these are bad guesses and may be wrong! feel free to update if you know better # NOTE: these are bad guesses and may be wrong! feel free to update if you know better
OPCODE_NAMES = { OPCODE_NAMES = {
# Small metadata / structural packets (NOT ISA op kinds) # ------------------------------------------------------------------------
0x01: "META_SMALL_ID", # 12-bit identifier / slot tag # 0x01–0x06: small “meta + maybe tiny delta” packets
0x02: "META_FLAG", # 1-byte flag/mode (CF/AF/8F/DF...) # ------------------------------------------------------------------------
0x03: "META_SUBEVENT_CODE", # 1-byte sub-event/classification code 0x01: "META_ID12_TS_SMALL", # 12-bit ID + 3-bit delta field
0x04: "META_BASE_INDEX_TAG", # 12-bit base index/tag (..D, 9D, 10D, 58D...) 0x02: "META_FLAG8_TS_SMALL", # 8-bit flag/mode + small delta
0x03: "META_SUBEVENT8_TS_SMALL", # 8-bit subevent/class + small delta
0x04: "META_BASE_INDEX12_TS", # 12-bit base index + small delta
0x05: "META_DESC24_TS_A", # 24-bit descriptor-ish + delta field
0x06: "META_DESC24_TS_B", # second flavour, 24-bit, delta field
# Instruction / timing / timestamp packets # ------------------------------------------------------------------------
0x0F: "TIME_SHORT_DELTA_PLUS4", # short ts, raw_delta+4 # 0x07–0x0F: pure timestamp-ish deltas
0x11: "TIME_WAVE_STATE", # compact wave timing/stall state record # ------------------------------------------------------------------------
0x14: "INST_EXEC_RECORD", # per-instruction execution record 0x07: "TS_DELTA_S8_W3", # shift=8, width=3 (small delta)
0x16: "TIME_LONG_OR_MARKER", # long delta / marker with 6-byte payload 0x08: "EVT_MATCH_SMALL", # event-ish, see fields below
0x09: "PERF_ROUTE_CONFIG", # routing/indirection config
0x0A: "TS_DELTA_S5_W2_A", # shift=5, width=2
0x0B: "TS_DELTA_S5_W3_A", # shift=5, width=3
0x0C: "TS_DELTA_S5_W3_B", # shift=5, width=3 (different consumer)
0x0D: "TS_DELTA_S5_W3_C", # shift=5, width=3
0x0E: "TS_DELTA_S7_W2", # shift=7, width=2
0x0F: "TS_DELTA_SHORT_PLUS4", # short delta; ROCm adds +4 before accumulate
# State / control / perf snapshots # ------------------------------------------------------------------------
0x09: "CONTROL_CONFIG_32B", # 32-bit control/config word (bursts of FE88..., C488...) # 0x10–0x19: timestamps, layout headers, events, perf
0x15: "PERFCOUNTER_SNAPSHOT", # perf / TT configuration snapshot (8-byte) # ------------------------------------------------------------------------
0x10: "PSEUDO_NEED_MORE_BITS", # not a real packet; decoder refill hint
# Extra descriptors / events / metrics 0x11: "TS_WAVE_STATE_SAMPLE", # wave stall/termination sample (byte at +10)
0x06: "META_DESCRIPTOR_24B", # 24-bit descriptor (seen in complex kernels like GEMM) 0x12: "EVT_SECONDARY_METRIC24", # 24-bit secondary timing/perf metric
0x08: "EVENT_SMALL", # small in-stream event (5-nibble payload) 0x13: "EVT_SMALL_GENERIC", # same structural family as 0x08/0x12/0x19
0x12: "TIME_SECONDARY_METRIC", # 3-byte secondary timing/latency/perf metric
0x18: "EVENT_SMALL_PAYLOAD", # generic small side-band payload (5 nibbles)
0x19: "EVENT_SUMMARY_48B", # rare 6-byte summary/aggregate metric
# Pseudo / unknown / not yet observed 0x14: "INST_EXEC_OR_CFG", # instruction exec record / config write / COR marker
0x07: "UNK_DELTA", # unknown 0x15: "PERFCOUNTER_SNAPSHOT", # small delta + 50-ish bits of snapshot
0x0A: "UNK_DELTA2", # unknown 0x16: "TS_DELTA36_OR_MARK", # 36-bit long delta or 36-bit marker
0x0B: "UNK_DELTA3", # unknown 0x17: "LAYOUT_MODE_HEADER", # layout/mode/group + selectors A/B
0x0C: "UNK_DELTA4", # unknown 0x18: "PERF_EVENT_SELECT", # packed selector → FUN_0010aba0
0x0D: "UNK_DELTA5", # unknown 0x19: "EVT_SUMMARY_48B", # 6-byte summary/aggregate metric
0x0E: "UNK_DELTA6", # unknown
0x10: "UNK_PSEUDO", # not seen; pseudo/placeholder
0x17: "UNK_NO_DELTA", # unknown, likely non-timing event
} }
# these tables are from rocprof trace decoder # these tables are from rocprof trace decoder
@ -113,17 +117,13 @@ DELTA_MAP_DEFAULT = {
def decode_packet_fields(opcode: int, reg: int, delta: int) -> str: def decode_packet_fields(opcode: int, reg: int, delta: int) -> str:
""" """
Conservative decoding of a few packet types. Decode packet payloads conservatively, using:
- NIBBLE_BUDGET[opcode & 0x1F] to mask reg down to true width.
Rules: - DELTA_MAP_DEFAULT[opcode] to expose the "primary" field (often delta).
- We first mask the 64-bit shift register down to the actual packet - Per-opcode layouts derived from rocprof's decompiled consumers.
width using NIBBLE_BUDGET[opcode & 0x1F], so we never read bits
that aren't really part of the packet.
- Only layouts that are clearly visible from the decompiled C are
decoded, and names are kept generic (cfg_*, idx_*, id_*, etc).
""" """
# --- 0. Restrict to the real packet bits for this opcode ------------- # --- 0. Restrict to real packet bits ---------------------------------
nb_bits = NIBBLE_BUDGET[opcode & 0x1F] # this table is in bits nb_bits = NIBBLE_BUDGET[opcode & 0x1F]
if nb_bits <= 0 or nb_bits >= 64: if nb_bits <= 0 or nb_bits >= 64:
pkt = reg & ((1 << 64) - 1) pkt = reg & ((1 << 64) - 1)
else: else:
@ -131,23 +131,46 @@ def decode_packet_fields(opcode: int, reg: int, delta: int) -> str:
fields: list[str] = [] fields: list[str] = []
# --- 1. Timestamp-ish opcodes ---------------------------------------- shift, width = DELTA_MAP_DEFAULT.get(opcode, (0, 0))
if width:
field_mask = (1 << width) - 1
shaped_field = (pkt >> shift) & field_mask
else:
field_mask = 0
shaped_field = 0
if opcode == 0x0F: # TIME_SHORT_DELTA_PLUS4 # =====================================================================
# By the time we get here, `delta` is already raw_delta+4. # 1. Timestamp-centric opcodes (actually drive 'time')
# =====================================================================
if opcode == 0x0F: # TS_DELTA_SHORT_PLUS4
# In the caller, delta already has +4 applied.
raw_delta = shaped_field
fields.append(f"raw_delta={raw_delta}")
fields.append(f"ts_short_plus4={delta}") fields.append(f"ts_short_plus4={delta}")
return ", ".join(fields) return ", ".join(fields)
if opcode == 0x11: # TIME_WAVE_STATE (medium/large delta) if opcode == 0x11: # TS_WAVE_STATE_SAMPLE
shift, width = DELTA_MAP_DEFAULT[opcode] # DELTA_MAP_DEFAULT: shift=7, width=9 -> small delta.
raw_delta = (pkt >> shift) & ((1 << width) - 1) raw_delta = shaped_field
coarse = (pkt >> (shift + width)) & 0xFF # next byte above delta coarse = (pkt >> (shift + width)) & 0xFF # matches byte at +10 in C
fields.append(f"raw_delta={raw_delta}") fields.append(f"raw_delta={raw_delta}")
if coarse: if coarse:
fields.append(f"raw_coarse=0x{coarse:02x}") fields.append(f"coarse_state=0x{coarse:02x}")
# From decomp:
# - when layout<3 and coarse&1, it sets a "has interesting wave" flag
# - when coarse&8, it marks all live waves as "terminated"
if coarse & 0x01:
fields.append("flag_wave_interest=1")
if coarse & 0x08:
fields.append("flag_terminate_all=1")
return ", ".join(fields) return ", ".join(fields)
if opcode == 0x16: # TIME_LONG_OR_MARKER if opcode == 0x16: # TS_DELTA36_OR_MARK
# Bits:
# bit8 -> 0x100
# bit9 -> 0x200
# bits 12..47 -> 36-bit field used as delta or marker
bit8 = bool(pkt & 0x100) bit8 = bool(pkt & 0x100)
bit9 = bool(pkt & 0x200) bit9 = bool(pkt & 0x200)
if not bit9: if not bit9:
@ -159,40 +182,95 @@ def decode_packet_fields(opcode: int, reg: int, delta: int) -> str:
val36 = (pkt >> 12) & ((1 << 36) - 1) val36 = (pkt >> 12) & ((1 << 36) - 1)
fields.append(f"mode={mode}") fields.append(f"mode={mode}")
fields.append(f"val36=0x{val36:x}") fields.append(f"val36=0x{val36:x}")
if mode == "delta":
fields.append(f"delta36={delta}")
return ", ".join(fields) return ", ".join(fields)
# --- 2. Opcode 0x14: exec/config record ------------------------------ # For 0x07, 0x0A–0x0E, we know they drive time (via DELTA_MAP_DEFAULT),
# but we don't see any other fields used in the decomp.
if opcode in (0x07, 0x0A, 0x0B, 0x0C, 0x0D, 0x0E):
if width:
raw_delta = shaped_field
leftover = pkt & ~(field_mask << shift)
fields.append(f"raw_delta={raw_delta}")
if leftover:
fields.append(f"payload=0x{leftover:x}")
return ", ".join(fields)
if opcode == 0x14: # =====================================================================
subop = (pkt >> 16) & 0xFFFF # matches (short)(w >> 0x10) # 2. Small "meta + tiny delta" packets (0x01–0x06)
val32 = (pkt >> 32) & 0xFFFFFFFF # matches (uint)(w >> 0x20) # =====================================================================
slot = (pkt >> 7) & 0x7 # used as (idx & 4) + (idx & 3)
hi_byte = (pkt >> 8) & 0xFF if opcode == 0x01: # META_ID12_TS_SMALL
id12 = pkt & 0xFFF
fields.append(f"id12=0x{id12:03x}")
if width:
fields.append(f"field_s{shift}_w{width}={shaped_field}")
return ", ".join(fields)
if opcode == 0x02: # META_FLAG8_TS_SMALL
flag8 = pkt & 0xFF
fields.append(f"flag8=0x{flag8:02x}")
if width:
fields.append(f"field_s{shift}_w{width}={shaped_field}")
return ", ".join(fields)
if opcode == 0x03: # META_SUBEVENT8_TS_SMALL
sub8 = pkt & 0xFF
fields.append(f"subevent8=0x{sub8:02x}")
if width:
fields.append(f"field_s{shift}_w{width}={shaped_field}")
return ", ".join(fields)
if opcode == 0x04: # META_BASE_INDEX12_TS
idx12 = pkt & 0xFFF
fields.append(f"base_index12=0x{idx12:03x}")
if width:
fields.append(f"field_s{shift}_w{width}={shaped_field}")
return ", ".join(fields)
if opcode in (0x05, 0x06): # META_DESC24_TS_A/B
desc24 = pkt & 0xFFFFFF
fields.append(f"desc24=0x{desc24:06x}")
if width:
fields.append(f"field_s{shift}_w{width}={shaped_field}")
return ", ".join(fields)
# =====================================================================
# 3. Opcode 0x14: exec/config record (+ COR marker)
# =====================================================================
if opcode == 0x14: # INST_EXEC_OR_CFG
subop = (pkt >> 16) & 0xFFFF # (short)(w >> 0x10)
val32 = (pkt >> 32) & 0xFFFFFFFF # (uint)(w >> 0x20)
slot = (pkt >> 7) & 0x7 # index in local_168[...] tables
hi_byte = (pkt >> 8) & 0xFF # determines config vs marker
fields.append(f"subop=0x{subop:04x}") fields.append(f"subop=0x{subop:04x}")
fields.append(f"slot={slot}") fields.append(f"slot={slot}")
fields.append(f"val32=0x{val32:08x}") fields.append(f"val32=0x{val32:08x}")
if hi_byte & 0x80: if hi_byte & 0x80:
# "config" flavour, writes into local_168[...] etc. # Config flavour: writes config words into per-slot state arrays.
fields.append("kind=config") fields.append("kind=config")
if subop == 0x000C: if subop == 0x000C:
fields.append("cfg_target=local_168[slot].lo") fields.append("cfg_target=local_168[slot].lo")
elif subop == 0x000D: elif subop == 0x000D:
fields.append("cfg_target=local_168[slot].hi") fields.append("cfg_target=local_168[slot].hi")
else: else:
# COR marker: subop 0xC342, val32==0x434F5200 ("COR\0") # COR marker: subop 0xC342, payload "COR\0" → start of a COR region.
if subop == 0xC342: if subop == 0xC342:
fields.append("kind=cor_stream") fields.append("kind=cor_stream")
if val32 == 0x434F5200: if val32 == 0x434F5200:
fields.append("cor_magic='COR\\0'") fields.append("cor_magic='COR\\0'")
return ", ".join(fields) return ", ".join(fields)
# --- 3. Opcode 0x17: mode/layout header ------------------------------ # =====================================================================
# 4. Opcode 0x17: layout / mode header
# =====================================================================
if opcode == 0x17: if opcode == 0x17: # LAYOUT_MODE_HEADER
# From case 0x17: # From decomp (two sites with identical logic):
# layout = (w >> 7) & 0x3f # layout = (w >> 7) & 0x3f
# mode = (w >> 0xd) & 3 # mode = (w >> 0xd) & 3
# group = (w >> 0xf) & 7 # group = (w >> 0xf) & 7
@ -213,27 +291,26 @@ def decode_packet_fields(opcode: int, reg: int, delta: int) -> str:
fields.append(f"sel_b={sel_b}") fields.append(f"sel_b={sel_b}")
if layout == 4: if layout == 4:
fields.append(f"layout4_flag={flag4}") fields.append(f"layout4_flag={flag4}")
return ", ".join(fields) return ", ".join(fields)
# --- 4. Opcode 0x09: state-ish / indirection record ------------------ # =====================================================================
# 5. Opcode 0x09: state / route config record
# =====================================================================
if opcode == 0x09: if opcode == 0x09: # PERF_ROUTE_CONFIG
# From case 9 on puVar58[1] (here pkt): # From case 9 in multiple consumers:
# # flag7 = (w >> 7) & 1 (low bit of uVar41)
# uVar41 = (w & 0xffffffff) >> 7; local_520 = uVar41 & 1 # cls2 = (w >> 8) & 3 (class / group)
# local_4a0 = (w >> 8) & 3; # slot4 = (w >> 10) & 0xf (slot / group index)
# local_4a8 = (w >> 10) & (7 or 0xf) (depends on local_494) # idx_lo = (w >> 0xd) & 0x1f (low index, layout<4 path)
# uVar69 = (w >> 0xd) or (w >> 0xf) (depends on local_494) # idx_hi = (w >> 0xf) & 0x1f (high index, layout>=4 path)
# local_518 = (w >> 0x19) & 0x7f; # id7 = (w >> 0x19) & 0x7f (7-bit id)
# flag7 = (pkt >> 7) & 0x1
# We *don't* know local_494 here, so we just expose the raw slices. cls2 = (pkt >> 8) & 0x3
flag7 = (pkt >> 7) & 0x1 # low bit of uVar41 slot4 = (pkt >> 10) & 0xF
cls2 = (pkt >> 8) & 0x3 # local_4a0 idx_lo = (pkt >> 13) & 0x1F
slot4 = (pkt >> 10) & 0xF # superset of 3-bit local_4a8 idx_hi = (pkt >> 15) & 0x1F
idx_lo = (pkt >> 13) & 0x1F # matches uVar69&0x1F when layout<4 id7 = (pkt >> 0x19) & 0x7F
idx_hi = (pkt >> 15) & 0x1F # matches uVar69&0x1F when layout>=4
id7 = (pkt >> 0x19) & 0x7F # local_518
fields.append(f"flag7={flag7}") fields.append(f"flag7={flag7}")
fields.append(f"cls2={cls2}") fields.append(f"cls2={cls2}")
@ -243,18 +320,18 @@ def decode_packet_fields(opcode: int, reg: int, delta: int) -> str:
fields.append(f"id7=0x{id7:x}") fields.append(f"id7=0x{id7:x}")
return ", ".join(fields) return ", ".join(fields)
# --- 5. Opcode 0x18: perf/event trigger ------------------------------ # =====================================================================
# 6. Opcode 0x18: perf/event selector (FUN_0010aba0)
# =====================================================================
if opcode == 0x18: if opcode == 0x18: # PERF_EVENT_SELECT
# From case 0x18: # From case 0x18:
# - low 3 bits: (w & 7) # low3 = w & 7
# - mid 3 bits: (w >> 3) & 7 or (w >> 4) & 7 (layout-dependent) # grp3 = (w >> 3) or (w >> 4) & 7 (layout-dependent)
# - hi id: (w >> 0xc) & 0xff OR (w >> 0xd) & 0x7f # flags = bits 6 (B6) and 7 (B7)
# - flag bits at 6 / 7 # hi8 = (w >> 0xc) & 0xff (layout 4 path)
# # hi7 = (w >> 0xd) & 0x7f (other layouts)
# The *real* semantics depend on global local_494 and accumulated # idx5 = (w >> 7) or (w >> 8) & 0x1f, used as wave index
# local_500, so we keep this as a raw view that's still useful for
# debugging, but not layout-dependent.
low3 = pkt & 0x7 low3 = pkt & 0x7
grp3_a = (pkt >> 3) & 0x7 grp3_a = (pkt >> 3) & 0x7
grp3_b = (pkt >> 4) & 0x7 grp3_b = (pkt >> 4) & 0x7
@ -276,13 +353,34 @@ def decode_packet_fields(opcode: int, reg: int, delta: int) -> str:
fields.append(f"hi7=0x{hi7:02x}") fields.append(f"hi7=0x{hi7:02x}")
return ", ".join(fields) return ", ".join(fields)
# --- 6. Generic tiny event-ish packets ------------------------------- # =====================================================================
# 7. Opcode 0x15: perfcounter snapshot
# =====================================================================
if opcode in (0x08, 0x12, 0x19): if opcode == 0x15: # PERFCOUNTER_SNAPSHOT
# These are all "small event" style tokens. The exact layout depends # NIBBLE_BUDGET gives full 64 bits here.
# on global state (local_494 etc), so we just show: # DELTA_MAP_DEFAULT: shift=7, width=3 → tiny delta field.
# - low 8 bits as a kind/flag byte raw_delta = shaped_field if width else 0
# - the rest as an opaque payload. # low bits below the delta field
snap_low = pkt & ((1 << shift) - 1) if shift else 0
# everything above delta field
snap_hi = pkt >> (shift + width) if width else (pkt >> shift)
fields.append(f"raw_delta={raw_delta}")
fields.append(f"snap_low_s{shift}=0x{snap_low:x}")
fields.append(f"snap_hi=0x{snap_hi:x}")
return ", ".join(fields)
# =====================================================================
# 8. Small event-ish packets (0x08 / 0x12 / 0x13 / 0x19)
# =====================================================================
if opcode in (0x08, 0x12, 0x13, 0x19):
# These are all "small event / metric" style tokens. The exact semantics
# depend on layout (0x17) and accumulated state (local_500 etc), so we
# expose:
# - low 8 bits as kind byte
# - rest as opaque payload.
kind = pkt & 0xFF kind = pkt & 0xFF
payload = pkt >> 8 payload = pkt >> 8
fields.append(f"kind_byte=0x{kind:02x}") fields.append(f"kind_byte=0x{kind:02x}")
@ -290,10 +388,27 @@ def decode_packet_fields(opcode: int, reg: int, delta: int) -> str:
fields.append(f"payload=0x{payload:x}") fields.append(f"payload=0x{payload:x}")
return ", ".join(fields) return ", ".join(fields)
# --- 7. Everything else: no extra decode ----------------------------- # =====================================================================
return "" # 9. Pseudo opcode 0x10: never a "real" packet
# =====================================================================
def parse_sqtt_print_packets(data: bytes, max_tokens: int = 100000) -> None: if opcode == 0x10: # PSEUDO_NEED_MORE_BITS
# The main loop never prints these; they're just a control token.
return ""
# =====================================================================
# 10. Generic fallback: expose the DELTA_MAP_DEFAULT field + leftover
# =====================================================================
if width:
fields.append(f"field_s{shift}_w{width}={shaped_field}")
leftover = pkt & ~(field_mask << shift)
if leftover:
fields.append(f"payload=0x{leftover:x}")
return ", ".join(fields)
def parse_sqtt_print_packets(data: bytes, max_tokens: int = 100000, filter=None) -> None:
""" """
Minimal debug: print ONE LINE per decoded token (packet). Minimal debug: print ONE LINE per decoded token (packet).
@ -350,18 +465,22 @@ def parse_sqtt_print_packets(data: bytes, max_tokens: int = 100000) -> None:
if two_bits == 1: if two_bits == 1:
flags |= 0x01 flags |= 0x01
# Common 36-bit field at bits [12..47]
val36 = (reg >> 12) & ((1 << 36) - 1)
if (reg & 0x200) == 0: if (reg & 0x200) == 0:
# delta mode: 36-bit delta at bits [12..47] # delta mode: add 36-bit delta to time
delta = (reg >> 12) & ((1 << 36) - 1) delta = val36
time += delta time += delta
note = "0x16-delta" note = "0x16-delta"
else: else:
# marker mode if bit9==1 and bit8==0 # marker / other modes: no time advance
if (reg & 0x100) == 0: if (reg & 0x100) == 0 and val36 != 0:
val = (reg >> 12) & ((1 << 36) - 1) # real marker: bit9=1, bit8=0, non-zero payload
delta = 0 delta = 0
note = f"0x16-marker val=0x{val:x}" note = f"0x16-marker val=0x{val36:x}"
else: else:
# "other" 0x16 variants, ignored for timing
delta = 0 delta = 0
note = "0x16-other" note = "0x16-other"
else: else:
@ -387,8 +506,7 @@ def parse_sqtt_print_packets(data: bytes, max_tokens: int = 100000) -> None:
extra = decode_packet_fields(opcode, reg, delta) extra = decode_packet_fields(opcode, reg, delta)
if extra: note = (note + " ; " + extra) if note else extra if extra: note = (note + " ; " + extra) if note else extra
BORING_OPCODES = {0x11, 0x14} if filter is None or opcode not in filter:
if opcode not in BORING_OPCODES or getenv("BORING", 1):
my_reg = reg my_reg = reg
my_reg &= (1 << nib_budget) - 1 my_reg &= (1 << nib_budget) - 1
print( print(

View file

@ -111,7 +111,6 @@ def decode(profile:list[ProfileEvent]) -> _ROCParseCtx:
@rocprof.rocprof_trace_decoder_isa_callback_t @rocprof.rocprof_trace_decoder_isa_callback_t
def isa_cb(instr_ptr, mem_size_ptr, size_ptr, pc, data_ptr): def isa_cb(instr_ptr, mem_size_ptr, size_ptr, pc, data_ptr):
if DEBUG >= 8: print(f"isa_cb {pc.address=} {pc.code_object_id=}")
instr, mem_size_ptr[0] = ROCParseCtx.disasms[(unwrap(ROCParseCtx.active_kern), pc.address)] instr, mem_size_ptr[0] = ROCParseCtx.disasms[(unwrap(ROCParseCtx.active_kern), pc.address)]
# this is the number of bytes to next instruction, set to 0 for end_pgm # this is the number of bytes to next instruction, set to 0 for end_pgm

View file

@ -17,7 +17,7 @@ def helper_collect_profile(*devs):
cpu_events.clear() cpu_events.clear()
profile_list = [] profile_list = []
with Context(VIZ=1): with Context(VIZ=1, PROFILE=1):
yield profile_list yield profile_list
for dev in devs: dev.synchronize() for dev in devs: dev.synchronize()
for dev in devs: dev._at_profile_finalize() for dev in devs: dev._at_profile_finalize()

View file

@ -66,6 +66,7 @@ class TestProgressBar(unittest.TestCase):
tqdm_output = tqdm.format_meter(n=total, total=total, elapsed=elapsed, ncols=ncols, prefix="Test") tqdm_output = tqdm.format_meter(n=total, total=total, elapsed=elapsed, ncols=ncols, prefix="Test")
self._compare_bars(tinytqdm_output, tqdm_output) self._compare_bars(tinytqdm_output, tqdm_output)
@unittest.skip("this is flaky")
@patch('sys.stderr', new_callable=StringIO) @patch('sys.stderr', new_callable=StringIO)
@patch('shutil.get_terminal_size') @patch('shutil.get_terminal_size')
def test_unit_scale(self, mock_terminal_size, mock_stderr): def test_unit_scale(self, mock_terminal_size, mock_stderr):

View file

@ -6,6 +6,7 @@ from tinygrad.uop.ops import UOp, UPat, Ops, PatternMatcher, TrackedPatternMatch
from tinygrad.uop.symbolic import sym from tinygrad.uop.symbolic import sym
from tinygrad.dtype import dtypes from tinygrad.dtype import dtypes
from tinygrad.helpers import PROFILE, colored, ansistrip, flatten, TracingKey, ProfileRangeEvent, ProfileEvent, Context, cpu_events, profile_marker from tinygrad.helpers import PROFILE, colored, ansistrip, flatten, TracingKey, ProfileRangeEvent, ProfileEvent, Context, cpu_events, profile_marker
from tinygrad.helpers import VIZ
from tinygrad.device import Buffer from tinygrad.device import Buffer
@track_rewrites(name=True) @track_rewrites(name=True)
@ -33,11 +34,14 @@ class BaseTestViz(unittest.TestCase):
cpu_events.clear() cpu_events.clear()
self.tms = TRACK_MATCH_STATS.value self.tms = TRACK_MATCH_STATS.value
self.profile = PROFILE.value self.profile = PROFILE.value
self.viz = VIZ.value
TRACK_MATCH_STATS.value = 2 TRACK_MATCH_STATS.value = 2
PROFILE.value = 1 PROFILE.value = 1
VIZ.value = 1
def tearDown(self): def tearDown(self):
TRACK_MATCH_STATS.value = self.tms TRACK_MATCH_STATS.value = self.tms
PROFILE.value = self.profile PROFILE.value = self.profile
VIZ.value = self.viz
class TestViz(BaseTestViz): class TestViz(BaseTestViz):
def test_simple(self): def test_simple(self):

View file

@ -5,7 +5,7 @@ from typing import Any, Generic, TypeVar, Iterator, Sequence, cast, Generator
import importlib, inspect, functools, pathlib, os, platform, contextlib, sys, re, atexit, pickle, decimal import importlib, inspect, functools, pathlib, os, platform, contextlib, sys, re, atexit, pickle, decimal
from tinygrad.helpers import CI, OSX, LRU, getenv, diskcache_get, diskcache_put, DEBUG, GlobalCounters, flat_mv, PROFILE, temp, colored, CPU_LLVM from tinygrad.helpers import CI, OSX, LRU, getenv, diskcache_get, diskcache_put, DEBUG, GlobalCounters, flat_mv, PROFILE, temp, colored, CPU_LLVM
from tinygrad.helpers import Context, CCACHE, ALLOW_DEVICE_USAGE, MAX_BUFFER_SIZE, cpu_events, ProfileEvent, ProfilePointEvent, dedup from tinygrad.helpers import Context, CCACHE, ALLOW_DEVICE_USAGE, MAX_BUFFER_SIZE, cpu_events, ProfileEvent, ProfilePointEvent, dedup
from tinygrad.helpers import unwrap_class_type, suppress_finalizing, AMD_LLVM, select_first_inited from tinygrad.helpers import unwrap_class_type, suppress_finalizing, AMD_LLVM, select_first_inited, VIZ
from tinygrad.dtype import DType, ImageDType, PtrDType, dtypes, _to_np_dtype from tinygrad.dtype import DType, ImageDType, PtrDType, dtypes, _to_np_dtype
from tinygrad.renderer import Renderer from tinygrad.renderer import Renderer
@ -355,8 +355,9 @@ if PROFILE:
with open(fn:=temp("profile.pkl", append_user=True), "wb") as f: pickle.dump(cpu_events+Compiled.profile_events+Buffer.profile_events, f) with open(fn:=temp("profile.pkl", append_user=True), "wb") as f: pickle.dump(cpu_events+Compiled.profile_events+Buffer.profile_events, f)
from tinygrad.uop.ops import launch_viz if VIZ:
launch_viz("PROFILE", fn) from tinygrad.uop.ops import launch_viz
launch_viz("PROFILE", fn)
def enumerate_devices_str() -> Generator[str, None, None]: def enumerate_devices_str() -> Generator[str, None, None]:
from tinygrad import Tensor, Device from tinygrad import Tensor, Device

View file

@ -179,7 +179,9 @@ ALLOW_DEVICE_USAGE, MAX_BUFFER_SIZE = ContextVar("ALLOW_DEVICE_USAGE", 1), Conte
EMULATE = ContextVar("EMULATE", "") EMULATE = ContextVar("EMULATE", "")
CPU_COUNT = ContextVar("CPU_COUNT", max(1, len(os.sched_getaffinity(0)) if hasattr(os, "sched_getaffinity") else (os.cpu_count() or 1))) CPU_COUNT = ContextVar("CPU_COUNT", max(1, len(os.sched_getaffinity(0)) if hasattr(os, "sched_getaffinity") else (os.cpu_count() or 1)))
CPU_LLVM, CPU_LVP, AMD_LLVM = ContextVar("CPU_LLVM", 0), ContextVar("CPU_LVP", 0), ContextVar("AMD_LLVM", 0) CPU_LLVM, CPU_LVP, AMD_LLVM = ContextVar("CPU_LLVM", 0), ContextVar("CPU_LVP", 0), ContextVar("AMD_LLVM", 0)
VIZ = PROFILE = ContextVar("VIZ", 0) # VIZ implies PROFILE, but you can run PROFILE without VIZ
VIZ = ContextVar("VIZ", 0)
PROFILE = ContextVar("PROFILE", VIZ.value)
SPEC = ContextVar("SPEC", 1) SPEC = ContextVar("SPEC", 1)
# TODO: disable by default due to speed # TODO: disable by default due to speed
IGNORE_OOB = ContextVar("IGNORE_OOB", 1) IGNORE_OOB = ContextVar("IGNORE_OOB", 1)