mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
Compare commits
83 commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
2f85319722 | ||
|
|
d9f0e9c40c | ||
|
|
3dcffbea25 | ||
|
|
41c5368266 | ||
|
|
4598a21f94 | ||
|
|
27084cd618 | ||
|
|
d2616e5daf | ||
|
|
3130c53f85 | ||
|
|
1c66e41383 | ||
|
|
93823b272c | ||
|
|
1c6147e9bf | ||
|
|
7f5656d236 | ||
|
|
5f55a61700 | ||
|
|
ed097df864 | ||
|
|
a83c97f17e | ||
|
|
c793076fb6 | ||
|
|
1f45601a97 | ||
|
|
14c4989f65 | ||
|
|
31b38640ac | ||
|
|
fe770e822c | ||
|
|
768231c065 | ||
|
|
4165594b30 | ||
|
|
c03b7b0da1 | ||
|
|
66249836c0 | ||
|
|
a0d6ed9914 | ||
|
|
99fcfc0e97 | ||
|
|
cf8bb15aef | ||
|
|
32dfc9b1d0 | ||
|
|
9803e389fe | ||
|
|
1f893b65cc | ||
|
|
35f5f05ad5 | ||
|
|
b9f08ad18a | ||
|
|
222ae38aa4 | ||
|
|
f0bf20d7b2 | ||
|
|
85ef097da6 | ||
|
|
0e240fb987 |
||
|
|
d2c1712e4c | ||
|
|
96b0ee0966 | ||
|
|
9b5c4bc698 | ||
|
|
6ea3586101 | ||
|
|
92cb8b6776 | ||
|
|
c416b20668 | ||
|
|
415b83ba18 | ||
|
|
8c7eacea59 | ||
|
|
81542699f8 | ||
|
|
79f55a5d5e | ||
|
|
37518fb236 | ||
|
|
672008ccab | ||
|
|
849af761a4 | ||
|
|
ab46b3d8d3 | ||
|
|
df20197bfb | ||
|
|
2b56c264d5 | ||
|
|
c7e5c2f996 | ||
|
|
659aa14043 | ||
|
|
21ffa1a86b | ||
|
|
29f3fb7af3 | ||
|
|
1edc7fc519 | ||
|
|
c9a3ac988c | ||
|
|
77d96acbe3 | ||
|
|
660ecf272b | ||
|
|
267bbb163e | ||
|
|
de29a49ea3 | ||
|
|
742e10a572 | ||
|
|
447fe8907b | ||
|
|
b0cfcec183 | ||
|
|
1726084b2a | ||
|
|
de069a4876 | ||
|
|
4573e91e61 | ||
|
|
8d43212bc6 | ||
|
|
a8bea4ec52 | ||
|
|
388514c5b1 | ||
|
|
729bb04d8c | ||
|
|
8f4de73141 | ||
|
|
a5959ef0f1 | ||
|
|
5ba06892c0 | ||
|
|
469efe313d | ||
|
|
e3b3cb163d | ||
|
|
3e32185faf | ||
|
|
5328913d2b | ||
|
|
9c49ec1cc1 | ||
|
|
000d4a125b | ||
|
|
63289902d8 | ||
|
|
b596f77e33 |
11 changed files with 5848 additions and 8 deletions
|
|
@ -2,12 +2,12 @@
|
|||
# mypy: ignore-errors
|
||||
from __future__ import annotations
|
||||
import ctypes, functools
|
||||
from tinygrad.helpers import DEBUG, colored, ansilen
|
||||
from tinygrad.runtime.autogen import hsa
|
||||
from extra.assembly.amd.dsl import Inst, unwrap, FLOAT_ENC, MASK32, MASK64, _f32, _i32, _sext, _f16, _i16, _f64, _i64
|
||||
from extra.assembly.amd.asm import detect_format
|
||||
from extra.assembly.amd.pcode import compile_pseudocode
|
||||
from extra.assembly.amd.dsl import Inst, unwrap, FLOAT_ENC, MASK32, MASK64, _f32, _i32, _sext, _f16, _i16, _f64, _i64, SrcEnum
|
||||
from extra.assembly.amd.pcode import Reg, compile_pseudocode
|
||||
from extra.assembly.amd.asm import detect_format, disasm
|
||||
from extra.assembly.amd.autogen.rdna3.str_pcode import PSEUDOCODE_STRINGS
|
||||
from extra.assembly.amd.dsl import SrcEnum
|
||||
from extra.assembly.amd.autogen.rdna3.ins import (SOP1, SOP2, SOPC, SOPK, SOPP, SMEM, VOP1, VOP2, VOP3, VOP3SD, VOP3P, VOPC, DS, FLAT, VOPD,
|
||||
SOP1Op, SOP2Op, SOPCOp, SOPKOp, SOPPOp, SMEMOp, VOP1Op, VOP2Op, VOP3Op, VOP3SDOp, VOP3POp, VOPCOp, DSOp, FLATOp, GLOBALOp, SCRATCHOp, VOPDOp)
|
||||
|
||||
|
|
@ -329,6 +329,294 @@ def exec_wmma(st: WaveState, inst, op: VOP3POp) -> None:
|
|||
else:
|
||||
for i in range(256): st.vgpr[i % 32][vdst + i//32] = _i32(mat_d[i])
|
||||
|
||||
# SQTT TRACING
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
|
||||
WAVESTART_TO_INST_CYCLES = 32
|
||||
SNOP_EXTRA_DELAY_MIN, SNOP_EXTRA_DELAY_MAX = 11, 22 # s_nop(11-22) has +4 penalty
|
||||
SNOP_EXTRA_DELAY_CYCLES = 4
|
||||
|
||||
from extra.assembly.amd.sqtt import WAVESTART, WAVEEND, IMMEDIATE, VALUINST, ALUEXEC, AluSrc
|
||||
|
||||
def _get_src_vgprs(inst: Inst) -> list[int]:
|
||||
if isinstance(inst, VOP1): return [inst.src0 - 256] if inst.src0 >= 256 else []
|
||||
if isinstance(inst, VOP2): return ([inst.src0 - 256] if inst.src0 >= 256 else []) + [inst.vsrc1]
|
||||
if isinstance(inst, VOP3): return [s - 256 for s in [inst.src0, inst.src1, getattr(inst, 'src2', None)] if s is not None and s >= 256]
|
||||
return []
|
||||
|
||||
class SQTTState:
|
||||
"""SQTT tracing with cycle-accurate RDNA3 VALU pipeline model.
|
||||
|
||||
NOTE: This is a hardware-plausible model derived from observed SQTT timing patterns.
|
||||
The model should be verified by tests against real hardware traces, not by fitting
|
||||
formulas to expected outputs. If tests fail, the model needs to be understood and
|
||||
fixed, not hacked with magic constants.
|
||||
|
||||
Physical model:
|
||||
- alu[4]: 4-stage ALU pipeline, each slot holds dest_vgpr or None
|
||||
- in_flight: up to 12 in-flight instructions (issued but not yet completed)
|
||||
- issue_queue: instructions waiting to enter ALU (sources not ready)
|
||||
- fwd_slots: 4 forwarding slots, reserved at issue, freed when consumer forwards
|
||||
- completed: vgprs with results ready (exited ALU)
|
||||
|
||||
Forwarding model (4 slots):
|
||||
- Slot reserved at ISSUE time if available (len(fwd_slots) < 4)
|
||||
- Slot freed when a consumer uses the result for forwarding
|
||||
- Consumer can forward if: has a slot AND producer is completed
|
||||
- If no slot at issue, instruction uses regfile path (+4 cycle penalty)
|
||||
"""
|
||||
def __init__(self, wave_id: int = 0, simd: int = 0, cu: int = 0):
|
||||
self.wave_id, self.simd, self.cu = wave_id, simd, cu
|
||||
self.cycle = 0
|
||||
self.packets = []
|
||||
|
||||
# 4-stage ALU pipeline: each slot holds dest_vgpr or None
|
||||
self.alu = [None, None, None, None]
|
||||
|
||||
# In-flight instructions: max 12 at a time, each is (dest_vgpr, srcs, has_fwd_slot)
|
||||
self.in_flight: list[tuple[int, list[int], bool]] = []
|
||||
|
||||
# Issue queue: list of (dest_vgpr, srcs, ready_at, has_fwd_slot, was_warm) waiting for deps
|
||||
# ready_at: cycle when this instruction can enter ALU (0 = no restriction)
|
||||
# has_fwd_slot: True if this instruction reserved a forwarding slot at issue time
|
||||
# was_warm: True if forwarding path was warm when this instruction was issued
|
||||
self.issue_queue: list[tuple[int, list[int], int, bool, bool]] = []
|
||||
|
||||
# 4 forwarding slots: consumer adds producer at issue, freed when consumer forwards
|
||||
self.fwd_slots: list[int] = [] # producer vgprs reserved for forwarding
|
||||
|
||||
# VGPRs that had a dependent try to add them to fwd_slots (successful or not)
|
||||
self.had_dependent: set[int] = set()
|
||||
|
||||
# VGPRs that were issued after forwarding chain broke (can't forward)
|
||||
self.fwd_chain_broken: set[int] = set()
|
||||
|
||||
# Set of completed vgprs (results ready, exited ALU)
|
||||
self.completed: set[int] = set()
|
||||
|
||||
# Cold start: first forwarding use has +1 cycle penalty
|
||||
self.forward_warm = False
|
||||
self.cold_used = False # True if cold start penalty was applied
|
||||
|
||||
def emit(self, pkt_class, **kwargs):
|
||||
self.packets.append(pkt_class(_time=self.cycle, **kwargs))
|
||||
|
||||
def _fmt_alu(self) -> str:
|
||||
# Fixed width: each slot 3 chars, total ALU[xxx,xxx,xxx,xxx] = 20 chars
|
||||
slots = [f'v{v}' if v is not None else '-' for v in self.alu]
|
||||
return 'ALU[' + ','.join(f'{s:>3}' for s in slots) + ']'
|
||||
|
||||
def _fmt_fwd(self) -> str:
|
||||
items = [f'v{v}' for v in self.fwd_slots]
|
||||
content = 'FWD[' + ','.join(items) + ']' if items else 'FWD[]'
|
||||
padded = f'{content:<24}'
|
||||
return colored(padded, 'yellow') if items else padded
|
||||
|
||||
def _fmt_iq(self) -> str:
|
||||
def fmt_item(d, r, fwd):
|
||||
s = f'v{d}'
|
||||
if r != 0: s += f'@{abs(r)}'
|
||||
if not fwd: s += 'R'
|
||||
return s
|
||||
items = [fmt_item(d, r, fwd) for d, _, r, fwd, _ in self.issue_queue]
|
||||
return 'IQ[' + ','.join(items) + ']' if items else 'IQ[]'
|
||||
|
||||
def _debug_line(self, events: list[str] | None = None):
|
||||
if DEBUG < 3: return
|
||||
# Skip empty cycles (nothing in ALU, no events, no IQ)
|
||||
has_alu = any(s is not None for s in self.alu)
|
||||
if not has_alu and not events and not self.issue_queue: return
|
||||
cycle = colored(f'C{self.cycle:>3}:', 'cyan')
|
||||
alu = self._fmt_alu()
|
||||
fwd = self._fmt_fwd()
|
||||
iq = f'{self._fmt_iq():<28}'
|
||||
ev_str = ' '.join(events) if events else ''
|
||||
ev_padded = f'{ev_str:<20}' if ev_str else ' ' * 20
|
||||
print(f"{cycle} {alu} {fwd} {iq} {ev_padded}")
|
||||
|
||||
def _can_issue(self) -> bool:
|
||||
return len(self.in_flight) < 12
|
||||
|
||||
def _has_pending_write(self, vgpr: int) -> bool:
|
||||
"""Check if there's a pending write to this VGPR (in ALU, in-flight, or issue queue)."""
|
||||
if any(slot == vgpr for slot in self.alu if slot is not None): return True
|
||||
if any(d == vgpr for d, _, _ in self.in_flight): return True
|
||||
if any(d == vgpr for d, _, _, _, _ in self.issue_queue): return True
|
||||
return False
|
||||
|
||||
def _all_srcs_ready(self, srcs: list[int]) -> bool:
|
||||
"""Returns True if all sources are ready (completed or no pending write)."""
|
||||
for src in srcs:
|
||||
if src in self.completed: continue
|
||||
if not self._has_pending_write(src): continue # initial value
|
||||
return False
|
||||
return True
|
||||
|
||||
def tick(self):
|
||||
self.cycle += 1
|
||||
if self.cycle > 10000: raise RuntimeError("cycle limit exceeded")
|
||||
events = []
|
||||
|
||||
# 1. ALU[3] exits - capture but don't add to completed yet
|
||||
exiting = self.alu[3]
|
||||
if exiting is not None:
|
||||
self.emit(ALUEXEC, src=AluSrc.VALU)
|
||||
events.append(colored(f"EXEC v{exiting}", 'red'))
|
||||
|
||||
# 2. Slide ALU pipeline
|
||||
self.alu[3] = self.alu[2]
|
||||
self.alu[2] = self.alu[1]
|
||||
self.alu[1] = self.alu[0]
|
||||
self.alu[0] = None
|
||||
|
||||
# 3. Try to promote from issue_queue to ALU[0] (before adding exiting to completed)
|
||||
if self.alu[0] is None and self.issue_queue:
|
||||
for i, (dest, srcs, ready_at, has_fwd_slot, was_warm) in enumerate(self.issue_queue):
|
||||
# Check if instruction has a minimum ready cycle
|
||||
if ready_at > 0 and self.cycle < ready_at:
|
||||
continue
|
||||
# Check if sources are ready
|
||||
ready = self._all_srcs_ready(srcs)
|
||||
has_deps = len(srcs) > 0
|
||||
if not ready:
|
||||
continue
|
||||
# Cold start penalty: first dependent instruction has +1 cycle delay (delta=6 vs delta=5)
|
||||
# Only applies if forwarding path wasn't warm when this instruction was issued
|
||||
if has_deps and not was_warm and not self.cold_used:
|
||||
self.cold_used = True
|
||||
self.issue_queue[i] = (dest, srcs, self.cycle + 1, has_fwd_slot, was_warm)
|
||||
continue
|
||||
# Forwarding: consumer can forward if:
|
||||
# 1. Not in fwd_chain_broken (chain must be intact), AND
|
||||
# 2. Producer has a slot (source is in fwd_slots), AND
|
||||
# 3. Either activated by dependent OR successfully added producer at issue
|
||||
# Note: if issued cold with no slot, activation only counts if the activator also has a dependent
|
||||
chain_intact = dest not in self.fwd_chain_broken
|
||||
producer_has_slot = has_deps and any(src in self.fwd_slots for src in srcs)
|
||||
# Check activation validity
|
||||
if dest in self.had_dependent:
|
||||
if was_warm or has_fwd_slot:
|
||||
activated_by_dependent = True
|
||||
else:
|
||||
# Cold + no slot: activation only counts if activator itself has a dependent
|
||||
# This handles the chain_6 vs chain_7 difference (chain_7 has v6 which activates v5)
|
||||
activated_by_dependent = (dest + 1) in self.had_dependent # activator is dest+1 in a chain
|
||||
else:
|
||||
activated_by_dependent = False
|
||||
can_forward = chain_intact and producer_has_slot and (activated_by_dependent or has_fwd_slot)
|
||||
# Regfile path: has dependencies but can't forward
|
||||
must_use_regfile = has_deps and not can_forward
|
||||
# Regfile penalty: add +4 cycles latency (only apply once)
|
||||
if must_use_regfile and ready_at == 0:
|
||||
self.issue_queue[i] = (dest, srcs, self.cycle + 4, has_fwd_slot, was_warm)
|
||||
continue
|
||||
# Enter ALU
|
||||
self.alu[0] = dest
|
||||
self.issue_queue.pop(i)
|
||||
# Free producer's forwarding slot when consumer dispatches (regardless of fwd/rf)
|
||||
for src in srcs:
|
||||
if src in self.fwd_slots:
|
||||
self.fwd_slots.remove(src)
|
||||
break
|
||||
events.append(colored(f"v{dest}->ALU" + ("(fwd)" if can_forward else "(rf)" if must_use_regfile else ""), 'green'))
|
||||
break
|
||||
|
||||
# 4. Now add exiting instruction to completed (after promotion decision)
|
||||
if exiting is not None:
|
||||
self.completed.add(exiting)
|
||||
# Remove from in_flight - any VALU completing warms up the forward path
|
||||
for idx, (d, _, _) in enumerate(self.in_flight):
|
||||
if d == exiting:
|
||||
self.forward_warm = True
|
||||
self.in_flight.pop(idx)
|
||||
break
|
||||
|
||||
self._debug_line(events)
|
||||
|
||||
def _pipeline_empty(self) -> bool:
|
||||
if any(s is not None for s in self.alu): return False
|
||||
if self.issue_queue: return False
|
||||
if self.in_flight: return False
|
||||
return True
|
||||
|
||||
def process_instruction(self, inst: Inst):
|
||||
if isinstance(inst, SOPP) and inst.op == SOPPOp.S_DELAY_ALU:
|
||||
# TODO: implement s_delay_alu properly
|
||||
return
|
||||
|
||||
elif isinstance(inst, SOPP) and inst.op == SOPPOp.S_NOP:
|
||||
# s_nop(N) delays N+1 cycles, plus extra penalty for s_nop(11-22)
|
||||
cycles = inst.simm16 + 1
|
||||
if SNOP_EXTRA_DELAY_MIN <= inst.simm16 <= SNOP_EXTRA_DELAY_MAX:
|
||||
cycles += SNOP_EXTRA_DELAY_CYCLES
|
||||
if DEBUG >= 3:
|
||||
cycle = colored(f'C{self.cycle:>3}:', 'cyan')
|
||||
# 20 (ALU) + 1 + 24 (FWD) + 1 + 28 (IQ) + 1 + 20 (events) = 95 padding after cycle
|
||||
print(f"{cycle} {' ' * 95} {disasm(inst)}")
|
||||
for _ in range(cycles): self.tick()
|
||||
self.emit(IMMEDIATE, wave=self.wave_id)
|
||||
|
||||
elif isinstance(inst, SOPP) and inst.op == SOPPOp.S_ENDPGM:
|
||||
# Drain pipeline before ending
|
||||
while not self._pipeline_empty(): self.tick()
|
||||
self.emit(WAVEEND, wave=self.wave_id, simd=self.simd, cu_lo=self.cu & 0x7, flag7=self.cu >> 3)
|
||||
|
||||
elif isinstance(inst, (VOP1, VOP2, VOP3)):
|
||||
# Check for issue stall (no free in-flight slots)
|
||||
while not self._can_issue():
|
||||
self.tick()
|
||||
|
||||
# Issue: add to in_flight and issue_queue
|
||||
srcs = _get_src_vgprs(inst)
|
||||
dest = inst.vdst
|
||||
# Clear stale state for this dest (WAW hazard)
|
||||
self.completed.discard(dest)
|
||||
if dest in self.fwd_slots: self.fwd_slots.remove(dest)
|
||||
|
||||
# Consumer adds producer to fwd_slots (if room and has dependency)
|
||||
# If producer is in fwd_chain_broken, or we can't add, the chain breaks for this instruction too
|
||||
has_fwd_slot = False
|
||||
if srcs:
|
||||
producer = srcs[0]
|
||||
self.had_dependent.add(producer) # record that producer has a dependent
|
||||
# Check if producer's forwarding chain is already broken
|
||||
if producer in self.fwd_chain_broken:
|
||||
# Chain is broken, this instruction also can't forward
|
||||
self.fwd_chain_broken.add(dest)
|
||||
elif len(self.fwd_slots) >= 4:
|
||||
# Can't add producer, chain breaks
|
||||
self.fwd_chain_broken.add(dest)
|
||||
else:
|
||||
# Can add producer
|
||||
if producer not in self.fwd_slots:
|
||||
self.fwd_slots.append(producer)
|
||||
has_fwd_slot = len(self.fwd_slots) < 4
|
||||
|
||||
# Record if forwarding path was warm at issue time
|
||||
was_warm = self.forward_warm
|
||||
|
||||
self.in_flight.append((dest, srcs, has_fwd_slot))
|
||||
self.issue_queue.append((dest, srcs, 0, has_fwd_slot, was_warm))
|
||||
self.emit(VALUINST, wave=self.wave_id)
|
||||
|
||||
if DEBUG >= 3:
|
||||
cycle = colored(f'C{self.cycle:>3}:', 'cyan')
|
||||
slot_info = "" if has_fwd_slot else colored(" NO_SLOT", 'red')
|
||||
issue = colored(f'ISSUE v{dest}', 'magenta') + slot_info
|
||||
padding = 95 - ansilen(issue)
|
||||
print(f"{cycle} {issue}{' ' * padding} {disasm(inst)}")
|
||||
|
||||
# One cycle per instruction issued, then try to enter ALU
|
||||
self.tick()
|
||||
return
|
||||
|
||||
# One cycle per instruction issued (for non-VALU)
|
||||
self.tick()
|
||||
|
||||
def emit_wavestart(self):
|
||||
self.emit(WAVESTART, wave=self.wave_id, simd=self.simd, cu_lo=self.cu & 0x7, flag7=self.cu >> 3)
|
||||
for _ in range(WAVESTART_TO_INST_CYCLES): self.tick()
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
# PROGRAM DECODE
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
|
|
|
|||
|
|
@ -119,6 +119,24 @@ class PacketType:
|
|||
cls._extract_info = [(name, bf.lo, bf.mask(), cls._field_types.get(name)) for name, bf in cls._fields.items()]
|
||||
cls._size_nibbles = ((max((f.hi for f in cls._fields.values()), default=0) + 4) // 4)
|
||||
|
||||
def __init__(self, _time: int = 0, **kwargs):
|
||||
"""Construct packet from named fields (like assembly instructions)."""
|
||||
raw = 0
|
||||
if self._encoding:
|
||||
bf, pattern = self._encoding
|
||||
raw |= pattern << bf.lo
|
||||
for name, bf in self._fields.items():
|
||||
val = kwargs.get(name, 0)
|
||||
if isinstance(val, IntEnum): val = val.value
|
||||
raw |= (val & bf.mask()) << bf.lo
|
||||
self._raw, self._time, self._values = raw, _time, {}
|
||||
for name, lo, mask, enum_type in self._extract_info:
|
||||
val = (raw >> lo) & mask
|
||||
if enum_type is not None:
|
||||
try: val = enum_type(val)
|
||||
except ValueError: pass
|
||||
self._values[name] = val
|
||||
|
||||
@classmethod
|
||||
def from_raw(cls, raw: int, time: int = 0):
|
||||
inst = object.__new__(cls)
|
||||
|
|
@ -305,6 +323,8 @@ PACKET_TYPES: list[type[PacketType]] = [
|
|||
NOP,
|
||||
]
|
||||
|
||||
PACKET_BY_NAME: dict[str, type[PacketType]] = {cls.__name__: cls for cls in PACKET_TYPES}
|
||||
|
||||
def _build_state_table() -> tuple[bytes, dict[int, type[PacketType]]]:
|
||||
table = [len(PACKET_TYPES) - 1] * 256 # default to NOP
|
||||
opcode_to_class: dict[int, type[PacketType]] = {i: cls for i, cls in enumerate(PACKET_TYPES)}
|
||||
|
|
@ -321,6 +341,11 @@ def _build_state_table() -> tuple[bytes, dict[int, type[PacketType]]]:
|
|||
|
||||
STATE_TO_OPCODE, OPCODE_TO_CLASS = _build_state_table()
|
||||
|
||||
OPCODE_TO_BYTES: dict[int, list[int]] = {}
|
||||
for _byte_val, _opcode in enumerate(STATE_TO_OPCODE):
|
||||
if _opcode not in OPCODE_TO_BYTES: OPCODE_TO_BYTES[_opcode] = []
|
||||
OPCODE_TO_BYTES[_opcode].append(_byte_val)
|
||||
|
||||
# Precompute special case opcodes
|
||||
_TS_DELTA_OR_MARK_OPCODE = next(op for op, cls in OPCODE_TO_CLASS.items() if cls is TS_DELTA_OR_MARK)
|
||||
_TS_DELTA_SHORT_OPCODE = next(op for op, cls in OPCODE_TO_CLASS.items() if cls is TS_DELTA_SHORT)
|
||||
|
|
@ -379,3 +404,47 @@ def decode(data: bytes) -> list[PacketType]:
|
|||
packets_append(pkt_cls.from_raw(reg, time))
|
||||
|
||||
return packets
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
# ENCODER
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
|
||||
def encode(packets: list[PacketType]) -> bytes:
|
||||
"""Encode a list of packet instances into raw SQTT blob."""
|
||||
if not packets: return b''
|
||||
|
||||
read_lengths = [16]
|
||||
for p in packets[:-1]:
|
||||
read_lengths.append(type(p)._size_nibbles)
|
||||
|
||||
total_nibbles = sum(read_lengths)
|
||||
bits_arr = [0] * (total_nibbles * 4)
|
||||
|
||||
cumulative = 0
|
||||
for i, p in enumerate(packets):
|
||||
cumulative += read_lengths[i]
|
||||
pkt_cls = type(p)
|
||||
opcode = next(op for op, cls in OPCODE_TO_CLASS.items() if cls is pkt_cls)
|
||||
|
||||
byte_vals = OPCODE_TO_BYTES.get(opcode)
|
||||
if not byte_vals: raise ValueError(f"No encoding for {pkt_cls.__name__}")
|
||||
opcode_byte = byte_vals[0]
|
||||
|
||||
delta_field = getattr(pkt_cls, 'delta', None)
|
||||
if delta_field is not None and delta_field.hi < 8:
|
||||
delta = p._values.get('delta', 0)
|
||||
if isinstance(delta, IntEnum): delta = delta.value
|
||||
if pkt_cls is TS_DELTA_SHORT: delta = max(0, delta - 8)
|
||||
delta = delta & delta_field.mask()
|
||||
opcode_byte = (opcode_byte & ~(delta_field.mask() << delta_field.lo)) | (delta << delta_field.lo)
|
||||
|
||||
opcode_nibble_pos = max(0, cumulative - 16)
|
||||
opcode_bit_pos = opcode_nibble_pos * 4
|
||||
|
||||
for b in range(8):
|
||||
if opcode_bit_pos + b < len(bits_arr):
|
||||
bits_arr[opcode_bit_pos + b] = (opcode_byte >> b) & 1
|
||||
|
||||
nibbles = [sum(bits_arr[i + j] << j for j in range(4) if i + j < len(bits_arr)) for i in range(0, len(bits_arr), 4)]
|
||||
while len(nibbles) % 2: nibbles.append(0)
|
||||
return bytes(nibbles[i] | (nibbles[i + 1] << 4) for i in range(0, len(nibbles), 2))
|
||||
|
|
|
|||
855
extra/assembly/amd/test/discover_instops.py
Normal file
855
extra/assembly/amd/test/discover_instops.py
Normal file
|
|
@ -0,0 +1,855 @@
|
|||
#!/usr/bin/env python3
|
||||
"""SQTT InstOp discovery tool - finds instruction opcodes by running different instructions.
|
||||
|
||||
Requires profiling enabled:
|
||||
echo 'profile_standard' | sudo tee /sys/class/drm/card1/device/power_dpm_force_performance_level
|
||||
|
||||
Run with: DEBUG=1 python extra/assembly/amd/test/discover_instops.py
|
||||
For full traces: DEBUG=2 python extra/assembly/amd/test/discover_instops.py
|
||||
"""
|
||||
import os
|
||||
os.environ["SQTT"] = "1"
|
||||
os.environ["PROFILE"] = "1"
|
||||
os.environ["SQTT_LIMIT_SE"] = "2" # Force work to traced SE only
|
||||
os.environ["SQTT_TOKEN_EXCLUDE"] = "3784" # Exclude WAVERDY, REG, EVENT, UTILCTR, WAVEALLOC, PERF
|
||||
|
||||
from tinygrad.helpers import DEBUG, colored
|
||||
from tinygrad.runtime.ops_amd import SQTT_SIMD_SEL
|
||||
|
||||
from extra.assembly.amd.autogen.rdna3.ins import (
|
||||
# VALU - basic (these are safe, just register ops)
|
||||
v_mov_b32_e32, v_add_f32_e32, v_mul_f32_e32,
|
||||
v_and_b32_e32, v_or_b32_e32, v_xor_b32_e32,
|
||||
v_lshlrev_b32_e32, v_lshrrev_b32_e32,
|
||||
# VALU - transcendental
|
||||
v_exp_f32_e32, v_log_f32_e32, v_rcp_f32_e32, v_sqrt_f32_e32,
|
||||
v_sin_f32_e32, v_cos_f32_e32,
|
||||
# VALU - 64-bit
|
||||
v_lshlrev_b64, v_lshrrev_b64, v_ashrrev_i64,
|
||||
v_add_f64, v_mul_f64, v_max_f64, v_min_f64,
|
||||
v_fma_f64,
|
||||
# VALU - 64-bit transcendental
|
||||
v_rcp_f64_e32, v_rsq_f64_e32, v_sqrt_f64_e32,
|
||||
v_trunc_f64_e32, v_ceil_f64_e32, v_floor_f64_e32, v_fract_f64_e32,
|
||||
v_frexp_exp_i32_f64_e32, v_frexp_mant_f64_e32,
|
||||
# VALU - div helpers
|
||||
v_div_fixup_f32, v_div_fixup_f64, v_div_fmas_f32, v_div_fmas_f64, v_div_scale_f32,
|
||||
# VALU - MAD64
|
||||
v_mad_u64_u32, v_mad_i64_i32,
|
||||
# VALU - compare (writes to VCC, safe)
|
||||
v_cmp_eq_u32_e32,
|
||||
# VALU - cmpx (modifies EXEC) - various types
|
||||
v_cmpx_eq_u32_e32, v_cmpx_lt_u32_e32, v_cmpx_gt_u32_e32,
|
||||
v_cmpx_eq_f32_e32, v_cmpx_lt_f32_e32,
|
||||
v_cmpx_eq_i32_e32,
|
||||
v_cmpx_class_f32_e32,
|
||||
# VALU - readlane/writelane
|
||||
v_readlane_b32, v_writelane_b32,
|
||||
v_readfirstlane_b32_e32,
|
||||
# SALU - basic (safe, just register ops)
|
||||
s_mov_b32, s_add_u32, s_and_b32, s_or_b32,
|
||||
s_lshl_b32, s_lshr_b32,
|
||||
s_nop, s_endpgm, s_waitcnt,
|
||||
# SALU - float
|
||||
s_ceil_f32, s_floor_f32, s_trunc_f32,
|
||||
# SALU - branch (safe if offset is 0 = next instruction)
|
||||
s_branch, s_cbranch_scc0, s_cbranch_execz, s_cbranch_execnz,
|
||||
# SALU - message
|
||||
s_sendmsg,
|
||||
# SALU - bit manipulation
|
||||
s_brev_b32, s_bcnt1_i32_b32, s_ctz_i32_b32, s_clz_i32_u32,
|
||||
# SALU - saveexec (modifies EXEC)
|
||||
s_and_saveexec_b32, s_or_saveexec_b32, s_xor_saveexec_b32,
|
||||
# SMEM - scalar memory (load from kernarg pointer in s[0:1])
|
||||
s_load_b32, s_load_b64,
|
||||
# GLOBAL - global memory (load/store) - various widths
|
||||
global_load_u8, global_load_u16, global_load_b32, global_load_b64, global_load_b96, global_load_b128,
|
||||
global_store_b8, global_store_b16, global_store_b32, global_store_b64, global_store_b96, global_store_b128,
|
||||
# GLOBAL - atomics
|
||||
global_atomic_add_u32, global_atomic_add_u64,
|
||||
# FLAT - flat memory access
|
||||
flat_load_b32, flat_load_b64, flat_load_b96, flat_load_b128,
|
||||
flat_store_b8, flat_store_b16, flat_store_b32, flat_store_b64, flat_store_b96, flat_store_b128,
|
||||
# LDS - local data share - various widths
|
||||
ds_load_b32, ds_load_b64, ds_load_b128,
|
||||
ds_store_b32, ds_store_b64, ds_store_b128,
|
||||
# LDS - atomics
|
||||
ds_add_u32, ds_max_u32, ds_min_u32,
|
||||
# VOP3P - packed
|
||||
v_pk_add_f16, v_pk_mul_f16, v_pk_fma_f16, v_pk_add_i16,
|
||||
# VOP3 - misc
|
||||
v_bfe_u32, v_bfi_b32, v_alignbit_b32, v_fma_f32,
|
||||
v_add3_u32, v_xad_u32, v_lshl_or_b32, v_add_nc_u32_e32,
|
||||
# VOP3 - carry-out
|
||||
v_add_co_u32, v_add_co_ci_u32_e32,
|
||||
# VOPD - dual issue
|
||||
v_dual_add_f32, v_dual_mul_f32,
|
||||
# VOP2 - fmac
|
||||
v_fmac_f32_e32,
|
||||
# DOT
|
||||
v_dot2_f16_f16,
|
||||
# WMMA
|
||||
v_wmma_f32_16x16x16_f16, v_wmma_f16_16x16x16_f16, v_wmma_i32_16x16x16_iu8,
|
||||
# Permlane ops
|
||||
v_permlane64_b32_e32, v_permlane16_b32, v_permlanex16_b32,
|
||||
# Interpolation
|
||||
v_interp_p10_f32, v_interp_p2_f32,
|
||||
# Barrier
|
||||
s_barrier,
|
||||
# SrcEnum for NULL soffset
|
||||
SrcEnum,
|
||||
)
|
||||
from extra.assembly.amd.dsl import v, s
|
||||
from extra.assembly.amd.sqtt import InstOp, INST, WAVESTART, WAVEEND, ALUEXEC, VMEMEXEC
|
||||
|
||||
from extra.assembly.amd.test.test_sqtt_hw import (
|
||||
run_asm_sqtt, decode_all_blobs, get_inst_ops, print_blobs, get_wave_packets, format_packet, PACKET_COLORS, count_valuinst
|
||||
)
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
# INSTRUCTION TEST CASES - only safe instructions that don't access memory
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
|
||||
# Helper: load buffer address from kernarg (s[0:1] -> s[2:3])
|
||||
# The runtime passes kernarg pointer in s[0:1], kernarg contains buffer address
|
||||
def _load_buf_addr():
|
||||
return [
|
||||
s_load_b64(s[2:3], s[0], 0, soffset=SrcEnum.NULL), # load buf addr from kernarg
|
||||
s_waitcnt(lgkmcnt=0), # wait for SMEM load
|
||||
]
|
||||
|
||||
INSTRUCTION_TESTS: dict[str, tuple[str, list]] = {
|
||||
# SALU (0x0) - scalar ALU, just register operations
|
||||
"SALU_mov": ("s_mov_b32", [s_mov_b32(s[4], 0), s_mov_b32(s[5], 1)]),
|
||||
"SALU_add": ("s_add_u32", [s_mov_b32(s[4], 1), s_mov_b32(s[5], 2), s_add_u32(s[6], s[4], s[5])]),
|
||||
"SALU_logic": ("s_and/or", [s_and_b32(s[6], s[4], s[5]), s_or_b32(s[7], s[4], s[5])]),
|
||||
"SALU_shift": ("s_lshl/lshr", [s_lshl_b32(s[6], s[4], 1), s_lshr_b32(s[7], s[4], 1)]),
|
||||
"SALU_nop": ("s_nop", [s_nop(0)]),
|
||||
|
||||
# JUMP (0x3) - branch taken
|
||||
"JUMP_branch": ("s_branch", [s_branch(0)]),
|
||||
"JUMP_cbranch_execnz": ("s_cbranch_execnz", [s_cbranch_execnz(0)]), # EXEC != 0, branch taken
|
||||
|
||||
# JUMP_NO (0x4) - branch not taken
|
||||
"JUMP_NO_cbranch_execz": ("s_cbranch_execz", [s_cbranch_execz(0)]), # EXEC != 0, branch not taken
|
||||
|
||||
# VALU (0xb) - vector ALU, just register operations
|
||||
"VALU_mov": ("v_mov_b32", [v_mov_b32_e32(v[0], 0), v_mov_b32_e32(v[1], 1.0)]),
|
||||
"VALU_add": ("v_add_f32", [v_mov_b32_e32(v[0], 1.0), v_mov_b32_e32(v[1], 2.0), v_add_f32_e32(v[2], v[0], v[1])]),
|
||||
"VALU_mul": ("v_mul_f32", [v_mul_f32_e32(v[2], v[0], v[1])]),
|
||||
"VALU_logic": ("v_and/or/xor", [v_and_b32_e32(v[2], v[0], v[1]), v_or_b32_e32(v[3], v[0], v[1]), v_xor_b32_e32(v[4], v[0], v[1])]),
|
||||
"VALU_shift": ("v_lshl/lshr", [v_lshlrev_b32_e32(v[2], 1, v[0]), v_lshrrev_b32_e32(v[3], 1, v[0])]),
|
||||
|
||||
# VALU transcendental - still just register ops
|
||||
"VALU_exp": ("v_exp_f32", [v_mov_b32_e32(v[0], 1.0), v_exp_f32_e32(v[1], v[0])]),
|
||||
"VALU_log": ("v_log_f32", [v_mov_b32_e32(v[0], 1.0), v_log_f32_e32(v[1], v[0])]),
|
||||
"VALU_rcp": ("v_rcp_f32", [v_mov_b32_e32(v[0], 1.0), v_rcp_f32_e32(v[1], v[0])]),
|
||||
"VALU_sqrt": ("v_sqrt_f32", [v_mov_b32_e32(v[0], 1.0), v_sqrt_f32_e32(v[1], v[0])]),
|
||||
|
||||
# VALU 64-bit shift (0xd)
|
||||
"VALU64_lshl": ("v_lshlrev_b64", [v_lshlrev_b64(v[0:1], 1, v[2:3])]),
|
||||
"VALU64_lshr": ("v_lshrrev_b64", [v_lshrrev_b64(v[0:1], 1, v[2:3])]),
|
||||
"VALU64_ashr": ("v_ashrrev_i64", [v_ashrrev_i64(v[0:1], 1, v[2:3])]),
|
||||
|
||||
# VALU 64-bit arithmetic
|
||||
"VALU64_add": ("v_add_f64", [v_add_f64(v[0:1], v[2:3], v[4:5])]),
|
||||
"VALU64_mul": ("v_mul_f64", [v_mul_f64(v[0:1], v[2:3], v[4:5])]),
|
||||
"VALU64_max": ("v_max_f64", [v_max_f64(v[0:1], v[2:3], v[4:5])]),
|
||||
"VALU64_min": ("v_min_f64", [v_min_f64(v[0:1], v[2:3], v[4:5])]),
|
||||
"VALU64_fma": ("v_fma_f64", [v_fma_f64(v[0:1], v[2:3], v[4:5], v[6:7])]),
|
||||
|
||||
# VALU 64-bit transcendental
|
||||
"VALU64_rcp": ("v_rcp_f64", [v_rcp_f64_e32(v[0:1], v[2:3])]),
|
||||
"VALU64_rsq": ("v_rsq_f64", [v_rsq_f64_e32(v[0:1], v[2:3])]),
|
||||
"VALU64_sqrt": ("v_sqrt_f64", [v_sqrt_f64_e32(v[0:1], v[2:3])]),
|
||||
|
||||
# VALU 64-bit rounding
|
||||
"VALU64_trunc": ("v_trunc_f64", [v_trunc_f64_e32(v[0:1], v[2:3])]),
|
||||
"VALU64_ceil": ("v_ceil_f64", [v_ceil_f64_e32(v[0:1], v[2:3])]),
|
||||
"VALU64_floor": ("v_floor_f64", [v_floor_f64_e32(v[0:1], v[2:3])]),
|
||||
"VALU64_fract": ("v_fract_f64", [v_fract_f64_e32(v[0:1], v[2:3])]),
|
||||
|
||||
# VALU 64-bit frexp
|
||||
"VALU64_frexp_exp": ("v_frexp_exp_i32_f64", [v_frexp_exp_i32_f64_e32(v[0], v[2:3])]),
|
||||
"VALU64_frexp_mant": ("v_frexp_mant_f64", [v_frexp_mant_f64_e32(v[0:1], v[2:3])]),
|
||||
|
||||
# VALU 64-bit div helpers
|
||||
"VALU64_div_fixup": ("v_div_fixup_f64", [v_div_fixup_f64(v[0:1], v[2:3], v[4:5], v[6:7])]),
|
||||
"VALU64_div_fmas": ("v_div_fmas_f64", [v_div_fmas_f64(v[0:1], v[2:3], v[4:5], v[6:7])]),
|
||||
|
||||
# VALU 32-bit div helpers
|
||||
"VALU_div_fixup": ("v_div_fixup_f32", [v_div_fixup_f32(v[0], v[1], v[2], v[3])]),
|
||||
"VALU_div_fmas": ("v_div_fmas_f32", [v_div_fmas_f32(v[0], v[1], v[2], v[3])]),
|
||||
"VALU_div_scale": ("v_div_scale_f32", [v_div_scale_f32(v[0], SrcEnum.VCC_LO, v[1], v[2], v[3])]),
|
||||
|
||||
# VALU MAD64 (0xe)
|
||||
"VALU_mad64u": ("v_mad_u64_u32", [
|
||||
v_mov_b32_e32(v[2], 2),
|
||||
v_mov_b32_e32(v[3], 3),
|
||||
v_mov_b32_e32(v[4], 0),
|
||||
v_mov_b32_e32(v[5], 0),
|
||||
v_mad_u64_u32(v[0:1], SrcEnum.NULL, v[2], v[3], v[4:5]),
|
||||
]),
|
||||
"VALU_mad64i": ("v_mad_i64_i32", [
|
||||
v_mov_b32_e32(v[2], 2),
|
||||
v_mov_b32_e32(v[3], 3),
|
||||
v_mov_b32_e32(v[4], 0),
|
||||
v_mov_b32_e32(v[5], 0),
|
||||
v_mad_i64_i32(v[0:1], SrcEnum.NULL, v[2], v[3], v[4:5]),
|
||||
]),
|
||||
|
||||
# VALU compare - writes to VCC
|
||||
"VALU_cmp": ("v_cmp_eq_u32", [v_cmp_eq_u32_e32(v[0], v[1])]),
|
||||
|
||||
# VALU CMPX (0x73) - modifies EXEC
|
||||
"VALU_cmpx_eq_u32": ("v_cmpx_eq_u32", [v_cmpx_eq_u32_e32(v[0], v[1])]),
|
||||
|
||||
# SALU saveexec (0x72) - modifies EXEC safely by ANDing with all-ones mask
|
||||
"SALU_saveexec": ("s_and_saveexec_b32", [
|
||||
s_mov_b32(s[5], 0xFFFFFFFF), # all lanes mask
|
||||
s_and_saveexec_b32(s[4], s[5]), # EXEC = EXEC & 0xFFFFFFFF = EXEC (unchanged)
|
||||
]),
|
||||
|
||||
# SALU float ops
|
||||
"SALU_ceil": ("s_ceil_f32", [s_ceil_f32(s[4], s[5])]),
|
||||
"SALU_floor": ("s_floor_f32", [s_floor_f32(s[4], s[5])]),
|
||||
"SALU_trunc": ("s_trunc_f32", [s_trunc_f32(s[4], s[5])]),
|
||||
|
||||
# SALU bit ops
|
||||
"SALU_brev": ("s_brev_b32", [s_brev_b32(s[4], s[5])]),
|
||||
"SALU_bcnt1": ("s_bcnt1_i32_b32", [s_bcnt1_i32_b32(s[4], s[5])]),
|
||||
"SALU_ctz": ("s_ctz_i32_b32", [s_ctz_i32_b32(s[4], s[5])]),
|
||||
"SALU_clz": ("s_clz_i32_u32", [s_clz_i32_u32(s[4], s[5])]),
|
||||
|
||||
# VALU sin/cos
|
||||
"VALU_sin": ("v_sin_f32", [v_sin_f32_e32(v[0], v[1])]),
|
||||
"VALU_cos": ("v_cos_f32", [v_cos_f32_e32(v[0], v[1])]),
|
||||
|
||||
# VOP3P - packed operations
|
||||
"VALU_pk_add_f16": ("v_pk_add_f16", [v_pk_add_f16(v[0], v[1], v[2])]),
|
||||
"VALU_pk_mul_f16": ("v_pk_mul_f16", [v_pk_mul_f16(v[0], v[1], v[2])]),
|
||||
"VALU_pk_fma_f16": ("v_pk_fma_f16", [v_pk_fma_f16(v[0], v[1], v[2], v[3])]),
|
||||
"VALU_pk_add_i16": ("v_pk_add_i16", [v_pk_add_i16(v[0], v[1], v[2])]),
|
||||
|
||||
# VOP3 - misc
|
||||
"VALU_bfe_u32": ("v_bfe_u32", [v_bfe_u32(v[0], v[1], 0, 8)]),
|
||||
"VALU_bfi_b32": ("v_bfi_b32", [v_bfi_b32(v[0], v[1], v[2], v[3])]),
|
||||
"VALU_alignbit": ("v_alignbit_b32", [v_alignbit_b32(v[0], v[1], v[2], 4)]),
|
||||
"VALU_fma_f32": ("v_fma_f32", [v_fma_f32(v[0], v[1], v[2], v[3])]),
|
||||
|
||||
# VOP3 - integer add variants (used by tinygrad kernels)
|
||||
"VALU_add3": ("v_add3_u32", [v_add3_u32(v[0], v[1], v[2], v[3])]),
|
||||
"VALU_xad": ("v_xad_u32", [v_xad_u32(v[0], v[1], v[2], v[3])]),
|
||||
"VALU_lshl_or": ("v_lshl_or_b32", [v_lshl_or_b32(v[0], v[1], 4, v[2])]),
|
||||
"VALU_add_nc": ("v_add_nc_u32", [v_add_nc_u32_e32(v[0], v[1], v[2])]),
|
||||
|
||||
# VOP3 - carry-out adds (used for 64-bit address calculation)
|
||||
"VALU_add_co": ("v_add_co_u32", [v_add_co_u32(v[0], SrcEnum.VCC_LO, v[1], v[2])]),
|
||||
"VALU_add_co_ci": ("v_add_co_ci_u32", [v_add_co_ci_u32_e32(v[0], v[1], v[2])]),
|
||||
|
||||
# VOPD - dual issue (used by tinygrad kernels)
|
||||
"VALU_dual_add": ("v_dual_add_f32", [v_dual_add_f32(v[0], v[1], v[2], v[3], v[4], v[5])]),
|
||||
"VALU_dual_mul": ("v_dual_mul_f32", [v_dual_mul_f32(v[0], v[1], v[2], v[3], v[4], v[5])]),
|
||||
|
||||
# VOP2 - fmac
|
||||
"VALU_fmac": ("v_fmac_f32", [v_fmac_f32_e32(v[0], v[1], v[0])]),
|
||||
|
||||
# DOT products
|
||||
"VALU_dot2": ("v_dot2_f16_f16", [v_dot2_f16_f16(v[0], v[1], v[2], v[3])]),
|
||||
|
||||
# WMMA - wave matrix multiply accumulate
|
||||
"VALU_wmma_f32_f16": ("v_wmma_f32_16x16x16_f16", [v_wmma_f32_16x16x16_f16(v[0:7], v[8:15], v[16:23], v[0:7])]),
|
||||
"VALU_wmma_f16_f16": ("v_wmma_f16_16x16x16_f16", [v_wmma_f16_16x16x16_f16(v[0:7], v[8:15], v[16:23], v[0:7])]),
|
||||
"VALU_wmma_i32_iu8": ("v_wmma_i32_16x16x16_iu8", [v_wmma_i32_16x16x16_iu8(v[0:7], v[8:11], v[12:15], v[0:7])]),
|
||||
|
||||
# Permlane operations - cross-lane data movement
|
||||
# NOTE: permlane64 produces NO SQTT packets in wave32 mode (it's for wave64 pairs)
|
||||
# NOTE: permlane16/x16 produce VALUINST packets (no specific InstOp)
|
||||
"VALU_permlane16": ("v_permlane16_b32", [v_permlane16_b32(v[0], v[1], s[2], s[3])]),
|
||||
"VALU_permlanex16": ("v_permlanex16_b32", [v_permlanex16_b32(v[0], v[1], s[2], s[3])]),
|
||||
|
||||
# Interpolation - used in graphics shaders (produces InstOp 0x12 VINTERP)
|
||||
"VINTERP_p10": ("v_interp_p10_f32", [v_interp_p10_f32(v[0], v[1], v[2], v[3])]),
|
||||
"VINTERP_p2": ("v_interp_p2_f32", [v_interp_p2_f32(v[0], v[1], v[2], v[3])]),
|
||||
|
||||
# Barrier - wave synchronization
|
||||
# NOTE: s_barrier produces NO SQTT instruction packets (with 1 wave, it's essentially a no-op)
|
||||
"SALU_barrier": ("s_barrier", [s_barrier()]),
|
||||
|
||||
# LDS atomics
|
||||
"LDS_atomic_add": ("ds_add_u32", [
|
||||
v_mov_b32_e32(v[0], 0), # LDS address
|
||||
v_mov_b32_e32(v[1], 1), # data to add
|
||||
ds_add_u32(addr=v[0], data0=v[1]),
|
||||
s_waitcnt(lgkmcnt=0),
|
||||
]),
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
# GLOBAL ATOMICS - access real buffer passed via kernarg
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
|
||||
# GLOBAL atomic add 32-bit (0x28 GLOBAL_ATOMIC)
|
||||
"GLOBAL_atomic_add": ("global_atomic_add_u32", [
|
||||
s_load_b64(s[2:3], s[0], 0, soffset=SrcEnum.NULL), # load buf addr from kernarg
|
||||
s_waitcnt(lgkmcnt=0),
|
||||
v_mov_b32_e32(v[0], 0), # offset = 0
|
||||
v_mov_b32_e32(v[1], 1), # data to add
|
||||
global_atomic_add_u32(addr=v[0], data=v[1], saddr=s[2]),
|
||||
s_waitcnt(vmcnt=0),
|
||||
]),
|
||||
|
||||
# GLOBAL atomic add 64-bit
|
||||
"GLOBAL_atomic_add64": ("global_atomic_add_u64", [
|
||||
s_load_b64(s[2:3], s[0], 0, soffset=SrcEnum.NULL),
|
||||
s_waitcnt(lgkmcnt=0),
|
||||
v_mov_b32_e32(v[0], 0),
|
||||
v_mov_b32_e32(v[2], 1),
|
||||
v_mov_b32_e32(v[3], 0),
|
||||
global_atomic_add_u64(addr=v[0], data=v[2:3], saddr=s[2]),
|
||||
s_waitcnt(vmcnt=0),
|
||||
]),
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
# MEMORY INSTRUCTIONS - access real buffer passed via kernarg
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
|
||||
# SMEM (0x1) - scalar memory load from buffer
|
||||
"SMEM_load": ("s_load_b32", [
|
||||
s_load_b64(s[2:3], s[0], 0, soffset=SrcEnum.NULL), # load buf addr from kernarg
|
||||
s_waitcnt(lgkmcnt=0),
|
||||
s_load_b32(s[4], s[2], 0, soffset=SrcEnum.NULL), # load from buffer
|
||||
s_waitcnt(lgkmcnt=0),
|
||||
]),
|
||||
|
||||
# GLOBAL load (0x21 GLOBAL_LOAD) - global memory load
|
||||
"GLOBAL_load": ("global_load_b32", [
|
||||
s_load_b64(s[2:3], s[0], 0, soffset=SrcEnum.NULL), # load buf addr from kernarg
|
||||
s_waitcnt(lgkmcnt=0),
|
||||
v_mov_b32_e32(v[0], 0), # offset = 0
|
||||
global_load_b32(v[1], addr=v[0], saddr=s[2], offset=0), # load from buffer
|
||||
s_waitcnt(vmcnt=0),
|
||||
]),
|
||||
|
||||
# GLOBAL store (0x24 GLOBAL_STORE) - global memory store
|
||||
"GLOBAL_store": ("global_store_b32", [
|
||||
s_load_b64(s[2:3], s[0], 0, soffset=SrcEnum.NULL), # load buf addr from kernarg
|
||||
s_waitcnt(lgkmcnt=0),
|
||||
v_mov_b32_e32(v[0], 0), # offset = 0
|
||||
v_mov_b32_e32(v[1], 42), # data to store
|
||||
global_store_b32(addr=v[0], data=v[1], saddr=s[2], offset=0), # store to buffer
|
||||
s_waitcnt(vmcnt=0),
|
||||
]),
|
||||
|
||||
# GLOBAL 8-bit load/store
|
||||
"GLOBAL_load8": ("global_load_u8", [
|
||||
s_load_b64(s[2:3], s[0], 0, soffset=SrcEnum.NULL),
|
||||
s_waitcnt(lgkmcnt=0),
|
||||
v_mov_b32_e32(v[0], 0),
|
||||
global_load_u8(v[1], addr=v[0], saddr=s[2], offset=0),
|
||||
s_waitcnt(vmcnt=0),
|
||||
]),
|
||||
|
||||
"GLOBAL_store8": ("global_store_b8", [
|
||||
s_load_b64(s[2:3], s[0], 0, soffset=SrcEnum.NULL),
|
||||
s_waitcnt(lgkmcnt=0),
|
||||
v_mov_b32_e32(v[0], 0),
|
||||
v_mov_b32_e32(v[1], 42),
|
||||
global_store_b8(addr=v[0], data=v[1], saddr=s[2], offset=0),
|
||||
s_waitcnt(vmcnt=0),
|
||||
]),
|
||||
|
||||
# GLOBAL 16-bit load/store
|
||||
"GLOBAL_load16": ("global_load_u16", [
|
||||
s_load_b64(s[2:3], s[0], 0, soffset=SrcEnum.NULL),
|
||||
s_waitcnt(lgkmcnt=0),
|
||||
v_mov_b32_e32(v[0], 0),
|
||||
global_load_u16(v[1], addr=v[0], saddr=s[2], offset=0),
|
||||
s_waitcnt(vmcnt=0),
|
||||
]),
|
||||
|
||||
"GLOBAL_store16": ("global_store_b16", [
|
||||
s_load_b64(s[2:3], s[0], 0, soffset=SrcEnum.NULL),
|
||||
s_waitcnt(lgkmcnt=0),
|
||||
v_mov_b32_e32(v[0], 0),
|
||||
v_mov_b32_e32(v[1], 42),
|
||||
global_store_b16(addr=v[0], data=v[1], saddr=s[2], offset=0),
|
||||
s_waitcnt(vmcnt=0),
|
||||
]),
|
||||
|
||||
# LDS load (0x29 LDS_LOAD) - local data share read
|
||||
"LDS_load": ("ds_load_b32", [
|
||||
v_mov_b32_e32(v[0], 0), # LDS address = 0
|
||||
ds_load_b32(v[1], v[0], offset=0), # read from LDS
|
||||
s_waitcnt(lgkmcnt=0),
|
||||
]),
|
||||
|
||||
# LDS store (0x2b LDS_STORE) - local data share write
|
||||
"LDS_store": ("ds_store_b32", [
|
||||
v_mov_b32_e32(v[0], 0), # LDS address = 0
|
||||
v_mov_b32_e32(v[1], 42), # data to store
|
||||
ds_store_b32(v[0], v[1], offset=0), # write to LDS
|
||||
s_waitcnt(lgkmcnt=0),
|
||||
]),
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
# WIDER MEMORY OPERATIONS - to discover more InstOp variants
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
|
||||
# GLOBAL 64-bit load
|
||||
"GLOBAL_load64": ("global_load_b64", [
|
||||
s_load_b64(s[2:3], s[0], 0, soffset=SrcEnum.NULL),
|
||||
s_waitcnt(lgkmcnt=0),
|
||||
v_mov_b32_e32(v[0], 0),
|
||||
global_load_b64(v[2:3], addr=v[0], saddr=s[2], offset=0),
|
||||
s_waitcnt(vmcnt=0),
|
||||
]),
|
||||
|
||||
# GLOBAL 96-bit load
|
||||
"GLOBAL_load96": ("global_load_b96", [
|
||||
s_load_b64(s[2:3], s[0], 0, soffset=SrcEnum.NULL),
|
||||
s_waitcnt(lgkmcnt=0),
|
||||
v_mov_b32_e32(v[0], 0),
|
||||
global_load_b96(v[4:6], addr=v[0], saddr=s[2], offset=0),
|
||||
s_waitcnt(vmcnt=0),
|
||||
]),
|
||||
|
||||
# GLOBAL 128-bit load
|
||||
"GLOBAL_load128": ("global_load_b128", [
|
||||
s_load_b64(s[2:3], s[0], 0, soffset=SrcEnum.NULL),
|
||||
s_waitcnt(lgkmcnt=0),
|
||||
v_mov_b32_e32(v[0], 0),
|
||||
global_load_b128(v[4:7], addr=v[0], saddr=s[2], offset=0),
|
||||
s_waitcnt(vmcnt=0),
|
||||
]),
|
||||
|
||||
# GLOBAL 64-bit store
|
||||
"GLOBAL_store64": ("global_store_b64", [
|
||||
s_load_b64(s[2:3], s[0], 0, soffset=SrcEnum.NULL),
|
||||
s_waitcnt(lgkmcnt=0),
|
||||
v_mov_b32_e32(v[0], 0),
|
||||
v_mov_b32_e32(v[2], 42),
|
||||
v_mov_b32_e32(v[3], 43),
|
||||
global_store_b64(addr=v[0], data=v[2:3], saddr=s[2], offset=0),
|
||||
s_waitcnt(vmcnt=0),
|
||||
]),
|
||||
|
||||
# GLOBAL 96-bit store
|
||||
"GLOBAL_store96": ("global_store_b96", [
|
||||
s_load_b64(s[2:3], s[0], 0, soffset=SrcEnum.NULL),
|
||||
s_waitcnt(lgkmcnt=0),
|
||||
v_mov_b32_e32(v[0], 0),
|
||||
v_mov_b32_e32(v[4], 42),
|
||||
v_mov_b32_e32(v[5], 43),
|
||||
v_mov_b32_e32(v[6], 44),
|
||||
global_store_b96(addr=v[0], data=v[4:6], saddr=s[2], offset=0),
|
||||
s_waitcnt(vmcnt=0),
|
||||
]),
|
||||
|
||||
# GLOBAL 128-bit store
|
||||
"GLOBAL_store128": ("global_store_b128", [
|
||||
s_load_b64(s[2:3], s[0], 0, soffset=SrcEnum.NULL),
|
||||
s_waitcnt(lgkmcnt=0),
|
||||
v_mov_b32_e32(v[0], 0),
|
||||
v_mov_b32_e32(v[4], 42),
|
||||
v_mov_b32_e32(v[5], 43),
|
||||
v_mov_b32_e32(v[6], 44),
|
||||
v_mov_b32_e32(v[7], 45),
|
||||
global_store_b128(addr=v[0], data=v[4:7], saddr=s[2], offset=0),
|
||||
s_waitcnt(vmcnt=0),
|
||||
]),
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
# GLOBAL VADDR (vector-only addressing, saddr=NULL) - used by tinygrad kernels
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
|
||||
# GLOBAL VADDR load (all sizes use same opcode 0x22)
|
||||
"GLOBAL_VADDR_load": ("global_load_b32 vaddr", [
|
||||
s_load_b64(s[2:3], s[0], 0, soffset=SrcEnum.NULL),
|
||||
s_waitcnt(lgkmcnt=0),
|
||||
v_mov_b32_e32(v[0], s[2]),
|
||||
v_mov_b32_e32(v[1], s[3]),
|
||||
global_load_b32(v[4], addr=v[0:1], saddr=SrcEnum.NULL, offset=0),
|
||||
s_waitcnt(vmcnt=0),
|
||||
]),
|
||||
|
||||
"GLOBAL_VADDR_load128": ("global_load_b128 vaddr", [
|
||||
s_load_b64(s[2:3], s[0], 0, soffset=SrcEnum.NULL),
|
||||
s_waitcnt(lgkmcnt=0),
|
||||
v_mov_b32_e32(v[0], s[2]),
|
||||
v_mov_b32_e32(v[1], s[3]),
|
||||
global_load_b128(v[4:7], addr=v[0:1], saddr=SrcEnum.NULL, offset=0),
|
||||
s_waitcnt(vmcnt=0),
|
||||
]),
|
||||
|
||||
# GLOBAL VADDR stores (size encoded: 32->0x25, 64->0x26, 96->0x27, 128->0x28)
|
||||
"GLOBAL_VADDR_store": ("global_store_b32 vaddr", [
|
||||
s_load_b64(s[2:3], s[0], 0, soffset=SrcEnum.NULL),
|
||||
s_waitcnt(lgkmcnt=0),
|
||||
v_mov_b32_e32(v[0], s[2]),
|
||||
v_mov_b32_e32(v[1], s[3]),
|
||||
v_mov_b32_e32(v[4], 42),
|
||||
global_store_b32(addr=v[0:1], data=v[4], saddr=SrcEnum.NULL, offset=0),
|
||||
s_waitcnt(vmcnt=0),
|
||||
]),
|
||||
|
||||
"GLOBAL_VADDR_store64": ("global_store_b64 vaddr", [
|
||||
s_load_b64(s[2:3], s[0], 0, soffset=SrcEnum.NULL),
|
||||
s_waitcnt(lgkmcnt=0),
|
||||
v_mov_b32_e32(v[0], s[2]),
|
||||
v_mov_b32_e32(v[1], s[3]),
|
||||
v_mov_b32_e32(v[4], 42),
|
||||
v_mov_b32_e32(v[5], 43),
|
||||
global_store_b64(addr=v[0:1], data=v[4:5], saddr=SrcEnum.NULL, offset=0),
|
||||
s_waitcnt(vmcnt=0),
|
||||
]),
|
||||
|
||||
"GLOBAL_VADDR_store96": ("global_store_b96 vaddr", [
|
||||
s_load_b64(s[2:3], s[0], 0, soffset=SrcEnum.NULL),
|
||||
s_waitcnt(lgkmcnt=0),
|
||||
v_mov_b32_e32(v[0], s[2]),
|
||||
v_mov_b32_e32(v[1], s[3]),
|
||||
v_mov_b32_e32(v[4], 42),
|
||||
v_mov_b32_e32(v[5], 43),
|
||||
v_mov_b32_e32(v[6], 44),
|
||||
global_store_b96(addr=v[0:1], data=v[4:6], saddr=SrcEnum.NULL, offset=0),
|
||||
s_waitcnt(vmcnt=0),
|
||||
]),
|
||||
|
||||
"GLOBAL_VADDR_store128": ("global_store_b128 vaddr", [
|
||||
s_load_b64(s[2:3], s[0], 0, soffset=SrcEnum.NULL),
|
||||
s_waitcnt(lgkmcnt=0),
|
||||
v_mov_b32_e32(v[0], s[2]),
|
||||
v_mov_b32_e32(v[1], s[3]),
|
||||
v_mov_b32_e32(v[4], 42),
|
||||
v_mov_b32_e32(v[5], 43),
|
||||
v_mov_b32_e32(v[6], 44),
|
||||
v_mov_b32_e32(v[7], 45),
|
||||
global_store_b128(addr=v[0:1], data=v[4:7], saddr=SrcEnum.NULL, offset=0),
|
||||
s_waitcnt(vmcnt=0),
|
||||
]),
|
||||
|
||||
# LDS 64-bit load
|
||||
"LDS_load64": ("ds_load_b64", [
|
||||
v_mov_b32_e32(v[0], 0),
|
||||
ds_load_b64(v[2:3], v[0], offset=0),
|
||||
s_waitcnt(lgkmcnt=0),
|
||||
]),
|
||||
|
||||
# LDS 128-bit load
|
||||
"LDS_load128": ("ds_load_b128", [
|
||||
v_mov_b32_e32(v[0], 0),
|
||||
ds_load_b128(v[4:7], v[0], offset=0),
|
||||
s_waitcnt(lgkmcnt=0),
|
||||
]),
|
||||
|
||||
# LDS 64-bit store
|
||||
"LDS_store64": ("ds_store_b64", [
|
||||
v_mov_b32_e32(v[0], 0),
|
||||
v_mov_b32_e32(v[2], 42),
|
||||
v_mov_b32_e32(v[3], 43),
|
||||
ds_store_b64(v[0], v[2:3], offset=0),
|
||||
s_waitcnt(lgkmcnt=0),
|
||||
]),
|
||||
|
||||
# LDS 128-bit store
|
||||
"LDS_store128": ("ds_store_b128", [
|
||||
v_mov_b32_e32(v[0], 0),
|
||||
v_mov_b32_e32(v[4], 42),
|
||||
v_mov_b32_e32(v[5], 43),
|
||||
v_mov_b32_e32(v[6], 44),
|
||||
v_mov_b32_e32(v[7], 45),
|
||||
ds_store_b128(v[0], v[4:7], offset=0),
|
||||
s_waitcnt(lgkmcnt=0),
|
||||
]),
|
||||
|
||||
# MESSAGE (0x9) - s_sendmsg
|
||||
"MESSAGE": ("s_sendmsg", [
|
||||
s_sendmsg(0), # send message 0 (NOP message)
|
||||
]),
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
# FLAT MEMORY - uses 64-bit virtual address in VGPRs
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
|
||||
# FLAT load - load using 64-bit address from buffer
|
||||
"FLAT_load": ("flat_load_b32", [
|
||||
s_load_b64(s[2:3], s[0], 0, soffset=SrcEnum.NULL), # load buf addr from kernarg
|
||||
s_waitcnt(lgkmcnt=0),
|
||||
v_mov_b32_e32(v[0], s[2]), # addr lo
|
||||
v_mov_b32_e32(v[1], s[3]), # addr hi
|
||||
flat_load_b32(v[2], addr=v[0:1]),
|
||||
s_waitcnt(vmcnt=0, lgkmcnt=0),
|
||||
]),
|
||||
|
||||
# FLAT store
|
||||
"FLAT_store": ("flat_store_b32", [
|
||||
s_load_b64(s[2:3], s[0], 0, soffset=SrcEnum.NULL),
|
||||
s_waitcnt(lgkmcnt=0),
|
||||
v_mov_b32_e32(v[0], s[2]),
|
||||
v_mov_b32_e32(v[1], s[3]),
|
||||
v_mov_b32_e32(v[2], 42),
|
||||
flat_store_b32(addr=v[0:1], data=v[2]),
|
||||
s_waitcnt(vmcnt=0, lgkmcnt=0),
|
||||
]),
|
||||
|
||||
# FLAT 64-bit
|
||||
"FLAT_load64": ("flat_load_b64", [
|
||||
s_load_b64(s[2:3], s[0], 0, soffset=SrcEnum.NULL),
|
||||
s_waitcnt(lgkmcnt=0),
|
||||
v_mov_b32_e32(v[0], s[2]),
|
||||
v_mov_b32_e32(v[1], s[3]),
|
||||
flat_load_b64(v[2:3], addr=v[0:1]),
|
||||
s_waitcnt(vmcnt=0, lgkmcnt=0),
|
||||
]),
|
||||
|
||||
"FLAT_store64": ("flat_store_b64", [
|
||||
s_load_b64(s[2:3], s[0], 0, soffset=SrcEnum.NULL),
|
||||
s_waitcnt(lgkmcnt=0),
|
||||
v_mov_b32_e32(v[0], s[2]),
|
||||
v_mov_b32_e32(v[1], s[3]),
|
||||
v_mov_b32_e32(v[4], 42),
|
||||
v_mov_b32_e32(v[5], 43),
|
||||
flat_store_b64(addr=v[0:1], data=v[4:5]),
|
||||
s_waitcnt(vmcnt=0, lgkmcnt=0),
|
||||
]),
|
||||
|
||||
# FLAT 96-bit
|
||||
"FLAT_load96": ("flat_load_b96", [
|
||||
s_load_b64(s[2:3], s[0], 0, soffset=SrcEnum.NULL),
|
||||
s_waitcnt(lgkmcnt=0),
|
||||
v_mov_b32_e32(v[0], s[2]),
|
||||
v_mov_b32_e32(v[1], s[3]),
|
||||
flat_load_b96(v[4:6], addr=v[0:1]),
|
||||
s_waitcnt(vmcnt=0, lgkmcnt=0),
|
||||
]),
|
||||
|
||||
"FLAT_store96": ("flat_store_b96", [
|
||||
s_load_b64(s[2:3], s[0], 0, soffset=SrcEnum.NULL),
|
||||
s_waitcnt(lgkmcnt=0),
|
||||
v_mov_b32_e32(v[0], s[2]),
|
||||
v_mov_b32_e32(v[1], s[3]),
|
||||
v_mov_b32_e32(v[4], 42),
|
||||
v_mov_b32_e32(v[5], 43),
|
||||
v_mov_b32_e32(v[6], 44),
|
||||
flat_store_b96(addr=v[0:1], data=v[4:6]),
|
||||
s_waitcnt(vmcnt=0, lgkmcnt=0),
|
||||
]),
|
||||
|
||||
# FLAT 128-bit
|
||||
"FLAT_load128": ("flat_load_b128", [
|
||||
s_load_b64(s[2:3], s[0], 0, soffset=SrcEnum.NULL),
|
||||
s_waitcnt(lgkmcnt=0),
|
||||
v_mov_b32_e32(v[0], s[2]),
|
||||
v_mov_b32_e32(v[1], s[3]),
|
||||
flat_load_b128(v[4:7], addr=v[0:1]),
|
||||
s_waitcnt(vmcnt=0, lgkmcnt=0),
|
||||
]),
|
||||
|
||||
"FLAT_store128": ("flat_store_b128", [
|
||||
s_load_b64(s[2:3], s[0], 0, soffset=SrcEnum.NULL),
|
||||
s_waitcnt(lgkmcnt=0),
|
||||
v_mov_b32_e32(v[0], s[2]),
|
||||
v_mov_b32_e32(v[1], s[3]),
|
||||
v_mov_b32_e32(v[4], 42),
|
||||
v_mov_b32_e32(v[5], 43),
|
||||
v_mov_b32_e32(v[6], 44),
|
||||
v_mov_b32_e32(v[7], 45),
|
||||
flat_store_b128(addr=v[0:1], data=v[4:7]),
|
||||
s_waitcnt(vmcnt=0, lgkmcnt=0),
|
||||
]),
|
||||
|
||||
# FLAT 8/16-bit stores
|
||||
"FLAT_store8": ("flat_store_b8", [
|
||||
s_load_b64(s[2:3], s[0], 0, soffset=SrcEnum.NULL),
|
||||
s_waitcnt(lgkmcnt=0),
|
||||
v_mov_b32_e32(v[0], s[2]),
|
||||
v_mov_b32_e32(v[1], s[3]),
|
||||
v_mov_b32_e32(v[2], 42),
|
||||
flat_store_b8(addr=v[0:1], data=v[2]),
|
||||
s_waitcnt(vmcnt=0, lgkmcnt=0),
|
||||
]),
|
||||
|
||||
"FLAT_store16": ("flat_store_b16", [
|
||||
s_load_b64(s[2:3], s[0], 0, soffset=SrcEnum.NULL),
|
||||
s_waitcnt(lgkmcnt=0),
|
||||
v_mov_b32_e32(v[0], s[2]),
|
||||
v_mov_b32_e32(v[1], s[3]),
|
||||
v_mov_b32_e32(v[2], 42),
|
||||
flat_store_b16(addr=v[0:1], data=v[2]),
|
||||
s_waitcnt(vmcnt=0, lgkmcnt=0),
|
||||
]),
|
||||
|
||||
}
|
||||
|
||||
|
||||
def run_with_retry(instructions: list, max_attempts: int = 20) -> tuple[list[tuple[int, list[bytes]]], list[list], set, int]:
|
||||
"""Run instructions multiple times to collect InstOp variants.
|
||||
|
||||
Memory ops produce different InstOp values (0x2x vs 0x5x) depending on which SIMD executes them:
|
||||
- 0x2x range: wave ran on traced SIMD (matched)
|
||||
- 0x5x range: wave ran on other SIMD (not matched)
|
||||
|
||||
Returns list of (traced_simd, blobs) tuples, all_packets, all_ops, max_valuinst_count.
|
||||
"""
|
||||
all_ops = set()
|
||||
all_runs: list[tuple[int, list[bytes]]] = []
|
||||
all_packets = []
|
||||
max_valuinst = 0
|
||||
SQTT_SIMD_SEL.value = 0 # only trace SIMD 0
|
||||
for _ in range(max_attempts):
|
||||
blobs = run_asm_sqtt(instructions)
|
||||
packets = decode_all_blobs(blobs)
|
||||
# get ops and valuinst from all SIMDs
|
||||
ops = set()
|
||||
valuinst_count = 0
|
||||
for simd in [0, 1, 2, 3]:
|
||||
ops.update(get_inst_ops(packets, traced_simd=simd))
|
||||
valuinst_count = max(valuinst_count, count_valuinst(packets, traced_simd=simd))
|
||||
all_runs.append((0, blobs))
|
||||
all_packets.append(packets)
|
||||
all_ops.update(ops)
|
||||
max_valuinst = max(max_valuinst, valuinst_count)
|
||||
return all_runs, all_packets, all_ops, max_valuinst
|
||||
|
||||
def discover_all_instops() -> tuple[dict[int, set[str]], dict[str, Exception], dict[str, int]]:
|
||||
"""Run all instruction tests and collect InstOp values."""
|
||||
discovered: dict[int, set[str]] = {}
|
||||
failures: dict[str, Exception] = {}
|
||||
valuinst_tests: dict[str, int] = {} # tests that produced VALUINST packets
|
||||
|
||||
for test_name, (instr_name, instructions) in INSTRUCTION_TESTS.items():
|
||||
try:
|
||||
all_runs, _, ops, valuinst_count = run_with_retry(instructions)
|
||||
|
||||
for op in ops:
|
||||
if op not in discovered:
|
||||
discovered[op] = set()
|
||||
discovered[op].add(f"{test_name}")
|
||||
|
||||
if valuinst_count > 0:
|
||||
valuinst_tests[test_name] = valuinst_count
|
||||
|
||||
if DEBUG >= 2:
|
||||
print(f"\n{'─'*60}")
|
||||
print(f"{test_name} ({instr_name}): ops={[hex(op) for op in sorted(ops)]}")
|
||||
|
||||
# collect wave patterns from traced SIMD runs (group by exact timing)
|
||||
patterns: dict[tuple, list] = {} # pattern (types + timing) -> list of (wave_packets, t0)
|
||||
for traced_simd, blobs in all_runs:
|
||||
for blob in blobs:
|
||||
packets = decode_all_blobs([blob])
|
||||
wave_packets = get_wave_packets(packets)
|
||||
# only include runs where wave ran on traced SIMD
|
||||
ws = next((p for p in wave_packets if isinstance(p, WAVESTART)), None)
|
||||
if ws and ws.simd == traced_simd and wave_packets:
|
||||
t0 = wave_packets[0]._time
|
||||
# pattern includes types AND normalized timing
|
||||
pattern = tuple((type(p).__name__, p._time - t0) for p in wave_packets)
|
||||
if pattern not in patterns:
|
||||
patterns[pattern] = []
|
||||
patterns[pattern].append((wave_packets, t0))
|
||||
|
||||
if patterns:
|
||||
counts = {p: len(runs) for p, runs in patterns.items()}
|
||||
most_common = max(counts, key=counts.get)
|
||||
count = counts[most_common]
|
||||
total = sum(counts.values())
|
||||
print(f"\n=== most common pattern ({count}/{total} runs) ===")
|
||||
wave_packets, t0 = patterns[most_common][0]
|
||||
last_time = t0
|
||||
for p in wave_packets:
|
||||
print(format_packet(p, last_time, t0))
|
||||
last_time = p._time
|
||||
if len(patterns) > 1:
|
||||
print(f"\n variations: {len(patterns)} unique timing patterns")
|
||||
|
||||
if DEBUG >= 3:
|
||||
for traced_simd, blobs in all_runs:
|
||||
print(f"\n=== traced simd={traced_simd} ===")
|
||||
print_blobs(blobs, wave_only=False)
|
||||
if DEBUG >= 1:
|
||||
status = colored("✓", "green") if ops else (colored("V", "cyan") if valuinst_count > 0 else colored("∅", "yellow"))
|
||||
ops_str = ", ".join(hex(op) for op in sorted(ops)) if ops else "none"
|
||||
valuinst_str = f" valuinst={valuinst_count}" if valuinst_count > 0 and not ops else ""
|
||||
print(f" {status} {test_name:25s} ops=[{ops_str}]{valuinst_str}")
|
||||
|
||||
except Exception as e:
|
||||
failures[test_name] = e
|
||||
if DEBUG >= 1:
|
||||
print(f" {colored('✗', 'red')} {test_name:25s} FAILED: {e}")
|
||||
|
||||
return discovered, failures, valuinst_tests
|
||||
|
||||
|
||||
def print_summary(discovered: dict[int, set[str]], failures: dict[str, Exception], valuinst_tests: dict[str, int]) -> None:
|
||||
"""Print discovery summary."""
|
||||
known_ops = {e.value for e in InstOp}
|
||||
discovered_ops = set(discovered.keys())
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print("DISCOVERED INSTOP VALUES")
|
||||
print("=" * 60)
|
||||
|
||||
for op in sorted(discovered_ops):
|
||||
try:
|
||||
name = InstOp(op).name
|
||||
status = colored("known", "green")
|
||||
except ValueError:
|
||||
name = f"UNKNOWN"
|
||||
status = colored("NEW!", "yellow")
|
||||
|
||||
sources = ", ".join(sorted(discovered[op]))
|
||||
print(f" 0x{op:02x} {name:20s} ({status}) <- {sources}")
|
||||
|
||||
# VALUINST tests (instructions that only produce VALUINST, not INST packets)
|
||||
valuinst_only = {k: v for k, v in valuinst_tests.items() if not any(k in tests for tests in discovered.values())}
|
||||
if valuinst_only:
|
||||
print("\n" + "=" * 60)
|
||||
print(colored("VALUINST-ONLY INSTRUCTIONS (no InstOp, use VALUINST packet)", "cyan"))
|
||||
print("=" * 60)
|
||||
for test_name, count in sorted(valuinst_only.items()):
|
||||
print(f" {test_name}: {count} VALUINST packets")
|
||||
|
||||
# Missing from enum
|
||||
missing = known_ops - discovered_ops
|
||||
if missing:
|
||||
print("\n" + "=" * 60)
|
||||
print("ENUM VALUES NOT DISCOVERED")
|
||||
print("=" * 60)
|
||||
print("(need memory ops: SMEM, VMEM, LDS)")
|
||||
for op in sorted(missing):
|
||||
print(f" 0x{op:02x} {InstOp(op).name}")
|
||||
|
||||
# New values to add
|
||||
new_ops = discovered_ops - known_ops
|
||||
if new_ops:
|
||||
print("\n" + "=" * 60)
|
||||
print(colored("NEW INSTOP VALUES TO ADD TO ENUM", "yellow"))
|
||||
print("=" * 60)
|
||||
for op in sorted(new_ops):
|
||||
sources = ", ".join(sorted(discovered[op]))
|
||||
print(f" {op:#04x}: \"{sources}\",")
|
||||
|
||||
# Stats
|
||||
print("\n" + "=" * 60)
|
||||
print("STATISTICS")
|
||||
print("=" * 60)
|
||||
print(f" Tests run: {len(INSTRUCTION_TESTS)}")
|
||||
print(f" Tests passed: {len(INSTRUCTION_TESTS) - len(failures)}")
|
||||
print(f" Tests failed: {len(failures)}")
|
||||
print(f" Known ops: {len(known_ops)}")
|
||||
print(f" Discovered: {len(discovered_ops)}")
|
||||
if known_ops:
|
||||
print(f" Coverage: {len(discovered_ops & known_ops)}/{len(known_ops)} ({100*len(discovered_ops & known_ops)//len(known_ops)}%)")
|
||||
print(f" New ops found: {len(new_ops)}")
|
||||
print(f" VALUINST-only: {len(valuinst_only)}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
print("=" * 60)
|
||||
print("SQTT InstOp Discovery Tool")
|
||||
print("=" * 60)
|
||||
print(f"Testing {len(INSTRUCTION_TESTS)} instruction categories...\n")
|
||||
|
||||
discovered, failures, valuinst_tests = discover_all_instops()
|
||||
print_summary(discovered, failures, valuinst_tests)
|
||||
289
extra/assembly/amd/test/discover_instops_tensor.py
Normal file
289
extra/assembly/amd/test/discover_instops_tensor.py
Normal file
|
|
@ -0,0 +1,289 @@
|
|||
#!/usr/bin/env python3
|
||||
"""SQTT InstOp discovery from tinygrad-generated kernels.
|
||||
|
||||
Runs various tinygrad operations and captures SQTT traces to find new InstOp values.
|
||||
|
||||
Requires profiling enabled:
|
||||
echo 'profile_standard' | sudo tee /sys/class/drm/card1/device/power_dpm_force_performance_level
|
||||
|
||||
Run with: DEBUG=1 python extra/assembly/amd/test/discover_instops_tensor.py
|
||||
For full traces: DEBUG=2 python extra/assembly/amd/test/discover_instops_tensor.py
|
||||
"""
|
||||
import os
|
||||
os.environ["SQTT"] = "1"
|
||||
os.environ["PROFILE"] = "1"
|
||||
os.environ["SQTT_LIMIT_SE"] = "2" # Force work to traced SE only
|
||||
os.environ["SQTT_TOKEN_EXCLUDE"] = "3784" # Exclude noisy packet types
|
||||
|
||||
from tinygrad import Tensor, dtypes, Device
|
||||
from tinygrad.helpers import DEBUG, colored
|
||||
from tinygrad.runtime.ops_amd import ProfileSQTTEvent, SQTT_SIMD_SEL
|
||||
|
||||
from extra.assembly.amd.sqtt import InstOp, decode, INST, WAVESTART, WAVEEND
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
# HELPERS
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
|
||||
def get_inst_ops_from_blobs(blobs: list[bytes]) -> set[int]:
|
||||
"""Extract all InstOp values from SQTT blobs."""
|
||||
ops = set()
|
||||
for blob in blobs:
|
||||
packets = decode(blob)
|
||||
in_wave = False
|
||||
for p in packets:
|
||||
if isinstance(p, WAVESTART):
|
||||
in_wave = True
|
||||
if in_wave and isinstance(p, INST):
|
||||
ops.add(p.op if isinstance(p.op, int) else p.op.value)
|
||||
if isinstance(p, WAVEEND):
|
||||
in_wave = False
|
||||
return ops
|
||||
|
||||
def run_and_capture(fn, attempts: int = 5) -> tuple[set[int], list[bytes]]:
|
||||
"""Run a function multiple times and collect SQTT traces."""
|
||||
dev = Device["AMD"]
|
||||
all_ops = set()
|
||||
all_blobs = []
|
||||
SQTT_SIMD_SEL.value = 0
|
||||
|
||||
for _ in range(attempts):
|
||||
dev.profile_events.clear()
|
||||
fn()
|
||||
blobs = [ev.blob for ev in dev.profile_events if isinstance(ev, ProfileSQTTEvent)]
|
||||
ops = get_inst_ops_from_blobs(blobs)
|
||||
all_ops.update(ops)
|
||||
all_blobs.extend(blobs)
|
||||
|
||||
return all_ops, all_blobs
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
# TENSOR OPERATIONS TO TEST
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
|
||||
TENSOR_TESTS: dict[str, tuple[str, callable]] = {
|
||||
# Basic arithmetic
|
||||
"add_f32": ("tensor add float32", lambda: (Tensor.rand(1024) + Tensor.rand(1024)).realize()),
|
||||
"mul_f32": ("tensor mul float32", lambda: (Tensor.rand(1024) * Tensor.rand(1024)).realize()),
|
||||
"sub_f32": ("tensor sub float32", lambda: (Tensor.rand(1024) - Tensor.rand(1024)).realize()),
|
||||
"div_f32": ("tensor div float32", lambda: (Tensor.rand(1024) / (Tensor.rand(1024) + 0.1)).realize()),
|
||||
|
||||
# Transcendental
|
||||
"exp_f32": ("tensor exp float32", lambda: Tensor.rand(1024).exp().realize()),
|
||||
"log_f32": ("tensor log float32", lambda: (Tensor.rand(1024) + 0.1).log().realize()),
|
||||
"sqrt_f32": ("tensor sqrt float32", lambda: Tensor.rand(1024).sqrt().realize()),
|
||||
"sin_f32": ("tensor sin float32", lambda: Tensor.rand(1024).sin().realize()),
|
||||
"cos_f32": ("tensor cos float32", lambda: Tensor.rand(1024).cos().realize()),
|
||||
"tanh_f32": ("tensor tanh float32", lambda: Tensor.rand(1024).tanh().realize()),
|
||||
"sigmoid_f32": ("tensor sigmoid float32", lambda: Tensor.rand(1024).sigmoid().realize()),
|
||||
|
||||
# Reductions
|
||||
"sum_f32": ("tensor sum float32", lambda: Tensor.rand(1024).sum().realize()),
|
||||
"max_f32": ("tensor max float32", lambda: Tensor.rand(1024).max().realize()),
|
||||
"mean_f32": ("tensor mean float32", lambda: Tensor.rand(1024).mean().realize()),
|
||||
|
||||
# Matmul - small
|
||||
"matmul_small": ("matmul 32x32", lambda: (Tensor.rand(32, 32) @ Tensor.rand(32, 32)).realize()),
|
||||
|
||||
# Matmul - medium (might use WMMA)
|
||||
"matmul_medium": ("matmul 128x128", lambda: (Tensor.rand(128, 128) @ Tensor.rand(128, 128)).realize()),
|
||||
|
||||
# Matmul - larger (more likely to use WMMA)
|
||||
"matmul_large": ("matmul 256x256", lambda: (Tensor.rand(256, 256) @ Tensor.rand(256, 256)).realize()),
|
||||
|
||||
# Different dtypes
|
||||
"add_f16": ("tensor add float16", lambda: (Tensor.rand(1024, dtype=dtypes.float16) + Tensor.rand(1024, dtype=dtypes.float16)).realize()),
|
||||
"mul_f16": ("tensor mul float16", lambda: (Tensor.rand(1024, dtype=dtypes.float16) * Tensor.rand(1024, dtype=dtypes.float16)).realize()),
|
||||
"matmul_f16": ("matmul float16 128x128", lambda: (Tensor.rand(128, 128, dtype=dtypes.float16) @ Tensor.rand(128, 128, dtype=dtypes.float16)).realize()),
|
||||
|
||||
# Integer ops
|
||||
"add_i32": ("tensor add int32", lambda: (Tensor.randint(1024, high=1000) + Tensor.randint(1024, high=1000)).realize()),
|
||||
"mul_i32": ("tensor mul int32", lambda: (Tensor.randint(1024, high=100) * Tensor.randint(1024, high=100)).realize()),
|
||||
|
||||
# Bitwise
|
||||
"and_i32": ("tensor bitwise and", lambda: (Tensor.randint(1024, high=1000) & Tensor.randint(1024, high=1000)).realize()),
|
||||
"or_i32": ("tensor bitwise or", lambda: (Tensor.randint(1024, high=1000) | Tensor.randint(1024, high=1000)).realize()),
|
||||
"xor_i32": ("tensor bitwise xor", lambda: (Tensor.randint(1024, high=1000) ^ Tensor.randint(1024, high=1000)).realize()),
|
||||
"lshift_i32": ("tensor left shift", lambda: (Tensor.randint(1024, high=1000) << 2).realize()),
|
||||
"rshift_i32": ("tensor right shift", lambda: (Tensor.randint(1024, high=1000) >> 2).realize()),
|
||||
|
||||
# Comparisons
|
||||
"cmp_eq": ("tensor compare eq", lambda: (Tensor.rand(1024) == 0.5).realize()),
|
||||
"cmp_lt": ("tensor compare lt", lambda: (Tensor.rand(1024) < 0.5).realize()),
|
||||
"cmp_gt": ("tensor compare gt", lambda: (Tensor.rand(1024) > 0.5).realize()),
|
||||
|
||||
# Where/select
|
||||
"where": ("tensor where", lambda: Tensor.rand(1024).where(Tensor.rand(1024), Tensor.rand(1024)).realize()),
|
||||
|
||||
# Reshaping/movement (may not generate interesting ops but let's check)
|
||||
"reshape": ("tensor reshape", lambda: Tensor.rand(32, 32).reshape(16, 64).realize()),
|
||||
"permute": ("tensor permute", lambda: Tensor.rand(32, 32).permute(1, 0).contiguous().realize()),
|
||||
"expand": ("tensor expand", lambda: Tensor.rand(1, 32).expand(32, 32).contiguous().realize()),
|
||||
|
||||
# Pad
|
||||
"pad": ("tensor pad", lambda: Tensor.rand(30, 30).pad(((1, 1), (1, 1))).realize()),
|
||||
|
||||
# Conv2D - small
|
||||
"conv2d_small": ("conv2d 3x3", lambda: Tensor.rand(1, 3, 32, 32).conv2d(Tensor.rand(8, 3, 3, 3)).realize()),
|
||||
|
||||
# Conv2D - larger
|
||||
"conv2d_medium": ("conv2d 3x3 64ch", lambda: Tensor.rand(1, 64, 32, 32).conv2d(Tensor.rand(64, 64, 3, 3)).realize()),
|
||||
|
||||
# Pooling
|
||||
"maxpool": ("max pool 2x2", lambda: Tensor.rand(1, 3, 32, 32).max_pool2d((2, 2)).realize()),
|
||||
"avgpool": ("avg pool 2x2", lambda: Tensor.rand(1, 3, 32, 32).avg_pool2d((2, 2)).realize()),
|
||||
|
||||
# Softmax
|
||||
"softmax": ("softmax", lambda: Tensor.rand(32, 128).softmax().realize()),
|
||||
|
||||
# LayerNorm-like
|
||||
"layernorm": ("layer norm pattern", lambda: _layernorm(Tensor.rand(32, 128))),
|
||||
|
||||
# BatchNorm-like
|
||||
"batchnorm": ("batch norm pattern", lambda: _batchnorm(Tensor.rand(1, 64, 32, 32))),
|
||||
|
||||
# Dropout-like (during training)
|
||||
"dropout": ("dropout pattern", lambda: (Tensor.rand(1024) * (Tensor.rand(1024) > 0.5)).realize()),
|
||||
|
||||
# Cast operations
|
||||
"cast_f32_to_f16": ("cast f32->f16", lambda: Tensor.rand(1024).cast(dtypes.float16).realize()),
|
||||
"cast_f16_to_f32": ("cast f16->f32", lambda: Tensor.rand(1024, dtype=dtypes.float16).cast(dtypes.float32).realize()),
|
||||
"cast_f32_to_i32": ("cast f32->i32", lambda: (Tensor.rand(1024) * 100).cast(dtypes.int32).realize()),
|
||||
"cast_i32_to_f32": ("cast i32->f32", lambda: Tensor.randint(1024, high=100).cast(dtypes.float32).realize()),
|
||||
|
||||
# Clamp/clip
|
||||
"clamp": ("tensor clamp", lambda: Tensor.rand(1024).clamp(0.2, 0.8).realize()),
|
||||
|
||||
# Abs/neg
|
||||
"abs": ("tensor abs", lambda: (Tensor.rand(1024) - 0.5).abs().realize()),
|
||||
"neg": ("tensor neg", lambda: (-Tensor.rand(1024)).realize()),
|
||||
|
||||
# Reciprocal
|
||||
"recip": ("tensor reciprocal", lambda: (Tensor.rand(1024) + 0.1).reciprocal().realize()),
|
||||
|
||||
# Power
|
||||
"pow2": ("tensor pow 2", lambda: (Tensor.rand(1024) ** 2).realize()),
|
||||
"pow3": ("tensor pow 3", lambda: (Tensor.rand(1024) ** 3).realize()),
|
||||
}
|
||||
|
||||
def _layernorm(x: Tensor) -> Tensor:
|
||||
"""Simple layer normalization pattern."""
|
||||
mean = x.mean(axis=-1, keepdim=True)
|
||||
var = ((x - mean) ** 2).mean(axis=-1, keepdim=True)
|
||||
return ((x - mean) / (var + 1e-5).sqrt()).realize()
|
||||
|
||||
def _batchnorm(x: Tensor) -> Tensor:
|
||||
"""Simple batch normalization pattern."""
|
||||
mean = x.mean(axis=(0, 2, 3), keepdim=True)
|
||||
var = ((x - mean) ** 2).mean(axis=(0, 2, 3), keepdim=True)
|
||||
return ((x - mean) / (var + 1e-5).sqrt()).realize()
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
# DISCOVERY
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
|
||||
def discover_all_instops() -> tuple[dict[int, set[str]], dict[str, Exception]]:
|
||||
"""Run all tensor tests and collect InstOp values."""
|
||||
discovered: dict[int, set[str]] = {}
|
||||
failures: dict[str, Exception] = {}
|
||||
|
||||
for test_name, (desc, fn) in TENSOR_TESTS.items():
|
||||
try:
|
||||
ops, blobs = run_and_capture(fn)
|
||||
|
||||
for op in ops:
|
||||
if op not in discovered:
|
||||
discovered[op] = set()
|
||||
discovered[op].add(test_name)
|
||||
|
||||
if DEBUG >= 1:
|
||||
status = colored("✓", "green") if ops else colored("∅", "yellow")
|
||||
ops_str = ", ".join(hex(op) for op in sorted(ops)) if ops else "none"
|
||||
print(f" {status} {test_name:25s} [{desc:25s}] ops=[{ops_str}]")
|
||||
|
||||
if DEBUG >= 2 and blobs:
|
||||
# Show first wave trace
|
||||
for blob in blobs[:1]:
|
||||
packets = decode(blob)
|
||||
print(f" First blob: {len(blob)} bytes, {len(packets)} packets")
|
||||
|
||||
except Exception as e:
|
||||
failures[test_name] = e
|
||||
if DEBUG >= 1:
|
||||
print(f" {colored('✗', 'red')} {test_name:25s} FAILED: {e}")
|
||||
|
||||
return discovered, failures
|
||||
|
||||
|
||||
def print_summary(discovered: dict[int, set[str]], failures: dict[str, Exception]) -> None:
|
||||
"""Print discovery summary."""
|
||||
known_ops = {e.value for e in InstOp}
|
||||
discovered_ops = set(discovered.keys())
|
||||
|
||||
print("\n" + "=" * 70)
|
||||
print("DISCOVERED INSTOP VALUES FROM TINYGRAD KERNELS")
|
||||
print("=" * 70)
|
||||
|
||||
for op in sorted(discovered_ops):
|
||||
try:
|
||||
name = InstOp(op).name
|
||||
status = colored("known", "green")
|
||||
except ValueError:
|
||||
name = "UNKNOWN"
|
||||
status = colored("NEW!", "yellow")
|
||||
|
||||
sources = ", ".join(sorted(discovered[op]))
|
||||
# Truncate sources if too long
|
||||
if len(sources) > 60:
|
||||
sources = sources[:57] + "..."
|
||||
print(f" 0x{op:02x} {name:20s} ({status}) <- {sources}")
|
||||
|
||||
# New values to add
|
||||
new_ops = discovered_ops - known_ops
|
||||
if new_ops:
|
||||
print("\n" + "=" * 70)
|
||||
print(colored("NEW INSTOP VALUES TO ADD TO ENUM", "yellow"))
|
||||
print("=" * 70)
|
||||
for op in sorted(new_ops):
|
||||
sources = ", ".join(sorted(discovered[op]))
|
||||
print(f" 0x{op:02x}: discovered from [{sources}]")
|
||||
|
||||
# Missing from enum (not discovered)
|
||||
missing = known_ops - discovered_ops
|
||||
if missing:
|
||||
print("\n" + "=" * 70)
|
||||
print("ENUM VALUES NOT DISCOVERED (may need specific instruction patterns)")
|
||||
print("=" * 70)
|
||||
for op in sorted(missing):
|
||||
print(f" 0x{op:02x} {InstOp(op).name}")
|
||||
|
||||
# Stats
|
||||
print("\n" + "=" * 70)
|
||||
print("STATISTICS")
|
||||
print("=" * 70)
|
||||
print(f" Tests run: {len(TENSOR_TESTS)}")
|
||||
print(f" Tests passed: {len(TENSOR_TESTS) - len(failures)}")
|
||||
print(f" Tests failed: {len(failures)}")
|
||||
print(f" Known ops: {len(known_ops)}")
|
||||
print(f" Discovered: {len(discovered_ops)}")
|
||||
if known_ops:
|
||||
coverage = len(discovered_ops & known_ops)
|
||||
print(f" Coverage: {coverage}/{len(known_ops)} ({100*coverage//len(known_ops)}%)")
|
||||
print(f" New ops found: {len(new_ops)}")
|
||||
|
||||
if failures:
|
||||
print("\n" + "=" * 70)
|
||||
print("FAILURES")
|
||||
print("=" * 70)
|
||||
for name, e in failures.items():
|
||||
print(f" {name}: {e}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
print("=" * 70)
|
||||
print("SQTT InstOp Discovery from Tinygrad Kernels")
|
||||
print("=" * 70)
|
||||
print(f"Testing {len(TENSOR_TESTS)} tensor operations...\n")
|
||||
|
||||
discovered, failures = discover_all_instops()
|
||||
print_summary(discovered, failures)
|
||||
79
extra/assembly/amd/test/test_sqtt.py
Normal file
79
extra/assembly/amd/test/test_sqtt.py
Normal file
|
|
@ -0,0 +1,79 @@
|
|||
#!/usr/bin/env python3
|
||||
"""Tests for SQTT packet codec (no hardware required)."""
|
||||
import unittest
|
||||
from extra.assembly.amd.sqtt import (
|
||||
LAYOUT_HEADER, WAVESTART, WAVEEND, INST, NOP,
|
||||
decode, encode, PACKET_TYPES, OPCODE_TO_CLASS
|
||||
)
|
||||
|
||||
|
||||
class TestSQTTCodec(unittest.TestCase):
|
||||
"""Tests for SQTT encoder/decoder roundtrip."""
|
||||
|
||||
def test_roundtrip_simple(self):
|
||||
"""Test encode/decode roundtrip for simple packets."""
|
||||
test_packets = [
|
||||
LAYOUT_HEADER.from_raw(0x100),
|
||||
WAVESTART.from_raw(0x0),
|
||||
INST.from_raw(0x10), # delta=1
|
||||
INST.from_raw(0x10), # delta=1
|
||||
WAVEEND.from_raw(0x40), # delta=2
|
||||
]
|
||||
encoded = encode(test_packets)
|
||||
decoded = decode(encoded)
|
||||
|
||||
self.assertGreaterEqual(len(decoded), len(test_packets))
|
||||
for i, (orig, dec) in enumerate(zip(test_packets, decoded)):
|
||||
self.assertEqual(type(orig), type(dec), f"type mismatch at {i}")
|
||||
|
||||
def test_decode_empty(self):
|
||||
"""Test decoding empty data."""
|
||||
packets = decode(b'')
|
||||
self.assertEqual(packets, [])
|
||||
|
||||
def test_encode_empty(self):
|
||||
"""Test encoding empty list."""
|
||||
data = encode([])
|
||||
self.assertEqual(data, b'')
|
||||
|
||||
def test_all_packet_types_have_encoding(self):
|
||||
"""All packet types should have an encoding defined."""
|
||||
for pkt_cls in PACKET_TYPES:
|
||||
self.assertIsNotNone(pkt_cls._encoding, f"{pkt_cls.__name__} missing encoding")
|
||||
|
||||
def test_packet_from_raw(self):
|
||||
"""Test creating packets from raw values."""
|
||||
# INST with wave=5, op=0x21, delta=2
|
||||
raw = (0x21 << 13) | (5 << 8) | (2 << 4) | 0b010
|
||||
pkt = INST.from_raw(raw)
|
||||
self.assertEqual(pkt.wave, 5)
|
||||
self.assertEqual(pkt.op, 0x21)
|
||||
self.assertEqual(pkt.delta, 2)
|
||||
|
||||
|
||||
class TestDecodeRealBlob(unittest.TestCase):
|
||||
"""Test decoding real SQTT blobs from examples."""
|
||||
|
||||
def test_decode_example_file(self):
|
||||
"""Test decoding a real SQTT blob from examples."""
|
||||
import pickle
|
||||
from pathlib import Path
|
||||
example_path = Path(__file__).parent.parent.parent.parent / "sqtt/examples/profile_plus_run_0.pkl"
|
||||
if not example_path.exists():
|
||||
self.skipTest(f"Example file not found: {example_path}")
|
||||
|
||||
from tinygrad.runtime.ops_amd import ProfileSQTTEvent
|
||||
with open(example_path, "rb") as f:
|
||||
data = pickle.load(f)
|
||||
|
||||
sqtt_events = [e for e in data if isinstance(e, ProfileSQTTEvent)]
|
||||
self.assertGreater(len(sqtt_events), 0, "No SQTT events in example")
|
||||
|
||||
packets = decode(sqtt_events[0].blob)
|
||||
self.assertGreater(len(packets), 0, "No packets decoded")
|
||||
# First packet should be LAYOUT_HEADER
|
||||
self.assertIsInstance(packets[0], LAYOUT_HEADER)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
3011
extra/assembly/amd/test/test_sqtt_compare.py
Normal file
3011
extra/assembly/amd/test/test_sqtt_compare.py
Normal file
File diff suppressed because it is too large
Load diff
545
extra/assembly/amd/test/test_sqtt_correct.py
Normal file
545
extra/assembly/amd/test/test_sqtt_correct.py
Normal file
|
|
@ -0,0 +1,545 @@
|
|||
#!/usr/bin/env python3
|
||||
"""Tests for SQTT emulator correctness against known hardware patterns.
|
||||
|
||||
NOTE: This file only tests NOP and VALU behavior. For WMMA/DP/trans tests,
|
||||
see test_sqtt_compare.py.
|
||||
|
||||
Run emulator tests: PYTHONPATH="." python3 extra/assembly/amd/test/test_sqtt_correct.py
|
||||
Run hardware tests: SQTT_HW=1 PYTHONPATH="." python3 extra/assembly/amd/test/test_sqtt_correct.py
|
||||
"""
|
||||
import os
|
||||
import unittest
|
||||
|
||||
USE_HW = os.environ.get("SQTT_HW", "0") == "1"
|
||||
|
||||
if USE_HW:
|
||||
os.environ["SQTT"] = "1"
|
||||
os.environ["PROFILE"] = "1"
|
||||
os.environ["SQTT_LIMIT_SE"] = "2"
|
||||
os.environ["SQTT_TOKEN_EXCLUDE"] = "3784"
|
||||
|
||||
from extra.assembly.amd.emu import SQTTState, decode_program, exec_wave, WaveState, LDSMem
|
||||
from extra.assembly.amd.sqtt import WAVESTART, WAVEEND
|
||||
from extra.assembly.amd.autogen.rdna3.ins import v_mov_b32_e32, v_add_f32_e32, s_nop, s_endpgm, s_delay_alu
|
||||
from extra.assembly.amd.dsl import v
|
||||
|
||||
def assemble(instructions: list) -> bytes:
|
||||
return b''.join(inst.to_bytes() for inst in instructions)
|
||||
|
||||
def wrap_with_nops(instructions: list, nops=16) -> list:
|
||||
return instructions + [s_nop(0)]*nops + [s_endpgm()]
|
||||
|
||||
def get_wave_packets(packets: list) -> list:
|
||||
result, in_wave = [], False
|
||||
for p in packets:
|
||||
if isinstance(p, WAVESTART) and p.simd == 0:
|
||||
in_wave, result = True, [p]
|
||||
elif in_wave:
|
||||
result.append(p)
|
||||
if isinstance(p, WAVEEND): break
|
||||
return result
|
||||
|
||||
def get_timing_deltas(packets: list) -> list[tuple[str, int]]:
|
||||
skip_types = {"NOP", "TS_DELTA_SHORT", "TS_WAVE_STATE", "TS_DELTA_OR_MARK", "TS_DELTA_S5_W2", "TS_DELTA_S5_W3", "TS_DELTA_S8_W3", "REG"}
|
||||
filtered = [p for p in packets if type(p).__name__ not in skip_types]
|
||||
if not filtered: return []
|
||||
result = [(type(filtered[0]).__name__, 0)]
|
||||
for i in range(1, len(filtered)):
|
||||
result.append((type(filtered[i]).__name__, filtered[i]._time - filtered[i-1]._time))
|
||||
return result
|
||||
|
||||
def run_emulator(instructions: list) -> list:
|
||||
code = assemble(instructions)
|
||||
program = decode_program(code)
|
||||
st = WaveState()
|
||||
st.exec_mask = (1 << 32) - 1
|
||||
lds = LDSMem(bytearray(65536))
|
||||
trace = SQTTState(wave_id=0, simd=0, cu=0)
|
||||
exec_wave(program, st, lds, 32, trace)
|
||||
return get_wave_packets(trace.packets)
|
||||
|
||||
def get_all_waves(packets: list) -> list[list]:
|
||||
"""Extract all WAVESTART..WAVEEND ranges on simd 0."""
|
||||
waves, in_wave, current = [], False, []
|
||||
for p in packets:
|
||||
if isinstance(p, WAVESTART) and p.simd == 0:
|
||||
in_wave, current = True, [p]
|
||||
elif in_wave:
|
||||
current.append(p)
|
||||
if isinstance(p, WAVEEND):
|
||||
waves.append(current)
|
||||
in_wave, current = False, []
|
||||
return waves
|
||||
|
||||
def run_hardware(instructions: list) -> list:
|
||||
from extra.assembly.amd.test.test_sqtt_hw import compile_asm_sqtt, run_prg_sqtt_batch
|
||||
from extra.assembly.amd.sqtt import decode
|
||||
from collections import Counter
|
||||
|
||||
prg = compile_asm_sqtt(instructions, alu_only=True)
|
||||
|
||||
for _ in range(10):
|
||||
blobs = run_prg_sqtt_batch(prg, n_runs=200)
|
||||
# Extract all waves from all blobs
|
||||
traces = []
|
||||
for blob in blobs:
|
||||
traces.extend(get_all_waves(decode(blob)))
|
||||
if not traces:
|
||||
continue
|
||||
# Find most common pattern
|
||||
delta_sets = [tuple(get_timing_deltas(t)) for t in traces]
|
||||
most_common = Counter(delta_sets).most_common(1)[0][0]
|
||||
for t in traces:
|
||||
if tuple(get_timing_deltas(t)) == most_common:
|
||||
return t
|
||||
return []
|
||||
|
||||
def run_sqtt(instructions: list, nops: int = 16) -> list:
|
||||
instructions = wrap_with_nops(instructions, nops=nops)
|
||||
return run_hardware(instructions) if USE_HW else run_emulator(instructions)
|
||||
|
||||
def get_deltas(instructions: list) -> tuple[list[int], list[int]]:
|
||||
"""Run and return (issue deltas, exec deltas).
|
||||
Issue = IMMEDIATE + VALUINST, Exec = ALUEXEC.
|
||||
Deltas are between consecutive packets of same stream."""
|
||||
deltas = get_timing_deltas(run_sqtt(instructions))
|
||||
time = 0
|
||||
issue_times, exec_times = [], []
|
||||
for ptype, delta in deltas:
|
||||
time += delta
|
||||
if ptype in ('IMMEDIATE', 'VALUINST'):
|
||||
issue_times.append(time)
|
||||
elif ptype == 'ALUEXEC':
|
||||
exec_times.append(time)
|
||||
issue = [issue_times[i] - issue_times[i-1] for i in range(1, len(issue_times))]
|
||||
execd = [exec_times[i] - exec_times[i-1] for i in range(1, len(exec_times))]
|
||||
return issue, execd
|
||||
|
||||
# ************************************ tests ************************************
|
||||
|
||||
class TestVALUChains(unittest.TestCase):
|
||||
"""VALU dependency chains."""
|
||||
def _chain(self, n, expected_issue, expected_exec):
|
||||
instrs = [v_mov_b32_e32(v[0], 1.0)] + [v_add_f32_e32(v[i], v[i-1], v[i-1]) for i in range(1, n)]
|
||||
issue, execd = get_deltas(instrs)
|
||||
self.assertEqual(issue[:n-1], expected_issue)
|
||||
if isinstance(expected_exec[0], list): self.assertIn(execd, expected_exec)
|
||||
else: self.assertEqual(execd, expected_exec)
|
||||
|
||||
def test_chain_2(self): self._chain(2, [1], [6])
|
||||
def test_chain_3(self): self._chain(3, [1, 1], [6, 5])
|
||||
def test_chain_4(self): self._chain(4, [1, 1, 1], [6, 5, 5])
|
||||
def test_chain_5(self): self._chain(5, [1, 1, 1, 1], [6, 5, 5, 9])
|
||||
def test_chain_6(self): self._chain(6, [1, 1, 1, 1, 1], [6, 5, 5, 9, 9])
|
||||
def test_chain_7(self): self._chain(7, [1, 1, 1, 1, 1, 1], [6, 5, 5, 5, 9, 9])
|
||||
def test_chain_8(self): self._chain(8, [1, 1, 1, 1, 1, 1, 1], [6, 5, 5, 5, 9, 9, 9])
|
||||
# NOTE: position 8 can be 5 or 9 depending on GPU variant
|
||||
def test_chain_12(self): self._chain(12, [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [[6, 5, 5, 5, 5, 9, 9, 9, 9, 9, 9], [6, 5, 5, 5, 5, 9, 9, 9, 5, 9, 9]])
|
||||
def test_chain_14(self): self._chain(14, [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [[6, 5, 5, 5, 5, 9, 9, 9, 9, 9, 9, 9, 9], [6, 5, 5, 5, 5, 9, 9, 9, 5, 9, 9, 9, 9]])
|
||||
# issue stalls start here
|
||||
def test_chain_15(self): self._chain(15, [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 3], [[6, 5, 5, 5, 5, 5, 9, 9, 9, 9, 9, 9, 9, 9], [6, 5, 5, 5, 5, 5, 9, 9, 5, 9, 9, 9, 9, 9]])
|
||||
def test_chain_16(self): self._chain(16, [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 3, 5], [[6, 5, 5, 5, 5, 5, 5, 9, 9, 9, 9, 9, 9, 9, 9], [6, 5, 5, 5, 5, 5, 5, 9, 5, 9, 9, 9, 9, 9, 9]])
|
||||
def test_chain_18(self): self._chain(18, [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 3, 5, 5, 5], [6, 5, 5, 5, 5, 5, 5, 5, 5, 9, 9, 9, 9, 9, 9, 9, 9])
|
||||
def test_chain_20(self): self._chain(20, [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 3, 5, 5, 5, 5, 5], [6, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 9, 9, 9, 9, 9, 9, 9, 9])
|
||||
|
||||
|
||||
class TestVALUChainsWithWarmup(unittest.TestCase):
|
||||
"""VALU dependency chains with early VALUs to isolate warmup effects."""
|
||||
# just the first stupid VALU takes 6
|
||||
def _chain(self, n, warmup=True):
|
||||
instrs = [v_mov_b32_e32(v[0], 1.0), s_nop(100)] if warmup else [s_nop(100)]
|
||||
instrs += [v_mov_b32_e32(v[0], 1.0)] + [v_add_f32_e32(v[i], v[i-1], v[i-1]) for i in range(1, n)]
|
||||
issue, execd = get_deltas(instrs)
|
||||
return execd[1:] if warmup else execd
|
||||
|
||||
def test_warmup_chain_2(self): self.assertEqual(self._chain(2), [5])
|
||||
def test_warmup_chain_3(self): self.assertEqual(self._chain(3), [5, 5])
|
||||
def test_warmup_chain_4(self): self.assertEqual(self._chain(4), [5, 5, 5])
|
||||
def test_warmup_chain_5(self): self.assertEqual(self._chain(5), [5, 5, 5, 9])
|
||||
def test_warmup_chain_6(self): self.assertEqual(self._chain(6), [5, 5, 5, 5, 9])
|
||||
def test_warmup_chain_7(self): self.assertEqual(self._chain(7), [5, 5, 5, 5, 9, 9])
|
||||
def test_warmup_chain_8(self): self.assertEqual(self._chain(8), [5, 5, 5, 5, 9, 9, 9])
|
||||
|
||||
def test_cold_chain_2(self): self.assertEqual(self._chain(2, False), [6])
|
||||
def test_cold_chain_3(self): self.assertEqual(self._chain(3, False), [6, 5])
|
||||
def test_cold_chain_4(self): self.assertEqual(self._chain(4, False), [6, 5, 5])
|
||||
def test_cold_chain_5(self): self.assertEqual(self._chain(5, False), [6, 5, 5, 9])
|
||||
def test_cold_chain_6(self): self.assertEqual(self._chain(6, False), [6, 5, 5, 9, 9])
|
||||
def test_cold_chain_7(self): self.assertEqual(self._chain(7, False), [6, 5, 5, 5, 9, 9])
|
||||
def test_cold_chain_8(self): self.assertEqual(self._chain(8, False), [6, 5, 5, 5, 9, 9, 9])
|
||||
|
||||
|
||||
class TestVALUIndependent(unittest.TestCase):
|
||||
"""Independent VALU instructions."""
|
||||
def _ind(self, n, expected_exec):
|
||||
instrs = [v_mov_b32_e32(v[i], float(i)) for i in range(n)]
|
||||
issue, execd = get_deltas(instrs)
|
||||
self.assertEqual(issue[:n-1], [1]*(n-1))
|
||||
self.assertEqual(execd, expected_exec)
|
||||
|
||||
def test_ind_2(self): self._ind(2, [1])
|
||||
def test_ind_3(self): self._ind(3, [1, 1])
|
||||
def test_ind_4(self): self._ind(4, [1, 1, 1])
|
||||
def test_ind_5(self): self._ind(5, [1, 1, 1, 1])
|
||||
def test_ind_6(self): self._ind(6, [1, 1, 1, 1, 1])
|
||||
def test_ind_7(self): self._ind(7, [1, 1, 1, 1, 1, 1])
|
||||
def test_ind_8(self): self._ind(8, [1, 1, 1, 1, 1, 1, 1])
|
||||
|
||||
|
||||
class TestForwardingGap(unittest.TestCase):
|
||||
"""Producer + N independent instructions + consumer - tests forwarding window."""
|
||||
def _exec_deltas(self, n_gap):
|
||||
instrs = [v_mov_b32_e32(v[0], 1.0)]
|
||||
instrs += [v_mov_b32_e32(v[10+i], float(i)) for i in range(n_gap)]
|
||||
instrs += [v_add_f32_e32(v[1], v[0], v[0])]
|
||||
_, execd = get_deltas(instrs)
|
||||
return execd
|
||||
|
||||
def test_gap0(self): self.assertEqual(self._exec_deltas(0), [6])
|
||||
def test_gap1(self): self.assertEqual(self._exec_deltas(1), [1, 5])
|
||||
def test_gap2(self): self.assertEqual(self._exec_deltas(2), [1, 1, 4])
|
||||
def test_gap3(self): self.assertIn(self._exec_deltas(3), [[1, 1, 1, 3], [1, 1, 1, 4]])
|
||||
def test_gap4(self): self.assertIn(self._exec_deltas(4), [[1, 1, 1, 1, 3], [1, 1, 1, 1, 4]])
|
||||
def test_gap5(self): self.assertEqual(self._exec_deltas(5), [1, 1, 1, 1, 1, 4]) # anomaly
|
||||
def test_gap6(self): self.assertEqual(self._exec_deltas(6), [1, 1, 1, 1, 1, 1, 3])
|
||||
def test_gap7(self): self.assertEqual(self._exec_deltas(7), [1, 1, 1, 1, 1, 1, 1, 3])
|
||||
def test_gap8(self): self.assertEqual(self._exec_deltas(8), [1, 1, 1, 1, 1, 1, 1, 1, 3])
|
||||
def test_gap9(self): self.assertEqual(self._exec_deltas(9), [1, 1, 1, 1, 1, 1, 1, 1, 1, 3])
|
||||
|
||||
|
||||
class TestChainWithIndependentGap(unittest.TestCase):
|
||||
"""Chain of dependent VALUs with independent VALUs inserted before the last one.
|
||||
|
||||
Hardware observation: In a chain v0->v1->v2->v3->v4, if we insert N independent VALUs
|
||||
before v4, the forwarding behavior changes:
|
||||
- 0-1 independent VALUs: v4 cannot forward from v3 (delta=9)
|
||||
- 2+ independent VALUs: v4 can forward from v3 (delta=5)
|
||||
|
||||
This suggests forwarding eligibility depends on whether the direct source is in the ALU
|
||||
at issue time, not just at dispatch time.
|
||||
"""
|
||||
def _chain5_gap(self, n_ind):
|
||||
"""Chain v0->v1->v2->v3->v4 with N independent VALUs before v4. Returns v3->v4 delta."""
|
||||
instrs = [s_nop(100),
|
||||
v_mov_b32_e32(v[0], 1.0),
|
||||
v_mov_b32_e32(v[1], v[0]),
|
||||
v_mov_b32_e32(v[2], v[1]),
|
||||
v_mov_b32_e32(v[3], v[2])]
|
||||
instrs += [v_mov_b32_e32(v[10+i], float(i)) for i in range(n_ind)]
|
||||
instrs += [v_mov_b32_e32(v[4], v[3])]
|
||||
_, execd = get_deltas(instrs)
|
||||
# Chain execs are at indices 0,1,2,3 and last one. Independent ones are in between.
|
||||
# v3->v4 delta = last exec time - 4th exec time (index 3)
|
||||
# With n_ind independent VALUs, execd has 4 + n_ind entries
|
||||
# We want delta between exec[3] (v3) and exec[4+n_ind-1] (v4)
|
||||
# Actually execd is already deltas, so we need absolute times
|
||||
time, exec_times = 0, []
|
||||
packets = run_sqtt(instrs)
|
||||
for ptype, delta in get_timing_deltas(packets):
|
||||
time += delta
|
||||
if ptype == 'ALUEXEC': exec_times.append(time)
|
||||
# v0,v1,v2,v3 are first 4, v4 is last
|
||||
return exec_times[-1] - exec_times[3]
|
||||
|
||||
def test_gap0(self): self.assertEqual(self._chain5_gap(0), 9)
|
||||
def test_gap1(self): self.assertEqual(self._chain5_gap(1), 9)
|
||||
def test_gap2(self): self.assertEqual(self._chain5_gap(2), 5)
|
||||
def test_gap3(self): self.assertEqual(self._chain5_gap(3), 5)
|
||||
def test_gap4(self): self.assertEqual(self._chain5_gap(4), 5)
|
||||
|
||||
|
||||
class TestVALULatency(unittest.TestCase):
|
||||
"""VALU latency depends on VGPR source reads.
|
||||
6 cycles: no VGPR source (constant only), stays 6 regardless of warmup
|
||||
8-11 cycles: VGPR source read, decreases with warmup (11->10->9->8)
|
||||
s_nop(0) after VALU immediately drops VGPR read latency to 8
|
||||
Anomalies:
|
||||
- 7 consecutive VALUs (no s_nop) causes +1 cycle penalty
|
||||
- n=0 or n=3 const VALUs + nop + vgpr = 9 cycles (not 8)
|
||||
"""
|
||||
def _get_latency(self, instrs):
|
||||
if not isinstance(instrs, list): instrs = [instrs]
|
||||
packets = run_sqtt(instrs)
|
||||
deltas = get_timing_deltas(packets)
|
||||
time, valu_times, exec_times = 0, [], []
|
||||
for ptype, delta in deltas:
|
||||
time += delta
|
||||
if ptype == 'VALUINST': valu_times.append(time)
|
||||
if ptype == 'ALUEXEC': exec_times.append(time)
|
||||
return exec_times[-1] - valu_times[-1] if valu_times and exec_times else None
|
||||
|
||||
# 6-cycle latency: no VGPR source (constant), always 6
|
||||
def test_const_single(self): self.assertEqual(self._get_latency(v_mov_b32_e32(v[0], 1.0)), 6)
|
||||
def test_const_literal(self): self.assertEqual(self._get_latency(v_mov_b32_e32(v[0], 565.0)), 6)
|
||||
def test_const_after_const(self): self.assertEqual(self._get_latency([v_mov_b32_e32(v[0], 1.0), v_mov_b32_e32(v[1], 2.0)]), 6)
|
||||
def test_const_after_nop(self): self.assertEqual(self._get_latency([v_mov_b32_e32(v[0], 1.0), s_nop(0), v_mov_b32_e32(v[1], 2.0)]), 6)
|
||||
|
||||
# VGPR read latency: cold start = 9
|
||||
def test_vgpr_cold(self): self.assertEqual(self._get_latency(v_mov_b32_e32(v[0], v[1])), 9)
|
||||
|
||||
# VGPR read latency: warmup decreases 11->10->9->8
|
||||
def _vgpr_after_n_const(self, n):
|
||||
return self._get_latency([v_mov_b32_e32(v[i], float(i)) for i in range(n)] + [v_mov_b32_e32(v[10], v[99])])
|
||||
def test_vgpr_after_1_const(self): self.assertEqual(self._vgpr_after_n_const(1), 11)
|
||||
def test_vgpr_after_2_const(self): self.assertEqual(self._vgpr_after_n_const(2), 10)
|
||||
def test_vgpr_after_3_const(self): self.assertEqual(self._vgpr_after_n_const(3), 9)
|
||||
def test_vgpr_after_4_const(self): self.assertIn(self._vgpr_after_n_const(4), [8, 9])
|
||||
def test_vgpr_after_5_const(self): self.assertIn(self._vgpr_after_n_const(5), [8, 9])
|
||||
def test_vgpr_after_6_const(self): self.assertEqual(self._vgpr_after_n_const(6), 9) # anomaly
|
||||
def test_vgpr_after_7_const(self): self.assertEqual(self._vgpr_after_n_const(7), 8)
|
||||
def test_vgpr_after_8_const(self): self.assertEqual(self._vgpr_after_n_const(8), 8)
|
||||
|
||||
# s_nop(0) immediately drops VGPR read latency to 8 (or 9 on some variants)
|
||||
def test_vgpr_nop_warmup(self): self.assertIn(self._get_latency([v_mov_b32_e32(v[0], 1.0), s_nop(0), v_mov_b32_e32(v[1], v[99])]), [8, 9])
|
||||
|
||||
# s_nop + vgpr read: latency depends on # of const VALUs before nop
|
||||
def _n_const_nop_vgpr(self, n):
|
||||
"""N const VALUs + s_nop(0) + vgpr read."""
|
||||
instrs = [v_mov_b32_e32(v[i], float(i)) for i in range(n)]
|
||||
instrs += [s_nop(0)]
|
||||
instrs += [v_mov_b32_e32(v[10], v[99])]
|
||||
return self._get_latency(instrs)
|
||||
def test_0_const_nop_vgpr(self): self.assertEqual(self._n_const_nop_vgpr(0), 9)
|
||||
def test_1_const_nop_vgpr(self): self.assertIn(self._n_const_nop_vgpr(1), [8, 9])
|
||||
def test_2_const_nop_vgpr(self): self.assertIn(self._n_const_nop_vgpr(2), [8, 9])
|
||||
def test_3_const_nop_vgpr(self): self.assertEqual(self._n_const_nop_vgpr(3), 9) # anomaly
|
||||
def test_4_const_nop_vgpr(self): self.assertEqual(self._n_const_nop_vgpr(4), 8)
|
||||
def test_5_const_nop_vgpr(self): self.assertEqual(self._n_const_nop_vgpr(5), 8)
|
||||
def test_6_const_nop_vgpr(self): self.assertEqual(self._n_const_nop_vgpr(6), 8)
|
||||
def test_7_const_nop_vgpr(self): self.assertEqual(self._n_const_nop_vgpr(7), 8)
|
||||
|
||||
|
||||
class TestChainWithNop(unittest.TestCase):
|
||||
"""Dependency chain with s_nop between instructions."""
|
||||
def _test(self, nop_val, expected_issue, expected_exec):
|
||||
issue, execd = get_deltas([v_mov_b32_e32(v[0], 1.0), s_nop(nop_val), v_add_f32_e32(v[1], v[0], v[0])])
|
||||
self.assertEqual(issue[:2], expected_issue)
|
||||
if isinstance(expected_exec[0], list): self.assertIn(execd, expected_exec)
|
||||
else: self.assertEqual(execd, expected_exec)
|
||||
|
||||
def test_nop0(self): self._test(0, [3, 1], [[6], [7]])
|
||||
def test_nop1(self): self._test(1, [4, 1], [[7], [8]])
|
||||
def test_nop2(self): self._test(2, [5, 1], [9])
|
||||
def test_nop3(self): self._test(3, [6, 1], [9])
|
||||
def test_nop4(self): self._test(4, [11, 1], [10])
|
||||
def test_nop5(self): self._test(5, [12, 1], [11])
|
||||
|
||||
|
||||
class TestIndWithNop(unittest.TestCase):
|
||||
"""Independent instructions with s_nop between."""
|
||||
def _test(self, nop_val, expected_issue, expected_exec):
|
||||
issue, execd = get_deltas([v_mov_b32_e32(v[0], 1.0), s_nop(nop_val), v_mov_b32_e32(v[1], 2.0)])
|
||||
self.assertEqual(issue[:2], expected_issue)
|
||||
self.assertEqual(execd, expected_exec)
|
||||
|
||||
def test_nop0(self): self._test(0, [3, 1], [4])
|
||||
def test_nop1(self): self._test(1, [4, 1], [5])
|
||||
def test_nop3(self): self._test(3, [6, 1], [7])
|
||||
def test_nop4(self): self._test(4, [11, 1], [8])
|
||||
def test_nop5(self): self._test(5, [12, 1], [9])
|
||||
|
||||
|
||||
class TestChain3NopMid(unittest.TestCase):
|
||||
"""3-instruction chain with s_nop in middle."""
|
||||
def _test(self, nop_val, expected_issue, expected_exec):
|
||||
issue, execd = get_deltas([
|
||||
v_mov_b32_e32(v[0], 1.0), v_add_f32_e32(v[1], v[0], v[0]),
|
||||
s_nop(nop_val), v_add_f32_e32(v[2], v[1], v[1])])
|
||||
self.assertEqual(issue[:3], expected_issue)
|
||||
self.assertEqual(execd, expected_exec)
|
||||
|
||||
def test_nop0(self): self._test(0, [1, 3, 1], [6, 5])
|
||||
def test_nop1(self): self._test(1, [1, 4, 1], [6, 5])
|
||||
def test_nop2(self): self._test(2, [1, 5, 1], [6, 5])
|
||||
def test_nop3(self): self._test(3, [1, 10, 1], [6, 5])
|
||||
|
||||
|
||||
class TestInd3NopMid(unittest.TestCase):
|
||||
"""3 independent instructions with s_nop in middle."""
|
||||
def _test(self, nop_val, expected_issue, expected_exec):
|
||||
issue, execd = get_deltas([
|
||||
v_mov_b32_e32(v[0], 1.0), v_mov_b32_e32(v[1], 2.0),
|
||||
s_nop(nop_val), v_mov_b32_e32(v[2], 3.0)])
|
||||
self.assertEqual(issue[:3], expected_issue)
|
||||
self.assertEqual(execd, expected_exec)
|
||||
|
||||
def test_nop0(self): self._test(0, [1, 3, 1], [1, 4])
|
||||
def test_nop1(self): self._test(1, [1, 4, 1], [1, 5])
|
||||
def test_nop2(self): self._test(2, [1, 5, 1], [1, 6])
|
||||
def test_nop3(self): self._test(3, [1, 10, 1], [1, 7])
|
||||
|
||||
|
||||
class TestSNopDelay(unittest.TestCase):
|
||||
"""Single s_nop delay between two independent v_movs.
|
||||
s_nop(n) delays n+1 cycles, plus +4 extra for n in [11, 22].
|
||||
Exec delta = n + 4 (baseline) + 4 (if 11 <= n <= 22)."""
|
||||
def _test(self, n, expected):
|
||||
_, execd = get_deltas([v_mov_b32_e32(v[0], 1.0), s_nop(n), v_mov_b32_e32(v[1], 2.0)])
|
||||
if isinstance(expected, list): self.assertIn(execd[0], expected)
|
||||
else: self.assertEqual(execd, [expected])
|
||||
|
||||
def test_snop_0(self): self._test(0, 4)
|
||||
def test_snop_1(self): self._test(1, 5)
|
||||
def test_snop_2(self): self._test(2, 6)
|
||||
def test_snop_3(self): self._test(3, 7)
|
||||
def test_snop_4(self): self._test(4, 8)
|
||||
def test_snop_5(self): self._test(5, 9)
|
||||
def test_snop_6(self): self._test(6, 10)
|
||||
def test_snop_7(self): self._test(7, 11)
|
||||
def test_snop_10(self): self._test(10, 14)
|
||||
def test_snop_11(self): self._test(11, 19) # +4 extra starts here
|
||||
def test_snop_15(self): self._test(15, 23)
|
||||
def test_snop_22(self): self._test(22, 30) # +4 extra ends here
|
||||
def test_snop_23(self): self._test(23, 27)
|
||||
def test_snop_31(self): self._test(31, 35)
|
||||
def test_snop_32(self): self._test(32, 36)
|
||||
def test_snop_63(self): self._test(63, [67, 71])
|
||||
|
||||
|
||||
class TestVALUExecWithNop(unittest.TestCase):
|
||||
"""Single VALU followed by s_nop - measures VALUINST to ALUEXEC delay."""
|
||||
def _get_delay(self, instrs, nops=16):
|
||||
deltas = get_timing_deltas(run_sqtt(instrs, nops=nops))
|
||||
time, valu_time, exec_time = 0, None, None
|
||||
for ptype, delta in deltas:
|
||||
time += delta
|
||||
if ptype == 'VALUINST' and valu_time is None: valu_time = time
|
||||
if ptype == 'ALUEXEC' and exec_time is None: exec_time = time
|
||||
return exec_time - valu_time
|
||||
|
||||
# Boundary: s_nop(0-3) = 6 cycles, s_nop(4+) = 10 cycles
|
||||
def test_nop0(self): self.assertEqual(self._get_delay([v_mov_b32_e32(v[0], 1.0), s_nop(0)]), 6)
|
||||
def test_nop1(self): self.assertEqual(self._get_delay([v_mov_b32_e32(v[0], 1.0), s_nop(1)]), 6)
|
||||
def test_nop2(self): self.assertEqual(self._get_delay([v_mov_b32_e32(v[0], 1.0), s_nop(2)]), 6)
|
||||
def test_nop3(self): self.assertEqual(self._get_delay([v_mov_b32_e32(v[0], 1.0), s_nop(3)]), 6)
|
||||
def test_nop4(self): self.assertEqual(self._get_delay([v_mov_b32_e32(v[0], 1.0), s_nop(4)]), 10)
|
||||
def test_nop5(self): self.assertEqual(self._get_delay([v_mov_b32_e32(v[0], 1.0), s_nop(5)]), 10)
|
||||
def test_nop6(self): self.assertEqual(self._get_delay([v_mov_b32_e32(v[0], 1.0), s_nop(6)]), 10)
|
||||
def test_nop7(self): self.assertEqual(self._get_delay([v_mov_b32_e32(v[0], 1.0), s_nop(7)]), 10)
|
||||
def test_nop8(self): self.assertEqual(self._get_delay([v_mov_b32_e32(v[0], 1.0), s_nop(8)]), 10)
|
||||
def test_nop9(self): self.assertEqual(self._get_delay([v_mov_b32_e32(v[0], 1.0), s_nop(9)]), 10)
|
||||
def test_nop10(self): self.assertEqual(self._get_delay([v_mov_b32_e32(v[0], 1.0), s_nop(10)]), 10)
|
||||
# No nop = slow path, one s_nop(0) padding = fast path
|
||||
def test_no_padding(self): self.assertEqual(self._get_delay([v_mov_b32_e32(v[0], 1.0)], nops=0), 10)
|
||||
def test_one_padding(self): self.assertEqual(self._get_delay([v_mov_b32_e32(v[0], 1.0)], nops=1), 6)
|
||||
# Multiple s_nop(0)s don't accumulate - still fast path
|
||||
def test_nop0_x2(self): self.assertEqual(self._get_delay([v_mov_b32_e32(v[0], 1.0), s_nop(0), s_nop(0)]), 6)
|
||||
# First nop determines path: s_nop(0) then s_nop(4) = fast, s_nop(4) then s_nop(0) = slow
|
||||
def test_nop0_nop4(self): self.assertEqual(self._get_delay([v_mov_b32_e32(v[0], 1.0), s_nop(0), s_nop(4)]), 6)
|
||||
def test_nop4_nop0(self): self.assertEqual(self._get_delay([v_mov_b32_e32(v[0], 1.0), s_nop(4), s_nop(0)]), 10)
|
||||
|
||||
|
||||
class TestDelayALU(unittest.TestCase):
|
||||
"""s_delay_alu behavior - helps understand hardware pipeline latencies.
|
||||
|
||||
s_delay_alu(simm16) where simm16 encodes:
|
||||
instid0[3:0] = dependency on VALU N instructions back (1-4), 0=none
|
||||
skip[6:4] = skip count for second dependency
|
||||
instid1[10:7] = second dependency
|
||||
|
||||
Key insight: s_delay_alu tells hardware to wait for a previous VALU to complete.
|
||||
The hardware determines how many cycles to stall based on pipeline state.
|
||||
"""
|
||||
def _exec_delta(self, instrs):
|
||||
"""Return exec delta for last instruction."""
|
||||
_, execd = get_deltas(instrs)
|
||||
return execd[-1] if execd else None
|
||||
|
||||
# Direct dependency (producer -> consumer), instid0=1 means "wait for VALU 1 back"
|
||||
def test_direct_no_delay(self):
|
||||
# Without s_delay_alu: 6 cycles
|
||||
self.assertEqual(self._exec_delta([v_mov_b32_e32(v[0], 1.0), v_add_f32_e32(v[1], v[0], v[0])]), 6)
|
||||
|
||||
def test_direct_delay1(self):
|
||||
# With s_delay_alu(instid0=1): 7-8 cycles (+1 from the delay instruction)
|
||||
self.assertIn(self._exec_delta([v_mov_b32_e32(v[0], 1.0), s_delay_alu(simm16=1), v_add_f32_e32(v[1], v[0], v[0])]), [7, 8])
|
||||
|
||||
def test_direct_delay2(self):
|
||||
# instid0=2 doesn't apply (only 1 VALU back), so no extra delay
|
||||
self.assertEqual(self._exec_delta([v_mov_b32_e32(v[0], 1.0), s_delay_alu(simm16=2), v_add_f32_e32(v[1], v[0], v[0])]), 6)
|
||||
|
||||
def test_direct_delay3(self):
|
||||
self.assertEqual(self._exec_delta([v_mov_b32_e32(v[0], 1.0), s_delay_alu(simm16=3), v_add_f32_e32(v[1], v[0], v[0])]), 6)
|
||||
|
||||
def test_direct_delay4(self):
|
||||
self.assertEqual(self._exec_delta([v_mov_b32_e32(v[0], 1.0), s_delay_alu(simm16=4), v_add_f32_e32(v[1], v[0], v[0])]), 6)
|
||||
|
||||
# With 1 independent instruction between producer and consumer
|
||||
def test_gap1_delay1(self):
|
||||
# instid0=1 waits for the independent instruction (not the producer)
|
||||
instrs = [v_mov_b32_e32(v[0], 1.0), v_mov_b32_e32(v[5], 5.0), s_delay_alu(simm16=1), v_add_f32_e32(v[1], v[0], v[0])]
|
||||
self.assertEqual(self._exec_delta(instrs), 8)
|
||||
|
||||
def test_gap1_delay2(self):
|
||||
# instid0=2 waits for the producer (2 VALUs back)
|
||||
instrs = [v_mov_b32_e32(v[0], 1.0), v_mov_b32_e32(v[5], 5.0), s_delay_alu(simm16=2), v_add_f32_e32(v[1], v[0], v[0])]
|
||||
self.assertIn(self._exec_delta(instrs), [6, 7])
|
||||
|
||||
def test_gap1_delay3(self):
|
||||
# instid0=3 doesn't apply (only 2 VALUs back)
|
||||
instrs = [v_mov_b32_e32(v[0], 1.0), v_mov_b32_e32(v[5], 5.0), s_delay_alu(simm16=3), v_add_f32_e32(v[1], v[0], v[0])]
|
||||
self.assertEqual(self._exec_delta(instrs), 5)
|
||||
|
||||
# With 2 independent instructions between
|
||||
def test_gap2_delay1(self):
|
||||
instrs = [v_mov_b32_e32(v[0], 1.0), v_mov_b32_e32(v[5], 5.0), v_mov_b32_e32(v[6], 6.0),
|
||||
s_delay_alu(simm16=1), v_add_f32_e32(v[1], v[0], v[0])]
|
||||
self.assertEqual(self._exec_delta(instrs), 7)
|
||||
|
||||
def test_gap2_delay2(self):
|
||||
instrs = [v_mov_b32_e32(v[0], 1.0), v_mov_b32_e32(v[5], 5.0), v_mov_b32_e32(v[6], 6.0),
|
||||
s_delay_alu(simm16=2), v_add_f32_e32(v[1], v[0], v[0])]
|
||||
self.assertEqual(self._exec_delta(instrs), 7)
|
||||
|
||||
def test_gap2_delay3(self):
|
||||
# instid0=3 waits for the producer (3 VALUs back)
|
||||
instrs = [v_mov_b32_e32(v[0], 1.0), v_mov_b32_e32(v[5], 5.0), v_mov_b32_e32(v[6], 6.0),
|
||||
s_delay_alu(simm16=3), v_add_f32_e32(v[1], v[0], v[0])]
|
||||
self.assertIn(self._exec_delta(instrs), [5, 6])
|
||||
|
||||
def test_gap2_delay4(self):
|
||||
instrs = [v_mov_b32_e32(v[0], 1.0), v_mov_b32_e32(v[5], 5.0), v_mov_b32_e32(v[6], 6.0),
|
||||
s_delay_alu(simm16=4), v_add_f32_e32(v[1], v[0], v[0])]
|
||||
self.assertEqual(self._exec_delta(instrs), 4)
|
||||
|
||||
|
||||
class TestNopTimingSensitivity(unittest.TestCase):
|
||||
"""Forwarding behavior has 128-cycle periodicity.
|
||||
|
||||
Hardware observation: when nop_cycles % 128 is in [72, 75], chain_6 gets 5 forwards
|
||||
instead of 4. This 4-cycle window repeats every 128 cycles, suggesting alignment
|
||||
with some hardware scheduling period (possibly wave scheduler or cache).
|
||||
|
||||
Windows found: nop 72-75, 200-203, 328-331, 456-459, ...
|
||||
"""
|
||||
def _chain6_fwd_count(self, nop_size):
|
||||
"""Count initial consecutive forwards for a 6-instruction chain after s_nop(n)."""
|
||||
instrs = [s_nop(nop_size), v_mov_b32_e32(v[99], 1.0)]
|
||||
instrs += [v_mov_b32_e32(v[0], 1.0)]
|
||||
for i in range(1, 6):
|
||||
instrs += [v_mov_b32_e32(v[i], v[i-1])]
|
||||
_, execd = get_deltas(instrs)
|
||||
chain_deltas = execd[1:]
|
||||
fwd_count = 0
|
||||
for d in chain_deltas:
|
||||
if d == 5: fwd_count += 1
|
||||
else: break
|
||||
return fwd_count
|
||||
|
||||
# Normal case: 4 forwards
|
||||
def test_nop71(self): self.assertEqual(self._chain6_fwd_count(71), 4)
|
||||
def test_nop76(self): self.assertEqual(self._chain6_fwd_count(76), 4)
|
||||
def test_nop199(self): self.assertEqual(self._chain6_fwd_count(199), 4)
|
||||
def test_nop204(self): self.assertEqual(self._chain6_fwd_count(204), 4)
|
||||
|
||||
# Anomaly window at nop % 128 == 72-75: 5 forwards on RDNA3, 4 on other variants
|
||||
def test_nop72(self): self.assertIn(self._chain6_fwd_count(72), [4, 5])
|
||||
def test_nop75(self): self.assertIn(self._chain6_fwd_count(75), [4, 5])
|
||||
def test_nop200(self): self.assertIn(self._chain6_fwd_count(200), [4, 5])
|
||||
def test_nop203(self): self.assertIn(self._chain6_fwd_count(203), [4, 5])
|
||||
def test_nop328(self): self.assertIn(self._chain6_fwd_count(328), [4, 5])
|
||||
def test_nop331(self): self.assertIn(self._chain6_fwd_count(331), [4, 5])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
463
extra/assembly/amd/test/test_sqtt_hw.py
Normal file
463
extra/assembly/amd/test/test_sqtt_hw.py
Normal file
|
|
@ -0,0 +1,463 @@
|
|||
#!/usr/bin/env python3
|
||||
"""Hardware tests for SQTT decoder - validates decoding of real SQTT streams.
|
||||
|
||||
Run with: python -m pytest extra/assembly/amd/test/test_sqtt_hw.py -v -s
|
||||
Requires AMD GPU with SQTT support.
|
||||
|
||||
For pretty trace output: DEBUG=2 python -m pytest extra/assembly/amd/test/test_sqtt_hw.py -v -s
|
||||
"""
|
||||
import os
|
||||
os.environ["SQTT"] = "1"
|
||||
os.environ["PROFILE"] = "1"
|
||||
os.environ["SQTT_ITRACE_SE_MASK"] = "1" # Enable instruction tracing on SE0
|
||||
os.environ["SQTT_LIMIT_SE"] = "2" # Force work to traced SE only
|
||||
|
||||
import unittest
|
||||
from tinygrad.helpers import DEBUG, colored
|
||||
from tinygrad.device import Device
|
||||
from tinygrad.runtime.ops_amd import AMDProgram, ProfileSQTTEvent
|
||||
from tinygrad.runtime.support.compiler_amd import HIPCompiler
|
||||
|
||||
from extra.assembly.amd.autogen.rdna3.ins import v_mov_b32_e32, v_add_f32_e32, v_mul_f32_e32, s_mov_b32, s_add_u32, s_nop, s_waitcnt, s_endpgm
|
||||
from extra.assembly.amd.dsl import v, s
|
||||
from extra.assembly.amd.sqtt import decode, LAYOUT_HEADER, WAVESTART, WAVEEND, INST, VALUINST, ALUEXEC, VMEMEXEC, InstOp, AluSrc, MemSrc
|
||||
|
||||
dev = Device["AMD"]
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
# PRETTY PRINTING
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
|
||||
PACKET_COLORS = {
|
||||
"INST": "WHITE", "VALUINST": "BLACK",
|
||||
"VMEMEXEC": "yellow", "ALUEXEC": "yellow",
|
||||
"IMMEDIATE": "YELLOW", "IMMEDIATE_MASK": "YELLOW",
|
||||
"WAVERDY": "cyan", "WAVEALLOC": "cyan",
|
||||
"WAVEEND": "blue", "WAVESTART": "blue",
|
||||
"PERF": "magenta",
|
||||
"EVENT": "red", "EVENT_BIG": "red",
|
||||
"REG": "green",
|
||||
"LAYOUT_HEADER": "white",
|
||||
"TS_DELTA_SHORT": "BLACK", "NOP": "BLACK", "TS_WAVE_STATE": "BLACK",
|
||||
"SNAPSHOT": "white", "TS_DELTA_OR_MARK": "BLACK",
|
||||
"TS_DELTA_S8_W3": "BLACK", "TS_DELTA_S5_W2": "BLACK", "TS_DELTA_S5_W3": "BLACK",
|
||||
"UTILCTR": "green",
|
||||
}
|
||||
|
||||
def format_packet(p, last_time: int = 0, time_offset: int = 0) -> str:
|
||||
"""Format a packet for pretty printing."""
|
||||
name = type(p).__name__
|
||||
color = PACKET_COLORS.get(name, "white")
|
||||
|
||||
fields = []
|
||||
if isinstance(p, INST):
|
||||
op = p.op
|
||||
op_name = op.name if isinstance(op, InstOp) else f"0x{op:02x}"
|
||||
fields = [f"wave={p.wave}", f"op={op_name}"]
|
||||
if p.flag1: fields.append("flag1")
|
||||
if p.flag2: fields.append("flag2")
|
||||
elif isinstance(p, VALUINST):
|
||||
fields = [f"wave={p.wave}"]
|
||||
if p.flag: fields.append("flag")
|
||||
elif isinstance(p, ALUEXEC):
|
||||
src_name = p.src.name if isinstance(p.src, AluSrc) else f"{p.src}"
|
||||
fields = [f"src={src_name}"]
|
||||
elif isinstance(p, VMEMEXEC):
|
||||
src_name = p.src.name if isinstance(p.src, MemSrc) else f"{p.src}"
|
||||
fields = [f"src={src_name}"]
|
||||
elif isinstance(p, WAVESTART):
|
||||
fields = [f"wave={p.wave}", f"simd={p.simd}", f"cu={p.cu}"]
|
||||
elif isinstance(p, WAVEEND):
|
||||
fields = [f"wave={p.wave}", f"simd={p.simd}", f"cu={p.cu}"]
|
||||
elif hasattr(p, '_values'):
|
||||
# Format hex fields appropriately
|
||||
hex_fields = {'snap', 'val32'}
|
||||
fields = [f"{k}=0x{v:x}" if k in hex_fields else f"{k}={v}" for k, v in p._values.items() if not k.startswith('_') and k != 'delta']
|
||||
|
||||
return colored(f"{name:18s}", color) + " " + ", ".join(fields)
|
||||
|
||||
def get_wave_packets(packets: list) -> list:
|
||||
"""Extract packets from WAVESTART to WAVEEND, filtering pure timing packets."""
|
||||
skip_types = {"NOP", "TS_DELTA_SHORT", "TS_WAVE_STATE", "TS_DELTA_OR_MARK", "TS_DELTA_S5_W2", "TS_DELTA_S5_W3", "TS_DELTA_S8_W3"}
|
||||
result = []
|
||||
in_wave = False
|
||||
for p in packets:
|
||||
name = type(p).__name__
|
||||
if isinstance(p, WAVESTART):
|
||||
in_wave = True
|
||||
if in_wave and name not in skip_types:
|
||||
result.append(p)
|
||||
if isinstance(p, WAVEEND):
|
||||
in_wave = False
|
||||
return result
|
||||
|
||||
|
||||
|
||||
def print_wave_trace(packets: list) -> None:
|
||||
"""Print packets from WAVESTART to WAVEEND with normalized time."""
|
||||
wave_packets = get_wave_packets(packets)
|
||||
if not wave_packets:
|
||||
return
|
||||
time_offset = wave_packets[0]._time
|
||||
last_time = time_offset
|
||||
for p in wave_packets:
|
||||
print(format_packet(p, last_time, time_offset))
|
||||
last_time = p._time
|
||||
|
||||
def print_blobs(blobs: list[bytes], wave_only: bool = True) -> None:
|
||||
"""Print traces for all blobs. wave_only=True filters to WAVESTART..WAVEEND only."""
|
||||
for i, blob in enumerate(blobs):
|
||||
packets = decode(blob)
|
||||
print(f"\n--- Blob {i}: {len(blob)} bytes, {len(packets)} packets ---")
|
||||
if wave_only:
|
||||
print_wave_trace(packets)
|
||||
else:
|
||||
print_all_packets(packets)
|
||||
|
||||
def print_all_packets(packets: list) -> None:
|
||||
"""Print all packets, filtering out pure timing packets."""
|
||||
skip_types = {"NOP", "TS_DELTA_SHORT", "TS_WAVE_STATE", "TS_DELTA_OR_MARK", "TS_DELTA_S5_W2", "TS_DELTA_S5_W3", "TS_DELTA_S8_W3"}
|
||||
if not packets: return
|
||||
time_offset = packets[0]._time
|
||||
last_time = time_offset
|
||||
for p in packets:
|
||||
if type(p).__name__ not in skip_types:
|
||||
print(format_packet(p, last_time, time_offset))
|
||||
last_time = p._time
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
# ASSEMBLY HELPERS
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
|
||||
def assemble(instructions: list) -> bytes:
|
||||
return b''.join(inst.to_bytes() for inst in instructions)
|
||||
|
||||
def wrap_with_nops(instructions: list, nops=16) -> list:
|
||||
"""Add epilogue for clean SQTT timing.
|
||||
|
||||
Need enough NOPs to cover long-latency ops (DP: 42 cycles, WMMA: 47 cycles).
|
||||
With 64 NOPs, the IMMEDIATE phase extends to cover these completions.
|
||||
"""
|
||||
return instructions + [s_nop(0)]*nops + [s_endpgm()]
|
||||
|
||||
def compile_asm_sqtt(instructions: list, alu_only: bool = False) -> AMDProgram:
|
||||
"""Compile instructions to an AMDProgram for SQTT tracing.
|
||||
|
||||
Args:
|
||||
instructions: List of instructions to compile
|
||||
alu_only: If True, use minimal kernel config with no kernargs/LDS/scratch
|
||||
Returns:
|
||||
Compiled AMDProgram ready to run
|
||||
"""
|
||||
compiler = HIPCompiler(dev.arch)
|
||||
# Add NOPs before s_endpgm to flush pipeline and get clean timing
|
||||
code = assemble(instructions)
|
||||
byte_str = ', '.join(f'0x{b:02x}' for b in code)
|
||||
|
||||
if alu_only:
|
||||
asm_src = f""".text
|
||||
.globl test
|
||||
.p2align 8
|
||||
.type test,@function
|
||||
test:
|
||||
.byte {byte_str}
|
||||
|
||||
.rodata
|
||||
.p2align 6
|
||||
.amdhsa_kernel test
|
||||
# basic memory
|
||||
.amdhsa_group_segment_fixed_size 0
|
||||
.amdhsa_private_segment_fixed_size 0
|
||||
.amdhsa_kernarg_size 32
|
||||
.amdhsa_enable_private_segment 0
|
||||
# register usage
|
||||
.amdhsa_next_free_vgpr 64
|
||||
.amdhsa_next_free_sgpr 8
|
||||
# RSRC1
|
||||
.amdhsa_wavefront_size32 1
|
||||
.amdhsa_memory_ordered 1
|
||||
.amdhsa_forward_progress 1
|
||||
# this is key
|
||||
.amdhsa_workgroup_processor_mode 0
|
||||
.end_amdhsa_kernel
|
||||
|
||||
.amdgpu_metadata
|
||||
---
|
||||
amdhsa.version:
|
||||
- 1
|
||||
- 0
|
||||
amdhsa.kernels:
|
||||
- .name: test
|
||||
.symbol: test.kd
|
||||
.kernarg_segment_size: 0
|
||||
.group_segment_fixed_size: 0
|
||||
.private_segment_fixed_size: 0
|
||||
.kernarg_segment_align: 8
|
||||
.wavefront_size: 32
|
||||
.sgpr_count: 8
|
||||
.vgpr_count: 64
|
||||
.max_flat_workgroup_size: 1024
|
||||
...
|
||||
.end_amdgpu_metadata
|
||||
"""
|
||||
else:
|
||||
asm_src = f""".text
|
||||
.globl test
|
||||
.p2align 8
|
||||
.type test,@function
|
||||
test:
|
||||
.byte {byte_str}
|
||||
|
||||
.rodata
|
||||
.p2align 6
|
||||
.amdhsa_kernel test
|
||||
.amdhsa_next_free_vgpr 8
|
||||
.amdhsa_next_free_sgpr 16
|
||||
.amdhsa_wavefront_size32 1
|
||||
.amdhsa_user_sgpr_kernarg_segment_ptr 1
|
||||
.amdhsa_kernarg_size 8
|
||||
.amdhsa_group_segment_fixed_size 0
|
||||
.amdhsa_private_segment_fixed_size 0
|
||||
.end_amdhsa_kernel
|
||||
|
||||
.amdgpu_metadata
|
||||
---
|
||||
amdhsa.version:
|
||||
- 1
|
||||
- 0
|
||||
amdhsa.kernels:
|
||||
- .name: test
|
||||
.symbol: test.kd
|
||||
.kernarg_segment_size: 8
|
||||
.group_segment_fixed_size: 0
|
||||
.private_segment_fixed_size: 0
|
||||
.kernarg_segment_align: 8
|
||||
.wavefront_size: 32
|
||||
.sgpr_count: 16
|
||||
.vgpr_count: 8
|
||||
.max_flat_workgroup_size: 1024
|
||||
...
|
||||
.end_amdgpu_metadata
|
||||
"""
|
||||
|
||||
lib = compiler.compile(asm_src)
|
||||
return AMDProgram(dev, "test", lib)
|
||||
|
||||
def run_asm_sqtt(instructions: list, n_lanes: int = 1, alu_only: bool = False) -> list[bytes]:
|
||||
"""Compile and run instructions on AMD hardware, return SQTT blobs.
|
||||
|
||||
Args:
|
||||
instructions: List of instructions to run
|
||||
n_lanes: Number of lanes to use
|
||||
alu_only: If True, use minimal kernel config with no kernargs/LDS/scratch
|
||||
"""
|
||||
prg = compile_asm_sqtt(instructions, alu_only=alu_only)
|
||||
return run_prg_sqtt(prg, n_lanes=n_lanes, alu_only=alu_only)
|
||||
|
||||
def run_prg_sqtt(prg: AMDProgram, n_lanes: int = 1, alu_only: bool = False) -> list[bytes]:
|
||||
"""Run a compiled AMDProgram and return SQTT blobs.
|
||||
|
||||
Args:
|
||||
prg: Compiled AMDProgram to run
|
||||
n_lanes: Number of lanes to use
|
||||
alu_only: If True, don't allocate kernarg buffer
|
||||
"""
|
||||
dev.profile_events.clear()
|
||||
if alu_only:
|
||||
prg(global_size=(1, 1, 1), local_size=(n_lanes, 1, 1), wait=True)
|
||||
else:
|
||||
out_gpu = dev.allocator.alloc(2048)
|
||||
prg(out_gpu, global_size=(1, 1, 1), local_size=(n_lanes, 1, 1), wait=True)
|
||||
return [ev.blob for ev in dev.profile_events if isinstance(ev, ProfileSQTTEvent)]
|
||||
|
||||
def run_prg_sqtt_batch(prg: AMDProgram, n_runs: int, n_lanes: int = 1) -> list[bytes]:
|
||||
"""Run a compiled AMDProgram N times in a single queue submission and return SQTT blobs.
|
||||
|
||||
This builds one queue with N kernel executions, submits it once, and collects SQTT.
|
||||
All N runs are captured in the same SQTT trace, reducing startup jitter.
|
||||
|
||||
Args:
|
||||
prg: Compiled AMDProgram to run
|
||||
n_runs: Number of times to execute the kernel in the queue
|
||||
n_lanes: Number of lanes to use
|
||||
Returns:
|
||||
List of SQTT blobs (one per shader engine)
|
||||
"""
|
||||
from typing import cast
|
||||
from tinygrad.runtime.ops_amd import AMDComputeQueue, SQTT_ITRACE_SE_MASK
|
||||
from tinygrad.device import Compiled
|
||||
import struct
|
||||
|
||||
dev.profile_events.clear()
|
||||
|
||||
# Build queue with sqtt_start, N kernel executions, sqtt_stop
|
||||
kernargs = prg.fill_kernargs([], ())
|
||||
q = cast(AMDComputeQueue, dev.hw_compute_queue_t())
|
||||
q.wait(dev.timeline_signal, dev.timeline_value - 1).memory_barrier()
|
||||
q.sqtt_start(dev.sqtt_buffers)
|
||||
|
||||
# Execute kernel N times
|
||||
for _ in range(n_runs):
|
||||
q.exec(prg, kernargs, (1, 1, 1), (n_lanes, 1, 1))
|
||||
|
||||
q.sqtt_stop(dev.sqtt_wptrs)
|
||||
q.signal(dev.timeline_signal, dev.next_timeline())
|
||||
q.submit(dev)
|
||||
dev.synchronize()
|
||||
|
||||
# Collect SQTT blobs
|
||||
blobs = []
|
||||
for se, buf in enumerate(dev.sqtt_buffers):
|
||||
wptr = (dev.sqtt_wptrs.cpu_view().view(fmt='I')[se] & 0x1FFFFFFF) * 32
|
||||
if dev.target[:2] == (11, 0): wptr -= ((buf.va_addr // 32) & 0x1FFFFFFF) * 32
|
||||
if wptr > 0 and wptr <= buf.size:
|
||||
dev.allocator._copyout(sqtt_mv:=memoryview(bytearray(wptr)), buf)
|
||||
resbuf = (struct.pack('<Q', 0x11 | (4 << 13) | (0xf << 16) | (se << 24)) + bytes(sqtt_mv)) if dev.target[0] == 9 else bytes(sqtt_mv)
|
||||
blobs.append(resbuf)
|
||||
|
||||
return blobs
|
||||
|
||||
def decode_all_blobs(blobs: list[bytes]) -> list:
|
||||
"""Decode all blobs and combine packets."""
|
||||
all_packets = []
|
||||
for blob in blobs:
|
||||
all_packets.extend(decode(blob))
|
||||
return all_packets
|
||||
|
||||
def get_inst_ops(packets: list, traced_simd: int | None = None) -> set:
|
||||
"""Extract all InstOp values from INST packets within WAVESTART..WAVEEND on traced SIMD."""
|
||||
ops = set()
|
||||
in_wave = False
|
||||
for p in packets:
|
||||
if isinstance(p, WAVESTART):
|
||||
in_wave = traced_simd is None or p.simd == traced_simd
|
||||
if in_wave and isinstance(p, INST):
|
||||
ops.add(p.op if isinstance(p.op, int) else p.op.value)
|
||||
if isinstance(p, WAVEEND):
|
||||
in_wave = False
|
||||
return ops
|
||||
|
||||
def count_valuinst(packets: list, traced_simd: int | None = None) -> int:
|
||||
"""Count VALUINST packets within WAVESTART..WAVEEND on traced SIMD."""
|
||||
count = 0
|
||||
in_wave = False
|
||||
for p in packets:
|
||||
if isinstance(p, WAVESTART):
|
||||
in_wave = traced_simd is None or p.simd == traced_simd
|
||||
if in_wave and isinstance(p, VALUINST):
|
||||
count += 1
|
||||
if isinstance(p, WAVEEND):
|
||||
in_wave = False
|
||||
return count
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
# TESTS
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
|
||||
@unittest.skipIf(not hasattr(dev, 'profile_events'), "AMD device required")
|
||||
class TestSQTTDecode(unittest.TestCase):
|
||||
"""Test SQTT decoder with real hardware traces."""
|
||||
|
||||
def test_basic_structure(self):
|
||||
"""Verify basic SQTT stream structure: LAYOUT_HEADER, WAVESTART, instructions, WAVEEND."""
|
||||
blobs = run_asm_sqtt([v_mov_b32_e32(v[0], 0)])
|
||||
|
||||
self.assertGreater(len(blobs), 0, "No SQTT data captured")
|
||||
packets = decode_all_blobs(blobs)
|
||||
|
||||
self.assertGreater(len(packets), 0, "No packets decoded")
|
||||
self.assertGreater(len([p for p in packets if isinstance(p, LAYOUT_HEADER)]), 0, "No LAYOUT_HEADER packets")
|
||||
self.assertGreater(len([p for p in packets if isinstance(p, WAVESTART)]), 0, "No WAVESTART packets")
|
||||
self.assertGreater(len([p for p in packets if isinstance(p, WAVEEND)]), 0, "No WAVEEND packets")
|
||||
|
||||
if DEBUG >= 2:
|
||||
print("\n=== Basic structure trace ===")
|
||||
print_trace(packets)
|
||||
|
||||
def test_valu_instructions(self):
|
||||
"""Verify VALU instructions produce INST or VALUINST packets."""
|
||||
instructions = [
|
||||
v_mov_b32_e32(v[0], 1.0),
|
||||
v_mov_b32_e32(v[1], 2.0),
|
||||
v_add_f32_e32(v[2], v[0], v[1]),
|
||||
v_add_f32_e32(v[3], v[2], v[1]),
|
||||
v_mul_f32_e32(v[4], v[2], v[3]),
|
||||
]
|
||||
blobs = run_asm_sqtt(instructions)
|
||||
|
||||
self.assertGreater(len(blobs), 0, "No SQTT data captured")
|
||||
packets = decode_all_blobs(blobs)
|
||||
|
||||
inst_packets = [p for p in packets if isinstance(p, (INST, VALUINST))]
|
||||
self.assertGreater(len(inst_packets), 0, "No INST/VALUINST packets for VALU instructions")
|
||||
|
||||
if DEBUG >= 2:
|
||||
print("\n=== VALU instructions trace ===")
|
||||
print_trace(packets)
|
||||
|
||||
def test_salu_instructions(self):
|
||||
"""Verify SALU instructions produce appropriate packets."""
|
||||
instructions = [
|
||||
s_mov_b32(s[0], 0),
|
||||
s_mov_b32(s[1], 1),
|
||||
s_add_u32(s[2], s[0], s[1]),
|
||||
s_add_u32(s[3], s[2], s[1]),
|
||||
s_nop(0),
|
||||
]
|
||||
blobs = run_asm_sqtt(instructions)
|
||||
|
||||
self.assertGreater(len(blobs), 0, "No SQTT data captured")
|
||||
packets = decode_all_blobs(blobs)
|
||||
|
||||
if DEBUG >= 2:
|
||||
print("\n=== SALU instructions trace ===")
|
||||
print_trace(packets)
|
||||
|
||||
def test_timing_increases(self):
|
||||
"""Verify time increases monotonically through packets within each blob."""
|
||||
instructions = [
|
||||
v_mov_b32_e32(v[0], 1.0),
|
||||
v_mov_b32_e32(v[1], 2.0),
|
||||
v_add_f32_e32(v[2], v[0], v[1]),
|
||||
v_mul_f32_e32(v[3], v[2], v[1]),
|
||||
]
|
||||
blobs = run_asm_sqtt(instructions)
|
||||
|
||||
self.assertGreater(len(blobs), 0, "No SQTT data captured")
|
||||
for blob in blobs:
|
||||
packets = decode(blob)
|
||||
prev_time = 0
|
||||
for p in packets:
|
||||
self.assertGreaterEqual(p._time, prev_time, f"Time decreased: {prev_time} -> {p._time}")
|
||||
prev_time = p._time
|
||||
|
||||
def test_wave_id_consistency(self):
|
||||
"""Verify wave IDs are consistent between WAVESTART/WAVEEND."""
|
||||
blobs = run_asm_sqtt([v_mov_b32_e32(v[0], 0)])
|
||||
|
||||
self.assertGreater(len(blobs), 0, "No SQTT data captured")
|
||||
packets = decode_all_blobs(blobs)
|
||||
|
||||
wavestarts = [p for p in packets if isinstance(p, WAVESTART)]
|
||||
waveends = [p for p in packets if isinstance(p, WAVEEND)]
|
||||
|
||||
if wavestarts and waveends:
|
||||
start_waves = {p.wave for p in wavestarts}
|
||||
end_waves = {p.wave for p in waveends}
|
||||
self.assertTrue(start_waves & end_waves, "No matching wave IDs between WAVESTART and WAVEEND")
|
||||
|
||||
def test_nop_sequence(self):
|
||||
"""Test a sequence of NOP instructions."""
|
||||
blobs = run_asm_sqtt([s_nop(0), s_nop(0), s_nop(0)])
|
||||
|
||||
self.assertGreater(len(blobs), 0, "No SQTT data captured")
|
||||
packets = decode_all_blobs(blobs)
|
||||
self.assertGreater(len(packets), 0, "No packets decoded")
|
||||
|
||||
if DEBUG >= 2:
|
||||
print("\n=== NOP sequence trace ===")
|
||||
print_trace(packets, filter_timing=False)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
35
extra/assembly/amd/test/test_sqtt_multiwave.py
Normal file
35
extra/assembly/amd/test/test_sqtt_multiwave.py
Normal file
|
|
@ -0,0 +1,35 @@
|
|||
import os
|
||||
os.environ["SQTT"] = "1"
|
||||
os.environ["PROFILE"] = "1"
|
||||
os.environ["SQTT_LIMIT_SE"] = "2"
|
||||
os.environ["SQTT_SIMD_SEL"] = "0"
|
||||
os.environ["SQTT_TOKEN_EXCLUDE"] = "3784" # Exclude WAVERDY, REG, EVENT, UTILCTR, WAVEALLOC, PERF
|
||||
|
||||
import unittest
|
||||
from extra.assembly.amd.autogen.rdna3.ins import *
|
||||
from extra.assembly.amd.sqtt import decode
|
||||
from extra.assembly.amd.test.test_sqtt_hw import compile_asm_sqtt, run_prg_sqtt_batch, format_packet
|
||||
from extra.assembly.amd.test.test_sqtt_compare import filter_noise_packets
|
||||
from tinygrad.uop.ops import UOp
|
||||
from tinygrad.engine.realize import get_runner
|
||||
|
||||
class SQTTMultiwave(unittest.TestCase):
|
||||
def test_simple_multiwave(self):
|
||||
ins = [
|
||||
s_barrier(),
|
||||
v_mov_b32_e32(v[0], v[1]),
|
||||
s_nop(0),
|
||||
s_nop(100),
|
||||
s_endpgm(),
|
||||
]
|
||||
#prg = get_runner("AMD", UOp.sink())._prg
|
||||
prg = compile_asm_sqtt(ins, alu_only=True)
|
||||
print(prg)
|
||||
blobs = run_prg_sqtt_batch(prg, n_runs=1, n_lanes=32*16)
|
||||
for blob in blobs:
|
||||
packets = decode(blob)
|
||||
for p in filter_noise_packets(packets):
|
||||
print(f" {p._time:8d}: {format_packet(p)}")
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
204
extra/assembly/amd/test/test_sqtt_ops.py
Normal file
204
extra/assembly/amd/test/test_sqtt_ops.py
Normal file
|
|
@ -0,0 +1,204 @@
|
|||
#!/usr/bin/env python3
|
||||
"""Tests validating SQTT packet definitions against the reference implementation.
|
||||
|
||||
Verifies that:
|
||||
1. Encoding patterns produce the correct STATE_TO_OPCODE table
|
||||
2. Packet sizes (derived from fields) match expected budget values
|
||||
3. Field extractions match attempt_sqtt_parse.py
|
||||
"""
|
||||
import unittest
|
||||
from extra.assembly.amd.sqtt import (
|
||||
VALUINST, VMEMEXEC, ALUEXEC, IMMEDIATE, IMMEDIATE_MASK, WAVERDY,
|
||||
WAVEEND, WAVESTART, PERF, TS_WAVE_STATE, EVENT, EVENT_BIG, REG, SNAPSHOT,
|
||||
TS_DELTA_OR_MARK, LAYOUT_HEADER, INST, UTILCTR, TS_DELTA_SHORT, NOP,
|
||||
TS_DELTA_S8_W3, TS_DELTA_S5_W2, TS_DELTA_S5_W3, WAVEALLOC,
|
||||
decode, encode, OPCODE_TO_CLASS, STATE_TO_OPCODE, PACKET_TYPES, BUDGET,
|
||||
AluSrc, MemSrc, InstOp
|
||||
)
|
||||
|
||||
# Reference table from rocprof trace decoder (attempt_sqtt_parse.py)
|
||||
REFERENCE_STATE_TABLE = bytes([
|
||||
0x10, 0x16, 0x18, 0x01, 0x05, 0x0b, 0x0c, 0x00, 0x0f, 0x14, 0x18, 0x01, 0x09, 0x04, 0x03, 0x02,
|
||||
0x10, 0x17, 0x18, 0x01, 0x06, 0x08, 0x0d, 0x00, 0x0f, 0x14, 0x18, 0x01, 0x0a, 0x04, 0x03, 0x02,
|
||||
0x10, 0x07, 0x18, 0x01, 0x05, 0x0b, 0x0c, 0x00, 0x0f, 0x14, 0x18, 0x01, 0x09, 0x04, 0x03, 0x02,
|
||||
0x10, 0x19, 0x18, 0x01, 0x06, 0x08, 0x0d, 0x00, 0x0f, 0x14, 0x18, 0x01, 0x0a, 0x04, 0x03, 0x02,
|
||||
0x10, 0x00, 0x18, 0x01, 0x05, 0x0b, 0x0c, 0x00, 0x0f, 0x14, 0x18, 0x01, 0x09, 0x04, 0x03, 0x02,
|
||||
0x10, 0x11, 0x18, 0x01, 0x06, 0x08, 0x0d, 0x00, 0x0f, 0x14, 0x18, 0x01, 0x0a, 0x04, 0x03, 0x02,
|
||||
0x10, 0x12, 0x18, 0x01, 0x05, 0x0b, 0x0c, 0x00, 0x0f, 0x14, 0x18, 0x01, 0x09, 0x04, 0x03, 0x02,
|
||||
0x10, 0x15, 0x18, 0x01, 0x06, 0x08, 0x0d, 0x00, 0x0f, 0x14, 0x18, 0x01, 0x0a, 0x04, 0x03, 0x02,
|
||||
0x10, 0x16, 0x18, 0x01, 0x05, 0x0b, 0x0c, 0x00, 0x0f, 0x14, 0x18, 0x01, 0x09, 0x04, 0x03, 0x02,
|
||||
0x10, 0x17, 0x18, 0x01, 0x06, 0x08, 0x0d, 0x00, 0x0f, 0x14, 0x18, 0x01, 0x0a, 0x04, 0x03, 0x02,
|
||||
0x10, 0x07, 0x18, 0x01, 0x05, 0x0b, 0x0c, 0x00, 0x0f, 0x14, 0x18, 0x01, 0x09, 0x04, 0x03, 0x02,
|
||||
0x10, 0x19, 0x18, 0x01, 0x06, 0x08, 0x0d, 0x00, 0x0f, 0x14, 0x18, 0x01, 0x0a, 0x04, 0x03, 0x02,
|
||||
0x10, 0x00, 0x18, 0x01, 0x05, 0x0b, 0x0c, 0x00, 0x0f, 0x14, 0x18, 0x01, 0x09, 0x04, 0x03, 0x02,
|
||||
0x10, 0x11, 0x18, 0x01, 0x06, 0x08, 0x0d, 0x00, 0x0f, 0x14, 0x18, 0x01, 0x0a, 0x04, 0x03, 0x02,
|
||||
0x10, 0x13, 0x18, 0x01, 0x05, 0x0b, 0x0c, 0x00, 0x0f, 0x14, 0x18, 0x01, 0x09, 0x04, 0x03, 0x02,
|
||||
0x10, 0x15, 0x18, 0x01, 0x06, 0x08, 0x0d, 0x00, 0x0f, 0x14, 0x18, 0x01, 0x0a, 0x04, 0x03, 0x02,
|
||||
])
|
||||
|
||||
# Reference opcode -> name mapping (old opcode values from rocprof)
|
||||
OLD_OPCODE_TO_NAME = {
|
||||
0x01: 'VALUINST', 0x02: 'VMEMEXEC', 0x03: 'ALUEXEC', 0x04: 'IMMEDIATE',
|
||||
0x05: 'IMMEDIATE_MASK', 0x06: 'WAVERDY', 0x07: 'TS_DELTA_S8_W3',
|
||||
0x08: 'WAVEEND', 0x09: 'WAVESTART', 0x0A: 'TS_DELTA_S5_W2',
|
||||
0x0B: 'WAVEALLOC', 0x0C: 'TS_DELTA_S5_W3', 0x0D: 'PERF',
|
||||
0x0F: 'TS_DELTA_SHORT', 0x10: 'NOP', 0x11: 'TS_WAVE_STATE',
|
||||
0x12: 'EVENT', 0x13: 'EVENT_BIG', 0x14: 'REG', 0x15: 'SNAPSHOT',
|
||||
0x16: 'TS_DELTA_OR_MARK', 0x17: 'LAYOUT_HEADER', 0x18: 'INST',
|
||||
0x19: 'UTILCTR', 0x00: 'NOP',
|
||||
}
|
||||
|
||||
# Reference budget values (nibbles for NEXT packet) from rocprof
|
||||
REFERENCE_BUDGET_NIBBLES = {
|
||||
'VALUINST': 3, 'VMEMEXEC': 2, 'ALUEXEC': 2, 'IMMEDIATE': 3,
|
||||
'IMMEDIATE_MASK': 6, 'WAVERDY': 6, 'TS_DELTA_S8_W3': 16,
|
||||
'WAVEEND': 5, 'WAVESTART': 8, 'TS_DELTA_S5_W2': 12,
|
||||
'WAVEALLOC': 5, 'TS_DELTA_S5_W3': 13, 'PERF': 7,
|
||||
'TS_DELTA_SHORT': 2, 'NOP': 1, 'TS_WAVE_STATE': 6,
|
||||
'EVENT': 6, 'EVENT_BIG': 8, 'REG': 16, 'SNAPSHOT': 16,
|
||||
'TS_DELTA_OR_MARK': 12, 'LAYOUT_HEADER': 16, 'INST': 5,
|
||||
'UTILCTR': 12,
|
||||
}
|
||||
|
||||
|
||||
class TestEncodingsMatchStateTable(unittest.TestCase):
|
||||
"""Verify encoding patterns produce the correct state decode table."""
|
||||
|
||||
def test_all_256_bytes_decode_correctly(self):
|
||||
"""Each byte value should decode to the same packet type as reference."""
|
||||
mismatches = []
|
||||
for byte_val in range(256):
|
||||
ref_opcode = REFERENCE_STATE_TABLE[byte_val]
|
||||
ref_name = OLD_OPCODE_TO_NAME.get(ref_opcode, f"UNK_{ref_opcode:02x}")
|
||||
|
||||
our_opcode = STATE_TO_OPCODE[byte_val]
|
||||
our_name = OPCODE_TO_CLASS[our_opcode].__name__
|
||||
|
||||
if ref_name != our_name:
|
||||
mismatches.append((byte_val, ref_name, our_name))
|
||||
|
||||
if mismatches:
|
||||
msg = "\n".join(f" 0x{b:02x}: expected {r}, got {o}" for b, r, o in mismatches[:10])
|
||||
self.fail(f"State table mismatches ({len(mismatches)} total):\n{msg}")
|
||||
|
||||
|
||||
class TestPacketSizesMatchBudget(unittest.TestCase):
|
||||
"""Verify packet sizes (from field definitions) match expected budget values."""
|
||||
|
||||
def test_all_packet_sizes(self):
|
||||
"""Each packet type's size should match the reference budget."""
|
||||
for pkt_cls in PACKET_TYPES:
|
||||
name = pkt_cls.__name__
|
||||
expected = REFERENCE_BUDGET_NIBBLES.get(name)
|
||||
if expected is None:
|
||||
continue
|
||||
|
||||
actual = pkt_cls.size_nibbles()
|
||||
self.assertEqual(expected, actual,
|
||||
f"{name}: expected {expected} nibbles, got {actual} (size_bits={pkt_cls.size_bits()})")
|
||||
|
||||
|
||||
class TestFieldExtraction(unittest.TestCase):
|
||||
"""Test that field values are extracted correctly."""
|
||||
|
||||
def test_valuinst(self):
|
||||
reg = 0b11110_1_001_011 # wave=0x1E, flag=1, delta=1
|
||||
pkt = VALUINST.from_raw(reg)
|
||||
self.assertEqual(pkt.delta, 1)
|
||||
self.assertEqual(pkt.flag, 1)
|
||||
self.assertEqual(pkt.wave, 0x1E)
|
||||
|
||||
def test_vmemexec_enum(self):
|
||||
reg = 0b11_00_1111 # src=3 (VMEM_ALT), delta=0
|
||||
pkt = VMEMEXEC.from_raw(reg)
|
||||
self.assertEqual(pkt.src, MemSrc.VMEM_ALT)
|
||||
|
||||
def test_aluexec_enum(self):
|
||||
reg = 0b10_01_1110 # src=2 (VALU), delta=1
|
||||
pkt = ALUEXEC.from_raw(reg)
|
||||
self.assertEqual(pkt.src, AluSrc.VALU)
|
||||
|
||||
def test_waveend(self):
|
||||
reg = (0x15 << 15) | (0x7 << 11) | (0x3 << 9) | (1 << 8) | 0b10101
|
||||
pkt = WAVEEND.from_raw(reg)
|
||||
self.assertEqual(pkt.flag7, 1)
|
||||
self.assertEqual(pkt.simd, 3)
|
||||
self.assertEqual(pkt.cu_lo, 7)
|
||||
self.assertEqual(pkt.wave, 0x15)
|
||||
self.assertEqual(pkt.cu, 0xF) # cu_lo | (flag7 << 3) = 7 | 8 = 15
|
||||
|
||||
def test_wavestart(self):
|
||||
reg = (0x7F << 18) | (0x15 << 13) | (0x7 << 10) | (0x3 << 8) | (1 << 7) | 0b01100
|
||||
pkt = WAVESTART.from_raw(reg)
|
||||
self.assertEqual(pkt.flag7, 1)
|
||||
self.assertEqual(pkt.simd, 3)
|
||||
self.assertEqual(pkt.cu_lo, 7)
|
||||
self.assertEqual(pkt.wave, 0x15)
|
||||
self.assertEqual(pkt.id7, 0x7F)
|
||||
self.assertEqual(pkt.cu, 0xF)
|
||||
|
||||
def test_inst_enum(self):
|
||||
reg = (0x21 << 13) | (0x15 << 8) | (1 << 7) | (1 << 3) | 0b010
|
||||
pkt = INST.from_raw(reg)
|
||||
self.assertEqual(pkt.flag1, 1)
|
||||
self.assertEqual(pkt.flag2, 1)
|
||||
self.assertEqual(pkt.wave, 0x15)
|
||||
self.assertEqual(pkt.op, InstOp.VMEM_LOAD)
|
||||
|
||||
def test_layout_header(self):
|
||||
reg = (0b101 << 33) | (0b1010 << 28) | (0b111 << 15) | (0b11 << 13) | (0b101010 << 7) | 0b0010001
|
||||
pkt = LAYOUT_HEADER.from_raw(reg)
|
||||
self.assertEqual(pkt.layout, 0b101010)
|
||||
self.assertEqual(pkt.simd, 0b11)
|
||||
self.assertEqual(pkt.group, 0b111)
|
||||
self.assertEqual(pkt.sel_a, 0b1010)
|
||||
self.assertEqual(pkt.sel_b, 0b101)
|
||||
|
||||
def test_ts_delta_or_mark_modes(self):
|
||||
# delta mode: bit9=0, bit8=0
|
||||
pkt_delta = TS_DELTA_OR_MARK.from_raw(0b0000001) # just the encoding pattern
|
||||
self.assertFalse(pkt_delta.is_marker)
|
||||
|
||||
# marker mode: bit9=1, bit8=0
|
||||
pkt_marker = TS_DELTA_OR_MARK.from_raw(0b0000001 | (1 << 9)) # bit9=1, bit8=0
|
||||
self.assertTrue(pkt_marker.is_marker)
|
||||
|
||||
# other mode: bit9=1, bit8=1 (not marker)
|
||||
pkt_other = TS_DELTA_OR_MARK.from_raw(0b0000001 | (1 << 8) | (1 << 9))
|
||||
self.assertFalse(pkt_other.is_marker)
|
||||
|
||||
def test_reg(self):
|
||||
# REG fields: slot=bits[9:7], hi_byte=bits[15:8], subop=bits[31:16], val32=bits[63:32]
|
||||
# Note: slot[2:1] overlaps with hi_byte[1:0], so we need to set them consistently
|
||||
# hi_byte=0x55 means bits 8-15 = 0b01010101, so slot bits 8-9 = 0b01
|
||||
# slot bit 7 = 1, so slot = 0b011 = 3
|
||||
reg = (0xDEADBEEF << 32) | (0xCAFE << 16) | (0x55 << 8) | (1 << 7) | 0b1001
|
||||
pkt = REG.from_raw(reg)
|
||||
self.assertEqual(pkt.slot, 0b011) # bit7=1, bits 8-9 from hi_byte low 2 bits = 01
|
||||
self.assertEqual(pkt.hi_byte, 0x55)
|
||||
self.assertEqual(pkt.subop, 0xCAFE)
|
||||
self.assertEqual(pkt.val32, 0xDEADBEEF)
|
||||
|
||||
|
||||
class TestRoundtrip(unittest.TestCase):
|
||||
"""Test encode/decode roundtrip."""
|
||||
|
||||
def test_simple_roundtrip(self):
|
||||
"""Test encode/decode roundtrip preserves packet types."""
|
||||
test_packets = [
|
||||
LAYOUT_HEADER.from_raw(0x100),
|
||||
WAVESTART.from_raw(0x0),
|
||||
INST.from_raw(0x10),
|
||||
INST.from_raw(0x10),
|
||||
WAVEEND.from_raw(0x40),
|
||||
]
|
||||
encoded = encode(test_packets)
|
||||
decoded = decode(encoded)
|
||||
|
||||
self.assertGreaterEqual(len(decoded), len(test_packets))
|
||||
for i, (orig, dec) in enumerate(zip(test_packets, decoded)):
|
||||
self.assertEqual(type(orig), type(dec), f"type mismatch at {i}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
|
@ -21,7 +21,8 @@ from tinygrad.runtime.support.memory import AddrSpace
|
|||
if getenv("IOCTL"): import extra.hip_gpu_driver.hip_ioctl # noqa: F401 # pylint: disable=unused-import
|
||||
|
||||
SQTT = ContextVar("SQTT", abs(VIZ.value)>=2)
|
||||
SQTT_ITRACE_SE_MASK, SQTT_LIMIT_SE = ContextVar("SQTT_ITRACE_SE_MASK", 0b11), ContextVar("SQTT_LIMIT_SE", 0)
|
||||
SQTT_ITRACE_SE_MASK, SQTT_LIMIT_SE, SQTT_SIMD_SEL, SQTT_TOKEN_EXCLUDE = \
|
||||
ContextVar("SQTT_ITRACE_SE_MASK", 0b11), ContextVar("SQTT_LIMIT_SE", 0), ContextVar("SQTT_SIMD_SEL", 0), ContextVar("SQTT_TOKEN_EXCLUDE", 0)
|
||||
PMC = ContextVar("PMC", abs(VIZ.value)>=2)
|
||||
EVENT_INDEX_PARTIAL_FLUSH = 4 # based on a comment in nvd.h
|
||||
WAIT_REG_MEM_FUNCTION_EQ = 3 # ==
|
||||
|
|
@ -252,17 +253,18 @@ class AMDComputeQueue(HWQueue):
|
|||
else:
|
||||
self.wreg(self.gc.regSQ_THREAD_TRACE_BUF0_SIZE, base_hi=buf0_hi, size=buf0s[se].size >> 12)
|
||||
self.wreg(self.gc.regSQ_THREAD_TRACE_BUF0_BASE, base_lo=buf0_lo)
|
||||
# NOTE: SQTT can only trace instructions on one simd per se, this selects first simd in first wgp in first sa.
|
||||
# NOTE: SQTT can only trace instructions on one simd per se, this selects the simd in first wgp in first sa.
|
||||
# For RGP to display instruction trace it has to see it on first SE. Howerver ACE/MEC/whatever does the dispatching starting with second se,
|
||||
# and on amdgpu/non-AM it also does weird things with dispatch order inside se: around 7 times out of 10 it starts from the last cu, but
|
||||
# sometimes not, especially if the kernel has more than one wavefront which means that kernels with small global size might get unlucky and
|
||||
# be dispatched on something else and not be seen in instruction tracing tab. You can force the wavefronts of a kernel to be dispatched on the
|
||||
# CUs you want to by disabling other CUs via bits in regCOMPUTE_STATIC_THREAD_MGMT_SE<x> and trace even kernels that only have one wavefront.
|
||||
# Use SQTT_SIMD_SEL to select which SIMD to trace (0-3). Memory ops show different InstOp values (0x2x vs 0x5x) based on SIMD.
|
||||
cs_wtype = (1 << 6) if self.dev.target >= (12,0,0) else self.soc.SQ_TT_WTYPE_INCLUDE_CS_BIT
|
||||
self.wreg(self.gc.regSQ_THREAD_TRACE_MASK, wtype_include=cs_wtype, simd_sel=0, wgp_sel=0, sa_sel=0)
|
||||
self.wreg(self.gc.regSQ_THREAD_TRACE_MASK, wtype_include=cs_wtype, simd_sel=SQTT_SIMD_SEL.value, wgp_sel=0, sa_sel=0)
|
||||
reg_include = self.soc.SQ_TT_TOKEN_MASK_SQDEC_BIT | self.soc.SQ_TT_TOKEN_MASK_SHDEC_BIT | self.soc.SQ_TT_TOKEN_MASK_GFXUDEC_BIT | \
|
||||
self.soc.SQ_TT_TOKEN_MASK_COMP_BIT | self.soc.SQ_TT_TOKEN_MASK_CONTEXT_BIT
|
||||
token_exclude = (1 << self.soc.SQ_TT_TOKEN_EXCLUDE_PERF_SHIFT) if self.dev.target < (12,0,0) else 0
|
||||
token_exclude = SQTT_TOKEN_EXCLUDE.value | ((1 << self.soc.SQ_TT_TOKEN_EXCLUDE_PERF_SHIFT) if self.dev.target < (12,0,0) else 0)
|
||||
|
||||
# disable instr tracing
|
||||
if not (SQTT_ITRACE_SE_MASK.value >> se) & 0b1:
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue