mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
assembly/amd: SQTT support (#14099)
* assembly/amd: SQTT support * simpler * cmp wave * instruction compare * rocprof decode * simpler * no llvm * no strcmp
This commit is contained in:
parent
8b5ff403fa
commit
8b1b15aec0
3 changed files with 587 additions and 0 deletions
2
.github/workflows/test.yml
vendored
2
.github/workflows/test.yml
vendored
|
|
@ -680,6 +680,8 @@ jobs:
|
|||
sudo apt-get install llvm-21 llvm-21-tools cloc
|
||||
- name: RDNA3 Line Count
|
||||
run: cloc --by-file extra/assembly/amd/*.py
|
||||
- name: Install rocprof-trace-decoder
|
||||
run: sudo PYTHONPATH="." ./extra/sqtt/install_sqtt_decoder.py
|
||||
- name: Run RDNA3 emulator tests
|
||||
run: python -m pytest -n=auto extra/assembly/amd/ --durations 20
|
||||
- name: Run RDNA3 emulator tests (AMD_LLVM=1)
|
||||
|
|
|
|||
381
extra/assembly/amd/sqtt.py
Normal file
381
extra/assembly/amd/sqtt.py
Normal file
|
|
@ -0,0 +1,381 @@
|
|||
"""SQTT (SQ Thread Trace) packet encoder and decoder for AMD GPUs.
|
||||
|
||||
This module provides encoding and decoding of raw SQTT byte streams.
|
||||
The format is nibble-based with variable-width packets determined by a state machine.
|
||||
Uses BitField infrastructure from dsl.py, similar to GPU instruction encoding.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
from enum import IntEnum
|
||||
from typing import get_type_hints
|
||||
from extra.assembly.amd.dsl import BitField, bits
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
# FIELD ENUMS
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
|
||||
class MemSrc(IntEnum):
|
||||
LDS = 0
|
||||
LDS_ALT = 1
|
||||
VMEM = 2
|
||||
VMEM_ALT = 3
|
||||
|
||||
class AluSrc(IntEnum):
|
||||
NONE = 0
|
||||
SALU = 1
|
||||
VALU = 2
|
||||
VALU_ALT = 3
|
||||
|
||||
class InstOp(IntEnum):
|
||||
"""SQTT instruction operation types.
|
||||
|
||||
Memory ops appear in two ranges depending on which SIMD executes them:
|
||||
- 0x1x-0x2x range: ops on traced SIMD
|
||||
- 0x5x range: ops on other SIMD (OTHER_ prefix)
|
||||
|
||||
GLOBAL memory ops encoding depends on addressing mode AND size:
|
||||
- Loads: 0x21 (saddr=SGPR) or 0x22 (saddr=NULL), all sizes same
|
||||
- Stores: base + size_offset, where VADDR is shifted +1 from SADDR
|
||||
SADDR: 0x24(32) 0x25(64) 0x26(96) 0x27(128)
|
||||
VADDR: 0x25(32) 0x26(64) 0x27(96) 0x28(128)
|
||||
|
||||
OTHER_ range follows same pattern but values overlap differently.
|
||||
"""
|
||||
SALU = 0x0
|
||||
SMEM = 0x1
|
||||
JUMP = 0x3 # branch taken
|
||||
JUMP_NO = 0x4 # branch not taken
|
||||
MESSAGE = 0x9
|
||||
VALU_TRANS = 0xb # transcendental: exp, log, rcp, sqrt, sin, cos
|
||||
VALU_64_SHIFT = 0xd # 64-bit shifts: lshl, lshr, ashr
|
||||
VALU_MAD64 = 0xe # 64-bit multiply-add
|
||||
VALU_64 = 0xf # 64-bit: add, mul, fma, rcp, sqrt, rounding, frexp, div helpers
|
||||
VINTERP = 0x12 # interpolation: v_interp_p10_f32, v_interp_p2_f32
|
||||
BARRIER = 0x13
|
||||
|
||||
# FLAT memory ops on traced SIMD (0x1x range)
|
||||
FLAT_LOAD = 0x1c
|
||||
FLAT_STORE = 0x1d
|
||||
FLAT_STORE_64 = 0x1e
|
||||
FLAT_STORE_96 = 0x1f
|
||||
FLAT_STORE_128 = 0x20
|
||||
|
||||
# GLOBAL memory ops on traced SIMD (0x2x range)
|
||||
GLOBAL_LOAD = 0x21 # saddr=SGPR, all sizes
|
||||
GLOBAL_LOAD_VADDR = 0x22 # saddr=NULL, all sizes
|
||||
GLOBAL_STORE = 0x24 # saddr=SGPR, 32-bit
|
||||
GLOBAL_STORE_64 = 0x25 # saddr=SGPR 64 or saddr=NULL 32
|
||||
GLOBAL_STORE_96 = 0x26 # saddr=SGPR 96 or saddr=NULL 64
|
||||
GLOBAL_STORE_128 = 0x27 # saddr=SGPR 128 or saddr=NULL 96
|
||||
GLOBAL_STORE_VADDR_128 = 0x28 # saddr=NULL, 128-bit
|
||||
|
||||
# LDS ops on traced SIMD
|
||||
LDS_LOAD = 0x29
|
||||
LDS_STORE = 0x2b
|
||||
LDS_STORE_64 = 0x2c
|
||||
LDS_STORE_128 = 0x2e
|
||||
|
||||
# Memory ops on other SIMD (0x5x range)
|
||||
OTHER_LDS_LOAD = 0x50
|
||||
OTHER_LDS_STORE = 0x51
|
||||
OTHER_LDS_STORE_64 = 0x52
|
||||
OTHER_LDS_STORE_128 = 0x54
|
||||
OTHER_FLAT_LOAD = 0x55
|
||||
OTHER_FLAT_STORE = 0x56
|
||||
OTHER_FLAT_STORE_64 = 0x57
|
||||
OTHER_FLAT_STORE_96 = 0x58
|
||||
OTHER_FLAT_STORE_128 = 0x59
|
||||
OTHER_GLOBAL_LOAD = 0x5a # saddr=SGPR, all sizes
|
||||
OTHER_GLOBAL_LOAD_VADDR = 0x5b # saddr=NULL or saddr=SGPR store 32
|
||||
OTHER_GLOBAL_STORE_64 = 0x5c # saddr=SGPR 64 or saddr=NULL 32
|
||||
OTHER_GLOBAL_STORE_96 = 0x5d # saddr=SGPR 96 or saddr=NULL 64
|
||||
OTHER_GLOBAL_STORE_128 = 0x5e # saddr=SGPR 128 or saddr=NULL 96
|
||||
OTHER_GLOBAL_STORE_VADDR_128 = 0x5f # saddr=NULL, 128-bit
|
||||
|
||||
# EXEC-modifying ops (0x7x range)
|
||||
SALU_SAVEEXEC = 0x72 # s_*_saveexec_b32/b64
|
||||
VALU_CMPX = 0x73 # v_cmpx_*
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
# PACKET TYPE BASE CLASS
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
|
||||
class PacketType:
|
||||
"""Base class for SQTT packet types."""
|
||||
_encoding: tuple[BitField, int] | None = None
|
||||
_field_types: dict[str, type] = {}
|
||||
_values: dict[str, int]
|
||||
_raw: int
|
||||
_time: int
|
||||
|
||||
def __init_subclass__(cls, **kwargs):
|
||||
super().__init_subclass__(**kwargs)
|
||||
if 'encoding' in cls.__dict__ and isinstance(cls.__dict__['encoding'], tuple):
|
||||
cls._encoding = cls.__dict__['encoding']
|
||||
# Cache field type annotations for enum conversion
|
||||
try: cls._field_types = {k: v for k, v in get_type_hints(cls).items() if isinstance(v, type) and issubclass(v, IntEnum)}
|
||||
except Exception: cls._field_types = {}
|
||||
# Cache fields and precompute extraction info: (name, lo, mask, enum_type)
|
||||
cls._fields = {k: v for k, v in cls.__dict__.items() if isinstance(v, BitField) and k != 'encoding'}
|
||||
cls._extract_info = [(name, bf.lo, bf.mask(), cls._field_types.get(name)) for name, bf in cls._fields.items()]
|
||||
cls._size_nibbles = ((max((f.hi for f in cls._fields.values()), default=0) + 4) // 4)
|
||||
|
||||
@classmethod
|
||||
def from_raw(cls, raw: int, time: int = 0):
|
||||
inst = object.__new__(cls)
|
||||
inst._raw, inst._time, inst._values = raw, time, {}
|
||||
for name, lo, mask, enum_type in cls._extract_info:
|
||||
val = (raw >> lo) & mask
|
||||
if enum_type is not None:
|
||||
try: val = enum_type(val)
|
||||
except ValueError: pass
|
||||
inst._values[name] = val
|
||||
return inst
|
||||
|
||||
def __getattr__(self, name: str):
|
||||
if name.startswith('_'): raise AttributeError(name)
|
||||
return self._values.get(name, 0)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
fields_str = ", ".join(f"{k}={v}" for k, v in self._values.items() if not k.startswith('_'))
|
||||
return f"{self.__class__.__name__}({fields_str})"
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
# PACKET TYPE DEFINITIONS
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
|
||||
class VALUINST(PacketType): # exclude: 1 << 2
|
||||
encoding = bits[2:0] == 0b011
|
||||
delta = bits[5:3]
|
||||
flag = bits[6:6]
|
||||
wave = bits[11:7]
|
||||
|
||||
class VMEMEXEC(PacketType): # exclude: 1 << 0
|
||||
encoding = bits[3:0] == 0b1111
|
||||
delta = bits[5:4]
|
||||
src: MemSrc = bits[7:6]
|
||||
|
||||
class ALUEXEC(PacketType): # exclude: 1 << 1
|
||||
encoding = bits[3:0] == 0b1110
|
||||
delta = bits[5:4]
|
||||
src: AluSrc = bits[7:6]
|
||||
|
||||
class IMMEDIATE(PacketType): # exclude: 1 << 5
|
||||
encoding = bits[3:0] == 0b1101
|
||||
delta = bits[6:4]
|
||||
wave = bits[11:7]
|
||||
|
||||
class IMMEDIATE_MASK(PacketType): # exclude: 1 << 5
|
||||
encoding = bits[4:0] == 0b00100
|
||||
delta = bits[7:5]
|
||||
mask = bits[23:8]
|
||||
|
||||
class WAVERDY(PacketType): # exclude: 1 << 3
|
||||
encoding = bits[4:0] == 0b10100
|
||||
delta = bits[7:5]
|
||||
mask = bits[23:8]
|
||||
|
||||
class TS_DELTA_S8_W3(PacketType):
|
||||
encoding = bits[6:0] == 0b0100001
|
||||
delta = bits[10:8]
|
||||
_padding = bits[63:11]
|
||||
|
||||
class WAVEEND(PacketType): # exclude: 1 << 4
|
||||
encoding = bits[4:0] == 0b10101
|
||||
delta = bits[7:5]
|
||||
flag7 = bits[8:8]
|
||||
simd = bits[10:9]
|
||||
cu_lo = bits[13:11]
|
||||
wave = bits[19:15]
|
||||
@property
|
||||
def cu(self) -> int: return self.cu_lo | (self.flag7 << 3)
|
||||
|
||||
class WAVESTART(PacketType): # exclude: 1 << 4
|
||||
encoding = bits[4:0] == 0b01100
|
||||
delta = bits[6:5]
|
||||
flag7 = bits[7:7]
|
||||
simd = bits[9:8]
|
||||
cu_lo = bits[12:10]
|
||||
wave = bits[17:13]
|
||||
id7 = bits[31:18]
|
||||
@property
|
||||
def cu(self) -> int: return self.cu_lo | (self.flag7 << 3)
|
||||
|
||||
class TS_DELTA_S5_W2(PacketType):
|
||||
encoding = bits[4:0] == 0b11100
|
||||
delta = bits[6:5]
|
||||
_padding = bits[47:7]
|
||||
|
||||
class WAVEALLOC(PacketType): # exclude: 1 << 10
|
||||
encoding = bits[4:0] == 0b00101
|
||||
delta = bits[7:5]
|
||||
_padding = bits[19:8]
|
||||
|
||||
class TS_DELTA_S5_W3(PacketType):
|
||||
encoding = bits[4:0] == 0b00110
|
||||
delta = bits[7:5]
|
||||
_padding = bits[51:8]
|
||||
|
||||
class PERF(PacketType): # exclude: 1 << 11
|
||||
encoding = bits[4:0] == 0b10110
|
||||
delta = bits[7:5]
|
||||
arg = bits[27:8]
|
||||
|
||||
class TS_DELTA_SHORT(PacketType):
|
||||
encoding = bits[3:0] == 0b1000
|
||||
delta = bits[7:4]
|
||||
|
||||
class NOP(PacketType):
|
||||
encoding = bits[3:0] == 0b0000
|
||||
delta = None # type: ignore
|
||||
_padding = bits[3:0]
|
||||
|
||||
class TS_WAVE_STATE(PacketType):
|
||||
encoding = bits[6:0] == 0b1010001
|
||||
delta = bits[15:7]
|
||||
coarse = bits[23:16]
|
||||
@property
|
||||
def wave_interest(self) -> bool: return bool(self.coarse & 1)
|
||||
@property
|
||||
def terminate_all(self) -> bool: return bool(self.coarse & 8)
|
||||
|
||||
class EVENT(PacketType): # exclude: 1 << 7
|
||||
encoding = bits[7:0] == 0b01100001
|
||||
delta = bits[10:8]
|
||||
event = bits[23:11]
|
||||
|
||||
class EVENT_BIG(PacketType):
|
||||
encoding = bits[7:0] == 0b11100001
|
||||
delta = bits[10:8]
|
||||
event = bits[31:11]
|
||||
|
||||
class REG(PacketType):
|
||||
encoding = bits[3:0] == 0b1001
|
||||
delta = bits[6:4]
|
||||
slot = bits[9:7]
|
||||
hi_byte = bits[15:8]
|
||||
subop = bits[31:16]
|
||||
val32 = bits[63:32]
|
||||
@property
|
||||
def is_config(self) -> bool: return bool(self.hi_byte & 0x80)
|
||||
|
||||
class SNAPSHOT(PacketType):
|
||||
encoding = bits[6:0] == 0b1110001
|
||||
delta = bits[9:7]
|
||||
snap = bits[63:10]
|
||||
|
||||
class TS_DELTA_OR_MARK(PacketType):
|
||||
encoding = bits[6:0] == 0b0000001
|
||||
delta = bits[47:12]
|
||||
bit8 = bits[8:8]
|
||||
bit9 = bits[9:9]
|
||||
@property
|
||||
def is_marker(self) -> bool: return bool(self.bit9 and not self.bit8)
|
||||
|
||||
class LAYOUT_HEADER(PacketType):
|
||||
encoding = bits[6:0] == 0b0010001
|
||||
delta = None # type: ignore
|
||||
layout = bits[12:7]
|
||||
simd = bits[14:13]
|
||||
group = bits[17:15]
|
||||
sel_a = bits[31:28]
|
||||
sel_b = bits[36:33]
|
||||
flag4 = bits[59:59]
|
||||
_padding = bits[63:60]
|
||||
|
||||
class INST(PacketType):
|
||||
encoding = bits[2:0] == 0b010
|
||||
delta = bits[6:4]
|
||||
flag1 = bits[3:3]
|
||||
flag2 = bits[7:7]
|
||||
wave = bits[12:8]
|
||||
op: InstOp = bits[19:13]
|
||||
|
||||
class UTILCTR(PacketType):
|
||||
encoding = bits[6:0] == 0b0110001
|
||||
delta = bits[8:7]
|
||||
ctr = bits[47:9]
|
||||
|
||||
# All packet types in encoding priority order (more specific masks first, NOP last as fallback)
|
||||
PACKET_TYPES: list[type[PacketType]] = [
|
||||
EVENT, EVENT_BIG,
|
||||
TS_DELTA_S8_W3, TS_WAVE_STATE, SNAPSHOT, TS_DELTA_OR_MARK, LAYOUT_HEADER, UTILCTR,
|
||||
IMMEDIATE_MASK, WAVERDY, WAVEEND, WAVESTART, TS_DELTA_S5_W2, WAVEALLOC, TS_DELTA_S5_W3, PERF,
|
||||
VMEMEXEC, ALUEXEC, IMMEDIATE, TS_DELTA_SHORT, REG,
|
||||
VALUINST, INST,
|
||||
NOP,
|
||||
]
|
||||
|
||||
def _build_state_table() -> tuple[bytes, dict[int, type[PacketType]]]:
|
||||
table = [len(PACKET_TYPES) - 1] * 256 # default to NOP
|
||||
opcode_to_class: dict[int, type[PacketType]] = {i: cls for i, cls in enumerate(PACKET_TYPES)}
|
||||
|
||||
for byte_val in range(256):
|
||||
for opcode, pkt_cls in enumerate(PACKET_TYPES):
|
||||
if pkt_cls._encoding is None: continue
|
||||
mask_bf, pattern = pkt_cls._encoding
|
||||
if (byte_val & mask_bf.mask()) == pattern:
|
||||
table[byte_val] = opcode
|
||||
break
|
||||
|
||||
return bytes(table), opcode_to_class
|
||||
|
||||
STATE_TO_OPCODE, OPCODE_TO_CLASS = _build_state_table()
|
||||
|
||||
# Precompute special case opcodes
|
||||
_TS_DELTA_OR_MARK_OPCODE = next(op for op, cls in OPCODE_TO_CLASS.items() if cls is TS_DELTA_OR_MARK)
|
||||
_TS_DELTA_SHORT_OPCODE = next(op for op, cls in OPCODE_TO_CLASS.items() if cls is TS_DELTA_SHORT)
|
||||
_TS_DELTA_OR_MARK_BIT8 = (TS_DELTA_OR_MARK.bit8.lo, TS_DELTA_OR_MARK.bit8.mask())
|
||||
_TS_DELTA_OR_MARK_BIT9 = (TS_DELTA_OR_MARK.bit9.lo, TS_DELTA_OR_MARK.bit9.mask())
|
||||
|
||||
# Combined lookup: opcode -> (pkt_cls, nib_count, delta_lo, delta_mask, special_case)
|
||||
# special_case: 0=none, 1=TS_DELTA_OR_MARK, 2=TS_DELTA_SHORT
|
||||
_DECODE_INFO: dict[int, tuple] = {}
|
||||
for _opcode, _pkt_cls in OPCODE_TO_CLASS.items():
|
||||
_delta_field = getattr(_pkt_cls, 'delta', None)
|
||||
_delta_lo = _delta_field.lo if _delta_field else 0
|
||||
_delta_mask = _delta_field.mask() if _delta_field else 0
|
||||
_special = 1 if _opcode == _TS_DELTA_OR_MARK_OPCODE else (2 if _opcode == _TS_DELTA_SHORT_OPCODE else 0)
|
||||
_DECODE_INFO[_opcode] = (_pkt_cls, _pkt_cls._size_nibbles, _delta_lo, _delta_mask, _special)
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
# DECODER
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
|
||||
def decode(data: bytes) -> list[PacketType]:
|
||||
"""Decode raw SQTT blob into list of packet instances."""
|
||||
packets: list[PacketType] = []
|
||||
packets_append = packets.append
|
||||
n = len(data)
|
||||
reg = 0
|
||||
offset = 0
|
||||
nib_count = 16
|
||||
time = 0
|
||||
state_to_opcode = STATE_TO_OPCODE
|
||||
decode_info = _DECODE_INFO
|
||||
mask64 = (1 << 64) - 1
|
||||
|
||||
while (offset >> 3) < n:
|
||||
target = offset + nib_count * 4
|
||||
while offset < target and (offset >> 3) < n:
|
||||
byte = data[offset >> 3]
|
||||
nib = (byte >> (offset & 4)) & 0xF
|
||||
reg = ((reg >> 4) | (nib << 60)) & mask64
|
||||
offset += 4
|
||||
if offset < target: break
|
||||
|
||||
opcode = state_to_opcode[reg & 0xFF]
|
||||
pkt_cls, nib_count, delta_lo, delta_mask, special = decode_info[opcode]
|
||||
|
||||
delta = (reg >> delta_lo) & delta_mask
|
||||
|
||||
if special == 1: # TS_DELTA_OR_MARK
|
||||
bit8 = (reg >> _TS_DELTA_OR_MARK_BIT8[0]) & _TS_DELTA_OR_MARK_BIT8[1]
|
||||
bit9 = (reg >> _TS_DELTA_OR_MARK_BIT9[0]) & _TS_DELTA_OR_MARK_BIT9[1]
|
||||
if bit9 and not bit8: delta = 0
|
||||
elif special == 2: # TS_DELTA_SHORT
|
||||
delta = delta + 8
|
||||
|
||||
time += delta
|
||||
packets_append(pkt_cls.from_raw(reg, time))
|
||||
|
||||
return packets
|
||||
204
extra/assembly/amd/test/test_sqtt_examples.py
Normal file
204
extra/assembly/amd/test/test_sqtt_examples.py
Normal file
|
|
@ -0,0 +1,204 @@
|
|||
#!/usr/bin/env python3
|
||||
"""Tests for SQTT packet decoding using real captured examples."""
|
||||
import pickle, unittest, ctypes
|
||||
from pathlib import Path
|
||||
from tinygrad.helpers import DEBUG, colored
|
||||
from tinygrad.runtime.autogen import rocprof
|
||||
from tinygrad.runtime.support.elf import elf_loader
|
||||
from extra.assembly.amd.asm import detect_format, disasm
|
||||
from extra.assembly.amd.autogen.rdna3.ins import SOPP
|
||||
from extra.assembly.amd.autogen.rdna3.enum import SOPPOp
|
||||
from extra.assembly.amd.sqtt import (decode, LAYOUT_HEADER, WAVESTART, WAVEEND, INST, VALUINST, IMMEDIATE, IMMEDIATE_MASK,
|
||||
ALUEXEC, VMEMEXEC, PACKET_TYPES, InstOp, AluSrc, MemSrc)
|
||||
|
||||
EXAMPLES_DIR = Path(__file__).parent.parent.parent.parent / "sqtt/examples"
|
||||
OTHER_OP_RANGE = range(0x50, 0x60) # INST ops for non-traced SIMDs
|
||||
|
||||
PACKET_COLORS = {
|
||||
"INST": "WHITE", "VALUINST": "BLACK", "VMEMEXEC": "yellow", "ALUEXEC": "yellow",
|
||||
"IMMEDIATE": "YELLOW", "IMMEDIATE_MASK": "YELLOW", "WAVERDY": "cyan", "WAVEALLOC": "cyan",
|
||||
"WAVEEND": "blue", "WAVESTART": "blue", "PERF": "magenta", "EVENT": "red", "EVENT_BIG": "red",
|
||||
"REG": "green", "LAYOUT_HEADER": "white", "SNAPSHOT": "white", "UTILCTR": "green",
|
||||
}
|
||||
|
||||
def format_packet(p, time_offset: int = 0) -> str:
|
||||
name, cycle = type(p).__name__, p._time - time_offset
|
||||
if isinstance(p, INST):
|
||||
op_name = p.op.name if isinstance(p.op, InstOp) else f"0x{p.op:02x}"
|
||||
fields = f"wave={p.wave} op={op_name}" + (" flag1" if p.flag1 else "") + (" flag2" if p.flag2 else "")
|
||||
elif isinstance(p, VALUINST): fields = f"wave={p.wave}" + (" flag" if p.flag else "")
|
||||
elif isinstance(p, ALUEXEC): fields = f"src={p.src.name if isinstance(p.src, AluSrc) else p.src}"
|
||||
elif isinstance(p, VMEMEXEC): fields = f"src={p.src.name if isinstance(p.src, MemSrc) else p.src}"
|
||||
elif isinstance(p, (WAVESTART, WAVEEND)): fields = f"wave={p.wave} simd={p.simd} cu={p.cu}"
|
||||
elif hasattr(p, '_values'):
|
||||
fields = " ".join(f"{k}=0x{v:x}" if k in {'snap', 'val32'} else f"{k}={v}"
|
||||
for k, v in p._values.items() if not k.startswith('_') and k != 'delta')
|
||||
else: fields = ""
|
||||
return f"{cycle:8}: {colored(f'{name:18}', PACKET_COLORS.get(name, 'white'))} {fields}"
|
||||
|
||||
def print_packets(packets: list) -> None:
|
||||
skip = {"NOP", "TS_DELTA_SHORT", "TS_WAVE_STATE", "TS_DELTA_OR_MARK", "TS_DELTA_S5_W2", "TS_DELTA_S5_W3", "TS_DELTA_S8_W3", "REG", "EVENT"}
|
||||
time_offset = packets[0]._time if packets else 0
|
||||
for p in packets:
|
||||
if type(p).__name__ not in skip: print(format_packet(p, time_offset))
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
# ROCPROF DECODER
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
|
||||
def run_rocprof_decoder(blobs: list[bytes], lib: bytes, base: int):
|
||||
"""Run rocprof decoder on SQTT blobs, returning raw occupancy and instruction records."""
|
||||
image, sections, _ = elf_loader(lib)
|
||||
text = next((sh for sh in sections if sh.name == ".text"), None)
|
||||
assert text is not None, "no .text section found"
|
||||
text_off, text_size = text.header.sh_addr, text.header.sh_size
|
||||
|
||||
blob_iter, current_blob = iter(blobs), [None]
|
||||
occupancy_records: list[tuple[int, int, int, int, bool]] = [] # (wave_id, simd, cu, time, is_start)
|
||||
wave_insts: list[list[tuple[int, int]]] = [] # per-wave list of (time, stall)
|
||||
|
||||
@rocprof.rocprof_trace_decoder_se_data_callback_t
|
||||
def copy_cb(buf, buf_size, _):
|
||||
blob = next(blob_iter, None)
|
||||
if blob is None: return 0
|
||||
current_blob[0] = (ctypes.c_ubyte * len(blob)).from_buffer_copy(blob)
|
||||
buf[0] = ctypes.cast(current_blob[0], ctypes.POINTER(ctypes.c_ubyte))
|
||||
buf_size[0] = len(current_blob[0])
|
||||
return len(current_blob[0])
|
||||
|
||||
@rocprof.rocprof_trace_decoder_trace_callback_t
|
||||
def trace_cb(record_type, events_ptr, n, _):
|
||||
if record_type == rocprof.ROCPROFILER_THREAD_TRACE_DECODER_RECORD_OCCUPANCY:
|
||||
for ev in (rocprof.rocprofiler_thread_trace_decoder_occupancy_t * n).from_address(events_ptr):
|
||||
occupancy_records.append((ev.wave_id, ev.simd, ev.cu, ev.time, ev.start))
|
||||
elif record_type == rocprof.ROCPROFILER_THREAD_TRACE_DECODER_RECORD_WAVE:
|
||||
for ev in (rocprof.rocprofiler_thread_trace_decoder_wave_t * n).from_address(events_ptr):
|
||||
if ev.instructions_size > 0:
|
||||
sz = ev.instructions_size * ctypes.sizeof(rocprof.rocprofiler_thread_trace_decoder_inst_t)
|
||||
insts_blob = bytearray(sz)
|
||||
ctypes.memmove((ctypes.c_char * sz).from_buffer(insts_blob), ev.instructions_array, sz)
|
||||
insts = list((rocprof.rocprofiler_thread_trace_decoder_inst_t * ev.instructions_size).from_buffer(insts_blob))
|
||||
wave_insts.append([(inst.time, inst.stall) for inst in insts])
|
||||
return rocprof.ROCPROFILER_THREAD_TRACE_DECODER_STATUS_SUCCESS
|
||||
|
||||
@rocprof.rocprof_trace_decoder_isa_callback_t
|
||||
def isa_cb(instr_ptr, mem_size_ptr, size_ptr, pc, _):
|
||||
offset = pc.address - base
|
||||
if offset < text_off or offset >= text_off + text_size:
|
||||
mem_size_ptr[0] = 0
|
||||
return rocprof.ROCPROFILER_THREAD_TRACE_DECODER_STATUS_SUCCESS
|
||||
try:
|
||||
fmt = detect_format(bytes(image[offset:]))
|
||||
inst = fmt.from_bytes(bytes(image[offset:]))
|
||||
instr_text, mem_size_ptr[0] = disasm(inst), inst._size()
|
||||
except (ValueError, AssertionError):
|
||||
mem_size_ptr[0] = 0
|
||||
return rocprof.ROCPROFILER_THREAD_TRACE_DECODER_STATUS_SUCCESS
|
||||
if isinstance(inst, SOPP) and inst.op == SOPPOp.S_ENDPGM: mem_size_ptr[0] = 0
|
||||
if (max_sz := size_ptr[0]) == 0: return rocprof.ROCPROFILER_THREAD_TRACE_DECODER_STATUS_ERROR_OUT_OF_RESOURCES
|
||||
instr_bytes = instr_text.encode()
|
||||
ctypes.memmove(instr_ptr, instr_bytes, min(len(instr_bytes), max_sz - 1))
|
||||
size_ptr[0] = min(len(instr_bytes), max_sz - 1)
|
||||
return rocprof.ROCPROFILER_THREAD_TRACE_DECODER_STATUS_SUCCESS
|
||||
|
||||
rocprof.rocprof_trace_decoder_parse_data(copy_cb, trace_cb, isa_cb, None)
|
||||
return occupancy_records, wave_insts
|
||||
|
||||
class TestSQTTExamples(unittest.TestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.examples = {}
|
||||
for pkl_path in sorted(EXAMPLES_DIR.glob("*.pkl")):
|
||||
with open(pkl_path, "rb") as f:
|
||||
data = pickle.load(f)
|
||||
sqtt_events = [e for e in data if type(e).__name__ == "ProfileSQTTEvent"]
|
||||
prg = next((e for e in data if type(e).__name__ == "ProfileProgramEvent"), None)
|
||||
if sqtt_events and prg:
|
||||
cls.examples[pkl_path.stem] = (sqtt_events, prg.lib, prg.base)
|
||||
|
||||
def test_examples_loaded(self):
|
||||
self.assertGreater(len(self.examples), 0, "no example files found")
|
||||
|
||||
def test_decode_all_examples(self):
|
||||
for name, (events, *_) in self.examples.items():
|
||||
for i, event in enumerate(events):
|
||||
with self.subTest(example=name, event=i):
|
||||
packets = decode(event.blob)
|
||||
if DEBUG >= 2: print(f"\n=== {name} event {i} ==="); print_packets(packets)
|
||||
self.assertGreater(len(packets), 0, f"no packets decoded from {name} event {i}")
|
||||
self.assertIsInstance(packets[0], LAYOUT_HEADER, f"first packet should be LAYOUT_HEADER in {name}")
|
||||
|
||||
def test_packet_types_valid(self):
|
||||
for name, (events, *_) in self.examples.items():
|
||||
for i, event in enumerate(events):
|
||||
with self.subTest(example=name, event=i):
|
||||
for pkt in decode(event.blob):
|
||||
self.assertIn(type(pkt), PACKET_TYPES, f"unknown packet type {type(pkt)} in {name}")
|
||||
|
||||
def test_wave_lifecycle(self):
|
||||
for name, (events, *_) in self.examples.items():
|
||||
if "empty" in name: continue
|
||||
with self.subTest(example=name):
|
||||
all_packets = [p for e in events for p in decode(e.blob)]
|
||||
self.assertGreater(len([p for p in all_packets if isinstance(p, WAVESTART)]), 0, f"no WAVESTART in {name}")
|
||||
self.assertGreater(len([p for p in all_packets if isinstance(p, WAVEEND)]), 0, f"no WAVEEND in {name}")
|
||||
|
||||
def test_time_monotonic(self):
|
||||
for name, (events, *_) in self.examples.items():
|
||||
for i, event in enumerate(events):
|
||||
with self.subTest(example=name, event=i):
|
||||
times = [p._time for p in decode(event.blob)]
|
||||
self.assertEqual(times, sorted(times), f"timestamps not monotonic in {name}")
|
||||
|
||||
def test_gemm_has_instructions(self):
|
||||
for name, (events, *_) in self.examples.items():
|
||||
if "gemm" not in name: continue
|
||||
with self.subTest(example=name):
|
||||
all_packets = [p for e in events for p in decode(e.blob)]
|
||||
self.assertGreater(len([p for p in all_packets if isinstance(p, INST)]), 0, f"no INST packets in {name}")
|
||||
|
||||
def test_rocprof_wave_times_match(self):
|
||||
"""Wave start/end times must match rocprof exactly."""
|
||||
for name, (events, lib, base) in self.examples.items():
|
||||
with self.subTest(example=name):
|
||||
occupancy, _ = run_rocprof_decoder([e.blob for e in events], lib, base)
|
||||
# extract from rocprof occupancy records
|
||||
roc_starts: dict[tuple[int, int, int], int] = {}
|
||||
roc_waves: list[tuple[int, int]] = []
|
||||
for wave_id, simd, cu, time, is_start in occupancy:
|
||||
key = (wave_id, simd, cu)
|
||||
if is_start: roc_starts[key] = time
|
||||
elif key in roc_starts: roc_waves.append((roc_starts.pop(key), time))
|
||||
# extract from our decoder
|
||||
our_waves: list[tuple[int, int]] = []
|
||||
for event in events:
|
||||
packets = decode(event.blob)
|
||||
wave_starts: dict[tuple[int, int, int], int] = {}
|
||||
for p in packets:
|
||||
if isinstance(p, WAVESTART): wave_starts[(p.wave, p.simd, p.cu)] = p._time
|
||||
elif isinstance(p, WAVEEND) and (key := (p.wave, p.simd, p.cu)) in wave_starts:
|
||||
our_waves.append((wave_starts[key], p._time))
|
||||
self.assertEqual(sorted(our_waves), sorted(roc_waves), f"wave times mismatch in {name}")
|
||||
|
||||
def test_rocprof_inst_times_match(self):
|
||||
"""Instruction times must match rocprof exactly (excluding s_endpgm)."""
|
||||
for name, (events, lib, base) in self.examples.items():
|
||||
with self.subTest(example=name):
|
||||
_, wave_insts = run_rocprof_decoder([e.blob for e in events], lib, base)
|
||||
# skip last inst per wave (s_endpgm) - it needs special handling (time + duration instead of time + stall)
|
||||
roc_insts = [time + stall for insts in wave_insts for time, stall in insts[:-1]]
|
||||
# extract from our decoder
|
||||
our_insts: list[int] = []
|
||||
for event in events:
|
||||
for p in decode(event.blob):
|
||||
if isinstance(p, INST):
|
||||
op_val = p.op if isinstance(p.op, int) else p.op.value
|
||||
if op_val not in OTHER_OP_RANGE: our_insts.append(p._time)
|
||||
elif isinstance(p, VALUINST): our_insts.append(p._time)
|
||||
elif isinstance(p, IMMEDIATE): our_insts.append(p._time)
|
||||
elif isinstance(p, IMMEDIATE_MASK):
|
||||
for _ in range(bin(p.mask).count('1')): our_insts.append(p._time)
|
||||
self.assertEqual(sorted(our_insts), sorted(roc_insts), f"instruction times mismatch in {name}")
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
Loading…
Add table
Add a link
Reference in a new issue