Compare commits

...

83 commits

Author SHA1 Message Date
George Hotz
2f85319722 Merge remote-tracking branch 'origin/master' into amd_sqtt
# Conflicts:
#	extra/assembly/amd/emu.py
#	extra/assembly/amd/sqtt.py
2026-01-12 05:42:04 +09:00
George Hotz
d9f0e9c40c something 2026-01-12 02:26:31 +09:00
George Hotz
3dcffbea25 NO SLOT 2026-01-11 17:31:50 +09:00
George Hotz
41c5368266 close 2026-01-11 17:26:20 +09:00
George Hotz
4598a21f94 rdna3 timing 2026-01-11 07:14:39 +00:00
George Hotz
27084cd618 some 2026-01-11 16:12:16 +09:00
George Hotz
d2616e5daf weird forward beavhior 2026-01-11 16:00:21 +09:00
George Hotz
3130c53f85 strange hardware behavior 2026-01-11 15:53:40 +09:00
George Hotz
1c66e41383 new test 2026-01-11 15:23:37 +09:00
George Hotz
93823b272c DEBUG=3 is pretty 2026-01-11 15:05:29 +09:00
George Hotz
1c6147e9bf better 2026-01-11 14:55:41 +09:00
George Hotz
7f5656d236 cold chain 2026-01-11 14:33:39 +09:00
George Hotz
5f55a61700 dumb 2026-01-11 13:22:39 +09:00
George Hotz
ed097df864 cleaner 2026-01-11 13:14:10 +09:00
George Hotz
a83c97f17e sqtt correct 2026-01-11 13:12:51 +09:00
George Hotz
c793076fb6 add s_delay_alu tests 2026-01-11 11:28:21 +09:00
George Hotz
1f45601a97 tests with early nops 2026-01-11 11:16:25 +09:00
George Hotz
14c4989f65 pipeline exec 2026-01-11 11:11:52 +09:00
George Hotz
31b38640ac nop anomaly 2026-01-11 10:41:13 +09:00
George Hotz
fe770e822c pats 2026-01-11 09:56:05 +09:00
George Hotz
768231c065 lat tests 2026-01-11 09:54:32 +09:00
George Hotz
4165594b30 first cycle lat 2026-01-11 09:48:04 +09:00
George Hotz
c03b7b0da1 gap5 anomaly 2026-01-11 09:13:51 +09:00
George Hotz
66249836c0 good test 2026-01-11 09:11:36 +09:00
George Hotz
a0d6ed9914 a couple more 2026-01-11 09:06:32 +09:00
George Hotz
99fcfc0e97 cleaner 2026-01-11 08:57:19 +09:00
George Hotz
cf8bb15aef padding 2026-01-11 08:55:51 +09:00
George Hotz
32dfc9b1d0 another test 2026-01-11 08:49:20 +09:00
George Hotz
9803e389fe good tests 2026-01-11 08:29:39 +09:00
George Hotz
1f893b65cc new hw free test 2026-01-11 07:49:50 +09:00
George Hotz
35f5f05ad5 multiwave 2026-01-09 21:01:26 -08:00
George Hotz
b9f08ad18a fix multiwave 2026-01-09 21:01:26 -08:00
George Hotz
222ae38aa4 fix multiwave 2026-01-09 21:01:26 -08:00
George Hotz
f0bf20d7b2 structuring 2026-01-09 21:01:26 -08:00
George Hotz
85ef097da6 snop passes 2026-01-09 21:01:26 -08:00
George Hotz
0e240fb987
Merge branch 'master' into amd_sqtt 2026-01-02 20:30:16 -05:00
George Hotz
d2c1712e4c more tests 2026-01-02 17:29:48 -08:00
George Hotz
96b0ee0966 lil 2026-01-02 16:53:31 -08:00
George Hotz
9b5c4bc698 shorter 2026-01-02 16:48:26 -08:00
George Hotz
6ea3586101 short 2026-01-02 16:45:34 -08:00
George Hotz
92cb8b6776 tests pass 2026-01-02 16:43:03 -08:00
George Hotz
c416b20668 failures 2026-01-02 15:54:02 -08:00
George Hotz
415b83ba18 tests pass 2026-01-02 15:47:39 -08:00
George Hotz
8c7eacea59 getting close 2026-01-02 15:25:18 -08:00
George Hotz
81542699f8 work 2026-01-02 14:39:52 -08:00
George Hotz
79f55a5d5e test_snop is correct 2026-01-02 12:01:08 -08:00
George Hotz
37518fb236 start with nop 2026-01-02 11:40:08 -08:00
George Hotz
672008ccab framework 2026-01-02 11:31:41 -08:00
George Hotz
849af761a4 simpler 2026-01-02 11:10:40 -08:00
George Hotz
ab46b3d8d3 origin/master 2026-01-02 10:47:00 -08:00
George Hotz
df20197bfb rever emu to master 2026-01-02 10:46:46 -08:00
George Hotz
2b56c264d5 compare tests 2026-01-02 10:39:07 -08:00
George Hotz
c7e5c2f996 Merge origin/master, remove deleted test_emu.py 2026-01-02 09:41:34 -08:00
George Hotz
659aa14043 orks 2026-01-02 05:29:48 -08:00
George Hotz
21ffa1a86b 64 nops 2026-01-02 00:38:27 -05:00
George Hotz
29f3fb7af3 still stable 2026-01-01 23:45:19 -05:00
George Hotz
1edc7fc519 stable 2026-01-01 23:43:43 -05:00
George Hotz
c9a3ac988c cleanest 2026-01-01 23:18:19 -05:00
George Hotz
77d96acbe3 clean 2026-01-01 22:59:07 -05:00
George Hotz
660ecf272b work 2026-01-01 22:50:50 -05:00
George Hotz
267bbb163e progress 2026-01-01 21:11:29 -05:00
George Hotz
de29a49ea3 all the ones i can find 2026-01-01 20:56:30 -05:00
George Hotz
742e10a572 remove fake ones 2026-01-01 20:26:53 -05:00
George Hotz
447fe8907b more 2026-01-01 20:22:52 -05:00
George Hotz
b0cfcec183 good 2026-01-01 20:12:20 -05:00
George Hotz
1726084b2a filt 2026-01-01 19:40:43 -05:00
George Hotz
de069a4876 many 2026-01-01 19:21:46 -05:00
George Hotz
4573e91e61 more 2026-01-01 18:51:31 -05:00
George Hotz
8d43212bc6 assembly/amd: start work on SQTT parsing/emulation 2026-01-01 18:40:58 -05:00
George Hotz
a8bea4ec52 remove __all__ 2026-01-01 16:14:15 -05:00
George Hotz
388514c5b1 better 2026-01-01 16:03:29 -05:00
George Hotz
729bb04d8c fix test failure 2026-01-01 13:21:55 -05:00
George Hotz
8f4de73141 two tests 2026-01-01 13:13:01 -05:00
George Hotz
a5959ef0f1 fix all tests 2026-01-01 13:11:51 -05:00
George Hotz
5ba06892c0 generic 2026-01-01 12:46:08 -05:00
George Hotz
469efe313d that's a hack 2026-01-01 12:40:14 -05:00
George Hotz
e3b3cb163d fix emu test 2026-01-01 12:12:47 -05:00
George Hotz
3e32185faf more tests 2026-01-01 12:04:41 -05:00
George Hotz
5328913d2b fix flat bug 2026-01-01 11:51:10 -05:00
George Hotz
9c49ec1cc1 update autogen 2026-01-01 11:36:33 -05:00
George Hotz
000d4a125b fix ds op 2026-01-01 10:36:37 -05:00
George Hotz
63289902d8 refactors 2025-12-31 17:57:27 -05:00
George Hotz
b596f77e33 assembly/amd: add pcode ds ops 2025-12-31 16:59:02 -05:00
11 changed files with 5848 additions and 8 deletions

View file

@ -2,12 +2,12 @@
# mypy: ignore-errors
from __future__ import annotations
import ctypes, functools
from tinygrad.helpers import DEBUG, colored, ansilen
from tinygrad.runtime.autogen import hsa
from extra.assembly.amd.dsl import Inst, unwrap, FLOAT_ENC, MASK32, MASK64, _f32, _i32, _sext, _f16, _i16, _f64, _i64
from extra.assembly.amd.asm import detect_format
from extra.assembly.amd.pcode import compile_pseudocode
from extra.assembly.amd.dsl import Inst, unwrap, FLOAT_ENC, MASK32, MASK64, _f32, _i32, _sext, _f16, _i16, _f64, _i64, SrcEnum
from extra.assembly.amd.pcode import Reg, compile_pseudocode
from extra.assembly.amd.asm import detect_format, disasm
from extra.assembly.amd.autogen.rdna3.str_pcode import PSEUDOCODE_STRINGS
from extra.assembly.amd.dsl import SrcEnum
from extra.assembly.amd.autogen.rdna3.ins import (SOP1, SOP2, SOPC, SOPK, SOPP, SMEM, VOP1, VOP2, VOP3, VOP3SD, VOP3P, VOPC, DS, FLAT, VOPD,
SOP1Op, SOP2Op, SOPCOp, SOPKOp, SOPPOp, SMEMOp, VOP1Op, VOP2Op, VOP3Op, VOP3SDOp, VOP3POp, VOPCOp, DSOp, FLATOp, GLOBALOp, SCRATCHOp, VOPDOp)
@ -329,6 +329,294 @@ def exec_wmma(st: WaveState, inst, op: VOP3POp) -> None:
else:
for i in range(256): st.vgpr[i % 32][vdst + i//32] = _i32(mat_d[i])
# SQTT TRACING
# ═══════════════════════════════════════════════════════════════════════════════
WAVESTART_TO_INST_CYCLES = 32
SNOP_EXTRA_DELAY_MIN, SNOP_EXTRA_DELAY_MAX = 11, 22 # s_nop(11-22) has +4 penalty
SNOP_EXTRA_DELAY_CYCLES = 4
from extra.assembly.amd.sqtt import WAVESTART, WAVEEND, IMMEDIATE, VALUINST, ALUEXEC, AluSrc
def _get_src_vgprs(inst: Inst) -> list[int]:
if isinstance(inst, VOP1): return [inst.src0 - 256] if inst.src0 >= 256 else []
if isinstance(inst, VOP2): return ([inst.src0 - 256] if inst.src0 >= 256 else []) + [inst.vsrc1]
if isinstance(inst, VOP3): return [s - 256 for s in [inst.src0, inst.src1, getattr(inst, 'src2', None)] if s is not None and s >= 256]
return []
class SQTTState:
"""SQTT tracing with cycle-accurate RDNA3 VALU pipeline model.
NOTE: This is a hardware-plausible model derived from observed SQTT timing patterns.
The model should be verified by tests against real hardware traces, not by fitting
formulas to expected outputs. If tests fail, the model needs to be understood and
fixed, not hacked with magic constants.
Physical model:
- alu[4]: 4-stage ALU pipeline, each slot holds dest_vgpr or None
- in_flight: up to 12 in-flight instructions (issued but not yet completed)
- issue_queue: instructions waiting to enter ALU (sources not ready)
- fwd_slots: 4 forwarding slots, reserved at issue, freed when consumer forwards
- completed: vgprs with results ready (exited ALU)
Forwarding model (4 slots):
- Slot reserved at ISSUE time if available (len(fwd_slots) < 4)
- Slot freed when a consumer uses the result for forwarding
- Consumer can forward if: has a slot AND producer is completed
- If no slot at issue, instruction uses regfile path (+4 cycle penalty)
"""
def __init__(self, wave_id: int = 0, simd: int = 0, cu: int = 0):
self.wave_id, self.simd, self.cu = wave_id, simd, cu
self.cycle = 0
self.packets = []
# 4-stage ALU pipeline: each slot holds dest_vgpr or None
self.alu = [None, None, None, None]
# In-flight instructions: max 12 at a time, each is (dest_vgpr, srcs, has_fwd_slot)
self.in_flight: list[tuple[int, list[int], bool]] = []
# Issue queue: list of (dest_vgpr, srcs, ready_at, has_fwd_slot, was_warm) waiting for deps
# ready_at: cycle when this instruction can enter ALU (0 = no restriction)
# has_fwd_slot: True if this instruction reserved a forwarding slot at issue time
# was_warm: True if forwarding path was warm when this instruction was issued
self.issue_queue: list[tuple[int, list[int], int, bool, bool]] = []
# 4 forwarding slots: consumer adds producer at issue, freed when consumer forwards
self.fwd_slots: list[int] = [] # producer vgprs reserved for forwarding
# VGPRs that had a dependent try to add them to fwd_slots (successful or not)
self.had_dependent: set[int] = set()
# VGPRs that were issued after forwarding chain broke (can't forward)
self.fwd_chain_broken: set[int] = set()
# Set of completed vgprs (results ready, exited ALU)
self.completed: set[int] = set()
# Cold start: first forwarding use has +1 cycle penalty
self.forward_warm = False
self.cold_used = False # True if cold start penalty was applied
def emit(self, pkt_class, **kwargs):
self.packets.append(pkt_class(_time=self.cycle, **kwargs))
def _fmt_alu(self) -> str:
# Fixed width: each slot 3 chars, total ALU[xxx,xxx,xxx,xxx] = 20 chars
slots = [f'v{v}' if v is not None else '-' for v in self.alu]
return 'ALU[' + ','.join(f'{s:>3}' for s in slots) + ']'
def _fmt_fwd(self) -> str:
items = [f'v{v}' for v in self.fwd_slots]
content = 'FWD[' + ','.join(items) + ']' if items else 'FWD[]'
padded = f'{content:<24}'
return colored(padded, 'yellow') if items else padded
def _fmt_iq(self) -> str:
def fmt_item(d, r, fwd):
s = f'v{d}'
if r != 0: s += f'@{abs(r)}'
if not fwd: s += 'R'
return s
items = [fmt_item(d, r, fwd) for d, _, r, fwd, _ in self.issue_queue]
return 'IQ[' + ','.join(items) + ']' if items else 'IQ[]'
def _debug_line(self, events: list[str] | None = None):
if DEBUG < 3: return
# Skip empty cycles (nothing in ALU, no events, no IQ)
has_alu = any(s is not None for s in self.alu)
if not has_alu and not events and not self.issue_queue: return
cycle = colored(f'C{self.cycle:>3}:', 'cyan')
alu = self._fmt_alu()
fwd = self._fmt_fwd()
iq = f'{self._fmt_iq():<28}'
ev_str = ' '.join(events) if events else ''
ev_padded = f'{ev_str:<20}' if ev_str else ' ' * 20
print(f"{cycle} {alu} {fwd} {iq} {ev_padded}")
def _can_issue(self) -> bool:
return len(self.in_flight) < 12
def _has_pending_write(self, vgpr: int) -> bool:
"""Check if there's a pending write to this VGPR (in ALU, in-flight, or issue queue)."""
if any(slot == vgpr for slot in self.alu if slot is not None): return True
if any(d == vgpr for d, _, _ in self.in_flight): return True
if any(d == vgpr for d, _, _, _, _ in self.issue_queue): return True
return False
def _all_srcs_ready(self, srcs: list[int]) -> bool:
"""Returns True if all sources are ready (completed or no pending write)."""
for src in srcs:
if src in self.completed: continue
if not self._has_pending_write(src): continue # initial value
return False
return True
def tick(self):
self.cycle += 1
if self.cycle > 10000: raise RuntimeError("cycle limit exceeded")
events = []
# 1. ALU[3] exits - capture but don't add to completed yet
exiting = self.alu[3]
if exiting is not None:
self.emit(ALUEXEC, src=AluSrc.VALU)
events.append(colored(f"EXEC v{exiting}", 'red'))
# 2. Slide ALU pipeline
self.alu[3] = self.alu[2]
self.alu[2] = self.alu[1]
self.alu[1] = self.alu[0]
self.alu[0] = None
# 3. Try to promote from issue_queue to ALU[0] (before adding exiting to completed)
if self.alu[0] is None and self.issue_queue:
for i, (dest, srcs, ready_at, has_fwd_slot, was_warm) in enumerate(self.issue_queue):
# Check if instruction has a minimum ready cycle
if ready_at > 0 and self.cycle < ready_at:
continue
# Check if sources are ready
ready = self._all_srcs_ready(srcs)
has_deps = len(srcs) > 0
if not ready:
continue
# Cold start penalty: first dependent instruction has +1 cycle delay (delta=6 vs delta=5)
# Only applies if forwarding path wasn't warm when this instruction was issued
if has_deps and not was_warm and not self.cold_used:
self.cold_used = True
self.issue_queue[i] = (dest, srcs, self.cycle + 1, has_fwd_slot, was_warm)
continue
# Forwarding: consumer can forward if:
# 1. Not in fwd_chain_broken (chain must be intact), AND
# 2. Producer has a slot (source is in fwd_slots), AND
# 3. Either activated by dependent OR successfully added producer at issue
# Note: if issued cold with no slot, activation only counts if the activator also has a dependent
chain_intact = dest not in self.fwd_chain_broken
producer_has_slot = has_deps and any(src in self.fwd_slots for src in srcs)
# Check activation validity
if dest in self.had_dependent:
if was_warm or has_fwd_slot:
activated_by_dependent = True
else:
# Cold + no slot: activation only counts if activator itself has a dependent
# This handles the chain_6 vs chain_7 difference (chain_7 has v6 which activates v5)
activated_by_dependent = (dest + 1) in self.had_dependent # activator is dest+1 in a chain
else:
activated_by_dependent = False
can_forward = chain_intact and producer_has_slot and (activated_by_dependent or has_fwd_slot)
# Regfile path: has dependencies but can't forward
must_use_regfile = has_deps and not can_forward
# Regfile penalty: add +4 cycles latency (only apply once)
if must_use_regfile and ready_at == 0:
self.issue_queue[i] = (dest, srcs, self.cycle + 4, has_fwd_slot, was_warm)
continue
# Enter ALU
self.alu[0] = dest
self.issue_queue.pop(i)
# Free producer's forwarding slot when consumer dispatches (regardless of fwd/rf)
for src in srcs:
if src in self.fwd_slots:
self.fwd_slots.remove(src)
break
events.append(colored(f"v{dest}->ALU" + ("(fwd)" if can_forward else "(rf)" if must_use_regfile else ""), 'green'))
break
# 4. Now add exiting instruction to completed (after promotion decision)
if exiting is not None:
self.completed.add(exiting)
# Remove from in_flight - any VALU completing warms up the forward path
for idx, (d, _, _) in enumerate(self.in_flight):
if d == exiting:
self.forward_warm = True
self.in_flight.pop(idx)
break
self._debug_line(events)
def _pipeline_empty(self) -> bool:
if any(s is not None for s in self.alu): return False
if self.issue_queue: return False
if self.in_flight: return False
return True
def process_instruction(self, inst: Inst):
if isinstance(inst, SOPP) and inst.op == SOPPOp.S_DELAY_ALU:
# TODO: implement s_delay_alu properly
return
elif isinstance(inst, SOPP) and inst.op == SOPPOp.S_NOP:
# s_nop(N) delays N+1 cycles, plus extra penalty for s_nop(11-22)
cycles = inst.simm16 + 1
if SNOP_EXTRA_DELAY_MIN <= inst.simm16 <= SNOP_EXTRA_DELAY_MAX:
cycles += SNOP_EXTRA_DELAY_CYCLES
if DEBUG >= 3:
cycle = colored(f'C{self.cycle:>3}:', 'cyan')
# 20 (ALU) + 1 + 24 (FWD) + 1 + 28 (IQ) + 1 + 20 (events) = 95 padding after cycle
print(f"{cycle} {' ' * 95} {disasm(inst)}")
for _ in range(cycles): self.tick()
self.emit(IMMEDIATE, wave=self.wave_id)
elif isinstance(inst, SOPP) and inst.op == SOPPOp.S_ENDPGM:
# Drain pipeline before ending
while not self._pipeline_empty(): self.tick()
self.emit(WAVEEND, wave=self.wave_id, simd=self.simd, cu_lo=self.cu & 0x7, flag7=self.cu >> 3)
elif isinstance(inst, (VOP1, VOP2, VOP3)):
# Check for issue stall (no free in-flight slots)
while not self._can_issue():
self.tick()
# Issue: add to in_flight and issue_queue
srcs = _get_src_vgprs(inst)
dest = inst.vdst
# Clear stale state for this dest (WAW hazard)
self.completed.discard(dest)
if dest in self.fwd_slots: self.fwd_slots.remove(dest)
# Consumer adds producer to fwd_slots (if room and has dependency)
# If producer is in fwd_chain_broken, or we can't add, the chain breaks for this instruction too
has_fwd_slot = False
if srcs:
producer = srcs[0]
self.had_dependent.add(producer) # record that producer has a dependent
# Check if producer's forwarding chain is already broken
if producer in self.fwd_chain_broken:
# Chain is broken, this instruction also can't forward
self.fwd_chain_broken.add(dest)
elif len(self.fwd_slots) >= 4:
# Can't add producer, chain breaks
self.fwd_chain_broken.add(dest)
else:
# Can add producer
if producer not in self.fwd_slots:
self.fwd_slots.append(producer)
has_fwd_slot = len(self.fwd_slots) < 4
# Record if forwarding path was warm at issue time
was_warm = self.forward_warm
self.in_flight.append((dest, srcs, has_fwd_slot))
self.issue_queue.append((dest, srcs, 0, has_fwd_slot, was_warm))
self.emit(VALUINST, wave=self.wave_id)
if DEBUG >= 3:
cycle = colored(f'C{self.cycle:>3}:', 'cyan')
slot_info = "" if has_fwd_slot else colored(" NO_SLOT", 'red')
issue = colored(f'ISSUE v{dest}', 'magenta') + slot_info
padding = 95 - ansilen(issue)
print(f"{cycle} {issue}{' ' * padding} {disasm(inst)}")
# One cycle per instruction issued, then try to enter ALU
self.tick()
return
# One cycle per instruction issued (for non-VALU)
self.tick()
def emit_wavestart(self):
self.emit(WAVESTART, wave=self.wave_id, simd=self.simd, cu_lo=self.cu & 0x7, flag7=self.cu >> 3)
for _ in range(WAVESTART_TO_INST_CYCLES): self.tick()
# ═══════════════════════════════════════════════════════════════════════════════
# PROGRAM DECODE
# ═══════════════════════════════════════════════════════════════════════════════

View file

@ -119,6 +119,24 @@ class PacketType:
cls._extract_info = [(name, bf.lo, bf.mask(), cls._field_types.get(name)) for name, bf in cls._fields.items()]
cls._size_nibbles = ((max((f.hi for f in cls._fields.values()), default=0) + 4) // 4)
def __init__(self, _time: int = 0, **kwargs):
"""Construct packet from named fields (like assembly instructions)."""
raw = 0
if self._encoding:
bf, pattern = self._encoding
raw |= pattern << bf.lo
for name, bf in self._fields.items():
val = kwargs.get(name, 0)
if isinstance(val, IntEnum): val = val.value
raw |= (val & bf.mask()) << bf.lo
self._raw, self._time, self._values = raw, _time, {}
for name, lo, mask, enum_type in self._extract_info:
val = (raw >> lo) & mask
if enum_type is not None:
try: val = enum_type(val)
except ValueError: pass
self._values[name] = val
@classmethod
def from_raw(cls, raw: int, time: int = 0):
inst = object.__new__(cls)
@ -305,6 +323,8 @@ PACKET_TYPES: list[type[PacketType]] = [
NOP,
]
PACKET_BY_NAME: dict[str, type[PacketType]] = {cls.__name__: cls for cls in PACKET_TYPES}
def _build_state_table() -> tuple[bytes, dict[int, type[PacketType]]]:
table = [len(PACKET_TYPES) - 1] * 256 # default to NOP
opcode_to_class: dict[int, type[PacketType]] = {i: cls for i, cls in enumerate(PACKET_TYPES)}
@ -321,6 +341,11 @@ def _build_state_table() -> tuple[bytes, dict[int, type[PacketType]]]:
STATE_TO_OPCODE, OPCODE_TO_CLASS = _build_state_table()
OPCODE_TO_BYTES: dict[int, list[int]] = {}
for _byte_val, _opcode in enumerate(STATE_TO_OPCODE):
if _opcode not in OPCODE_TO_BYTES: OPCODE_TO_BYTES[_opcode] = []
OPCODE_TO_BYTES[_opcode].append(_byte_val)
# Precompute special case opcodes
_TS_DELTA_OR_MARK_OPCODE = next(op for op, cls in OPCODE_TO_CLASS.items() if cls is TS_DELTA_OR_MARK)
_TS_DELTA_SHORT_OPCODE = next(op for op, cls in OPCODE_TO_CLASS.items() if cls is TS_DELTA_SHORT)
@ -379,3 +404,47 @@ def decode(data: bytes) -> list[PacketType]:
packets_append(pkt_cls.from_raw(reg, time))
return packets
# ═══════════════════════════════════════════════════════════════════════════════
# ENCODER
# ═══════════════════════════════════════════════════════════════════════════════
def encode(packets: list[PacketType]) -> bytes:
"""Encode a list of packet instances into raw SQTT blob."""
if not packets: return b''
read_lengths = [16]
for p in packets[:-1]:
read_lengths.append(type(p)._size_nibbles)
total_nibbles = sum(read_lengths)
bits_arr = [0] * (total_nibbles * 4)
cumulative = 0
for i, p in enumerate(packets):
cumulative += read_lengths[i]
pkt_cls = type(p)
opcode = next(op for op, cls in OPCODE_TO_CLASS.items() if cls is pkt_cls)
byte_vals = OPCODE_TO_BYTES.get(opcode)
if not byte_vals: raise ValueError(f"No encoding for {pkt_cls.__name__}")
opcode_byte = byte_vals[0]
delta_field = getattr(pkt_cls, 'delta', None)
if delta_field is not None and delta_field.hi < 8:
delta = p._values.get('delta', 0)
if isinstance(delta, IntEnum): delta = delta.value
if pkt_cls is TS_DELTA_SHORT: delta = max(0, delta - 8)
delta = delta & delta_field.mask()
opcode_byte = (opcode_byte & ~(delta_field.mask() << delta_field.lo)) | (delta << delta_field.lo)
opcode_nibble_pos = max(0, cumulative - 16)
opcode_bit_pos = opcode_nibble_pos * 4
for b in range(8):
if opcode_bit_pos + b < len(bits_arr):
bits_arr[opcode_bit_pos + b] = (opcode_byte >> b) & 1
nibbles = [sum(bits_arr[i + j] << j for j in range(4) if i + j < len(bits_arr)) for i in range(0, len(bits_arr), 4)]
while len(nibbles) % 2: nibbles.append(0)
return bytes(nibbles[i] | (nibbles[i + 1] << 4) for i in range(0, len(nibbles), 2))

View file

@ -0,0 +1,855 @@
#!/usr/bin/env python3
"""SQTT InstOp discovery tool - finds instruction opcodes by running different instructions.
Requires profiling enabled:
echo 'profile_standard' | sudo tee /sys/class/drm/card1/device/power_dpm_force_performance_level
Run with: DEBUG=1 python extra/assembly/amd/test/discover_instops.py
For full traces: DEBUG=2 python extra/assembly/amd/test/discover_instops.py
"""
import os
os.environ["SQTT"] = "1"
os.environ["PROFILE"] = "1"
os.environ["SQTT_LIMIT_SE"] = "2" # Force work to traced SE only
os.environ["SQTT_TOKEN_EXCLUDE"] = "3784" # Exclude WAVERDY, REG, EVENT, UTILCTR, WAVEALLOC, PERF
from tinygrad.helpers import DEBUG, colored
from tinygrad.runtime.ops_amd import SQTT_SIMD_SEL
from extra.assembly.amd.autogen.rdna3.ins import (
# VALU - basic (these are safe, just register ops)
v_mov_b32_e32, v_add_f32_e32, v_mul_f32_e32,
v_and_b32_e32, v_or_b32_e32, v_xor_b32_e32,
v_lshlrev_b32_e32, v_lshrrev_b32_e32,
# VALU - transcendental
v_exp_f32_e32, v_log_f32_e32, v_rcp_f32_e32, v_sqrt_f32_e32,
v_sin_f32_e32, v_cos_f32_e32,
# VALU - 64-bit
v_lshlrev_b64, v_lshrrev_b64, v_ashrrev_i64,
v_add_f64, v_mul_f64, v_max_f64, v_min_f64,
v_fma_f64,
# VALU - 64-bit transcendental
v_rcp_f64_e32, v_rsq_f64_e32, v_sqrt_f64_e32,
v_trunc_f64_e32, v_ceil_f64_e32, v_floor_f64_e32, v_fract_f64_e32,
v_frexp_exp_i32_f64_e32, v_frexp_mant_f64_e32,
# VALU - div helpers
v_div_fixup_f32, v_div_fixup_f64, v_div_fmas_f32, v_div_fmas_f64, v_div_scale_f32,
# VALU - MAD64
v_mad_u64_u32, v_mad_i64_i32,
# VALU - compare (writes to VCC, safe)
v_cmp_eq_u32_e32,
# VALU - cmpx (modifies EXEC) - various types
v_cmpx_eq_u32_e32, v_cmpx_lt_u32_e32, v_cmpx_gt_u32_e32,
v_cmpx_eq_f32_e32, v_cmpx_lt_f32_e32,
v_cmpx_eq_i32_e32,
v_cmpx_class_f32_e32,
# VALU - readlane/writelane
v_readlane_b32, v_writelane_b32,
v_readfirstlane_b32_e32,
# SALU - basic (safe, just register ops)
s_mov_b32, s_add_u32, s_and_b32, s_or_b32,
s_lshl_b32, s_lshr_b32,
s_nop, s_endpgm, s_waitcnt,
# SALU - float
s_ceil_f32, s_floor_f32, s_trunc_f32,
# SALU - branch (safe if offset is 0 = next instruction)
s_branch, s_cbranch_scc0, s_cbranch_execz, s_cbranch_execnz,
# SALU - message
s_sendmsg,
# SALU - bit manipulation
s_brev_b32, s_bcnt1_i32_b32, s_ctz_i32_b32, s_clz_i32_u32,
# SALU - saveexec (modifies EXEC)
s_and_saveexec_b32, s_or_saveexec_b32, s_xor_saveexec_b32,
# SMEM - scalar memory (load from kernarg pointer in s[0:1])
s_load_b32, s_load_b64,
# GLOBAL - global memory (load/store) - various widths
global_load_u8, global_load_u16, global_load_b32, global_load_b64, global_load_b96, global_load_b128,
global_store_b8, global_store_b16, global_store_b32, global_store_b64, global_store_b96, global_store_b128,
# GLOBAL - atomics
global_atomic_add_u32, global_atomic_add_u64,
# FLAT - flat memory access
flat_load_b32, flat_load_b64, flat_load_b96, flat_load_b128,
flat_store_b8, flat_store_b16, flat_store_b32, flat_store_b64, flat_store_b96, flat_store_b128,
# LDS - local data share - various widths
ds_load_b32, ds_load_b64, ds_load_b128,
ds_store_b32, ds_store_b64, ds_store_b128,
# LDS - atomics
ds_add_u32, ds_max_u32, ds_min_u32,
# VOP3P - packed
v_pk_add_f16, v_pk_mul_f16, v_pk_fma_f16, v_pk_add_i16,
# VOP3 - misc
v_bfe_u32, v_bfi_b32, v_alignbit_b32, v_fma_f32,
v_add3_u32, v_xad_u32, v_lshl_or_b32, v_add_nc_u32_e32,
# VOP3 - carry-out
v_add_co_u32, v_add_co_ci_u32_e32,
# VOPD - dual issue
v_dual_add_f32, v_dual_mul_f32,
# VOP2 - fmac
v_fmac_f32_e32,
# DOT
v_dot2_f16_f16,
# WMMA
v_wmma_f32_16x16x16_f16, v_wmma_f16_16x16x16_f16, v_wmma_i32_16x16x16_iu8,
# Permlane ops
v_permlane64_b32_e32, v_permlane16_b32, v_permlanex16_b32,
# Interpolation
v_interp_p10_f32, v_interp_p2_f32,
# Barrier
s_barrier,
# SrcEnum for NULL soffset
SrcEnum,
)
from extra.assembly.amd.dsl import v, s
from extra.assembly.amd.sqtt import InstOp, INST, WAVESTART, WAVEEND, ALUEXEC, VMEMEXEC
from extra.assembly.amd.test.test_sqtt_hw import (
run_asm_sqtt, decode_all_blobs, get_inst_ops, print_blobs, get_wave_packets, format_packet, PACKET_COLORS, count_valuinst
)
# ═══════════════════════════════════════════════════════════════════════════════
# INSTRUCTION TEST CASES - only safe instructions that don't access memory
# ═══════════════════════════════════════════════════════════════════════════════
# Helper: load buffer address from kernarg (s[0:1] -> s[2:3])
# The runtime passes kernarg pointer in s[0:1], kernarg contains buffer address
def _load_buf_addr():
return [
s_load_b64(s[2:3], s[0], 0, soffset=SrcEnum.NULL), # load buf addr from kernarg
s_waitcnt(lgkmcnt=0), # wait for SMEM load
]
INSTRUCTION_TESTS: dict[str, tuple[str, list]] = {
# SALU (0x0) - scalar ALU, just register operations
"SALU_mov": ("s_mov_b32", [s_mov_b32(s[4], 0), s_mov_b32(s[5], 1)]),
"SALU_add": ("s_add_u32", [s_mov_b32(s[4], 1), s_mov_b32(s[5], 2), s_add_u32(s[6], s[4], s[5])]),
"SALU_logic": ("s_and/or", [s_and_b32(s[6], s[4], s[5]), s_or_b32(s[7], s[4], s[5])]),
"SALU_shift": ("s_lshl/lshr", [s_lshl_b32(s[6], s[4], 1), s_lshr_b32(s[7], s[4], 1)]),
"SALU_nop": ("s_nop", [s_nop(0)]),
# JUMP (0x3) - branch taken
"JUMP_branch": ("s_branch", [s_branch(0)]),
"JUMP_cbranch_execnz": ("s_cbranch_execnz", [s_cbranch_execnz(0)]), # EXEC != 0, branch taken
# JUMP_NO (0x4) - branch not taken
"JUMP_NO_cbranch_execz": ("s_cbranch_execz", [s_cbranch_execz(0)]), # EXEC != 0, branch not taken
# VALU (0xb) - vector ALU, just register operations
"VALU_mov": ("v_mov_b32", [v_mov_b32_e32(v[0], 0), v_mov_b32_e32(v[1], 1.0)]),
"VALU_add": ("v_add_f32", [v_mov_b32_e32(v[0], 1.0), v_mov_b32_e32(v[1], 2.0), v_add_f32_e32(v[2], v[0], v[1])]),
"VALU_mul": ("v_mul_f32", [v_mul_f32_e32(v[2], v[0], v[1])]),
"VALU_logic": ("v_and/or/xor", [v_and_b32_e32(v[2], v[0], v[1]), v_or_b32_e32(v[3], v[0], v[1]), v_xor_b32_e32(v[4], v[0], v[1])]),
"VALU_shift": ("v_lshl/lshr", [v_lshlrev_b32_e32(v[2], 1, v[0]), v_lshrrev_b32_e32(v[3], 1, v[0])]),
# VALU transcendental - still just register ops
"VALU_exp": ("v_exp_f32", [v_mov_b32_e32(v[0], 1.0), v_exp_f32_e32(v[1], v[0])]),
"VALU_log": ("v_log_f32", [v_mov_b32_e32(v[0], 1.0), v_log_f32_e32(v[1], v[0])]),
"VALU_rcp": ("v_rcp_f32", [v_mov_b32_e32(v[0], 1.0), v_rcp_f32_e32(v[1], v[0])]),
"VALU_sqrt": ("v_sqrt_f32", [v_mov_b32_e32(v[0], 1.0), v_sqrt_f32_e32(v[1], v[0])]),
# VALU 64-bit shift (0xd)
"VALU64_lshl": ("v_lshlrev_b64", [v_lshlrev_b64(v[0:1], 1, v[2:3])]),
"VALU64_lshr": ("v_lshrrev_b64", [v_lshrrev_b64(v[0:1], 1, v[2:3])]),
"VALU64_ashr": ("v_ashrrev_i64", [v_ashrrev_i64(v[0:1], 1, v[2:3])]),
# VALU 64-bit arithmetic
"VALU64_add": ("v_add_f64", [v_add_f64(v[0:1], v[2:3], v[4:5])]),
"VALU64_mul": ("v_mul_f64", [v_mul_f64(v[0:1], v[2:3], v[4:5])]),
"VALU64_max": ("v_max_f64", [v_max_f64(v[0:1], v[2:3], v[4:5])]),
"VALU64_min": ("v_min_f64", [v_min_f64(v[0:1], v[2:3], v[4:5])]),
"VALU64_fma": ("v_fma_f64", [v_fma_f64(v[0:1], v[2:3], v[4:5], v[6:7])]),
# VALU 64-bit transcendental
"VALU64_rcp": ("v_rcp_f64", [v_rcp_f64_e32(v[0:1], v[2:3])]),
"VALU64_rsq": ("v_rsq_f64", [v_rsq_f64_e32(v[0:1], v[2:3])]),
"VALU64_sqrt": ("v_sqrt_f64", [v_sqrt_f64_e32(v[0:1], v[2:3])]),
# VALU 64-bit rounding
"VALU64_trunc": ("v_trunc_f64", [v_trunc_f64_e32(v[0:1], v[2:3])]),
"VALU64_ceil": ("v_ceil_f64", [v_ceil_f64_e32(v[0:1], v[2:3])]),
"VALU64_floor": ("v_floor_f64", [v_floor_f64_e32(v[0:1], v[2:3])]),
"VALU64_fract": ("v_fract_f64", [v_fract_f64_e32(v[0:1], v[2:3])]),
# VALU 64-bit frexp
"VALU64_frexp_exp": ("v_frexp_exp_i32_f64", [v_frexp_exp_i32_f64_e32(v[0], v[2:3])]),
"VALU64_frexp_mant": ("v_frexp_mant_f64", [v_frexp_mant_f64_e32(v[0:1], v[2:3])]),
# VALU 64-bit div helpers
"VALU64_div_fixup": ("v_div_fixup_f64", [v_div_fixup_f64(v[0:1], v[2:3], v[4:5], v[6:7])]),
"VALU64_div_fmas": ("v_div_fmas_f64", [v_div_fmas_f64(v[0:1], v[2:3], v[4:5], v[6:7])]),
# VALU 32-bit div helpers
"VALU_div_fixup": ("v_div_fixup_f32", [v_div_fixup_f32(v[0], v[1], v[2], v[3])]),
"VALU_div_fmas": ("v_div_fmas_f32", [v_div_fmas_f32(v[0], v[1], v[2], v[3])]),
"VALU_div_scale": ("v_div_scale_f32", [v_div_scale_f32(v[0], SrcEnum.VCC_LO, v[1], v[2], v[3])]),
# VALU MAD64 (0xe)
"VALU_mad64u": ("v_mad_u64_u32", [
v_mov_b32_e32(v[2], 2),
v_mov_b32_e32(v[3], 3),
v_mov_b32_e32(v[4], 0),
v_mov_b32_e32(v[5], 0),
v_mad_u64_u32(v[0:1], SrcEnum.NULL, v[2], v[3], v[4:5]),
]),
"VALU_mad64i": ("v_mad_i64_i32", [
v_mov_b32_e32(v[2], 2),
v_mov_b32_e32(v[3], 3),
v_mov_b32_e32(v[4], 0),
v_mov_b32_e32(v[5], 0),
v_mad_i64_i32(v[0:1], SrcEnum.NULL, v[2], v[3], v[4:5]),
]),
# VALU compare - writes to VCC
"VALU_cmp": ("v_cmp_eq_u32", [v_cmp_eq_u32_e32(v[0], v[1])]),
# VALU CMPX (0x73) - modifies EXEC
"VALU_cmpx_eq_u32": ("v_cmpx_eq_u32", [v_cmpx_eq_u32_e32(v[0], v[1])]),
# SALU saveexec (0x72) - modifies EXEC safely by ANDing with all-ones mask
"SALU_saveexec": ("s_and_saveexec_b32", [
s_mov_b32(s[5], 0xFFFFFFFF), # all lanes mask
s_and_saveexec_b32(s[4], s[5]), # EXEC = EXEC & 0xFFFFFFFF = EXEC (unchanged)
]),
# SALU float ops
"SALU_ceil": ("s_ceil_f32", [s_ceil_f32(s[4], s[5])]),
"SALU_floor": ("s_floor_f32", [s_floor_f32(s[4], s[5])]),
"SALU_trunc": ("s_trunc_f32", [s_trunc_f32(s[4], s[5])]),
# SALU bit ops
"SALU_brev": ("s_brev_b32", [s_brev_b32(s[4], s[5])]),
"SALU_bcnt1": ("s_bcnt1_i32_b32", [s_bcnt1_i32_b32(s[4], s[5])]),
"SALU_ctz": ("s_ctz_i32_b32", [s_ctz_i32_b32(s[4], s[5])]),
"SALU_clz": ("s_clz_i32_u32", [s_clz_i32_u32(s[4], s[5])]),
# VALU sin/cos
"VALU_sin": ("v_sin_f32", [v_sin_f32_e32(v[0], v[1])]),
"VALU_cos": ("v_cos_f32", [v_cos_f32_e32(v[0], v[1])]),
# VOP3P - packed operations
"VALU_pk_add_f16": ("v_pk_add_f16", [v_pk_add_f16(v[0], v[1], v[2])]),
"VALU_pk_mul_f16": ("v_pk_mul_f16", [v_pk_mul_f16(v[0], v[1], v[2])]),
"VALU_pk_fma_f16": ("v_pk_fma_f16", [v_pk_fma_f16(v[0], v[1], v[2], v[3])]),
"VALU_pk_add_i16": ("v_pk_add_i16", [v_pk_add_i16(v[0], v[1], v[2])]),
# VOP3 - misc
"VALU_bfe_u32": ("v_bfe_u32", [v_bfe_u32(v[0], v[1], 0, 8)]),
"VALU_bfi_b32": ("v_bfi_b32", [v_bfi_b32(v[0], v[1], v[2], v[3])]),
"VALU_alignbit": ("v_alignbit_b32", [v_alignbit_b32(v[0], v[1], v[2], 4)]),
"VALU_fma_f32": ("v_fma_f32", [v_fma_f32(v[0], v[1], v[2], v[3])]),
# VOP3 - integer add variants (used by tinygrad kernels)
"VALU_add3": ("v_add3_u32", [v_add3_u32(v[0], v[1], v[2], v[3])]),
"VALU_xad": ("v_xad_u32", [v_xad_u32(v[0], v[1], v[2], v[3])]),
"VALU_lshl_or": ("v_lshl_or_b32", [v_lshl_or_b32(v[0], v[1], 4, v[2])]),
"VALU_add_nc": ("v_add_nc_u32", [v_add_nc_u32_e32(v[0], v[1], v[2])]),
# VOP3 - carry-out adds (used for 64-bit address calculation)
"VALU_add_co": ("v_add_co_u32", [v_add_co_u32(v[0], SrcEnum.VCC_LO, v[1], v[2])]),
"VALU_add_co_ci": ("v_add_co_ci_u32", [v_add_co_ci_u32_e32(v[0], v[1], v[2])]),
# VOPD - dual issue (used by tinygrad kernels)
"VALU_dual_add": ("v_dual_add_f32", [v_dual_add_f32(v[0], v[1], v[2], v[3], v[4], v[5])]),
"VALU_dual_mul": ("v_dual_mul_f32", [v_dual_mul_f32(v[0], v[1], v[2], v[3], v[4], v[5])]),
# VOP2 - fmac
"VALU_fmac": ("v_fmac_f32", [v_fmac_f32_e32(v[0], v[1], v[0])]),
# DOT products
"VALU_dot2": ("v_dot2_f16_f16", [v_dot2_f16_f16(v[0], v[1], v[2], v[3])]),
# WMMA - wave matrix multiply accumulate
"VALU_wmma_f32_f16": ("v_wmma_f32_16x16x16_f16", [v_wmma_f32_16x16x16_f16(v[0:7], v[8:15], v[16:23], v[0:7])]),
"VALU_wmma_f16_f16": ("v_wmma_f16_16x16x16_f16", [v_wmma_f16_16x16x16_f16(v[0:7], v[8:15], v[16:23], v[0:7])]),
"VALU_wmma_i32_iu8": ("v_wmma_i32_16x16x16_iu8", [v_wmma_i32_16x16x16_iu8(v[0:7], v[8:11], v[12:15], v[0:7])]),
# Permlane operations - cross-lane data movement
# NOTE: permlane64 produces NO SQTT packets in wave32 mode (it's for wave64 pairs)
# NOTE: permlane16/x16 produce VALUINST packets (no specific InstOp)
"VALU_permlane16": ("v_permlane16_b32", [v_permlane16_b32(v[0], v[1], s[2], s[3])]),
"VALU_permlanex16": ("v_permlanex16_b32", [v_permlanex16_b32(v[0], v[1], s[2], s[3])]),
# Interpolation - used in graphics shaders (produces InstOp 0x12 VINTERP)
"VINTERP_p10": ("v_interp_p10_f32", [v_interp_p10_f32(v[0], v[1], v[2], v[3])]),
"VINTERP_p2": ("v_interp_p2_f32", [v_interp_p2_f32(v[0], v[1], v[2], v[3])]),
# Barrier - wave synchronization
# NOTE: s_barrier produces NO SQTT instruction packets (with 1 wave, it's essentially a no-op)
"SALU_barrier": ("s_barrier", [s_barrier()]),
# LDS atomics
"LDS_atomic_add": ("ds_add_u32", [
v_mov_b32_e32(v[0], 0), # LDS address
v_mov_b32_e32(v[1], 1), # data to add
ds_add_u32(addr=v[0], data0=v[1]),
s_waitcnt(lgkmcnt=0),
]),
# ═══════════════════════════════════════════════════════════════════════════════
# GLOBAL ATOMICS - access real buffer passed via kernarg
# ═══════════════════════════════════════════════════════════════════════════════
# GLOBAL atomic add 32-bit (0x28 GLOBAL_ATOMIC)
"GLOBAL_atomic_add": ("global_atomic_add_u32", [
s_load_b64(s[2:3], s[0], 0, soffset=SrcEnum.NULL), # load buf addr from kernarg
s_waitcnt(lgkmcnt=0),
v_mov_b32_e32(v[0], 0), # offset = 0
v_mov_b32_e32(v[1], 1), # data to add
global_atomic_add_u32(addr=v[0], data=v[1], saddr=s[2]),
s_waitcnt(vmcnt=0),
]),
# GLOBAL atomic add 64-bit
"GLOBAL_atomic_add64": ("global_atomic_add_u64", [
s_load_b64(s[2:3], s[0], 0, soffset=SrcEnum.NULL),
s_waitcnt(lgkmcnt=0),
v_mov_b32_e32(v[0], 0),
v_mov_b32_e32(v[2], 1),
v_mov_b32_e32(v[3], 0),
global_atomic_add_u64(addr=v[0], data=v[2:3], saddr=s[2]),
s_waitcnt(vmcnt=0),
]),
# ═══════════════════════════════════════════════════════════════════════════════
# MEMORY INSTRUCTIONS - access real buffer passed via kernarg
# ═══════════════════════════════════════════════════════════════════════════════
# SMEM (0x1) - scalar memory load from buffer
"SMEM_load": ("s_load_b32", [
s_load_b64(s[2:3], s[0], 0, soffset=SrcEnum.NULL), # load buf addr from kernarg
s_waitcnt(lgkmcnt=0),
s_load_b32(s[4], s[2], 0, soffset=SrcEnum.NULL), # load from buffer
s_waitcnt(lgkmcnt=0),
]),
# GLOBAL load (0x21 GLOBAL_LOAD) - global memory load
"GLOBAL_load": ("global_load_b32", [
s_load_b64(s[2:3], s[0], 0, soffset=SrcEnum.NULL), # load buf addr from kernarg
s_waitcnt(lgkmcnt=0),
v_mov_b32_e32(v[0], 0), # offset = 0
global_load_b32(v[1], addr=v[0], saddr=s[2], offset=0), # load from buffer
s_waitcnt(vmcnt=0),
]),
# GLOBAL store (0x24 GLOBAL_STORE) - global memory store
"GLOBAL_store": ("global_store_b32", [
s_load_b64(s[2:3], s[0], 0, soffset=SrcEnum.NULL), # load buf addr from kernarg
s_waitcnt(lgkmcnt=0),
v_mov_b32_e32(v[0], 0), # offset = 0
v_mov_b32_e32(v[1], 42), # data to store
global_store_b32(addr=v[0], data=v[1], saddr=s[2], offset=0), # store to buffer
s_waitcnt(vmcnt=0),
]),
# GLOBAL 8-bit load/store
"GLOBAL_load8": ("global_load_u8", [
s_load_b64(s[2:3], s[0], 0, soffset=SrcEnum.NULL),
s_waitcnt(lgkmcnt=0),
v_mov_b32_e32(v[0], 0),
global_load_u8(v[1], addr=v[0], saddr=s[2], offset=0),
s_waitcnt(vmcnt=0),
]),
"GLOBAL_store8": ("global_store_b8", [
s_load_b64(s[2:3], s[0], 0, soffset=SrcEnum.NULL),
s_waitcnt(lgkmcnt=0),
v_mov_b32_e32(v[0], 0),
v_mov_b32_e32(v[1], 42),
global_store_b8(addr=v[0], data=v[1], saddr=s[2], offset=0),
s_waitcnt(vmcnt=0),
]),
# GLOBAL 16-bit load/store
"GLOBAL_load16": ("global_load_u16", [
s_load_b64(s[2:3], s[0], 0, soffset=SrcEnum.NULL),
s_waitcnt(lgkmcnt=0),
v_mov_b32_e32(v[0], 0),
global_load_u16(v[1], addr=v[0], saddr=s[2], offset=0),
s_waitcnt(vmcnt=0),
]),
"GLOBAL_store16": ("global_store_b16", [
s_load_b64(s[2:3], s[0], 0, soffset=SrcEnum.NULL),
s_waitcnt(lgkmcnt=0),
v_mov_b32_e32(v[0], 0),
v_mov_b32_e32(v[1], 42),
global_store_b16(addr=v[0], data=v[1], saddr=s[2], offset=0),
s_waitcnt(vmcnt=0),
]),
# LDS load (0x29 LDS_LOAD) - local data share read
"LDS_load": ("ds_load_b32", [
v_mov_b32_e32(v[0], 0), # LDS address = 0
ds_load_b32(v[1], v[0], offset=0), # read from LDS
s_waitcnt(lgkmcnt=0),
]),
# LDS store (0x2b LDS_STORE) - local data share write
"LDS_store": ("ds_store_b32", [
v_mov_b32_e32(v[0], 0), # LDS address = 0
v_mov_b32_e32(v[1], 42), # data to store
ds_store_b32(v[0], v[1], offset=0), # write to LDS
s_waitcnt(lgkmcnt=0),
]),
# ═══════════════════════════════════════════════════════════════════════════════
# WIDER MEMORY OPERATIONS - to discover more InstOp variants
# ═══════════════════════════════════════════════════════════════════════════════
# GLOBAL 64-bit load
"GLOBAL_load64": ("global_load_b64", [
s_load_b64(s[2:3], s[0], 0, soffset=SrcEnum.NULL),
s_waitcnt(lgkmcnt=0),
v_mov_b32_e32(v[0], 0),
global_load_b64(v[2:3], addr=v[0], saddr=s[2], offset=0),
s_waitcnt(vmcnt=0),
]),
# GLOBAL 96-bit load
"GLOBAL_load96": ("global_load_b96", [
s_load_b64(s[2:3], s[0], 0, soffset=SrcEnum.NULL),
s_waitcnt(lgkmcnt=0),
v_mov_b32_e32(v[0], 0),
global_load_b96(v[4:6], addr=v[0], saddr=s[2], offset=0),
s_waitcnt(vmcnt=0),
]),
# GLOBAL 128-bit load
"GLOBAL_load128": ("global_load_b128", [
s_load_b64(s[2:3], s[0], 0, soffset=SrcEnum.NULL),
s_waitcnt(lgkmcnt=0),
v_mov_b32_e32(v[0], 0),
global_load_b128(v[4:7], addr=v[0], saddr=s[2], offset=0),
s_waitcnt(vmcnt=0),
]),
# GLOBAL 64-bit store
"GLOBAL_store64": ("global_store_b64", [
s_load_b64(s[2:3], s[0], 0, soffset=SrcEnum.NULL),
s_waitcnt(lgkmcnt=0),
v_mov_b32_e32(v[0], 0),
v_mov_b32_e32(v[2], 42),
v_mov_b32_e32(v[3], 43),
global_store_b64(addr=v[0], data=v[2:3], saddr=s[2], offset=0),
s_waitcnt(vmcnt=0),
]),
# GLOBAL 96-bit store
"GLOBAL_store96": ("global_store_b96", [
s_load_b64(s[2:3], s[0], 0, soffset=SrcEnum.NULL),
s_waitcnt(lgkmcnt=0),
v_mov_b32_e32(v[0], 0),
v_mov_b32_e32(v[4], 42),
v_mov_b32_e32(v[5], 43),
v_mov_b32_e32(v[6], 44),
global_store_b96(addr=v[0], data=v[4:6], saddr=s[2], offset=0),
s_waitcnt(vmcnt=0),
]),
# GLOBAL 128-bit store
"GLOBAL_store128": ("global_store_b128", [
s_load_b64(s[2:3], s[0], 0, soffset=SrcEnum.NULL),
s_waitcnt(lgkmcnt=0),
v_mov_b32_e32(v[0], 0),
v_mov_b32_e32(v[4], 42),
v_mov_b32_e32(v[5], 43),
v_mov_b32_e32(v[6], 44),
v_mov_b32_e32(v[7], 45),
global_store_b128(addr=v[0], data=v[4:7], saddr=s[2], offset=0),
s_waitcnt(vmcnt=0),
]),
# ═══════════════════════════════════════════════════════════════════════════════
# GLOBAL VADDR (vector-only addressing, saddr=NULL) - used by tinygrad kernels
# ═══════════════════════════════════════════════════════════════════════════════
# GLOBAL VADDR load (all sizes use same opcode 0x22)
"GLOBAL_VADDR_load": ("global_load_b32 vaddr", [
s_load_b64(s[2:3], s[0], 0, soffset=SrcEnum.NULL),
s_waitcnt(lgkmcnt=0),
v_mov_b32_e32(v[0], s[2]),
v_mov_b32_e32(v[1], s[3]),
global_load_b32(v[4], addr=v[0:1], saddr=SrcEnum.NULL, offset=0),
s_waitcnt(vmcnt=0),
]),
"GLOBAL_VADDR_load128": ("global_load_b128 vaddr", [
s_load_b64(s[2:3], s[0], 0, soffset=SrcEnum.NULL),
s_waitcnt(lgkmcnt=0),
v_mov_b32_e32(v[0], s[2]),
v_mov_b32_e32(v[1], s[3]),
global_load_b128(v[4:7], addr=v[0:1], saddr=SrcEnum.NULL, offset=0),
s_waitcnt(vmcnt=0),
]),
# GLOBAL VADDR stores (size encoded: 32->0x25, 64->0x26, 96->0x27, 128->0x28)
"GLOBAL_VADDR_store": ("global_store_b32 vaddr", [
s_load_b64(s[2:3], s[0], 0, soffset=SrcEnum.NULL),
s_waitcnt(lgkmcnt=0),
v_mov_b32_e32(v[0], s[2]),
v_mov_b32_e32(v[1], s[3]),
v_mov_b32_e32(v[4], 42),
global_store_b32(addr=v[0:1], data=v[4], saddr=SrcEnum.NULL, offset=0),
s_waitcnt(vmcnt=0),
]),
"GLOBAL_VADDR_store64": ("global_store_b64 vaddr", [
s_load_b64(s[2:3], s[0], 0, soffset=SrcEnum.NULL),
s_waitcnt(lgkmcnt=0),
v_mov_b32_e32(v[0], s[2]),
v_mov_b32_e32(v[1], s[3]),
v_mov_b32_e32(v[4], 42),
v_mov_b32_e32(v[5], 43),
global_store_b64(addr=v[0:1], data=v[4:5], saddr=SrcEnum.NULL, offset=0),
s_waitcnt(vmcnt=0),
]),
"GLOBAL_VADDR_store96": ("global_store_b96 vaddr", [
s_load_b64(s[2:3], s[0], 0, soffset=SrcEnum.NULL),
s_waitcnt(lgkmcnt=0),
v_mov_b32_e32(v[0], s[2]),
v_mov_b32_e32(v[1], s[3]),
v_mov_b32_e32(v[4], 42),
v_mov_b32_e32(v[5], 43),
v_mov_b32_e32(v[6], 44),
global_store_b96(addr=v[0:1], data=v[4:6], saddr=SrcEnum.NULL, offset=0),
s_waitcnt(vmcnt=0),
]),
"GLOBAL_VADDR_store128": ("global_store_b128 vaddr", [
s_load_b64(s[2:3], s[0], 0, soffset=SrcEnum.NULL),
s_waitcnt(lgkmcnt=0),
v_mov_b32_e32(v[0], s[2]),
v_mov_b32_e32(v[1], s[3]),
v_mov_b32_e32(v[4], 42),
v_mov_b32_e32(v[5], 43),
v_mov_b32_e32(v[6], 44),
v_mov_b32_e32(v[7], 45),
global_store_b128(addr=v[0:1], data=v[4:7], saddr=SrcEnum.NULL, offset=0),
s_waitcnt(vmcnt=0),
]),
# LDS 64-bit load
"LDS_load64": ("ds_load_b64", [
v_mov_b32_e32(v[0], 0),
ds_load_b64(v[2:3], v[0], offset=0),
s_waitcnt(lgkmcnt=0),
]),
# LDS 128-bit load
"LDS_load128": ("ds_load_b128", [
v_mov_b32_e32(v[0], 0),
ds_load_b128(v[4:7], v[0], offset=0),
s_waitcnt(lgkmcnt=0),
]),
# LDS 64-bit store
"LDS_store64": ("ds_store_b64", [
v_mov_b32_e32(v[0], 0),
v_mov_b32_e32(v[2], 42),
v_mov_b32_e32(v[3], 43),
ds_store_b64(v[0], v[2:3], offset=0),
s_waitcnt(lgkmcnt=0),
]),
# LDS 128-bit store
"LDS_store128": ("ds_store_b128", [
v_mov_b32_e32(v[0], 0),
v_mov_b32_e32(v[4], 42),
v_mov_b32_e32(v[5], 43),
v_mov_b32_e32(v[6], 44),
v_mov_b32_e32(v[7], 45),
ds_store_b128(v[0], v[4:7], offset=0),
s_waitcnt(lgkmcnt=0),
]),
# MESSAGE (0x9) - s_sendmsg
"MESSAGE": ("s_sendmsg", [
s_sendmsg(0), # send message 0 (NOP message)
]),
# ═══════════════════════════════════════════════════════════════════════════════
# FLAT MEMORY - uses 64-bit virtual address in VGPRs
# ═══════════════════════════════════════════════════════════════════════════════
# FLAT load - load using 64-bit address from buffer
"FLAT_load": ("flat_load_b32", [
s_load_b64(s[2:3], s[0], 0, soffset=SrcEnum.NULL), # load buf addr from kernarg
s_waitcnt(lgkmcnt=0),
v_mov_b32_e32(v[0], s[2]), # addr lo
v_mov_b32_e32(v[1], s[3]), # addr hi
flat_load_b32(v[2], addr=v[0:1]),
s_waitcnt(vmcnt=0, lgkmcnt=0),
]),
# FLAT store
"FLAT_store": ("flat_store_b32", [
s_load_b64(s[2:3], s[0], 0, soffset=SrcEnum.NULL),
s_waitcnt(lgkmcnt=0),
v_mov_b32_e32(v[0], s[2]),
v_mov_b32_e32(v[1], s[3]),
v_mov_b32_e32(v[2], 42),
flat_store_b32(addr=v[0:1], data=v[2]),
s_waitcnt(vmcnt=0, lgkmcnt=0),
]),
# FLAT 64-bit
"FLAT_load64": ("flat_load_b64", [
s_load_b64(s[2:3], s[0], 0, soffset=SrcEnum.NULL),
s_waitcnt(lgkmcnt=0),
v_mov_b32_e32(v[0], s[2]),
v_mov_b32_e32(v[1], s[3]),
flat_load_b64(v[2:3], addr=v[0:1]),
s_waitcnt(vmcnt=0, lgkmcnt=0),
]),
"FLAT_store64": ("flat_store_b64", [
s_load_b64(s[2:3], s[0], 0, soffset=SrcEnum.NULL),
s_waitcnt(lgkmcnt=0),
v_mov_b32_e32(v[0], s[2]),
v_mov_b32_e32(v[1], s[3]),
v_mov_b32_e32(v[4], 42),
v_mov_b32_e32(v[5], 43),
flat_store_b64(addr=v[0:1], data=v[4:5]),
s_waitcnt(vmcnt=0, lgkmcnt=0),
]),
# FLAT 96-bit
"FLAT_load96": ("flat_load_b96", [
s_load_b64(s[2:3], s[0], 0, soffset=SrcEnum.NULL),
s_waitcnt(lgkmcnt=0),
v_mov_b32_e32(v[0], s[2]),
v_mov_b32_e32(v[1], s[3]),
flat_load_b96(v[4:6], addr=v[0:1]),
s_waitcnt(vmcnt=0, lgkmcnt=0),
]),
"FLAT_store96": ("flat_store_b96", [
s_load_b64(s[2:3], s[0], 0, soffset=SrcEnum.NULL),
s_waitcnt(lgkmcnt=0),
v_mov_b32_e32(v[0], s[2]),
v_mov_b32_e32(v[1], s[3]),
v_mov_b32_e32(v[4], 42),
v_mov_b32_e32(v[5], 43),
v_mov_b32_e32(v[6], 44),
flat_store_b96(addr=v[0:1], data=v[4:6]),
s_waitcnt(vmcnt=0, lgkmcnt=0),
]),
# FLAT 128-bit
"FLAT_load128": ("flat_load_b128", [
s_load_b64(s[2:3], s[0], 0, soffset=SrcEnum.NULL),
s_waitcnt(lgkmcnt=0),
v_mov_b32_e32(v[0], s[2]),
v_mov_b32_e32(v[1], s[3]),
flat_load_b128(v[4:7], addr=v[0:1]),
s_waitcnt(vmcnt=0, lgkmcnt=0),
]),
"FLAT_store128": ("flat_store_b128", [
s_load_b64(s[2:3], s[0], 0, soffset=SrcEnum.NULL),
s_waitcnt(lgkmcnt=0),
v_mov_b32_e32(v[0], s[2]),
v_mov_b32_e32(v[1], s[3]),
v_mov_b32_e32(v[4], 42),
v_mov_b32_e32(v[5], 43),
v_mov_b32_e32(v[6], 44),
v_mov_b32_e32(v[7], 45),
flat_store_b128(addr=v[0:1], data=v[4:7]),
s_waitcnt(vmcnt=0, lgkmcnt=0),
]),
# FLAT 8/16-bit stores
"FLAT_store8": ("flat_store_b8", [
s_load_b64(s[2:3], s[0], 0, soffset=SrcEnum.NULL),
s_waitcnt(lgkmcnt=0),
v_mov_b32_e32(v[0], s[2]),
v_mov_b32_e32(v[1], s[3]),
v_mov_b32_e32(v[2], 42),
flat_store_b8(addr=v[0:1], data=v[2]),
s_waitcnt(vmcnt=0, lgkmcnt=0),
]),
"FLAT_store16": ("flat_store_b16", [
s_load_b64(s[2:3], s[0], 0, soffset=SrcEnum.NULL),
s_waitcnt(lgkmcnt=0),
v_mov_b32_e32(v[0], s[2]),
v_mov_b32_e32(v[1], s[3]),
v_mov_b32_e32(v[2], 42),
flat_store_b16(addr=v[0:1], data=v[2]),
s_waitcnt(vmcnt=0, lgkmcnt=0),
]),
}
def run_with_retry(instructions: list, max_attempts: int = 20) -> tuple[list[tuple[int, list[bytes]]], list[list], set, int]:
"""Run instructions multiple times to collect InstOp variants.
Memory ops produce different InstOp values (0x2x vs 0x5x) depending on which SIMD executes them:
- 0x2x range: wave ran on traced SIMD (matched)
- 0x5x range: wave ran on other SIMD (not matched)
Returns list of (traced_simd, blobs) tuples, all_packets, all_ops, max_valuinst_count.
"""
all_ops = set()
all_runs: list[tuple[int, list[bytes]]] = []
all_packets = []
max_valuinst = 0
SQTT_SIMD_SEL.value = 0 # only trace SIMD 0
for _ in range(max_attempts):
blobs = run_asm_sqtt(instructions)
packets = decode_all_blobs(blobs)
# get ops and valuinst from all SIMDs
ops = set()
valuinst_count = 0
for simd in [0, 1, 2, 3]:
ops.update(get_inst_ops(packets, traced_simd=simd))
valuinst_count = max(valuinst_count, count_valuinst(packets, traced_simd=simd))
all_runs.append((0, blobs))
all_packets.append(packets)
all_ops.update(ops)
max_valuinst = max(max_valuinst, valuinst_count)
return all_runs, all_packets, all_ops, max_valuinst
def discover_all_instops() -> tuple[dict[int, set[str]], dict[str, Exception], dict[str, int]]:
"""Run all instruction tests and collect InstOp values."""
discovered: dict[int, set[str]] = {}
failures: dict[str, Exception] = {}
valuinst_tests: dict[str, int] = {} # tests that produced VALUINST packets
for test_name, (instr_name, instructions) in INSTRUCTION_TESTS.items():
try:
all_runs, _, ops, valuinst_count = run_with_retry(instructions)
for op in ops:
if op not in discovered:
discovered[op] = set()
discovered[op].add(f"{test_name}")
if valuinst_count > 0:
valuinst_tests[test_name] = valuinst_count
if DEBUG >= 2:
print(f"\n{''*60}")
print(f"{test_name} ({instr_name}): ops={[hex(op) for op in sorted(ops)]}")
# collect wave patterns from traced SIMD runs (group by exact timing)
patterns: dict[tuple, list] = {} # pattern (types + timing) -> list of (wave_packets, t0)
for traced_simd, blobs in all_runs:
for blob in blobs:
packets = decode_all_blobs([blob])
wave_packets = get_wave_packets(packets)
# only include runs where wave ran on traced SIMD
ws = next((p for p in wave_packets if isinstance(p, WAVESTART)), None)
if ws and ws.simd == traced_simd and wave_packets:
t0 = wave_packets[0]._time
# pattern includes types AND normalized timing
pattern = tuple((type(p).__name__, p._time - t0) for p in wave_packets)
if pattern not in patterns:
patterns[pattern] = []
patterns[pattern].append((wave_packets, t0))
if patterns:
counts = {p: len(runs) for p, runs in patterns.items()}
most_common = max(counts, key=counts.get)
count = counts[most_common]
total = sum(counts.values())
print(f"\n=== most common pattern ({count}/{total} runs) ===")
wave_packets, t0 = patterns[most_common][0]
last_time = t0
for p in wave_packets:
print(format_packet(p, last_time, t0))
last_time = p._time
if len(patterns) > 1:
print(f"\n variations: {len(patterns)} unique timing patterns")
if DEBUG >= 3:
for traced_simd, blobs in all_runs:
print(f"\n=== traced simd={traced_simd} ===")
print_blobs(blobs, wave_only=False)
if DEBUG >= 1:
status = colored("", "green") if ops else (colored("V", "cyan") if valuinst_count > 0 else colored("", "yellow"))
ops_str = ", ".join(hex(op) for op in sorted(ops)) if ops else "none"
valuinst_str = f" valuinst={valuinst_count}" if valuinst_count > 0 and not ops else ""
print(f" {status} {test_name:25s} ops=[{ops_str}]{valuinst_str}")
except Exception as e:
failures[test_name] = e
if DEBUG >= 1:
print(f" {colored('', 'red')} {test_name:25s} FAILED: {e}")
return discovered, failures, valuinst_tests
def print_summary(discovered: dict[int, set[str]], failures: dict[str, Exception], valuinst_tests: dict[str, int]) -> None:
"""Print discovery summary."""
known_ops = {e.value for e in InstOp}
discovered_ops = set(discovered.keys())
print("\n" + "=" * 60)
print("DISCOVERED INSTOP VALUES")
print("=" * 60)
for op in sorted(discovered_ops):
try:
name = InstOp(op).name
status = colored("known", "green")
except ValueError:
name = f"UNKNOWN"
status = colored("NEW!", "yellow")
sources = ", ".join(sorted(discovered[op]))
print(f" 0x{op:02x} {name:20s} ({status}) <- {sources}")
# VALUINST tests (instructions that only produce VALUINST, not INST packets)
valuinst_only = {k: v for k, v in valuinst_tests.items() if not any(k in tests for tests in discovered.values())}
if valuinst_only:
print("\n" + "=" * 60)
print(colored("VALUINST-ONLY INSTRUCTIONS (no InstOp, use VALUINST packet)", "cyan"))
print("=" * 60)
for test_name, count in sorted(valuinst_only.items()):
print(f" {test_name}: {count} VALUINST packets")
# Missing from enum
missing = known_ops - discovered_ops
if missing:
print("\n" + "=" * 60)
print("ENUM VALUES NOT DISCOVERED")
print("=" * 60)
print("(need memory ops: SMEM, VMEM, LDS)")
for op in sorted(missing):
print(f" 0x{op:02x} {InstOp(op).name}")
# New values to add
new_ops = discovered_ops - known_ops
if new_ops:
print("\n" + "=" * 60)
print(colored("NEW INSTOP VALUES TO ADD TO ENUM", "yellow"))
print("=" * 60)
for op in sorted(new_ops):
sources = ", ".join(sorted(discovered[op]))
print(f" {op:#04x}: \"{sources}\",")
# Stats
print("\n" + "=" * 60)
print("STATISTICS")
print("=" * 60)
print(f" Tests run: {len(INSTRUCTION_TESTS)}")
print(f" Tests passed: {len(INSTRUCTION_TESTS) - len(failures)}")
print(f" Tests failed: {len(failures)}")
print(f" Known ops: {len(known_ops)}")
print(f" Discovered: {len(discovered_ops)}")
if known_ops:
print(f" Coverage: {len(discovered_ops & known_ops)}/{len(known_ops)} ({100*len(discovered_ops & known_ops)//len(known_ops)}%)")
print(f" New ops found: {len(new_ops)}")
print(f" VALUINST-only: {len(valuinst_only)}")
if __name__ == "__main__":
print("=" * 60)
print("SQTT InstOp Discovery Tool")
print("=" * 60)
print(f"Testing {len(INSTRUCTION_TESTS)} instruction categories...\n")
discovered, failures, valuinst_tests = discover_all_instops()
print_summary(discovered, failures, valuinst_tests)

View file

@ -0,0 +1,289 @@
#!/usr/bin/env python3
"""SQTT InstOp discovery from tinygrad-generated kernels.
Runs various tinygrad operations and captures SQTT traces to find new InstOp values.
Requires profiling enabled:
echo 'profile_standard' | sudo tee /sys/class/drm/card1/device/power_dpm_force_performance_level
Run with: DEBUG=1 python extra/assembly/amd/test/discover_instops_tensor.py
For full traces: DEBUG=2 python extra/assembly/amd/test/discover_instops_tensor.py
"""
import os
os.environ["SQTT"] = "1"
os.environ["PROFILE"] = "1"
os.environ["SQTT_LIMIT_SE"] = "2" # Force work to traced SE only
os.environ["SQTT_TOKEN_EXCLUDE"] = "3784" # Exclude noisy packet types
from tinygrad import Tensor, dtypes, Device
from tinygrad.helpers import DEBUG, colored
from tinygrad.runtime.ops_amd import ProfileSQTTEvent, SQTT_SIMD_SEL
from extra.assembly.amd.sqtt import InstOp, decode, INST, WAVESTART, WAVEEND
# ═══════════════════════════════════════════════════════════════════════════════
# HELPERS
# ═══════════════════════════════════════════════════════════════════════════════
def get_inst_ops_from_blobs(blobs: list[bytes]) -> set[int]:
"""Extract all InstOp values from SQTT blobs."""
ops = set()
for blob in blobs:
packets = decode(blob)
in_wave = False
for p in packets:
if isinstance(p, WAVESTART):
in_wave = True
if in_wave and isinstance(p, INST):
ops.add(p.op if isinstance(p.op, int) else p.op.value)
if isinstance(p, WAVEEND):
in_wave = False
return ops
def run_and_capture(fn, attempts: int = 5) -> tuple[set[int], list[bytes]]:
"""Run a function multiple times and collect SQTT traces."""
dev = Device["AMD"]
all_ops = set()
all_blobs = []
SQTT_SIMD_SEL.value = 0
for _ in range(attempts):
dev.profile_events.clear()
fn()
blobs = [ev.blob for ev in dev.profile_events if isinstance(ev, ProfileSQTTEvent)]
ops = get_inst_ops_from_blobs(blobs)
all_ops.update(ops)
all_blobs.extend(blobs)
return all_ops, all_blobs
# ═══════════════════════════════════════════════════════════════════════════════
# TENSOR OPERATIONS TO TEST
# ═══════════════════════════════════════════════════════════════════════════════
TENSOR_TESTS: dict[str, tuple[str, callable]] = {
# Basic arithmetic
"add_f32": ("tensor add float32", lambda: (Tensor.rand(1024) + Tensor.rand(1024)).realize()),
"mul_f32": ("tensor mul float32", lambda: (Tensor.rand(1024) * Tensor.rand(1024)).realize()),
"sub_f32": ("tensor sub float32", lambda: (Tensor.rand(1024) - Tensor.rand(1024)).realize()),
"div_f32": ("tensor div float32", lambda: (Tensor.rand(1024) / (Tensor.rand(1024) + 0.1)).realize()),
# Transcendental
"exp_f32": ("tensor exp float32", lambda: Tensor.rand(1024).exp().realize()),
"log_f32": ("tensor log float32", lambda: (Tensor.rand(1024) + 0.1).log().realize()),
"sqrt_f32": ("tensor sqrt float32", lambda: Tensor.rand(1024).sqrt().realize()),
"sin_f32": ("tensor sin float32", lambda: Tensor.rand(1024).sin().realize()),
"cos_f32": ("tensor cos float32", lambda: Tensor.rand(1024).cos().realize()),
"tanh_f32": ("tensor tanh float32", lambda: Tensor.rand(1024).tanh().realize()),
"sigmoid_f32": ("tensor sigmoid float32", lambda: Tensor.rand(1024).sigmoid().realize()),
# Reductions
"sum_f32": ("tensor sum float32", lambda: Tensor.rand(1024).sum().realize()),
"max_f32": ("tensor max float32", lambda: Tensor.rand(1024).max().realize()),
"mean_f32": ("tensor mean float32", lambda: Tensor.rand(1024).mean().realize()),
# Matmul - small
"matmul_small": ("matmul 32x32", lambda: (Tensor.rand(32, 32) @ Tensor.rand(32, 32)).realize()),
# Matmul - medium (might use WMMA)
"matmul_medium": ("matmul 128x128", lambda: (Tensor.rand(128, 128) @ Tensor.rand(128, 128)).realize()),
# Matmul - larger (more likely to use WMMA)
"matmul_large": ("matmul 256x256", lambda: (Tensor.rand(256, 256) @ Tensor.rand(256, 256)).realize()),
# Different dtypes
"add_f16": ("tensor add float16", lambda: (Tensor.rand(1024, dtype=dtypes.float16) + Tensor.rand(1024, dtype=dtypes.float16)).realize()),
"mul_f16": ("tensor mul float16", lambda: (Tensor.rand(1024, dtype=dtypes.float16) * Tensor.rand(1024, dtype=dtypes.float16)).realize()),
"matmul_f16": ("matmul float16 128x128", lambda: (Tensor.rand(128, 128, dtype=dtypes.float16) @ Tensor.rand(128, 128, dtype=dtypes.float16)).realize()),
# Integer ops
"add_i32": ("tensor add int32", lambda: (Tensor.randint(1024, high=1000) + Tensor.randint(1024, high=1000)).realize()),
"mul_i32": ("tensor mul int32", lambda: (Tensor.randint(1024, high=100) * Tensor.randint(1024, high=100)).realize()),
# Bitwise
"and_i32": ("tensor bitwise and", lambda: (Tensor.randint(1024, high=1000) & Tensor.randint(1024, high=1000)).realize()),
"or_i32": ("tensor bitwise or", lambda: (Tensor.randint(1024, high=1000) | Tensor.randint(1024, high=1000)).realize()),
"xor_i32": ("tensor bitwise xor", lambda: (Tensor.randint(1024, high=1000) ^ Tensor.randint(1024, high=1000)).realize()),
"lshift_i32": ("tensor left shift", lambda: (Tensor.randint(1024, high=1000) << 2).realize()),
"rshift_i32": ("tensor right shift", lambda: (Tensor.randint(1024, high=1000) >> 2).realize()),
# Comparisons
"cmp_eq": ("tensor compare eq", lambda: (Tensor.rand(1024) == 0.5).realize()),
"cmp_lt": ("tensor compare lt", lambda: (Tensor.rand(1024) < 0.5).realize()),
"cmp_gt": ("tensor compare gt", lambda: (Tensor.rand(1024) > 0.5).realize()),
# Where/select
"where": ("tensor where", lambda: Tensor.rand(1024).where(Tensor.rand(1024), Tensor.rand(1024)).realize()),
# Reshaping/movement (may not generate interesting ops but let's check)
"reshape": ("tensor reshape", lambda: Tensor.rand(32, 32).reshape(16, 64).realize()),
"permute": ("tensor permute", lambda: Tensor.rand(32, 32).permute(1, 0).contiguous().realize()),
"expand": ("tensor expand", lambda: Tensor.rand(1, 32).expand(32, 32).contiguous().realize()),
# Pad
"pad": ("tensor pad", lambda: Tensor.rand(30, 30).pad(((1, 1), (1, 1))).realize()),
# Conv2D - small
"conv2d_small": ("conv2d 3x3", lambda: Tensor.rand(1, 3, 32, 32).conv2d(Tensor.rand(8, 3, 3, 3)).realize()),
# Conv2D - larger
"conv2d_medium": ("conv2d 3x3 64ch", lambda: Tensor.rand(1, 64, 32, 32).conv2d(Tensor.rand(64, 64, 3, 3)).realize()),
# Pooling
"maxpool": ("max pool 2x2", lambda: Tensor.rand(1, 3, 32, 32).max_pool2d((2, 2)).realize()),
"avgpool": ("avg pool 2x2", lambda: Tensor.rand(1, 3, 32, 32).avg_pool2d((2, 2)).realize()),
# Softmax
"softmax": ("softmax", lambda: Tensor.rand(32, 128).softmax().realize()),
# LayerNorm-like
"layernorm": ("layer norm pattern", lambda: _layernorm(Tensor.rand(32, 128))),
# BatchNorm-like
"batchnorm": ("batch norm pattern", lambda: _batchnorm(Tensor.rand(1, 64, 32, 32))),
# Dropout-like (during training)
"dropout": ("dropout pattern", lambda: (Tensor.rand(1024) * (Tensor.rand(1024) > 0.5)).realize()),
# Cast operations
"cast_f32_to_f16": ("cast f32->f16", lambda: Tensor.rand(1024).cast(dtypes.float16).realize()),
"cast_f16_to_f32": ("cast f16->f32", lambda: Tensor.rand(1024, dtype=dtypes.float16).cast(dtypes.float32).realize()),
"cast_f32_to_i32": ("cast f32->i32", lambda: (Tensor.rand(1024) * 100).cast(dtypes.int32).realize()),
"cast_i32_to_f32": ("cast i32->f32", lambda: Tensor.randint(1024, high=100).cast(dtypes.float32).realize()),
# Clamp/clip
"clamp": ("tensor clamp", lambda: Tensor.rand(1024).clamp(0.2, 0.8).realize()),
# Abs/neg
"abs": ("tensor abs", lambda: (Tensor.rand(1024) - 0.5).abs().realize()),
"neg": ("tensor neg", lambda: (-Tensor.rand(1024)).realize()),
# Reciprocal
"recip": ("tensor reciprocal", lambda: (Tensor.rand(1024) + 0.1).reciprocal().realize()),
# Power
"pow2": ("tensor pow 2", lambda: (Tensor.rand(1024) ** 2).realize()),
"pow3": ("tensor pow 3", lambda: (Tensor.rand(1024) ** 3).realize()),
}
def _layernorm(x: Tensor) -> Tensor:
"""Simple layer normalization pattern."""
mean = x.mean(axis=-1, keepdim=True)
var = ((x - mean) ** 2).mean(axis=-1, keepdim=True)
return ((x - mean) / (var + 1e-5).sqrt()).realize()
def _batchnorm(x: Tensor) -> Tensor:
"""Simple batch normalization pattern."""
mean = x.mean(axis=(0, 2, 3), keepdim=True)
var = ((x - mean) ** 2).mean(axis=(0, 2, 3), keepdim=True)
return ((x - mean) / (var + 1e-5).sqrt()).realize()
# ═══════════════════════════════════════════════════════════════════════════════
# DISCOVERY
# ═══════════════════════════════════════════════════════════════════════════════
def discover_all_instops() -> tuple[dict[int, set[str]], dict[str, Exception]]:
"""Run all tensor tests and collect InstOp values."""
discovered: dict[int, set[str]] = {}
failures: dict[str, Exception] = {}
for test_name, (desc, fn) in TENSOR_TESTS.items():
try:
ops, blobs = run_and_capture(fn)
for op in ops:
if op not in discovered:
discovered[op] = set()
discovered[op].add(test_name)
if DEBUG >= 1:
status = colored("", "green") if ops else colored("", "yellow")
ops_str = ", ".join(hex(op) for op in sorted(ops)) if ops else "none"
print(f" {status} {test_name:25s} [{desc:25s}] ops=[{ops_str}]")
if DEBUG >= 2 and blobs:
# Show first wave trace
for blob in blobs[:1]:
packets = decode(blob)
print(f" First blob: {len(blob)} bytes, {len(packets)} packets")
except Exception as e:
failures[test_name] = e
if DEBUG >= 1:
print(f" {colored('', 'red')} {test_name:25s} FAILED: {e}")
return discovered, failures
def print_summary(discovered: dict[int, set[str]], failures: dict[str, Exception]) -> None:
"""Print discovery summary."""
known_ops = {e.value for e in InstOp}
discovered_ops = set(discovered.keys())
print("\n" + "=" * 70)
print("DISCOVERED INSTOP VALUES FROM TINYGRAD KERNELS")
print("=" * 70)
for op in sorted(discovered_ops):
try:
name = InstOp(op).name
status = colored("known", "green")
except ValueError:
name = "UNKNOWN"
status = colored("NEW!", "yellow")
sources = ", ".join(sorted(discovered[op]))
# Truncate sources if too long
if len(sources) > 60:
sources = sources[:57] + "..."
print(f" 0x{op:02x} {name:20s} ({status}) <- {sources}")
# New values to add
new_ops = discovered_ops - known_ops
if new_ops:
print("\n" + "=" * 70)
print(colored("NEW INSTOP VALUES TO ADD TO ENUM", "yellow"))
print("=" * 70)
for op in sorted(new_ops):
sources = ", ".join(sorted(discovered[op]))
print(f" 0x{op:02x}: discovered from [{sources}]")
# Missing from enum (not discovered)
missing = known_ops - discovered_ops
if missing:
print("\n" + "=" * 70)
print("ENUM VALUES NOT DISCOVERED (may need specific instruction patterns)")
print("=" * 70)
for op in sorted(missing):
print(f" 0x{op:02x} {InstOp(op).name}")
# Stats
print("\n" + "=" * 70)
print("STATISTICS")
print("=" * 70)
print(f" Tests run: {len(TENSOR_TESTS)}")
print(f" Tests passed: {len(TENSOR_TESTS) - len(failures)}")
print(f" Tests failed: {len(failures)}")
print(f" Known ops: {len(known_ops)}")
print(f" Discovered: {len(discovered_ops)}")
if known_ops:
coverage = len(discovered_ops & known_ops)
print(f" Coverage: {coverage}/{len(known_ops)} ({100*coverage//len(known_ops)}%)")
print(f" New ops found: {len(new_ops)}")
if failures:
print("\n" + "=" * 70)
print("FAILURES")
print("=" * 70)
for name, e in failures.items():
print(f" {name}: {e}")
if __name__ == "__main__":
print("=" * 70)
print("SQTT InstOp Discovery from Tinygrad Kernels")
print("=" * 70)
print(f"Testing {len(TENSOR_TESTS)} tensor operations...\n")
discovered, failures = discover_all_instops()
print_summary(discovered, failures)

View file

@ -0,0 +1,79 @@
#!/usr/bin/env python3
"""Tests for SQTT packet codec (no hardware required)."""
import unittest
from extra.assembly.amd.sqtt import (
LAYOUT_HEADER, WAVESTART, WAVEEND, INST, NOP,
decode, encode, PACKET_TYPES, OPCODE_TO_CLASS
)
class TestSQTTCodec(unittest.TestCase):
"""Tests for SQTT encoder/decoder roundtrip."""
def test_roundtrip_simple(self):
"""Test encode/decode roundtrip for simple packets."""
test_packets = [
LAYOUT_HEADER.from_raw(0x100),
WAVESTART.from_raw(0x0),
INST.from_raw(0x10), # delta=1
INST.from_raw(0x10), # delta=1
WAVEEND.from_raw(0x40), # delta=2
]
encoded = encode(test_packets)
decoded = decode(encoded)
self.assertGreaterEqual(len(decoded), len(test_packets))
for i, (orig, dec) in enumerate(zip(test_packets, decoded)):
self.assertEqual(type(orig), type(dec), f"type mismatch at {i}")
def test_decode_empty(self):
"""Test decoding empty data."""
packets = decode(b'')
self.assertEqual(packets, [])
def test_encode_empty(self):
"""Test encoding empty list."""
data = encode([])
self.assertEqual(data, b'')
def test_all_packet_types_have_encoding(self):
"""All packet types should have an encoding defined."""
for pkt_cls in PACKET_TYPES:
self.assertIsNotNone(pkt_cls._encoding, f"{pkt_cls.__name__} missing encoding")
def test_packet_from_raw(self):
"""Test creating packets from raw values."""
# INST with wave=5, op=0x21, delta=2
raw = (0x21 << 13) | (5 << 8) | (2 << 4) | 0b010
pkt = INST.from_raw(raw)
self.assertEqual(pkt.wave, 5)
self.assertEqual(pkt.op, 0x21)
self.assertEqual(pkt.delta, 2)
class TestDecodeRealBlob(unittest.TestCase):
"""Test decoding real SQTT blobs from examples."""
def test_decode_example_file(self):
"""Test decoding a real SQTT blob from examples."""
import pickle
from pathlib import Path
example_path = Path(__file__).parent.parent.parent.parent / "sqtt/examples/profile_plus_run_0.pkl"
if not example_path.exists():
self.skipTest(f"Example file not found: {example_path}")
from tinygrad.runtime.ops_amd import ProfileSQTTEvent
with open(example_path, "rb") as f:
data = pickle.load(f)
sqtt_events = [e for e in data if isinstance(e, ProfileSQTTEvent)]
self.assertGreater(len(sqtt_events), 0, "No SQTT events in example")
packets = decode(sqtt_events[0].blob)
self.assertGreater(len(packets), 0, "No packets decoded")
# First packet should be LAYOUT_HEADER
self.assertIsInstance(packets[0], LAYOUT_HEADER)
if __name__ == "__main__":
unittest.main()

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,545 @@
#!/usr/bin/env python3
"""Tests for SQTT emulator correctness against known hardware patterns.
NOTE: This file only tests NOP and VALU behavior. For WMMA/DP/trans tests,
see test_sqtt_compare.py.
Run emulator tests: PYTHONPATH="." python3 extra/assembly/amd/test/test_sqtt_correct.py
Run hardware tests: SQTT_HW=1 PYTHONPATH="." python3 extra/assembly/amd/test/test_sqtt_correct.py
"""
import os
import unittest
USE_HW = os.environ.get("SQTT_HW", "0") == "1"
if USE_HW:
os.environ["SQTT"] = "1"
os.environ["PROFILE"] = "1"
os.environ["SQTT_LIMIT_SE"] = "2"
os.environ["SQTT_TOKEN_EXCLUDE"] = "3784"
from extra.assembly.amd.emu import SQTTState, decode_program, exec_wave, WaveState, LDSMem
from extra.assembly.amd.sqtt import WAVESTART, WAVEEND
from extra.assembly.amd.autogen.rdna3.ins import v_mov_b32_e32, v_add_f32_e32, s_nop, s_endpgm, s_delay_alu
from extra.assembly.amd.dsl import v
def assemble(instructions: list) -> bytes:
return b''.join(inst.to_bytes() for inst in instructions)
def wrap_with_nops(instructions: list, nops=16) -> list:
return instructions + [s_nop(0)]*nops + [s_endpgm()]
def get_wave_packets(packets: list) -> list:
result, in_wave = [], False
for p in packets:
if isinstance(p, WAVESTART) and p.simd == 0:
in_wave, result = True, [p]
elif in_wave:
result.append(p)
if isinstance(p, WAVEEND): break
return result
def get_timing_deltas(packets: list) -> list[tuple[str, int]]:
skip_types = {"NOP", "TS_DELTA_SHORT", "TS_WAVE_STATE", "TS_DELTA_OR_MARK", "TS_DELTA_S5_W2", "TS_DELTA_S5_W3", "TS_DELTA_S8_W3", "REG"}
filtered = [p for p in packets if type(p).__name__ not in skip_types]
if not filtered: return []
result = [(type(filtered[0]).__name__, 0)]
for i in range(1, len(filtered)):
result.append((type(filtered[i]).__name__, filtered[i]._time - filtered[i-1]._time))
return result
def run_emulator(instructions: list) -> list:
code = assemble(instructions)
program = decode_program(code)
st = WaveState()
st.exec_mask = (1 << 32) - 1
lds = LDSMem(bytearray(65536))
trace = SQTTState(wave_id=0, simd=0, cu=0)
exec_wave(program, st, lds, 32, trace)
return get_wave_packets(trace.packets)
def get_all_waves(packets: list) -> list[list]:
"""Extract all WAVESTART..WAVEEND ranges on simd 0."""
waves, in_wave, current = [], False, []
for p in packets:
if isinstance(p, WAVESTART) and p.simd == 0:
in_wave, current = True, [p]
elif in_wave:
current.append(p)
if isinstance(p, WAVEEND):
waves.append(current)
in_wave, current = False, []
return waves
def run_hardware(instructions: list) -> list:
from extra.assembly.amd.test.test_sqtt_hw import compile_asm_sqtt, run_prg_sqtt_batch
from extra.assembly.amd.sqtt import decode
from collections import Counter
prg = compile_asm_sqtt(instructions, alu_only=True)
for _ in range(10):
blobs = run_prg_sqtt_batch(prg, n_runs=200)
# Extract all waves from all blobs
traces = []
for blob in blobs:
traces.extend(get_all_waves(decode(blob)))
if not traces:
continue
# Find most common pattern
delta_sets = [tuple(get_timing_deltas(t)) for t in traces]
most_common = Counter(delta_sets).most_common(1)[0][0]
for t in traces:
if tuple(get_timing_deltas(t)) == most_common:
return t
return []
def run_sqtt(instructions: list, nops: int = 16) -> list:
instructions = wrap_with_nops(instructions, nops=nops)
return run_hardware(instructions) if USE_HW else run_emulator(instructions)
def get_deltas(instructions: list) -> tuple[list[int], list[int]]:
"""Run and return (issue deltas, exec deltas).
Issue = IMMEDIATE + VALUINST, Exec = ALUEXEC.
Deltas are between consecutive packets of same stream."""
deltas = get_timing_deltas(run_sqtt(instructions))
time = 0
issue_times, exec_times = [], []
for ptype, delta in deltas:
time += delta
if ptype in ('IMMEDIATE', 'VALUINST'):
issue_times.append(time)
elif ptype == 'ALUEXEC':
exec_times.append(time)
issue = [issue_times[i] - issue_times[i-1] for i in range(1, len(issue_times))]
execd = [exec_times[i] - exec_times[i-1] for i in range(1, len(exec_times))]
return issue, execd
# ************************************ tests ************************************
class TestVALUChains(unittest.TestCase):
"""VALU dependency chains."""
def _chain(self, n, expected_issue, expected_exec):
instrs = [v_mov_b32_e32(v[0], 1.0)] + [v_add_f32_e32(v[i], v[i-1], v[i-1]) for i in range(1, n)]
issue, execd = get_deltas(instrs)
self.assertEqual(issue[:n-1], expected_issue)
if isinstance(expected_exec[0], list): self.assertIn(execd, expected_exec)
else: self.assertEqual(execd, expected_exec)
def test_chain_2(self): self._chain(2, [1], [6])
def test_chain_3(self): self._chain(3, [1, 1], [6, 5])
def test_chain_4(self): self._chain(4, [1, 1, 1], [6, 5, 5])
def test_chain_5(self): self._chain(5, [1, 1, 1, 1], [6, 5, 5, 9])
def test_chain_6(self): self._chain(6, [1, 1, 1, 1, 1], [6, 5, 5, 9, 9])
def test_chain_7(self): self._chain(7, [1, 1, 1, 1, 1, 1], [6, 5, 5, 5, 9, 9])
def test_chain_8(self): self._chain(8, [1, 1, 1, 1, 1, 1, 1], [6, 5, 5, 5, 9, 9, 9])
# NOTE: position 8 can be 5 or 9 depending on GPU variant
def test_chain_12(self): self._chain(12, [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [[6, 5, 5, 5, 5, 9, 9, 9, 9, 9, 9], [6, 5, 5, 5, 5, 9, 9, 9, 5, 9, 9]])
def test_chain_14(self): self._chain(14, [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [[6, 5, 5, 5, 5, 9, 9, 9, 9, 9, 9, 9, 9], [6, 5, 5, 5, 5, 9, 9, 9, 5, 9, 9, 9, 9]])
# issue stalls start here
def test_chain_15(self): self._chain(15, [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 3], [[6, 5, 5, 5, 5, 5, 9, 9, 9, 9, 9, 9, 9, 9], [6, 5, 5, 5, 5, 5, 9, 9, 5, 9, 9, 9, 9, 9]])
def test_chain_16(self): self._chain(16, [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 3, 5], [[6, 5, 5, 5, 5, 5, 5, 9, 9, 9, 9, 9, 9, 9, 9], [6, 5, 5, 5, 5, 5, 5, 9, 5, 9, 9, 9, 9, 9, 9]])
def test_chain_18(self): self._chain(18, [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 3, 5, 5, 5], [6, 5, 5, 5, 5, 5, 5, 5, 5, 9, 9, 9, 9, 9, 9, 9, 9])
def test_chain_20(self): self._chain(20, [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 3, 5, 5, 5, 5, 5], [6, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 9, 9, 9, 9, 9, 9, 9, 9])
class TestVALUChainsWithWarmup(unittest.TestCase):
"""VALU dependency chains with early VALUs to isolate warmup effects."""
# just the first stupid VALU takes 6
def _chain(self, n, warmup=True):
instrs = [v_mov_b32_e32(v[0], 1.0), s_nop(100)] if warmup else [s_nop(100)]
instrs += [v_mov_b32_e32(v[0], 1.0)] + [v_add_f32_e32(v[i], v[i-1], v[i-1]) for i in range(1, n)]
issue, execd = get_deltas(instrs)
return execd[1:] if warmup else execd
def test_warmup_chain_2(self): self.assertEqual(self._chain(2), [5])
def test_warmup_chain_3(self): self.assertEqual(self._chain(3), [5, 5])
def test_warmup_chain_4(self): self.assertEqual(self._chain(4), [5, 5, 5])
def test_warmup_chain_5(self): self.assertEqual(self._chain(5), [5, 5, 5, 9])
def test_warmup_chain_6(self): self.assertEqual(self._chain(6), [5, 5, 5, 5, 9])
def test_warmup_chain_7(self): self.assertEqual(self._chain(7), [5, 5, 5, 5, 9, 9])
def test_warmup_chain_8(self): self.assertEqual(self._chain(8), [5, 5, 5, 5, 9, 9, 9])
def test_cold_chain_2(self): self.assertEqual(self._chain(2, False), [6])
def test_cold_chain_3(self): self.assertEqual(self._chain(3, False), [6, 5])
def test_cold_chain_4(self): self.assertEqual(self._chain(4, False), [6, 5, 5])
def test_cold_chain_5(self): self.assertEqual(self._chain(5, False), [6, 5, 5, 9])
def test_cold_chain_6(self): self.assertEqual(self._chain(6, False), [6, 5, 5, 9, 9])
def test_cold_chain_7(self): self.assertEqual(self._chain(7, False), [6, 5, 5, 5, 9, 9])
def test_cold_chain_8(self): self.assertEqual(self._chain(8, False), [6, 5, 5, 5, 9, 9, 9])
class TestVALUIndependent(unittest.TestCase):
"""Independent VALU instructions."""
def _ind(self, n, expected_exec):
instrs = [v_mov_b32_e32(v[i], float(i)) for i in range(n)]
issue, execd = get_deltas(instrs)
self.assertEqual(issue[:n-1], [1]*(n-1))
self.assertEqual(execd, expected_exec)
def test_ind_2(self): self._ind(2, [1])
def test_ind_3(self): self._ind(3, [1, 1])
def test_ind_4(self): self._ind(4, [1, 1, 1])
def test_ind_5(self): self._ind(5, [1, 1, 1, 1])
def test_ind_6(self): self._ind(6, [1, 1, 1, 1, 1])
def test_ind_7(self): self._ind(7, [1, 1, 1, 1, 1, 1])
def test_ind_8(self): self._ind(8, [1, 1, 1, 1, 1, 1, 1])
class TestForwardingGap(unittest.TestCase):
"""Producer + N independent instructions + consumer - tests forwarding window."""
def _exec_deltas(self, n_gap):
instrs = [v_mov_b32_e32(v[0], 1.0)]
instrs += [v_mov_b32_e32(v[10+i], float(i)) for i in range(n_gap)]
instrs += [v_add_f32_e32(v[1], v[0], v[0])]
_, execd = get_deltas(instrs)
return execd
def test_gap0(self): self.assertEqual(self._exec_deltas(0), [6])
def test_gap1(self): self.assertEqual(self._exec_deltas(1), [1, 5])
def test_gap2(self): self.assertEqual(self._exec_deltas(2), [1, 1, 4])
def test_gap3(self): self.assertIn(self._exec_deltas(3), [[1, 1, 1, 3], [1, 1, 1, 4]])
def test_gap4(self): self.assertIn(self._exec_deltas(4), [[1, 1, 1, 1, 3], [1, 1, 1, 1, 4]])
def test_gap5(self): self.assertEqual(self._exec_deltas(5), [1, 1, 1, 1, 1, 4]) # anomaly
def test_gap6(self): self.assertEqual(self._exec_deltas(6), [1, 1, 1, 1, 1, 1, 3])
def test_gap7(self): self.assertEqual(self._exec_deltas(7), [1, 1, 1, 1, 1, 1, 1, 3])
def test_gap8(self): self.assertEqual(self._exec_deltas(8), [1, 1, 1, 1, 1, 1, 1, 1, 3])
def test_gap9(self): self.assertEqual(self._exec_deltas(9), [1, 1, 1, 1, 1, 1, 1, 1, 1, 3])
class TestChainWithIndependentGap(unittest.TestCase):
"""Chain of dependent VALUs with independent VALUs inserted before the last one.
Hardware observation: In a chain v0->v1->v2->v3->v4, if we insert N independent VALUs
before v4, the forwarding behavior changes:
- 0-1 independent VALUs: v4 cannot forward from v3 (delta=9)
- 2+ independent VALUs: v4 can forward from v3 (delta=5)
This suggests forwarding eligibility depends on whether the direct source is in the ALU
at issue time, not just at dispatch time.
"""
def _chain5_gap(self, n_ind):
"""Chain v0->v1->v2->v3->v4 with N independent VALUs before v4. Returns v3->v4 delta."""
instrs = [s_nop(100),
v_mov_b32_e32(v[0], 1.0),
v_mov_b32_e32(v[1], v[0]),
v_mov_b32_e32(v[2], v[1]),
v_mov_b32_e32(v[3], v[2])]
instrs += [v_mov_b32_e32(v[10+i], float(i)) for i in range(n_ind)]
instrs += [v_mov_b32_e32(v[4], v[3])]
_, execd = get_deltas(instrs)
# Chain execs are at indices 0,1,2,3 and last one. Independent ones are in between.
# v3->v4 delta = last exec time - 4th exec time (index 3)
# With n_ind independent VALUs, execd has 4 + n_ind entries
# We want delta between exec[3] (v3) and exec[4+n_ind-1] (v4)
# Actually execd is already deltas, so we need absolute times
time, exec_times = 0, []
packets = run_sqtt(instrs)
for ptype, delta in get_timing_deltas(packets):
time += delta
if ptype == 'ALUEXEC': exec_times.append(time)
# v0,v1,v2,v3 are first 4, v4 is last
return exec_times[-1] - exec_times[3]
def test_gap0(self): self.assertEqual(self._chain5_gap(0), 9)
def test_gap1(self): self.assertEqual(self._chain5_gap(1), 9)
def test_gap2(self): self.assertEqual(self._chain5_gap(2), 5)
def test_gap3(self): self.assertEqual(self._chain5_gap(3), 5)
def test_gap4(self): self.assertEqual(self._chain5_gap(4), 5)
class TestVALULatency(unittest.TestCase):
"""VALU latency depends on VGPR source reads.
6 cycles: no VGPR source (constant only), stays 6 regardless of warmup
8-11 cycles: VGPR source read, decreases with warmup (11->10->9->8)
s_nop(0) after VALU immediately drops VGPR read latency to 8
Anomalies:
- 7 consecutive VALUs (no s_nop) causes +1 cycle penalty
- n=0 or n=3 const VALUs + nop + vgpr = 9 cycles (not 8)
"""
def _get_latency(self, instrs):
if not isinstance(instrs, list): instrs = [instrs]
packets = run_sqtt(instrs)
deltas = get_timing_deltas(packets)
time, valu_times, exec_times = 0, [], []
for ptype, delta in deltas:
time += delta
if ptype == 'VALUINST': valu_times.append(time)
if ptype == 'ALUEXEC': exec_times.append(time)
return exec_times[-1] - valu_times[-1] if valu_times and exec_times else None
# 6-cycle latency: no VGPR source (constant), always 6
def test_const_single(self): self.assertEqual(self._get_latency(v_mov_b32_e32(v[0], 1.0)), 6)
def test_const_literal(self): self.assertEqual(self._get_latency(v_mov_b32_e32(v[0], 565.0)), 6)
def test_const_after_const(self): self.assertEqual(self._get_latency([v_mov_b32_e32(v[0], 1.0), v_mov_b32_e32(v[1], 2.0)]), 6)
def test_const_after_nop(self): self.assertEqual(self._get_latency([v_mov_b32_e32(v[0], 1.0), s_nop(0), v_mov_b32_e32(v[1], 2.0)]), 6)
# VGPR read latency: cold start = 9
def test_vgpr_cold(self): self.assertEqual(self._get_latency(v_mov_b32_e32(v[0], v[1])), 9)
# VGPR read latency: warmup decreases 11->10->9->8
def _vgpr_after_n_const(self, n):
return self._get_latency([v_mov_b32_e32(v[i], float(i)) for i in range(n)] + [v_mov_b32_e32(v[10], v[99])])
def test_vgpr_after_1_const(self): self.assertEqual(self._vgpr_after_n_const(1), 11)
def test_vgpr_after_2_const(self): self.assertEqual(self._vgpr_after_n_const(2), 10)
def test_vgpr_after_3_const(self): self.assertEqual(self._vgpr_after_n_const(3), 9)
def test_vgpr_after_4_const(self): self.assertIn(self._vgpr_after_n_const(4), [8, 9])
def test_vgpr_after_5_const(self): self.assertIn(self._vgpr_after_n_const(5), [8, 9])
def test_vgpr_after_6_const(self): self.assertEqual(self._vgpr_after_n_const(6), 9) # anomaly
def test_vgpr_after_7_const(self): self.assertEqual(self._vgpr_after_n_const(7), 8)
def test_vgpr_after_8_const(self): self.assertEqual(self._vgpr_after_n_const(8), 8)
# s_nop(0) immediately drops VGPR read latency to 8 (or 9 on some variants)
def test_vgpr_nop_warmup(self): self.assertIn(self._get_latency([v_mov_b32_e32(v[0], 1.0), s_nop(0), v_mov_b32_e32(v[1], v[99])]), [8, 9])
# s_nop + vgpr read: latency depends on # of const VALUs before nop
def _n_const_nop_vgpr(self, n):
"""N const VALUs + s_nop(0) + vgpr read."""
instrs = [v_mov_b32_e32(v[i], float(i)) for i in range(n)]
instrs += [s_nop(0)]
instrs += [v_mov_b32_e32(v[10], v[99])]
return self._get_latency(instrs)
def test_0_const_nop_vgpr(self): self.assertEqual(self._n_const_nop_vgpr(0), 9)
def test_1_const_nop_vgpr(self): self.assertIn(self._n_const_nop_vgpr(1), [8, 9])
def test_2_const_nop_vgpr(self): self.assertIn(self._n_const_nop_vgpr(2), [8, 9])
def test_3_const_nop_vgpr(self): self.assertEqual(self._n_const_nop_vgpr(3), 9) # anomaly
def test_4_const_nop_vgpr(self): self.assertEqual(self._n_const_nop_vgpr(4), 8)
def test_5_const_nop_vgpr(self): self.assertEqual(self._n_const_nop_vgpr(5), 8)
def test_6_const_nop_vgpr(self): self.assertEqual(self._n_const_nop_vgpr(6), 8)
def test_7_const_nop_vgpr(self): self.assertEqual(self._n_const_nop_vgpr(7), 8)
class TestChainWithNop(unittest.TestCase):
"""Dependency chain with s_nop between instructions."""
def _test(self, nop_val, expected_issue, expected_exec):
issue, execd = get_deltas([v_mov_b32_e32(v[0], 1.0), s_nop(nop_val), v_add_f32_e32(v[1], v[0], v[0])])
self.assertEqual(issue[:2], expected_issue)
if isinstance(expected_exec[0], list): self.assertIn(execd, expected_exec)
else: self.assertEqual(execd, expected_exec)
def test_nop0(self): self._test(0, [3, 1], [[6], [7]])
def test_nop1(self): self._test(1, [4, 1], [[7], [8]])
def test_nop2(self): self._test(2, [5, 1], [9])
def test_nop3(self): self._test(3, [6, 1], [9])
def test_nop4(self): self._test(4, [11, 1], [10])
def test_nop5(self): self._test(5, [12, 1], [11])
class TestIndWithNop(unittest.TestCase):
"""Independent instructions with s_nop between."""
def _test(self, nop_val, expected_issue, expected_exec):
issue, execd = get_deltas([v_mov_b32_e32(v[0], 1.0), s_nop(nop_val), v_mov_b32_e32(v[1], 2.0)])
self.assertEqual(issue[:2], expected_issue)
self.assertEqual(execd, expected_exec)
def test_nop0(self): self._test(0, [3, 1], [4])
def test_nop1(self): self._test(1, [4, 1], [5])
def test_nop3(self): self._test(3, [6, 1], [7])
def test_nop4(self): self._test(4, [11, 1], [8])
def test_nop5(self): self._test(5, [12, 1], [9])
class TestChain3NopMid(unittest.TestCase):
"""3-instruction chain with s_nop in middle."""
def _test(self, nop_val, expected_issue, expected_exec):
issue, execd = get_deltas([
v_mov_b32_e32(v[0], 1.0), v_add_f32_e32(v[1], v[0], v[0]),
s_nop(nop_val), v_add_f32_e32(v[2], v[1], v[1])])
self.assertEqual(issue[:3], expected_issue)
self.assertEqual(execd, expected_exec)
def test_nop0(self): self._test(0, [1, 3, 1], [6, 5])
def test_nop1(self): self._test(1, [1, 4, 1], [6, 5])
def test_nop2(self): self._test(2, [1, 5, 1], [6, 5])
def test_nop3(self): self._test(3, [1, 10, 1], [6, 5])
class TestInd3NopMid(unittest.TestCase):
"""3 independent instructions with s_nop in middle."""
def _test(self, nop_val, expected_issue, expected_exec):
issue, execd = get_deltas([
v_mov_b32_e32(v[0], 1.0), v_mov_b32_e32(v[1], 2.0),
s_nop(nop_val), v_mov_b32_e32(v[2], 3.0)])
self.assertEqual(issue[:3], expected_issue)
self.assertEqual(execd, expected_exec)
def test_nop0(self): self._test(0, [1, 3, 1], [1, 4])
def test_nop1(self): self._test(1, [1, 4, 1], [1, 5])
def test_nop2(self): self._test(2, [1, 5, 1], [1, 6])
def test_nop3(self): self._test(3, [1, 10, 1], [1, 7])
class TestSNopDelay(unittest.TestCase):
"""Single s_nop delay between two independent v_movs.
s_nop(n) delays n+1 cycles, plus +4 extra for n in [11, 22].
Exec delta = n + 4 (baseline) + 4 (if 11 <= n <= 22)."""
def _test(self, n, expected):
_, execd = get_deltas([v_mov_b32_e32(v[0], 1.0), s_nop(n), v_mov_b32_e32(v[1], 2.0)])
if isinstance(expected, list): self.assertIn(execd[0], expected)
else: self.assertEqual(execd, [expected])
def test_snop_0(self): self._test(0, 4)
def test_snop_1(self): self._test(1, 5)
def test_snop_2(self): self._test(2, 6)
def test_snop_3(self): self._test(3, 7)
def test_snop_4(self): self._test(4, 8)
def test_snop_5(self): self._test(5, 9)
def test_snop_6(self): self._test(6, 10)
def test_snop_7(self): self._test(7, 11)
def test_snop_10(self): self._test(10, 14)
def test_snop_11(self): self._test(11, 19) # +4 extra starts here
def test_snop_15(self): self._test(15, 23)
def test_snop_22(self): self._test(22, 30) # +4 extra ends here
def test_snop_23(self): self._test(23, 27)
def test_snop_31(self): self._test(31, 35)
def test_snop_32(self): self._test(32, 36)
def test_snop_63(self): self._test(63, [67, 71])
class TestVALUExecWithNop(unittest.TestCase):
"""Single VALU followed by s_nop - measures VALUINST to ALUEXEC delay."""
def _get_delay(self, instrs, nops=16):
deltas = get_timing_deltas(run_sqtt(instrs, nops=nops))
time, valu_time, exec_time = 0, None, None
for ptype, delta in deltas:
time += delta
if ptype == 'VALUINST' and valu_time is None: valu_time = time
if ptype == 'ALUEXEC' and exec_time is None: exec_time = time
return exec_time - valu_time
# Boundary: s_nop(0-3) = 6 cycles, s_nop(4+) = 10 cycles
def test_nop0(self): self.assertEqual(self._get_delay([v_mov_b32_e32(v[0], 1.0), s_nop(0)]), 6)
def test_nop1(self): self.assertEqual(self._get_delay([v_mov_b32_e32(v[0], 1.0), s_nop(1)]), 6)
def test_nop2(self): self.assertEqual(self._get_delay([v_mov_b32_e32(v[0], 1.0), s_nop(2)]), 6)
def test_nop3(self): self.assertEqual(self._get_delay([v_mov_b32_e32(v[0], 1.0), s_nop(3)]), 6)
def test_nop4(self): self.assertEqual(self._get_delay([v_mov_b32_e32(v[0], 1.0), s_nop(4)]), 10)
def test_nop5(self): self.assertEqual(self._get_delay([v_mov_b32_e32(v[0], 1.0), s_nop(5)]), 10)
def test_nop6(self): self.assertEqual(self._get_delay([v_mov_b32_e32(v[0], 1.0), s_nop(6)]), 10)
def test_nop7(self): self.assertEqual(self._get_delay([v_mov_b32_e32(v[0], 1.0), s_nop(7)]), 10)
def test_nop8(self): self.assertEqual(self._get_delay([v_mov_b32_e32(v[0], 1.0), s_nop(8)]), 10)
def test_nop9(self): self.assertEqual(self._get_delay([v_mov_b32_e32(v[0], 1.0), s_nop(9)]), 10)
def test_nop10(self): self.assertEqual(self._get_delay([v_mov_b32_e32(v[0], 1.0), s_nop(10)]), 10)
# No nop = slow path, one s_nop(0) padding = fast path
def test_no_padding(self): self.assertEqual(self._get_delay([v_mov_b32_e32(v[0], 1.0)], nops=0), 10)
def test_one_padding(self): self.assertEqual(self._get_delay([v_mov_b32_e32(v[0], 1.0)], nops=1), 6)
# Multiple s_nop(0)s don't accumulate - still fast path
def test_nop0_x2(self): self.assertEqual(self._get_delay([v_mov_b32_e32(v[0], 1.0), s_nop(0), s_nop(0)]), 6)
# First nop determines path: s_nop(0) then s_nop(4) = fast, s_nop(4) then s_nop(0) = slow
def test_nop0_nop4(self): self.assertEqual(self._get_delay([v_mov_b32_e32(v[0], 1.0), s_nop(0), s_nop(4)]), 6)
def test_nop4_nop0(self): self.assertEqual(self._get_delay([v_mov_b32_e32(v[0], 1.0), s_nop(4), s_nop(0)]), 10)
class TestDelayALU(unittest.TestCase):
"""s_delay_alu behavior - helps understand hardware pipeline latencies.
s_delay_alu(simm16) where simm16 encodes:
instid0[3:0] = dependency on VALU N instructions back (1-4), 0=none
skip[6:4] = skip count for second dependency
instid1[10:7] = second dependency
Key insight: s_delay_alu tells hardware to wait for a previous VALU to complete.
The hardware determines how many cycles to stall based on pipeline state.
"""
def _exec_delta(self, instrs):
"""Return exec delta for last instruction."""
_, execd = get_deltas(instrs)
return execd[-1] if execd else None
# Direct dependency (producer -> consumer), instid0=1 means "wait for VALU 1 back"
def test_direct_no_delay(self):
# Without s_delay_alu: 6 cycles
self.assertEqual(self._exec_delta([v_mov_b32_e32(v[0], 1.0), v_add_f32_e32(v[1], v[0], v[0])]), 6)
def test_direct_delay1(self):
# With s_delay_alu(instid0=1): 7-8 cycles (+1 from the delay instruction)
self.assertIn(self._exec_delta([v_mov_b32_e32(v[0], 1.0), s_delay_alu(simm16=1), v_add_f32_e32(v[1], v[0], v[0])]), [7, 8])
def test_direct_delay2(self):
# instid0=2 doesn't apply (only 1 VALU back), so no extra delay
self.assertEqual(self._exec_delta([v_mov_b32_e32(v[0], 1.0), s_delay_alu(simm16=2), v_add_f32_e32(v[1], v[0], v[0])]), 6)
def test_direct_delay3(self):
self.assertEqual(self._exec_delta([v_mov_b32_e32(v[0], 1.0), s_delay_alu(simm16=3), v_add_f32_e32(v[1], v[0], v[0])]), 6)
def test_direct_delay4(self):
self.assertEqual(self._exec_delta([v_mov_b32_e32(v[0], 1.0), s_delay_alu(simm16=4), v_add_f32_e32(v[1], v[0], v[0])]), 6)
# With 1 independent instruction between producer and consumer
def test_gap1_delay1(self):
# instid0=1 waits for the independent instruction (not the producer)
instrs = [v_mov_b32_e32(v[0], 1.0), v_mov_b32_e32(v[5], 5.0), s_delay_alu(simm16=1), v_add_f32_e32(v[1], v[0], v[0])]
self.assertEqual(self._exec_delta(instrs), 8)
def test_gap1_delay2(self):
# instid0=2 waits for the producer (2 VALUs back)
instrs = [v_mov_b32_e32(v[0], 1.0), v_mov_b32_e32(v[5], 5.0), s_delay_alu(simm16=2), v_add_f32_e32(v[1], v[0], v[0])]
self.assertIn(self._exec_delta(instrs), [6, 7])
def test_gap1_delay3(self):
# instid0=3 doesn't apply (only 2 VALUs back)
instrs = [v_mov_b32_e32(v[0], 1.0), v_mov_b32_e32(v[5], 5.0), s_delay_alu(simm16=3), v_add_f32_e32(v[1], v[0], v[0])]
self.assertEqual(self._exec_delta(instrs), 5)
# With 2 independent instructions between
def test_gap2_delay1(self):
instrs = [v_mov_b32_e32(v[0], 1.0), v_mov_b32_e32(v[5], 5.0), v_mov_b32_e32(v[6], 6.0),
s_delay_alu(simm16=1), v_add_f32_e32(v[1], v[0], v[0])]
self.assertEqual(self._exec_delta(instrs), 7)
def test_gap2_delay2(self):
instrs = [v_mov_b32_e32(v[0], 1.0), v_mov_b32_e32(v[5], 5.0), v_mov_b32_e32(v[6], 6.0),
s_delay_alu(simm16=2), v_add_f32_e32(v[1], v[0], v[0])]
self.assertEqual(self._exec_delta(instrs), 7)
def test_gap2_delay3(self):
# instid0=3 waits for the producer (3 VALUs back)
instrs = [v_mov_b32_e32(v[0], 1.0), v_mov_b32_e32(v[5], 5.0), v_mov_b32_e32(v[6], 6.0),
s_delay_alu(simm16=3), v_add_f32_e32(v[1], v[0], v[0])]
self.assertIn(self._exec_delta(instrs), [5, 6])
def test_gap2_delay4(self):
instrs = [v_mov_b32_e32(v[0], 1.0), v_mov_b32_e32(v[5], 5.0), v_mov_b32_e32(v[6], 6.0),
s_delay_alu(simm16=4), v_add_f32_e32(v[1], v[0], v[0])]
self.assertEqual(self._exec_delta(instrs), 4)
class TestNopTimingSensitivity(unittest.TestCase):
"""Forwarding behavior has 128-cycle periodicity.
Hardware observation: when nop_cycles % 128 is in [72, 75], chain_6 gets 5 forwards
instead of 4. This 4-cycle window repeats every 128 cycles, suggesting alignment
with some hardware scheduling period (possibly wave scheduler or cache).
Windows found: nop 72-75, 200-203, 328-331, 456-459, ...
"""
def _chain6_fwd_count(self, nop_size):
"""Count initial consecutive forwards for a 6-instruction chain after s_nop(n)."""
instrs = [s_nop(nop_size), v_mov_b32_e32(v[99], 1.0)]
instrs += [v_mov_b32_e32(v[0], 1.0)]
for i in range(1, 6):
instrs += [v_mov_b32_e32(v[i], v[i-1])]
_, execd = get_deltas(instrs)
chain_deltas = execd[1:]
fwd_count = 0
for d in chain_deltas:
if d == 5: fwd_count += 1
else: break
return fwd_count
# Normal case: 4 forwards
def test_nop71(self): self.assertEqual(self._chain6_fwd_count(71), 4)
def test_nop76(self): self.assertEqual(self._chain6_fwd_count(76), 4)
def test_nop199(self): self.assertEqual(self._chain6_fwd_count(199), 4)
def test_nop204(self): self.assertEqual(self._chain6_fwd_count(204), 4)
# Anomaly window at nop % 128 == 72-75: 5 forwards on RDNA3, 4 on other variants
def test_nop72(self): self.assertIn(self._chain6_fwd_count(72), [4, 5])
def test_nop75(self): self.assertIn(self._chain6_fwd_count(75), [4, 5])
def test_nop200(self): self.assertIn(self._chain6_fwd_count(200), [4, 5])
def test_nop203(self): self.assertIn(self._chain6_fwd_count(203), [4, 5])
def test_nop328(self): self.assertIn(self._chain6_fwd_count(328), [4, 5])
def test_nop331(self): self.assertIn(self._chain6_fwd_count(331), [4, 5])
if __name__ == "__main__":
unittest.main()

View file

@ -0,0 +1,463 @@
#!/usr/bin/env python3
"""Hardware tests for SQTT decoder - validates decoding of real SQTT streams.
Run with: python -m pytest extra/assembly/amd/test/test_sqtt_hw.py -v -s
Requires AMD GPU with SQTT support.
For pretty trace output: DEBUG=2 python -m pytest extra/assembly/amd/test/test_sqtt_hw.py -v -s
"""
import os
os.environ["SQTT"] = "1"
os.environ["PROFILE"] = "1"
os.environ["SQTT_ITRACE_SE_MASK"] = "1" # Enable instruction tracing on SE0
os.environ["SQTT_LIMIT_SE"] = "2" # Force work to traced SE only
import unittest
from tinygrad.helpers import DEBUG, colored
from tinygrad.device import Device
from tinygrad.runtime.ops_amd import AMDProgram, ProfileSQTTEvent
from tinygrad.runtime.support.compiler_amd import HIPCompiler
from extra.assembly.amd.autogen.rdna3.ins import v_mov_b32_e32, v_add_f32_e32, v_mul_f32_e32, s_mov_b32, s_add_u32, s_nop, s_waitcnt, s_endpgm
from extra.assembly.amd.dsl import v, s
from extra.assembly.amd.sqtt import decode, LAYOUT_HEADER, WAVESTART, WAVEEND, INST, VALUINST, ALUEXEC, VMEMEXEC, InstOp, AluSrc, MemSrc
dev = Device["AMD"]
# ═══════════════════════════════════════════════════════════════════════════════
# PRETTY PRINTING
# ═══════════════════════════════════════════════════════════════════════════════
PACKET_COLORS = {
"INST": "WHITE", "VALUINST": "BLACK",
"VMEMEXEC": "yellow", "ALUEXEC": "yellow",
"IMMEDIATE": "YELLOW", "IMMEDIATE_MASK": "YELLOW",
"WAVERDY": "cyan", "WAVEALLOC": "cyan",
"WAVEEND": "blue", "WAVESTART": "blue",
"PERF": "magenta",
"EVENT": "red", "EVENT_BIG": "red",
"REG": "green",
"LAYOUT_HEADER": "white",
"TS_DELTA_SHORT": "BLACK", "NOP": "BLACK", "TS_WAVE_STATE": "BLACK",
"SNAPSHOT": "white", "TS_DELTA_OR_MARK": "BLACK",
"TS_DELTA_S8_W3": "BLACK", "TS_DELTA_S5_W2": "BLACK", "TS_DELTA_S5_W3": "BLACK",
"UTILCTR": "green",
}
def format_packet(p, last_time: int = 0, time_offset: int = 0) -> str:
"""Format a packet for pretty printing."""
name = type(p).__name__
color = PACKET_COLORS.get(name, "white")
fields = []
if isinstance(p, INST):
op = p.op
op_name = op.name if isinstance(op, InstOp) else f"0x{op:02x}"
fields = [f"wave={p.wave}", f"op={op_name}"]
if p.flag1: fields.append("flag1")
if p.flag2: fields.append("flag2")
elif isinstance(p, VALUINST):
fields = [f"wave={p.wave}"]
if p.flag: fields.append("flag")
elif isinstance(p, ALUEXEC):
src_name = p.src.name if isinstance(p.src, AluSrc) else f"{p.src}"
fields = [f"src={src_name}"]
elif isinstance(p, VMEMEXEC):
src_name = p.src.name if isinstance(p.src, MemSrc) else f"{p.src}"
fields = [f"src={src_name}"]
elif isinstance(p, WAVESTART):
fields = [f"wave={p.wave}", f"simd={p.simd}", f"cu={p.cu}"]
elif isinstance(p, WAVEEND):
fields = [f"wave={p.wave}", f"simd={p.simd}", f"cu={p.cu}"]
elif hasattr(p, '_values'):
# Format hex fields appropriately
hex_fields = {'snap', 'val32'}
fields = [f"{k}=0x{v:x}" if k in hex_fields else f"{k}={v}" for k, v in p._values.items() if not k.startswith('_') and k != 'delta']
return colored(f"{name:18s}", color) + " " + ", ".join(fields)
def get_wave_packets(packets: list) -> list:
"""Extract packets from WAVESTART to WAVEEND, filtering pure timing packets."""
skip_types = {"NOP", "TS_DELTA_SHORT", "TS_WAVE_STATE", "TS_DELTA_OR_MARK", "TS_DELTA_S5_W2", "TS_DELTA_S5_W3", "TS_DELTA_S8_W3"}
result = []
in_wave = False
for p in packets:
name = type(p).__name__
if isinstance(p, WAVESTART):
in_wave = True
if in_wave and name not in skip_types:
result.append(p)
if isinstance(p, WAVEEND):
in_wave = False
return result
def print_wave_trace(packets: list) -> None:
"""Print packets from WAVESTART to WAVEEND with normalized time."""
wave_packets = get_wave_packets(packets)
if not wave_packets:
return
time_offset = wave_packets[0]._time
last_time = time_offset
for p in wave_packets:
print(format_packet(p, last_time, time_offset))
last_time = p._time
def print_blobs(blobs: list[bytes], wave_only: bool = True) -> None:
"""Print traces for all blobs. wave_only=True filters to WAVESTART..WAVEEND only."""
for i, blob in enumerate(blobs):
packets = decode(blob)
print(f"\n--- Blob {i}: {len(blob)} bytes, {len(packets)} packets ---")
if wave_only:
print_wave_trace(packets)
else:
print_all_packets(packets)
def print_all_packets(packets: list) -> None:
"""Print all packets, filtering out pure timing packets."""
skip_types = {"NOP", "TS_DELTA_SHORT", "TS_WAVE_STATE", "TS_DELTA_OR_MARK", "TS_DELTA_S5_W2", "TS_DELTA_S5_W3", "TS_DELTA_S8_W3"}
if not packets: return
time_offset = packets[0]._time
last_time = time_offset
for p in packets:
if type(p).__name__ not in skip_types:
print(format_packet(p, last_time, time_offset))
last_time = p._time
# ═══════════════════════════════════════════════════════════════════════════════
# ASSEMBLY HELPERS
# ═══════════════════════════════════════════════════════════════════════════════
def assemble(instructions: list) -> bytes:
return b''.join(inst.to_bytes() for inst in instructions)
def wrap_with_nops(instructions: list, nops=16) -> list:
"""Add epilogue for clean SQTT timing.
Need enough NOPs to cover long-latency ops (DP: 42 cycles, WMMA: 47 cycles).
With 64 NOPs, the IMMEDIATE phase extends to cover these completions.
"""
return instructions + [s_nop(0)]*nops + [s_endpgm()]
def compile_asm_sqtt(instructions: list, alu_only: bool = False) -> AMDProgram:
"""Compile instructions to an AMDProgram for SQTT tracing.
Args:
instructions: List of instructions to compile
alu_only: If True, use minimal kernel config with no kernargs/LDS/scratch
Returns:
Compiled AMDProgram ready to run
"""
compiler = HIPCompiler(dev.arch)
# Add NOPs before s_endpgm to flush pipeline and get clean timing
code = assemble(instructions)
byte_str = ', '.join(f'0x{b:02x}' for b in code)
if alu_only:
asm_src = f""".text
.globl test
.p2align 8
.type test,@function
test:
.byte {byte_str}
.rodata
.p2align 6
.amdhsa_kernel test
# basic memory
.amdhsa_group_segment_fixed_size 0
.amdhsa_private_segment_fixed_size 0
.amdhsa_kernarg_size 32
.amdhsa_enable_private_segment 0
# register usage
.amdhsa_next_free_vgpr 64
.amdhsa_next_free_sgpr 8
# RSRC1
.amdhsa_wavefront_size32 1
.amdhsa_memory_ordered 1
.amdhsa_forward_progress 1
# this is key
.amdhsa_workgroup_processor_mode 0
.end_amdhsa_kernel
.amdgpu_metadata
---
amdhsa.version:
- 1
- 0
amdhsa.kernels:
- .name: test
.symbol: test.kd
.kernarg_segment_size: 0
.group_segment_fixed_size: 0
.private_segment_fixed_size: 0
.kernarg_segment_align: 8
.wavefront_size: 32
.sgpr_count: 8
.vgpr_count: 64
.max_flat_workgroup_size: 1024
...
.end_amdgpu_metadata
"""
else:
asm_src = f""".text
.globl test
.p2align 8
.type test,@function
test:
.byte {byte_str}
.rodata
.p2align 6
.amdhsa_kernel test
.amdhsa_next_free_vgpr 8
.amdhsa_next_free_sgpr 16
.amdhsa_wavefront_size32 1
.amdhsa_user_sgpr_kernarg_segment_ptr 1
.amdhsa_kernarg_size 8
.amdhsa_group_segment_fixed_size 0
.amdhsa_private_segment_fixed_size 0
.end_amdhsa_kernel
.amdgpu_metadata
---
amdhsa.version:
- 1
- 0
amdhsa.kernels:
- .name: test
.symbol: test.kd
.kernarg_segment_size: 8
.group_segment_fixed_size: 0
.private_segment_fixed_size: 0
.kernarg_segment_align: 8
.wavefront_size: 32
.sgpr_count: 16
.vgpr_count: 8
.max_flat_workgroup_size: 1024
...
.end_amdgpu_metadata
"""
lib = compiler.compile(asm_src)
return AMDProgram(dev, "test", lib)
def run_asm_sqtt(instructions: list, n_lanes: int = 1, alu_only: bool = False) -> list[bytes]:
"""Compile and run instructions on AMD hardware, return SQTT blobs.
Args:
instructions: List of instructions to run
n_lanes: Number of lanes to use
alu_only: If True, use minimal kernel config with no kernargs/LDS/scratch
"""
prg = compile_asm_sqtt(instructions, alu_only=alu_only)
return run_prg_sqtt(prg, n_lanes=n_lanes, alu_only=alu_only)
def run_prg_sqtt(prg: AMDProgram, n_lanes: int = 1, alu_only: bool = False) -> list[bytes]:
"""Run a compiled AMDProgram and return SQTT blobs.
Args:
prg: Compiled AMDProgram to run
n_lanes: Number of lanes to use
alu_only: If True, don't allocate kernarg buffer
"""
dev.profile_events.clear()
if alu_only:
prg(global_size=(1, 1, 1), local_size=(n_lanes, 1, 1), wait=True)
else:
out_gpu = dev.allocator.alloc(2048)
prg(out_gpu, global_size=(1, 1, 1), local_size=(n_lanes, 1, 1), wait=True)
return [ev.blob for ev in dev.profile_events if isinstance(ev, ProfileSQTTEvent)]
def run_prg_sqtt_batch(prg: AMDProgram, n_runs: int, n_lanes: int = 1) -> list[bytes]:
"""Run a compiled AMDProgram N times in a single queue submission and return SQTT blobs.
This builds one queue with N kernel executions, submits it once, and collects SQTT.
All N runs are captured in the same SQTT trace, reducing startup jitter.
Args:
prg: Compiled AMDProgram to run
n_runs: Number of times to execute the kernel in the queue
n_lanes: Number of lanes to use
Returns:
List of SQTT blobs (one per shader engine)
"""
from typing import cast
from tinygrad.runtime.ops_amd import AMDComputeQueue, SQTT_ITRACE_SE_MASK
from tinygrad.device import Compiled
import struct
dev.profile_events.clear()
# Build queue with sqtt_start, N kernel executions, sqtt_stop
kernargs = prg.fill_kernargs([], ())
q = cast(AMDComputeQueue, dev.hw_compute_queue_t())
q.wait(dev.timeline_signal, dev.timeline_value - 1).memory_barrier()
q.sqtt_start(dev.sqtt_buffers)
# Execute kernel N times
for _ in range(n_runs):
q.exec(prg, kernargs, (1, 1, 1), (n_lanes, 1, 1))
q.sqtt_stop(dev.sqtt_wptrs)
q.signal(dev.timeline_signal, dev.next_timeline())
q.submit(dev)
dev.synchronize()
# Collect SQTT blobs
blobs = []
for se, buf in enumerate(dev.sqtt_buffers):
wptr = (dev.sqtt_wptrs.cpu_view().view(fmt='I')[se] & 0x1FFFFFFF) * 32
if dev.target[:2] == (11, 0): wptr -= ((buf.va_addr // 32) & 0x1FFFFFFF) * 32
if wptr > 0 and wptr <= buf.size:
dev.allocator._copyout(sqtt_mv:=memoryview(bytearray(wptr)), buf)
resbuf = (struct.pack('<Q', 0x11 | (4 << 13) | (0xf << 16) | (se << 24)) + bytes(sqtt_mv)) if dev.target[0] == 9 else bytes(sqtt_mv)
blobs.append(resbuf)
return blobs
def decode_all_blobs(blobs: list[bytes]) -> list:
"""Decode all blobs and combine packets."""
all_packets = []
for blob in blobs:
all_packets.extend(decode(blob))
return all_packets
def get_inst_ops(packets: list, traced_simd: int | None = None) -> set:
"""Extract all InstOp values from INST packets within WAVESTART..WAVEEND on traced SIMD."""
ops = set()
in_wave = False
for p in packets:
if isinstance(p, WAVESTART):
in_wave = traced_simd is None or p.simd == traced_simd
if in_wave and isinstance(p, INST):
ops.add(p.op if isinstance(p.op, int) else p.op.value)
if isinstance(p, WAVEEND):
in_wave = False
return ops
def count_valuinst(packets: list, traced_simd: int | None = None) -> int:
"""Count VALUINST packets within WAVESTART..WAVEEND on traced SIMD."""
count = 0
in_wave = False
for p in packets:
if isinstance(p, WAVESTART):
in_wave = traced_simd is None or p.simd == traced_simd
if in_wave and isinstance(p, VALUINST):
count += 1
if isinstance(p, WAVEEND):
in_wave = False
return count
# ═══════════════════════════════════════════════════════════════════════════════
# TESTS
# ═══════════════════════════════════════════════════════════════════════════════
@unittest.skipIf(not hasattr(dev, 'profile_events'), "AMD device required")
class TestSQTTDecode(unittest.TestCase):
"""Test SQTT decoder with real hardware traces."""
def test_basic_structure(self):
"""Verify basic SQTT stream structure: LAYOUT_HEADER, WAVESTART, instructions, WAVEEND."""
blobs = run_asm_sqtt([v_mov_b32_e32(v[0], 0)])
self.assertGreater(len(blobs), 0, "No SQTT data captured")
packets = decode_all_blobs(blobs)
self.assertGreater(len(packets), 0, "No packets decoded")
self.assertGreater(len([p for p in packets if isinstance(p, LAYOUT_HEADER)]), 0, "No LAYOUT_HEADER packets")
self.assertGreater(len([p for p in packets if isinstance(p, WAVESTART)]), 0, "No WAVESTART packets")
self.assertGreater(len([p for p in packets if isinstance(p, WAVEEND)]), 0, "No WAVEEND packets")
if DEBUG >= 2:
print("\n=== Basic structure trace ===")
print_trace(packets)
def test_valu_instructions(self):
"""Verify VALU instructions produce INST or VALUINST packets."""
instructions = [
v_mov_b32_e32(v[0], 1.0),
v_mov_b32_e32(v[1], 2.0),
v_add_f32_e32(v[2], v[0], v[1]),
v_add_f32_e32(v[3], v[2], v[1]),
v_mul_f32_e32(v[4], v[2], v[3]),
]
blobs = run_asm_sqtt(instructions)
self.assertGreater(len(blobs), 0, "No SQTT data captured")
packets = decode_all_blobs(blobs)
inst_packets = [p for p in packets if isinstance(p, (INST, VALUINST))]
self.assertGreater(len(inst_packets), 0, "No INST/VALUINST packets for VALU instructions")
if DEBUG >= 2:
print("\n=== VALU instructions trace ===")
print_trace(packets)
def test_salu_instructions(self):
"""Verify SALU instructions produce appropriate packets."""
instructions = [
s_mov_b32(s[0], 0),
s_mov_b32(s[1], 1),
s_add_u32(s[2], s[0], s[1]),
s_add_u32(s[3], s[2], s[1]),
s_nop(0),
]
blobs = run_asm_sqtt(instructions)
self.assertGreater(len(blobs), 0, "No SQTT data captured")
packets = decode_all_blobs(blobs)
if DEBUG >= 2:
print("\n=== SALU instructions trace ===")
print_trace(packets)
def test_timing_increases(self):
"""Verify time increases monotonically through packets within each blob."""
instructions = [
v_mov_b32_e32(v[0], 1.0),
v_mov_b32_e32(v[1], 2.0),
v_add_f32_e32(v[2], v[0], v[1]),
v_mul_f32_e32(v[3], v[2], v[1]),
]
blobs = run_asm_sqtt(instructions)
self.assertGreater(len(blobs), 0, "No SQTT data captured")
for blob in blobs:
packets = decode(blob)
prev_time = 0
for p in packets:
self.assertGreaterEqual(p._time, prev_time, f"Time decreased: {prev_time} -> {p._time}")
prev_time = p._time
def test_wave_id_consistency(self):
"""Verify wave IDs are consistent between WAVESTART/WAVEEND."""
blobs = run_asm_sqtt([v_mov_b32_e32(v[0], 0)])
self.assertGreater(len(blobs), 0, "No SQTT data captured")
packets = decode_all_blobs(blobs)
wavestarts = [p for p in packets if isinstance(p, WAVESTART)]
waveends = [p for p in packets if isinstance(p, WAVEEND)]
if wavestarts and waveends:
start_waves = {p.wave for p in wavestarts}
end_waves = {p.wave for p in waveends}
self.assertTrue(start_waves & end_waves, "No matching wave IDs between WAVESTART and WAVEEND")
def test_nop_sequence(self):
"""Test a sequence of NOP instructions."""
blobs = run_asm_sqtt([s_nop(0), s_nop(0), s_nop(0)])
self.assertGreater(len(blobs), 0, "No SQTT data captured")
packets = decode_all_blobs(blobs)
self.assertGreater(len(packets), 0, "No packets decoded")
if DEBUG >= 2:
print("\n=== NOP sequence trace ===")
print_trace(packets, filter_timing=False)
if __name__ == "__main__":
unittest.main()

View file

@ -0,0 +1,35 @@
import os
os.environ["SQTT"] = "1"
os.environ["PROFILE"] = "1"
os.environ["SQTT_LIMIT_SE"] = "2"
os.environ["SQTT_SIMD_SEL"] = "0"
os.environ["SQTT_TOKEN_EXCLUDE"] = "3784" # Exclude WAVERDY, REG, EVENT, UTILCTR, WAVEALLOC, PERF
import unittest
from extra.assembly.amd.autogen.rdna3.ins import *
from extra.assembly.amd.sqtt import decode
from extra.assembly.amd.test.test_sqtt_hw import compile_asm_sqtt, run_prg_sqtt_batch, format_packet
from extra.assembly.amd.test.test_sqtt_compare import filter_noise_packets
from tinygrad.uop.ops import UOp
from tinygrad.engine.realize import get_runner
class SQTTMultiwave(unittest.TestCase):
def test_simple_multiwave(self):
ins = [
s_barrier(),
v_mov_b32_e32(v[0], v[1]),
s_nop(0),
s_nop(100),
s_endpgm(),
]
#prg = get_runner("AMD", UOp.sink())._prg
prg = compile_asm_sqtt(ins, alu_only=True)
print(prg)
blobs = run_prg_sqtt_batch(prg, n_runs=1, n_lanes=32*16)
for blob in blobs:
packets = decode(blob)
for p in filter_noise_packets(packets):
print(f" {p._time:8d}: {format_packet(p)}")
if __name__ == "__main__":
unittest.main()

View file

@ -0,0 +1,204 @@
#!/usr/bin/env python3
"""Tests validating SQTT packet definitions against the reference implementation.
Verifies that:
1. Encoding patterns produce the correct STATE_TO_OPCODE table
2. Packet sizes (derived from fields) match expected budget values
3. Field extractions match attempt_sqtt_parse.py
"""
import unittest
from extra.assembly.amd.sqtt import (
VALUINST, VMEMEXEC, ALUEXEC, IMMEDIATE, IMMEDIATE_MASK, WAVERDY,
WAVEEND, WAVESTART, PERF, TS_WAVE_STATE, EVENT, EVENT_BIG, REG, SNAPSHOT,
TS_DELTA_OR_MARK, LAYOUT_HEADER, INST, UTILCTR, TS_DELTA_SHORT, NOP,
TS_DELTA_S8_W3, TS_DELTA_S5_W2, TS_DELTA_S5_W3, WAVEALLOC,
decode, encode, OPCODE_TO_CLASS, STATE_TO_OPCODE, PACKET_TYPES, BUDGET,
AluSrc, MemSrc, InstOp
)
# Reference table from rocprof trace decoder (attempt_sqtt_parse.py)
REFERENCE_STATE_TABLE = bytes([
0x10, 0x16, 0x18, 0x01, 0x05, 0x0b, 0x0c, 0x00, 0x0f, 0x14, 0x18, 0x01, 0x09, 0x04, 0x03, 0x02,
0x10, 0x17, 0x18, 0x01, 0x06, 0x08, 0x0d, 0x00, 0x0f, 0x14, 0x18, 0x01, 0x0a, 0x04, 0x03, 0x02,
0x10, 0x07, 0x18, 0x01, 0x05, 0x0b, 0x0c, 0x00, 0x0f, 0x14, 0x18, 0x01, 0x09, 0x04, 0x03, 0x02,
0x10, 0x19, 0x18, 0x01, 0x06, 0x08, 0x0d, 0x00, 0x0f, 0x14, 0x18, 0x01, 0x0a, 0x04, 0x03, 0x02,
0x10, 0x00, 0x18, 0x01, 0x05, 0x0b, 0x0c, 0x00, 0x0f, 0x14, 0x18, 0x01, 0x09, 0x04, 0x03, 0x02,
0x10, 0x11, 0x18, 0x01, 0x06, 0x08, 0x0d, 0x00, 0x0f, 0x14, 0x18, 0x01, 0x0a, 0x04, 0x03, 0x02,
0x10, 0x12, 0x18, 0x01, 0x05, 0x0b, 0x0c, 0x00, 0x0f, 0x14, 0x18, 0x01, 0x09, 0x04, 0x03, 0x02,
0x10, 0x15, 0x18, 0x01, 0x06, 0x08, 0x0d, 0x00, 0x0f, 0x14, 0x18, 0x01, 0x0a, 0x04, 0x03, 0x02,
0x10, 0x16, 0x18, 0x01, 0x05, 0x0b, 0x0c, 0x00, 0x0f, 0x14, 0x18, 0x01, 0x09, 0x04, 0x03, 0x02,
0x10, 0x17, 0x18, 0x01, 0x06, 0x08, 0x0d, 0x00, 0x0f, 0x14, 0x18, 0x01, 0x0a, 0x04, 0x03, 0x02,
0x10, 0x07, 0x18, 0x01, 0x05, 0x0b, 0x0c, 0x00, 0x0f, 0x14, 0x18, 0x01, 0x09, 0x04, 0x03, 0x02,
0x10, 0x19, 0x18, 0x01, 0x06, 0x08, 0x0d, 0x00, 0x0f, 0x14, 0x18, 0x01, 0x0a, 0x04, 0x03, 0x02,
0x10, 0x00, 0x18, 0x01, 0x05, 0x0b, 0x0c, 0x00, 0x0f, 0x14, 0x18, 0x01, 0x09, 0x04, 0x03, 0x02,
0x10, 0x11, 0x18, 0x01, 0x06, 0x08, 0x0d, 0x00, 0x0f, 0x14, 0x18, 0x01, 0x0a, 0x04, 0x03, 0x02,
0x10, 0x13, 0x18, 0x01, 0x05, 0x0b, 0x0c, 0x00, 0x0f, 0x14, 0x18, 0x01, 0x09, 0x04, 0x03, 0x02,
0x10, 0x15, 0x18, 0x01, 0x06, 0x08, 0x0d, 0x00, 0x0f, 0x14, 0x18, 0x01, 0x0a, 0x04, 0x03, 0x02,
])
# Reference opcode -> name mapping (old opcode values from rocprof)
OLD_OPCODE_TO_NAME = {
0x01: 'VALUINST', 0x02: 'VMEMEXEC', 0x03: 'ALUEXEC', 0x04: 'IMMEDIATE',
0x05: 'IMMEDIATE_MASK', 0x06: 'WAVERDY', 0x07: 'TS_DELTA_S8_W3',
0x08: 'WAVEEND', 0x09: 'WAVESTART', 0x0A: 'TS_DELTA_S5_W2',
0x0B: 'WAVEALLOC', 0x0C: 'TS_DELTA_S5_W3', 0x0D: 'PERF',
0x0F: 'TS_DELTA_SHORT', 0x10: 'NOP', 0x11: 'TS_WAVE_STATE',
0x12: 'EVENT', 0x13: 'EVENT_BIG', 0x14: 'REG', 0x15: 'SNAPSHOT',
0x16: 'TS_DELTA_OR_MARK', 0x17: 'LAYOUT_HEADER', 0x18: 'INST',
0x19: 'UTILCTR', 0x00: 'NOP',
}
# Reference budget values (nibbles for NEXT packet) from rocprof
REFERENCE_BUDGET_NIBBLES = {
'VALUINST': 3, 'VMEMEXEC': 2, 'ALUEXEC': 2, 'IMMEDIATE': 3,
'IMMEDIATE_MASK': 6, 'WAVERDY': 6, 'TS_DELTA_S8_W3': 16,
'WAVEEND': 5, 'WAVESTART': 8, 'TS_DELTA_S5_W2': 12,
'WAVEALLOC': 5, 'TS_DELTA_S5_W3': 13, 'PERF': 7,
'TS_DELTA_SHORT': 2, 'NOP': 1, 'TS_WAVE_STATE': 6,
'EVENT': 6, 'EVENT_BIG': 8, 'REG': 16, 'SNAPSHOT': 16,
'TS_DELTA_OR_MARK': 12, 'LAYOUT_HEADER': 16, 'INST': 5,
'UTILCTR': 12,
}
class TestEncodingsMatchStateTable(unittest.TestCase):
"""Verify encoding patterns produce the correct state decode table."""
def test_all_256_bytes_decode_correctly(self):
"""Each byte value should decode to the same packet type as reference."""
mismatches = []
for byte_val in range(256):
ref_opcode = REFERENCE_STATE_TABLE[byte_val]
ref_name = OLD_OPCODE_TO_NAME.get(ref_opcode, f"UNK_{ref_opcode:02x}")
our_opcode = STATE_TO_OPCODE[byte_val]
our_name = OPCODE_TO_CLASS[our_opcode].__name__
if ref_name != our_name:
mismatches.append((byte_val, ref_name, our_name))
if mismatches:
msg = "\n".join(f" 0x{b:02x}: expected {r}, got {o}" for b, r, o in mismatches[:10])
self.fail(f"State table mismatches ({len(mismatches)} total):\n{msg}")
class TestPacketSizesMatchBudget(unittest.TestCase):
"""Verify packet sizes (from field definitions) match expected budget values."""
def test_all_packet_sizes(self):
"""Each packet type's size should match the reference budget."""
for pkt_cls in PACKET_TYPES:
name = pkt_cls.__name__
expected = REFERENCE_BUDGET_NIBBLES.get(name)
if expected is None:
continue
actual = pkt_cls.size_nibbles()
self.assertEqual(expected, actual,
f"{name}: expected {expected} nibbles, got {actual} (size_bits={pkt_cls.size_bits()})")
class TestFieldExtraction(unittest.TestCase):
"""Test that field values are extracted correctly."""
def test_valuinst(self):
reg = 0b11110_1_001_011 # wave=0x1E, flag=1, delta=1
pkt = VALUINST.from_raw(reg)
self.assertEqual(pkt.delta, 1)
self.assertEqual(pkt.flag, 1)
self.assertEqual(pkt.wave, 0x1E)
def test_vmemexec_enum(self):
reg = 0b11_00_1111 # src=3 (VMEM_ALT), delta=0
pkt = VMEMEXEC.from_raw(reg)
self.assertEqual(pkt.src, MemSrc.VMEM_ALT)
def test_aluexec_enum(self):
reg = 0b10_01_1110 # src=2 (VALU), delta=1
pkt = ALUEXEC.from_raw(reg)
self.assertEqual(pkt.src, AluSrc.VALU)
def test_waveend(self):
reg = (0x15 << 15) | (0x7 << 11) | (0x3 << 9) | (1 << 8) | 0b10101
pkt = WAVEEND.from_raw(reg)
self.assertEqual(pkt.flag7, 1)
self.assertEqual(pkt.simd, 3)
self.assertEqual(pkt.cu_lo, 7)
self.assertEqual(pkt.wave, 0x15)
self.assertEqual(pkt.cu, 0xF) # cu_lo | (flag7 << 3) = 7 | 8 = 15
def test_wavestart(self):
reg = (0x7F << 18) | (0x15 << 13) | (0x7 << 10) | (0x3 << 8) | (1 << 7) | 0b01100
pkt = WAVESTART.from_raw(reg)
self.assertEqual(pkt.flag7, 1)
self.assertEqual(pkt.simd, 3)
self.assertEqual(pkt.cu_lo, 7)
self.assertEqual(pkt.wave, 0x15)
self.assertEqual(pkt.id7, 0x7F)
self.assertEqual(pkt.cu, 0xF)
def test_inst_enum(self):
reg = (0x21 << 13) | (0x15 << 8) | (1 << 7) | (1 << 3) | 0b010
pkt = INST.from_raw(reg)
self.assertEqual(pkt.flag1, 1)
self.assertEqual(pkt.flag2, 1)
self.assertEqual(pkt.wave, 0x15)
self.assertEqual(pkt.op, InstOp.VMEM_LOAD)
def test_layout_header(self):
reg = (0b101 << 33) | (0b1010 << 28) | (0b111 << 15) | (0b11 << 13) | (0b101010 << 7) | 0b0010001
pkt = LAYOUT_HEADER.from_raw(reg)
self.assertEqual(pkt.layout, 0b101010)
self.assertEqual(pkt.simd, 0b11)
self.assertEqual(pkt.group, 0b111)
self.assertEqual(pkt.sel_a, 0b1010)
self.assertEqual(pkt.sel_b, 0b101)
def test_ts_delta_or_mark_modes(self):
# delta mode: bit9=0, bit8=0
pkt_delta = TS_DELTA_OR_MARK.from_raw(0b0000001) # just the encoding pattern
self.assertFalse(pkt_delta.is_marker)
# marker mode: bit9=1, bit8=0
pkt_marker = TS_DELTA_OR_MARK.from_raw(0b0000001 | (1 << 9)) # bit9=1, bit8=0
self.assertTrue(pkt_marker.is_marker)
# other mode: bit9=1, bit8=1 (not marker)
pkt_other = TS_DELTA_OR_MARK.from_raw(0b0000001 | (1 << 8) | (1 << 9))
self.assertFalse(pkt_other.is_marker)
def test_reg(self):
# REG fields: slot=bits[9:7], hi_byte=bits[15:8], subop=bits[31:16], val32=bits[63:32]
# Note: slot[2:1] overlaps with hi_byte[1:0], so we need to set them consistently
# hi_byte=0x55 means bits 8-15 = 0b01010101, so slot bits 8-9 = 0b01
# slot bit 7 = 1, so slot = 0b011 = 3
reg = (0xDEADBEEF << 32) | (0xCAFE << 16) | (0x55 << 8) | (1 << 7) | 0b1001
pkt = REG.from_raw(reg)
self.assertEqual(pkt.slot, 0b011) # bit7=1, bits 8-9 from hi_byte low 2 bits = 01
self.assertEqual(pkt.hi_byte, 0x55)
self.assertEqual(pkt.subop, 0xCAFE)
self.assertEqual(pkt.val32, 0xDEADBEEF)
class TestRoundtrip(unittest.TestCase):
"""Test encode/decode roundtrip."""
def test_simple_roundtrip(self):
"""Test encode/decode roundtrip preserves packet types."""
test_packets = [
LAYOUT_HEADER.from_raw(0x100),
WAVESTART.from_raw(0x0),
INST.from_raw(0x10),
INST.from_raw(0x10),
WAVEEND.from_raw(0x40),
]
encoded = encode(test_packets)
decoded = decode(encoded)
self.assertGreaterEqual(len(decoded), len(test_packets))
for i, (orig, dec) in enumerate(zip(test_packets, decoded)):
self.assertEqual(type(orig), type(dec), f"type mismatch at {i}")
if __name__ == "__main__":
unittest.main()

View file

@ -21,7 +21,8 @@ from tinygrad.runtime.support.memory import AddrSpace
if getenv("IOCTL"): import extra.hip_gpu_driver.hip_ioctl # noqa: F401 # pylint: disable=unused-import
SQTT = ContextVar("SQTT", abs(VIZ.value)>=2)
SQTT_ITRACE_SE_MASK, SQTT_LIMIT_SE = ContextVar("SQTT_ITRACE_SE_MASK", 0b11), ContextVar("SQTT_LIMIT_SE", 0)
SQTT_ITRACE_SE_MASK, SQTT_LIMIT_SE, SQTT_SIMD_SEL, SQTT_TOKEN_EXCLUDE = \
ContextVar("SQTT_ITRACE_SE_MASK", 0b11), ContextVar("SQTT_LIMIT_SE", 0), ContextVar("SQTT_SIMD_SEL", 0), ContextVar("SQTT_TOKEN_EXCLUDE", 0)
PMC = ContextVar("PMC", abs(VIZ.value)>=2)
EVENT_INDEX_PARTIAL_FLUSH = 4 # based on a comment in nvd.h
WAIT_REG_MEM_FUNCTION_EQ = 3 # ==
@ -252,17 +253,18 @@ class AMDComputeQueue(HWQueue):
else:
self.wreg(self.gc.regSQ_THREAD_TRACE_BUF0_SIZE, base_hi=buf0_hi, size=buf0s[se].size >> 12)
self.wreg(self.gc.regSQ_THREAD_TRACE_BUF0_BASE, base_lo=buf0_lo)
# NOTE: SQTT can only trace instructions on one simd per se, this selects first simd in first wgp in first sa.
# NOTE: SQTT can only trace instructions on one simd per se, this selects the simd in first wgp in first sa.
# For RGP to display instruction trace it has to see it on first SE. Howerver ACE/MEC/whatever does the dispatching starting with second se,
# and on amdgpu/non-AM it also does weird things with dispatch order inside se: around 7 times out of 10 it starts from the last cu, but
# sometimes not, especially if the kernel has more than one wavefront which means that kernels with small global size might get unlucky and
# be dispatched on something else and not be seen in instruction tracing tab. You can force the wavefronts of a kernel to be dispatched on the
# CUs you want to by disabling other CUs via bits in regCOMPUTE_STATIC_THREAD_MGMT_SE<x> and trace even kernels that only have one wavefront.
# Use SQTT_SIMD_SEL to select which SIMD to trace (0-3). Memory ops show different InstOp values (0x2x vs 0x5x) based on SIMD.
cs_wtype = (1 << 6) if self.dev.target >= (12,0,0) else self.soc.SQ_TT_WTYPE_INCLUDE_CS_BIT
self.wreg(self.gc.regSQ_THREAD_TRACE_MASK, wtype_include=cs_wtype, simd_sel=0, wgp_sel=0, sa_sel=0)
self.wreg(self.gc.regSQ_THREAD_TRACE_MASK, wtype_include=cs_wtype, simd_sel=SQTT_SIMD_SEL.value, wgp_sel=0, sa_sel=0)
reg_include = self.soc.SQ_TT_TOKEN_MASK_SQDEC_BIT | self.soc.SQ_TT_TOKEN_MASK_SHDEC_BIT | self.soc.SQ_TT_TOKEN_MASK_GFXUDEC_BIT | \
self.soc.SQ_TT_TOKEN_MASK_COMP_BIT | self.soc.SQ_TT_TOKEN_MASK_CONTEXT_BIT
token_exclude = (1 << self.soc.SQ_TT_TOKEN_EXCLUDE_PERF_SHIFT) if self.dev.target < (12,0,0) else 0
token_exclude = SQTT_TOKEN_EXCLUDE.value | ((1 << self.soc.SQ_TT_TOKEN_EXCLUDE_PERF_SHIFT) if self.dev.target < (12,0,0) else 0)
# disable instr tracing
if not (SQTT_ITRACE_SE_MASK.value >> se) & 0b1: