mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
Compare commits
11 commits
master
...
parse_sqtt
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
f8ff531b68 | ||
|
|
c9254c32df | ||
|
|
ddaaeb16de | ||
|
|
5fba9ccb85 | ||
|
|
2136c76fa8 | ||
|
|
f9f5fd2b41 | ||
|
|
dac087f2b7 | ||
|
|
1129e4c5d5 | ||
|
|
aaec1130a3 | ||
|
|
24c7f38105 | ||
|
|
22432917d3 |
9 changed files with 349 additions and 123 deletions
|
|
@ -1,8 +1,10 @@
|
|||
import pathlib
|
||||
import os, pathlib
|
||||
|
||||
# TODO: there is a timing bug without this
|
||||
os.environ["AMD_AQL"] = "1"
|
||||
|
||||
from tinygrad.device import Device
|
||||
from tinygrad.runtime.ops_amd import AMDProgram, HIPCompiler
|
||||
import time
|
||||
import os
|
||||
|
||||
NUM_WORKGROUPS = 96
|
||||
WAVE_SIZE = 32
|
||||
|
|
@ -44,9 +46,9 @@ if __name__=="__main__":
|
|||
raise RuntimeError("Error while initiating AMD device")
|
||||
|
||||
COMPILER = HIPCompiler(DEV.arch)
|
||||
if DEV.arch in {'gfx1100', 'gfx1103'}:
|
||||
if DEV.arch == 'gfx1103':
|
||||
NUM_WORKGROUPS = 8
|
||||
if DEV.arch in {'gfx1100', 'gfx1103', 'gfx1151'}:
|
||||
if DEV.arch == 'gfx1103': NUM_WORKGROUPS = 8
|
||||
if DEV.arch == 'gfx1151': NUM_WORKGROUPS = 40
|
||||
launchBenchmark("v_wmma_bf16_16x16x16_bf16", (7,8,15))
|
||||
launchBenchmark("v_wmma_f16_16x16x16_f16", (7,8,15))
|
||||
launchBenchmark("v_wmma_f32_16x16x16_bf16", (7,8,15))
|
||||
|
|
|
|||
99
extra/sqtt/active_sqtt_parse.py
Normal file
99
extra/sqtt/active_sqtt_parse.py
Normal file
|
|
@ -0,0 +1,99 @@
|
|||
import os
|
||||
os.environ["PYTHONPATH"] = "."
|
||||
os.environ["SQTT"] = "1"
|
||||
if "DEV" not in os.environ: os.environ["DEV"] = "AMD"
|
||||
os.environ["PROFILE"] = "1"
|
||||
os.environ["AMD_LLVM"] = "0"
|
||||
|
||||
from dataclasses import replace
|
||||
import atexit, contextlib
|
||||
from tinygrad.helpers import system, getenv
|
||||
from tinygrad.runtime.ops_amd import AMDProgram
|
||||
from extra.sqtt.roc import decode, WaveExec, ProfileSQTTEvent
|
||||
from tinygrad.device import Device, ProfileDeviceEvent
|
||||
|
||||
from extra.sqtt.attempt_sqtt_parse import parse_sqtt_print_packets
|
||||
|
||||
def set_power(x): system(f"sudo /opt/rocm/bin/amd-smi set -l {x}")
|
||||
@atexit.register
|
||||
def reset_power(): set_power("auto")
|
||||
set_power("stable_std")
|
||||
|
||||
dev = Device["AMD"]
|
||||
|
||||
@contextlib.contextmanager
|
||||
def save_sqtt():
|
||||
# clear the old traces
|
||||
dev.profile_events.clear()
|
||||
sqtt:dict[str, list[WaveExec]] = {}
|
||||
yield sqtt
|
||||
events = dev.profile_events+[ProfileDeviceEvent("AMD", props=dev.device_props())]
|
||||
|
||||
rctx = decode(events)
|
||||
assert len(rctx.inst_execs) > 0, "empty sqtt output"
|
||||
sqtt.update(rctx.inst_execs)
|
||||
|
||||
for e in events:
|
||||
if isinstance(e, ProfileSQTTEvent):
|
||||
print(replace(e, blob=b''))
|
||||
if e.se == 0:
|
||||
parse_sqtt_print_packets(e.blob, filter=[0xf, 0x11, 0x12, 0x14] if getenv("FILTER", 1) else None)
|
||||
|
||||
|
||||
template = """.text
|
||||
.globl matmul
|
||||
.p2align 8
|
||||
.type matmul,@function
|
||||
matmul:
|
||||
INSTRUCTION
|
||||
s_endpgm
|
||||
|
||||
.rodata
|
||||
.p2align 6
|
||||
.amdhsa_kernel matmul
|
||||
.amdhsa_next_free_vgpr .amdgcn.next_free_vgpr
|
||||
.amdhsa_next_free_sgpr .amdgcn.next_free_sgpr
|
||||
.amdhsa_wavefront_size32 1
|
||||
.end_amdhsa_kernel
|
||||
|
||||
.amdgpu_metadata
|
||||
---
|
||||
amdhsa.version:
|
||||
- 1
|
||||
- 0
|
||||
amdhsa.kernels:
|
||||
- .name: matmul
|
||||
.symbol: matmul.kd
|
||||
.kernarg_segment_size: 0
|
||||
.group_segment_fixed_size: 0
|
||||
.private_segment_fixed_size: 0
|
||||
.kernarg_segment_align: 4
|
||||
.wavefront_size: 32
|
||||
.sgpr_count: 8
|
||||
.vgpr_count: 32
|
||||
.max_flat_workgroup_size: 1024
|
||||
...
|
||||
.end_amdgpu_metadata
|
||||
"""
|
||||
|
||||
def run_asm(src):
|
||||
NUM_WORKGROUPS = 1
|
||||
WAVE_SIZE = 32
|
||||
NUM_WAVES = 1
|
||||
lib = dev.compiler.compile(template.replace("INSTRUCTION", '\n'.join(src)))
|
||||
dev.compiler.disassemble(lib)
|
||||
fxn = AMDProgram(dev, "matmul", lib)
|
||||
fxn(global_size=(NUM_WORKGROUPS,1,1), local_size=(WAVE_SIZE*NUM_WAVES,1,1), wait=True)
|
||||
|
||||
if __name__ == "__main__":
|
||||
with save_sqtt() as sqtt:
|
||||
run_asm([
|
||||
#"v_rcp_f32 v1, v0"
|
||||
"v_add_f32_e32 v1 v0 v0",
|
||||
"v_add_f32_e32 v3 v2 v2",
|
||||
"v_add_f32_e32 v5 v4 v4",
|
||||
"v_add_f32_e32 v7 v6 v6",
|
||||
#"v_add_f32_e32 v1 v0 v0",
|
||||
#"v_add_f32_e32 v2 v1 v1",
|
||||
#"s_nop 1"
|
||||
]*1)
|
||||
|
|
@ -1,44 +1,48 @@
|
|||
import pickle
|
||||
from hexdump import hexdump
|
||||
from extra.sqtt.roc import decode, ProfileSQTTEvent
|
||||
from tinygrad.helpers import getenv
|
||||
|
||||
# Instruction packets (one per ISA op)
|
||||
# NOTE: these are bad guesses and may be wrong! feel free to update if you know better
|
||||
|
||||
OPCODE_NAMES = {
|
||||
# Small metadata / structural packets (NOT ISA op kinds)
|
||||
0x01: "META_SMALL_ID", # 12-bit identifier / slot tag
|
||||
0x02: "META_FLAG", # 1-byte flag/mode (CF/AF/8F/DF...)
|
||||
0x03: "META_SUBEVENT_CODE", # 1-byte sub-event/classification code
|
||||
0x04: "META_BASE_INDEX_TAG", # 12-bit base index/tag (..D, 9D, 10D, 58D...)
|
||||
# ------------------------------------------------------------------------
|
||||
# 0x01–0x06: small “meta + maybe tiny delta” packets
|
||||
# ------------------------------------------------------------------------
|
||||
0x01: "META_ID12_TS_SMALL", # 12-bit ID + 3-bit delta field
|
||||
0x02: "META_FLAG8_TS_SMALL", # 8-bit flag/mode + small delta
|
||||
0x03: "META_SUBEVENT8_TS_SMALL", # 8-bit subevent/class + small delta
|
||||
0x04: "META_BASE_INDEX12_TS", # 12-bit base index + small delta
|
||||
0x05: "META_DESC24_TS_A", # 24-bit descriptor-ish + delta field
|
||||
0x06: "META_DESC24_TS_B", # second flavour, 24-bit, delta field
|
||||
|
||||
# Instruction / timing / timestamp packets
|
||||
0x0F: "TIME_SHORT_DELTA_PLUS4", # short ts, raw_delta+4
|
||||
0x11: "TIME_WAVE_STATE", # compact wave timing/stall state record
|
||||
0x14: "INST_EXEC_RECORD", # per-instruction execution record
|
||||
0x16: "TIME_LONG_OR_MARKER", # long delta / marker with 6-byte payload
|
||||
# ------------------------------------------------------------------------
|
||||
# 0x07–0x0F: pure timestamp-ish deltas
|
||||
# ------------------------------------------------------------------------
|
||||
0x07: "TS_DELTA_S8_W3", # shift=8, width=3 (small delta)
|
||||
0x08: "EVT_MATCH_SMALL", # event-ish, see fields below
|
||||
0x09: "PERF_ROUTE_CONFIG", # routing/indirection config
|
||||
0x0A: "TS_DELTA_S5_W2_A", # shift=5, width=2
|
||||
0x0B: "TS_DELTA_S5_W3_A", # shift=5, width=3
|
||||
0x0C: "TS_DELTA_S5_W3_B", # shift=5, width=3 (different consumer)
|
||||
0x0D: "TS_DELTA_S5_W3_C", # shift=5, width=3
|
||||
0x0E: "TS_DELTA_S7_W2", # shift=7, width=2
|
||||
0x0F: "TS_DELTA_SHORT_PLUS4", # short delta; ROCm adds +4 before accumulate
|
||||
|
||||
# State / control / perf snapshots
|
||||
0x09: "CONTROL_CONFIG_32B", # 32-bit control/config word (bursts of FE88..., C488...)
|
||||
0x15: "PERFCOUNTER_SNAPSHOT", # perf / TT configuration snapshot (8-byte)
|
||||
# ------------------------------------------------------------------------
|
||||
# 0x10–0x19: timestamps, layout headers, events, perf
|
||||
# ------------------------------------------------------------------------
|
||||
0x10: "PSEUDO_NEED_MORE_BITS", # not a real packet; decoder refill hint
|
||||
|
||||
# Extra descriptors / events / metrics
|
||||
0x06: "META_DESCRIPTOR_24B", # 24-bit descriptor (seen in complex kernels like GEMM)
|
||||
0x08: "EVENT_SMALL", # small in-stream event (5-nibble payload)
|
||||
0x12: "TIME_SECONDARY_METRIC", # 3-byte secondary timing/latency/perf metric
|
||||
0x18: "EVENT_SMALL_PAYLOAD", # generic small side-band payload (5 nibbles)
|
||||
0x19: "EVENT_SUMMARY_48B", # rare 6-byte summary/aggregate metric
|
||||
0x11: "TS_WAVE_STATE_SAMPLE", # wave stall/termination sample (byte at +10)
|
||||
0x12: "EVT_SECONDARY_METRIC24", # 24-bit secondary timing/perf metric
|
||||
0x13: "EVT_SMALL_GENERIC", # same structural family as 0x08/0x12/0x19
|
||||
|
||||
# Pseudo / unknown / not yet observed
|
||||
0x07: "UNK_DELTA", # unknown
|
||||
0x0A: "UNK_DELTA2", # unknown
|
||||
0x0B: "UNK_DELTA3", # unknown
|
||||
0x0C: "UNK_DELTA4", # unknown
|
||||
0x0D: "UNK_DELTA5", # unknown
|
||||
0x0E: "UNK_DELTA6", # unknown
|
||||
0x10: "UNK_PSEUDO", # not seen; pseudo/placeholder
|
||||
0x17: "UNK_NO_DELTA", # unknown, likely non-timing event
|
||||
0x14: "INST_EXEC_OR_CFG", # instruction exec record / config write / COR marker
|
||||
0x15: "PERFCOUNTER_SNAPSHOT", # small delta + 50-ish bits of snapshot
|
||||
0x16: "TS_DELTA36_OR_MARK", # 36-bit long delta or 36-bit marker
|
||||
0x17: "LAYOUT_MODE_HEADER", # layout/mode/group + selectors A/B
|
||||
0x18: "PERF_EVENT_SELECT", # packed selector → FUN_0010aba0
|
||||
0x19: "EVT_SUMMARY_48B", # 6-byte summary/aggregate metric
|
||||
}
|
||||
|
||||
# these tables are from rocprof trace decoder
|
||||
|
|
@ -113,17 +117,13 @@ DELTA_MAP_DEFAULT = {
|
|||
|
||||
def decode_packet_fields(opcode: int, reg: int, delta: int) -> str:
|
||||
"""
|
||||
Conservative decoding of a few packet types.
|
||||
|
||||
Rules:
|
||||
- We first mask the 64-bit shift register down to the actual packet
|
||||
width using NIBBLE_BUDGET[opcode & 0x1F], so we never read bits
|
||||
that aren't really part of the packet.
|
||||
- Only layouts that are clearly visible from the decompiled C are
|
||||
decoded, and names are kept generic (cfg_*, idx_*, id_*, etc).
|
||||
Decode packet payloads conservatively, using:
|
||||
- NIBBLE_BUDGET[opcode & 0x1F] to mask reg down to true width.
|
||||
- DELTA_MAP_DEFAULT[opcode] to expose the "primary" field (often delta).
|
||||
- Per-opcode layouts derived from rocprof's decompiled consumers.
|
||||
"""
|
||||
# --- 0. Restrict to the real packet bits for this opcode -------------
|
||||
nb_bits = NIBBLE_BUDGET[opcode & 0x1F] # this table is in bits
|
||||
# --- 0. Restrict to real packet bits ---------------------------------
|
||||
nb_bits = NIBBLE_BUDGET[opcode & 0x1F]
|
||||
if nb_bits <= 0 or nb_bits >= 64:
|
||||
pkt = reg & ((1 << 64) - 1)
|
||||
else:
|
||||
|
|
@ -131,23 +131,46 @@ def decode_packet_fields(opcode: int, reg: int, delta: int) -> str:
|
|||
|
||||
fields: list[str] = []
|
||||
|
||||
# --- 1. Timestamp-ish opcodes ----------------------------------------
|
||||
shift, width = DELTA_MAP_DEFAULT.get(opcode, (0, 0))
|
||||
if width:
|
||||
field_mask = (1 << width) - 1
|
||||
shaped_field = (pkt >> shift) & field_mask
|
||||
else:
|
||||
field_mask = 0
|
||||
shaped_field = 0
|
||||
|
||||
if opcode == 0x0F: # TIME_SHORT_DELTA_PLUS4
|
||||
# By the time we get here, `delta` is already raw_delta+4.
|
||||
# =====================================================================
|
||||
# 1. Timestamp-centric opcodes (actually drive 'time')
|
||||
# =====================================================================
|
||||
|
||||
if opcode == 0x0F: # TS_DELTA_SHORT_PLUS4
|
||||
# In the caller, delta already has +4 applied.
|
||||
raw_delta = shaped_field
|
||||
fields.append(f"raw_delta={raw_delta}")
|
||||
fields.append(f"ts_short_plus4={delta}")
|
||||
return ", ".join(fields)
|
||||
|
||||
if opcode == 0x11: # TIME_WAVE_STATE (medium/large delta)
|
||||
shift, width = DELTA_MAP_DEFAULT[opcode]
|
||||
raw_delta = (pkt >> shift) & ((1 << width) - 1)
|
||||
coarse = (pkt >> (shift + width)) & 0xFF # next byte above delta
|
||||
if opcode == 0x11: # TS_WAVE_STATE_SAMPLE
|
||||
# DELTA_MAP_DEFAULT: shift=7, width=9 -> small delta.
|
||||
raw_delta = shaped_field
|
||||
coarse = (pkt >> (shift + width)) & 0xFF # matches byte at +10 in C
|
||||
fields.append(f"raw_delta={raw_delta}")
|
||||
if coarse:
|
||||
fields.append(f"raw_coarse=0x{coarse:02x}")
|
||||
fields.append(f"coarse_state=0x{coarse:02x}")
|
||||
# From decomp:
|
||||
# - when layout<3 and coarse&1, it sets a "has interesting wave" flag
|
||||
# - when coarse&8, it marks all live waves as "terminated"
|
||||
if coarse & 0x01:
|
||||
fields.append("flag_wave_interest=1")
|
||||
if coarse & 0x08:
|
||||
fields.append("flag_terminate_all=1")
|
||||
return ", ".join(fields)
|
||||
|
||||
if opcode == 0x16: # TIME_LONG_OR_MARKER
|
||||
if opcode == 0x16: # TS_DELTA36_OR_MARK
|
||||
# Bits:
|
||||
# bit8 -> 0x100
|
||||
# bit9 -> 0x200
|
||||
# bits 12..47 -> 36-bit field used as delta or marker
|
||||
bit8 = bool(pkt & 0x100)
|
||||
bit9 = bool(pkt & 0x200)
|
||||
if not bit9:
|
||||
|
|
@ -159,40 +182,95 @@ def decode_packet_fields(opcode: int, reg: int, delta: int) -> str:
|
|||
val36 = (pkt >> 12) & ((1 << 36) - 1)
|
||||
fields.append(f"mode={mode}")
|
||||
fields.append(f"val36=0x{val36:x}")
|
||||
if mode == "delta":
|
||||
fields.append(f"delta36={delta}")
|
||||
return ", ".join(fields)
|
||||
|
||||
# --- 2. Opcode 0x14: exec/config record ------------------------------
|
||||
# For 0x07, 0x0A–0x0E, we know they drive time (via DELTA_MAP_DEFAULT),
|
||||
# but we don't see any other fields used in the decomp.
|
||||
if opcode in (0x07, 0x0A, 0x0B, 0x0C, 0x0D, 0x0E):
|
||||
if width:
|
||||
raw_delta = shaped_field
|
||||
leftover = pkt & ~(field_mask << shift)
|
||||
fields.append(f"raw_delta={raw_delta}")
|
||||
if leftover:
|
||||
fields.append(f"payload=0x{leftover:x}")
|
||||
return ", ".join(fields)
|
||||
|
||||
if opcode == 0x14:
|
||||
subop = (pkt >> 16) & 0xFFFF # matches (short)(w >> 0x10)
|
||||
val32 = (pkt >> 32) & 0xFFFFFFFF # matches (uint)(w >> 0x20)
|
||||
slot = (pkt >> 7) & 0x7 # used as (idx & 4) + (idx & 3)
|
||||
hi_byte = (pkt >> 8) & 0xFF
|
||||
# =====================================================================
|
||||
# 2. Small "meta + tiny delta" packets (0x01–0x06)
|
||||
# =====================================================================
|
||||
|
||||
if opcode == 0x01: # META_ID12_TS_SMALL
|
||||
id12 = pkt & 0xFFF
|
||||
fields.append(f"id12=0x{id12:03x}")
|
||||
if width:
|
||||
fields.append(f"field_s{shift}_w{width}={shaped_field}")
|
||||
return ", ".join(fields)
|
||||
|
||||
if opcode == 0x02: # META_FLAG8_TS_SMALL
|
||||
flag8 = pkt & 0xFF
|
||||
fields.append(f"flag8=0x{flag8:02x}")
|
||||
if width:
|
||||
fields.append(f"field_s{shift}_w{width}={shaped_field}")
|
||||
return ", ".join(fields)
|
||||
|
||||
if opcode == 0x03: # META_SUBEVENT8_TS_SMALL
|
||||
sub8 = pkt & 0xFF
|
||||
fields.append(f"subevent8=0x{sub8:02x}")
|
||||
if width:
|
||||
fields.append(f"field_s{shift}_w{width}={shaped_field}")
|
||||
return ", ".join(fields)
|
||||
|
||||
if opcode == 0x04: # META_BASE_INDEX12_TS
|
||||
idx12 = pkt & 0xFFF
|
||||
fields.append(f"base_index12=0x{idx12:03x}")
|
||||
if width:
|
||||
fields.append(f"field_s{shift}_w{width}={shaped_field}")
|
||||
return ", ".join(fields)
|
||||
|
||||
if opcode in (0x05, 0x06): # META_DESC24_TS_A/B
|
||||
desc24 = pkt & 0xFFFFFF
|
||||
fields.append(f"desc24=0x{desc24:06x}")
|
||||
if width:
|
||||
fields.append(f"field_s{shift}_w{width}={shaped_field}")
|
||||
return ", ".join(fields)
|
||||
|
||||
# =====================================================================
|
||||
# 3. Opcode 0x14: exec/config record (+ COR marker)
|
||||
# =====================================================================
|
||||
|
||||
if opcode == 0x14: # INST_EXEC_OR_CFG
|
||||
subop = (pkt >> 16) & 0xFFFF # (short)(w >> 0x10)
|
||||
val32 = (pkt >> 32) & 0xFFFFFFFF # (uint)(w >> 0x20)
|
||||
slot = (pkt >> 7) & 0x7 # index in local_168[...] tables
|
||||
hi_byte = (pkt >> 8) & 0xFF # determines config vs marker
|
||||
|
||||
fields.append(f"subop=0x{subop:04x}")
|
||||
fields.append(f"slot={slot}")
|
||||
fields.append(f"val32=0x{val32:08x}")
|
||||
|
||||
if hi_byte & 0x80:
|
||||
# "config" flavour, writes into local_168[...] etc.
|
||||
# Config flavour: writes config words into per-slot state arrays.
|
||||
fields.append("kind=config")
|
||||
if subop == 0x000C:
|
||||
fields.append("cfg_target=local_168[slot].lo")
|
||||
elif subop == 0x000D:
|
||||
fields.append("cfg_target=local_168[slot].hi")
|
||||
else:
|
||||
# COR marker: subop 0xC342, val32==0x434F5200 ("COR\0")
|
||||
# COR marker: subop 0xC342, payload "COR\0" → start of a COR region.
|
||||
if subop == 0xC342:
|
||||
fields.append("kind=cor_stream")
|
||||
if val32 == 0x434F5200:
|
||||
fields.append("cor_magic='COR\\0'")
|
||||
|
||||
return ", ".join(fields)
|
||||
|
||||
# --- 3. Opcode 0x17: mode/layout header ------------------------------
|
||||
# =====================================================================
|
||||
# 4. Opcode 0x17: layout / mode header
|
||||
# =====================================================================
|
||||
|
||||
if opcode == 0x17:
|
||||
# From case 0x17:
|
||||
if opcode == 0x17: # LAYOUT_MODE_HEADER
|
||||
# From decomp (two sites with identical logic):
|
||||
# layout = (w >> 7) & 0x3f
|
||||
# mode = (w >> 0xd) & 3
|
||||
# group = (w >> 0xf) & 7
|
||||
|
|
@ -213,27 +291,26 @@ def decode_packet_fields(opcode: int, reg: int, delta: int) -> str:
|
|||
fields.append(f"sel_b={sel_b}")
|
||||
if layout == 4:
|
||||
fields.append(f"layout4_flag={flag4}")
|
||||
|
||||
return ", ".join(fields)
|
||||
|
||||
# --- 4. Opcode 0x09: state-ish / indirection record ------------------
|
||||
# =====================================================================
|
||||
# 5. Opcode 0x09: state / route config record
|
||||
# =====================================================================
|
||||
|
||||
if opcode == 0x09:
|
||||
# From case 9 on puVar58[1] (here pkt):
|
||||
#
|
||||
# uVar41 = (w & 0xffffffff) >> 7; local_520 = uVar41 & 1
|
||||
# local_4a0 = (w >> 8) & 3;
|
||||
# local_4a8 = (w >> 10) & (7 or 0xf) (depends on local_494)
|
||||
# uVar69 = (w >> 0xd) or (w >> 0xf) (depends on local_494)
|
||||
# local_518 = (w >> 0x19) & 0x7f;
|
||||
#
|
||||
# We *don’t* know local_494 here, so we just expose the raw slices.
|
||||
flag7 = (pkt >> 7) & 0x1 # low bit of uVar41
|
||||
cls2 = (pkt >> 8) & 0x3 # local_4a0
|
||||
slot4 = (pkt >> 10) & 0xF # superset of 3-bit local_4a8
|
||||
idx_lo = (pkt >> 13) & 0x1F # matches uVar69&0x1F when layout<4
|
||||
idx_hi = (pkt >> 15) & 0x1F # matches uVar69&0x1F when layout>=4
|
||||
id7 = (pkt >> 0x19) & 0x7F # local_518
|
||||
if opcode == 0x09: # PERF_ROUTE_CONFIG
|
||||
# From case 9 in multiple consumers:
|
||||
# flag7 = (w >> 7) & 1 (low bit of uVar41)
|
||||
# cls2 = (w >> 8) & 3 (class / group)
|
||||
# slot4 = (w >> 10) & 0xf (slot / group index)
|
||||
# idx_lo = (w >> 0xd) & 0x1f (low index, layout<4 path)
|
||||
# idx_hi = (w >> 0xf) & 0x1f (high index, layout>=4 path)
|
||||
# id7 = (w >> 0x19) & 0x7f (7-bit id)
|
||||
flag7 = (pkt >> 7) & 0x1
|
||||
cls2 = (pkt >> 8) & 0x3
|
||||
slot4 = (pkt >> 10) & 0xF
|
||||
idx_lo = (pkt >> 13) & 0x1F
|
||||
idx_hi = (pkt >> 15) & 0x1F
|
||||
id7 = (pkt >> 0x19) & 0x7F
|
||||
|
||||
fields.append(f"flag7={flag7}")
|
||||
fields.append(f"cls2={cls2}")
|
||||
|
|
@ -243,18 +320,18 @@ def decode_packet_fields(opcode: int, reg: int, delta: int) -> str:
|
|||
fields.append(f"id7=0x{id7:x}")
|
||||
return ", ".join(fields)
|
||||
|
||||
# --- 5. Opcode 0x18: perf/event trigger ------------------------------
|
||||
# =====================================================================
|
||||
# 6. Opcode 0x18: perf/event selector (FUN_0010aba0)
|
||||
# =====================================================================
|
||||
|
||||
if opcode == 0x18:
|
||||
if opcode == 0x18: # PERF_EVENT_SELECT
|
||||
# From case 0x18:
|
||||
# - low 3 bits: (w & 7)
|
||||
# - mid 3 bits: (w >> 3) & 7 or (w >> 4) & 7 (layout–dependent)
|
||||
# - hi id: (w >> 0xc) & 0xff OR (w >> 0xd) & 0x7f
|
||||
# - flag bits at 6 / 7
|
||||
#
|
||||
# The *real* semantics depend on global local_494 and accumulated
|
||||
# local_500, so we keep this as a raw view that’s still useful for
|
||||
# debugging, but not layout-dependent.
|
||||
# low3 = w & 7
|
||||
# grp3 = (w >> 3) or (w >> 4) & 7 (layout-dependent)
|
||||
# flags = bits 6 (B6) and 7 (B7)
|
||||
# hi8 = (w >> 0xc) & 0xff (layout 4 path)
|
||||
# hi7 = (w >> 0xd) & 0x7f (other layouts)
|
||||
# idx5 = (w >> 7) or (w >> 8) & 0x1f, used as wave index
|
||||
low3 = pkt & 0x7
|
||||
grp3_a = (pkt >> 3) & 0x7
|
||||
grp3_b = (pkt >> 4) & 0x7
|
||||
|
|
@ -276,13 +353,34 @@ def decode_packet_fields(opcode: int, reg: int, delta: int) -> str:
|
|||
fields.append(f"hi7=0x{hi7:02x}")
|
||||
return ", ".join(fields)
|
||||
|
||||
# --- 6. Generic tiny event-ish packets -------------------------------
|
||||
# =====================================================================
|
||||
# 7. Opcode 0x15: perfcounter snapshot
|
||||
# =====================================================================
|
||||
|
||||
if opcode in (0x08, 0x12, 0x19):
|
||||
# These are all "small event" style tokens. The exact layout depends
|
||||
# on global state (local_494 etc), so we just show:
|
||||
# - low 8 bits as a kind/flag byte
|
||||
# - the rest as an opaque payload.
|
||||
if opcode == 0x15: # PERFCOUNTER_SNAPSHOT
|
||||
# NIBBLE_BUDGET gives full 64 bits here.
|
||||
# DELTA_MAP_DEFAULT: shift=7, width=3 → tiny delta field.
|
||||
raw_delta = shaped_field if width else 0
|
||||
# low bits below the delta field
|
||||
snap_low = pkt & ((1 << shift) - 1) if shift else 0
|
||||
# everything above delta field
|
||||
snap_hi = pkt >> (shift + width) if width else (pkt >> shift)
|
||||
|
||||
fields.append(f"raw_delta={raw_delta}")
|
||||
fields.append(f"snap_low_s{shift}=0x{snap_low:x}")
|
||||
fields.append(f"snap_hi=0x{snap_hi:x}")
|
||||
return ", ".join(fields)
|
||||
|
||||
# =====================================================================
|
||||
# 8. Small event-ish packets (0x08 / 0x12 / 0x13 / 0x19)
|
||||
# =====================================================================
|
||||
|
||||
if opcode in (0x08, 0x12, 0x13, 0x19):
|
||||
# These are all "small event / metric" style tokens. The exact semantics
|
||||
# depend on layout (0x17) and accumulated state (local_500 etc), so we
|
||||
# expose:
|
||||
# - low 8 bits as kind byte
|
||||
# - rest as opaque payload.
|
||||
kind = pkt & 0xFF
|
||||
payload = pkt >> 8
|
||||
fields.append(f"kind_byte=0x{kind:02x}")
|
||||
|
|
@ -290,10 +388,27 @@ def decode_packet_fields(opcode: int, reg: int, delta: int) -> str:
|
|||
fields.append(f"payload=0x{payload:x}")
|
||||
return ", ".join(fields)
|
||||
|
||||
# --- 7. Everything else: no extra decode -----------------------------
|
||||
return ""
|
||||
# =====================================================================
|
||||
# 9. Pseudo opcode 0x10: never a "real" packet
|
||||
# =====================================================================
|
||||
|
||||
def parse_sqtt_print_packets(data: bytes, max_tokens: int = 100000) -> None:
|
||||
if opcode == 0x10: # PSEUDO_NEED_MORE_BITS
|
||||
# The main loop never prints these; they're just a control token.
|
||||
return ""
|
||||
|
||||
# =====================================================================
|
||||
# 10. Generic fallback: expose the DELTA_MAP_DEFAULT field + leftover
|
||||
# =====================================================================
|
||||
|
||||
if width:
|
||||
fields.append(f"field_s{shift}_w{width}={shaped_field}")
|
||||
leftover = pkt & ~(field_mask << shift)
|
||||
if leftover:
|
||||
fields.append(f"payload=0x{leftover:x}")
|
||||
|
||||
return ", ".join(fields)
|
||||
|
||||
def parse_sqtt_print_packets(data: bytes, max_tokens: int = 100000, filter=None) -> None:
|
||||
"""
|
||||
Minimal debug: print ONE LINE per decoded token (packet).
|
||||
|
||||
|
|
@ -350,18 +465,22 @@ def parse_sqtt_print_packets(data: bytes, max_tokens: int = 100000) -> None:
|
|||
if two_bits == 1:
|
||||
flags |= 0x01
|
||||
|
||||
# Common 36-bit field at bits [12..47]
|
||||
val36 = (reg >> 12) & ((1 << 36) - 1)
|
||||
|
||||
if (reg & 0x200) == 0:
|
||||
# delta mode: 36-bit delta at bits [12..47]
|
||||
delta = (reg >> 12) & ((1 << 36) - 1)
|
||||
# delta mode: add 36-bit delta to time
|
||||
delta = val36
|
||||
time += delta
|
||||
note = "0x16-delta"
|
||||
else:
|
||||
# marker mode if bit9==1 and bit8==0
|
||||
if (reg & 0x100) == 0:
|
||||
val = (reg >> 12) & ((1 << 36) - 1)
|
||||
# marker / other modes: no time advance
|
||||
if (reg & 0x100) == 0 and val36 != 0:
|
||||
# real marker: bit9=1, bit8=0, non-zero payload
|
||||
delta = 0
|
||||
note = f"0x16-marker val=0x{val:x}"
|
||||
note = f"0x16-marker val=0x{val36:x}"
|
||||
else:
|
||||
# "other" 0x16 variants, ignored for timing
|
||||
delta = 0
|
||||
note = "0x16-other"
|
||||
else:
|
||||
|
|
@ -387,8 +506,7 @@ def parse_sqtt_print_packets(data: bytes, max_tokens: int = 100000) -> None:
|
|||
extra = decode_packet_fields(opcode, reg, delta)
|
||||
if extra: note = (note + " ; " + extra) if note else extra
|
||||
|
||||
BORING_OPCODES = {0x11, 0x14}
|
||||
if opcode not in BORING_OPCODES or getenv("BORING", 1):
|
||||
if filter is None or opcode not in filter:
|
||||
my_reg = reg
|
||||
my_reg &= (1 << nib_budget) - 1
|
||||
print(
|
||||
|
|
|
|||
|
|
@ -111,7 +111,6 @@ def decode(profile:list[ProfileEvent]) -> _ROCParseCtx:
|
|||
|
||||
@rocprof.rocprof_trace_decoder_isa_callback_t
|
||||
def isa_cb(instr_ptr, mem_size_ptr, size_ptr, pc, data_ptr):
|
||||
if DEBUG >= 8: print(f"isa_cb {pc.address=} {pc.code_object_id=}")
|
||||
instr, mem_size_ptr[0] = ROCParseCtx.disasms[(unwrap(ROCParseCtx.active_kern), pc.address)]
|
||||
|
||||
# this is the number of bytes to next instruction, set to 0 for end_pgm
|
||||
|
|
|
|||
|
|
@ -17,7 +17,7 @@ def helper_collect_profile(*devs):
|
|||
cpu_events.clear()
|
||||
|
||||
profile_list = []
|
||||
with Context(VIZ=1):
|
||||
with Context(VIZ=1, PROFILE=1):
|
||||
yield profile_list
|
||||
for dev in devs: dev.synchronize()
|
||||
for dev in devs: dev._at_profile_finalize()
|
||||
|
|
|
|||
|
|
@ -66,6 +66,7 @@ class TestProgressBar(unittest.TestCase):
|
|||
tqdm_output = tqdm.format_meter(n=total, total=total, elapsed=elapsed, ncols=ncols, prefix="Test")
|
||||
self._compare_bars(tinytqdm_output, tqdm_output)
|
||||
|
||||
@unittest.skip("this is flaky")
|
||||
@patch('sys.stderr', new_callable=StringIO)
|
||||
@patch('shutil.get_terminal_size')
|
||||
def test_unit_scale(self, mock_terminal_size, mock_stderr):
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@ from tinygrad.uop.ops import UOp, UPat, Ops, PatternMatcher, TrackedPatternMatch
|
|||
from tinygrad.uop.symbolic import sym
|
||||
from tinygrad.dtype import dtypes
|
||||
from tinygrad.helpers import PROFILE, colored, ansistrip, flatten, TracingKey, ProfileRangeEvent, ProfileEvent, Context, cpu_events, profile_marker
|
||||
from tinygrad.helpers import VIZ
|
||||
from tinygrad.device import Buffer
|
||||
|
||||
@track_rewrites(name=True)
|
||||
|
|
@ -33,11 +34,14 @@ class BaseTestViz(unittest.TestCase):
|
|||
cpu_events.clear()
|
||||
self.tms = TRACK_MATCH_STATS.value
|
||||
self.profile = PROFILE.value
|
||||
self.viz = VIZ.value
|
||||
TRACK_MATCH_STATS.value = 2
|
||||
PROFILE.value = 1
|
||||
VIZ.value = 1
|
||||
def tearDown(self):
|
||||
TRACK_MATCH_STATS.value = self.tms
|
||||
PROFILE.value = self.profile
|
||||
VIZ.value = self.viz
|
||||
|
||||
class TestViz(BaseTestViz):
|
||||
def test_simple(self):
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@ from typing import Any, Generic, TypeVar, Iterator, Sequence, cast, Generator
|
|||
import importlib, inspect, functools, pathlib, os, platform, contextlib, sys, re, atexit, pickle, decimal
|
||||
from tinygrad.helpers import CI, OSX, LRU, getenv, diskcache_get, diskcache_put, DEBUG, GlobalCounters, flat_mv, PROFILE, temp, colored, CPU_LLVM
|
||||
from tinygrad.helpers import Context, CCACHE, ALLOW_DEVICE_USAGE, MAX_BUFFER_SIZE, cpu_events, ProfileEvent, ProfilePointEvent, dedup
|
||||
from tinygrad.helpers import unwrap_class_type, suppress_finalizing, AMD_LLVM, select_first_inited
|
||||
from tinygrad.helpers import unwrap_class_type, suppress_finalizing, AMD_LLVM, select_first_inited, VIZ
|
||||
from tinygrad.dtype import DType, ImageDType, PtrDType, dtypes, _to_np_dtype
|
||||
from tinygrad.renderer import Renderer
|
||||
|
||||
|
|
@ -355,8 +355,9 @@ if PROFILE:
|
|||
|
||||
with open(fn:=temp("profile.pkl", append_user=True), "wb") as f: pickle.dump(cpu_events+Compiled.profile_events+Buffer.profile_events, f)
|
||||
|
||||
from tinygrad.uop.ops import launch_viz
|
||||
launch_viz("PROFILE", fn)
|
||||
if VIZ:
|
||||
from tinygrad.uop.ops import launch_viz
|
||||
launch_viz("PROFILE", fn)
|
||||
|
||||
def enumerate_devices_str() -> Generator[str, None, None]:
|
||||
from tinygrad import Tensor, Device
|
||||
|
|
|
|||
|
|
@ -179,7 +179,9 @@ ALLOW_DEVICE_USAGE, MAX_BUFFER_SIZE = ContextVar("ALLOW_DEVICE_USAGE", 1), Conte
|
|||
EMULATE = ContextVar("EMULATE", "")
|
||||
CPU_COUNT = ContextVar("CPU_COUNT", max(1, len(os.sched_getaffinity(0)) if hasattr(os, "sched_getaffinity") else (os.cpu_count() or 1)))
|
||||
CPU_LLVM, CPU_LVP, AMD_LLVM = ContextVar("CPU_LLVM", 0), ContextVar("CPU_LVP", 0), ContextVar("AMD_LLVM", 0)
|
||||
VIZ = PROFILE = ContextVar("VIZ", 0)
|
||||
# VIZ implies PROFILE, but you can run PROFILE without VIZ
|
||||
VIZ = ContextVar("VIZ", 0)
|
||||
PROFILE = ContextVar("PROFILE", VIZ.value)
|
||||
SPEC = ContextVar("SPEC", 1)
|
||||
# TODO: disable by default due to speed
|
||||
IGNORE_OOB = ContextVar("IGNORE_OOB", 1)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue