Compare commits

...

8 commits

Author SHA1 Message Date
George Hotz
103a00d4c5 Merge origin/master 2026-01-01 17:15:45 +00:00
George Hotz
8c14d9f427 rdna4 2026-01-01 17:14:52 +00:00
George Hotz
4e03b3ebef rdna4 work 2026-01-01 16:45:39 +00:00
George Hotz
4571979fac refactor 2025-12-31 21:33:37 +00:00
George Hotz
9302f38f5b rdna4 works 2025-12-31 21:20:47 +00:00
George Hotz
2a6904029b more rdna4 work 2025-12-31 18:07:42 +00:00
George Hotz
14bc1b0c68 Merge origin/master 2025-12-31 17:47:59 +00:00
George Hotz
c9b074639e work on rdna4 asm 2025-12-31 16:18:19 +00:00
6 changed files with 1204 additions and 389 deletions

View file

@ -1,12 +1,15 @@
# RDNA3 assembler and disassembler
from __future__ import annotations
import re
from extra.assembly.amd.dsl import Inst, RawImm, Reg, SrcMod, SGPR, VGPR, TTMP, s, v, ttmp, _RegFactory
from extra.assembly.amd.dsl import Inst, RawImm, Reg, SrcMod, SGPR, VGPR, TTMP, s, v, ttmp, _RegFactory, SRC_FIELDS, unwrap
from extra.assembly.amd.dsl import VCC_LO, VCC_HI, VCC, EXEC_LO, EXEC_HI, EXEC, SCC, M0, NULL, OFF
from extra.assembly.amd.dsl import SPECIAL_GPRS, SPECIAL_PAIRS, FLOAT_DEC, FLOAT_ENC, decode_src
from extra.assembly.amd.autogen.rdna3 import ins
from extra.assembly.amd.autogen.rdna3.ins import (VOP1, VOP2, VOP3, VOP3SD, VOP3P, VOPC, VOPD, VINTERP, SOP1, SOP2, SOPC, SOPK, SOPP, SMEM, DS, FLAT, MUBUF, MTBUF, MIMG, EXP,
VOP1Op, VOP2Op, VOP3Op, VOP3SDOp, VOPDOp, SOP1Op, SOPKOp, SOPPOp, SMEMOp, DSOp, MUBUFOp)
from extra.assembly.amd.autogen.rdna3.ins import (VOP1, VOP2, VOP3, VOP3SD, VOP3P, VOPC, VOPD, VINTERP, SOP1, SOP2, SOPC, SOPK, SOPP, SMEM, DS, FLAT, MUBUF, MTBUF, MIMG, EXP, LDSDIR,
VOP1Op, VOP2Op, VOP3Op, VOP3SDOp, VOP3POp, VOPCOp, VOPDOp, VINTERPOp, SOP1Op, SOP2Op, SOPCOp, SOPKOp, SOPPOp, SMEMOp, DSOp, FLATOp, MUBUFOp, MTBUFOp, MIMGOp)
# VOP3SD opcodes that share VOP3 encoding
VOP3SD_OPS = {288, 289, 290, 764, 765, 766, 767, 768, 769, 770}
def _matches_encoding(word: int, cls: type[Inst]) -> bool:
"""Check if word matches the encoding pattern of an instruction class."""
@ -26,7 +29,7 @@ def detect_format(data: bytes) -> type[Inst]:
if (word >> 30) == 0b11:
for cls in _FORMATS_64:
if _matches_encoding(word, cls):
return VOP3SD if cls is VOP3 and ((word >> 16) & 0x3ff) in Inst._VOP3SD_OPS else cls
return VOP3SD if cls is VOP3 and ((word >> 16) & 0x3ff) in VOP3SD_OPS else cls
raise ValueError(f"unknown 64-bit format word={word:#010x}")
# 32-bit formats
for cls in _FORMATS_32:
@ -37,10 +40,15 @@ def detect_format(data: bytes) -> type[Inst]:
# CONSTANTS
# ═══════════════════════════════════════════════════════════════════════════════
# GFX11 HWREG IDs
HWREG = {1: 'HW_REG_MODE', 2: 'HW_REG_STATUS', 3: 'HW_REG_TRAPSTS', 4: 'HW_REG_HW_ID', 5: 'HW_REG_GPR_ALLOC',
6: 'HW_REG_LDS_ALLOC', 7: 'HW_REG_IB_STS', 15: 'HW_REG_SH_MEM_BASES', 18: 'HW_REG_PERF_SNAPSHOT_PC_LO',
19: 'HW_REG_PERF_SNAPSHOT_PC_HI', 20: 'HW_REG_FLAT_SCR_LO', 21: 'HW_REG_FLAT_SCR_HI', 22: 'HW_REG_XNACK_MASK',
23: 'HW_REG_HW_ID1', 24: 'HW_REG_HW_ID2', 25: 'HW_REG_POPS_PACKER', 28: 'HW_REG_IB_STS2'}
# GFX12 HWREG IDs - use names that LLVM recognizes
HWREG_GFX12 = {1: 'HW_REG_MODE', 2: 'HW_REG_STATUS', 5: 'HW_REG_GPR_ALLOC', 6: 'HW_REG_LDS_ALLOC', 7: 'HW_REG_IB_STS',
18: 'HW_REG_EXCP_FLAG_USER', 19: 'HW_REG_TRAP_CTRL', 20: 'HW_REG_SCRATCH_BASE_LO', 21: 'HW_REG_SCRATCH_BASE_HI',
23: 'HW_REG_HW_ID1', 24: 'HW_REG_HW_ID2', 29: 'HW_REG_SHADER_CYCLES_LO', 30: 'HW_REG_SHADER_CYCLES_HI'}
HWREG_IDS = {v.lower(): k for k, v in HWREG.items()}
MSG = {128: 'MSG_RTN_GET_DOORBELL', 129: 'MSG_RTN_GET_DDID', 130: 'MSG_RTN_GET_TMA',
131: 'MSG_RTN_GET_REALTIME', 132: 'MSG_RTN_SAVE_WAVE', 133: 'MSG_RTN_GET_TBA'}
@ -76,6 +84,8 @@ def waitcnt(vmcnt: int = 0x3f, expcnt: int = 0x7, lgkmcnt: int = 0x3f) -> int:
return (expcnt & 0x7) | ((lgkmcnt & 0x3f) << 4) | ((vmcnt & 0x3f) << 10)
def _has(op: str, *subs) -> bool: return any(s in op for s in subs)
def _is16(op: str) -> bool: return _has(op, 'f16', 'i16', 'u16', 'b16') and not _has(op, '_f32', '_i32')
def _is64(op: str) -> bool: return _has(op, 'f64', 'i64', 'u64', 'b64')
def _omod(v: int) -> str: return {1: " mul:2", 2: " mul:4", 3: " div:2"}.get(v, "")
def _src16(inst, v: int) -> str: return _fmt_v16(v) if v >= 256 else inst.lit(v) # format 16-bit src: vgpr.h/l or literal
def _mods(*pairs) -> str: return " ".join(m for c, m in pairs if c)
@ -100,43 +110,72 @@ def _opsel_str(opsel: int, n: int, need: bool, is16_d: bool) -> str:
# DISASSEMBLER
# ═══════════════════════════════════════════════════════════════════════════════
_VOP1_F64 = {VOP1Op.V_CEIL_F64, VOP1Op.V_FLOOR_F64, VOP1Op.V_FRACT_F64, VOP1Op.V_FREXP_MANT_F64, VOP1Op.V_RCP_F64, VOP1Op.V_RNDNE_F64, VOP1Op.V_RSQ_F64, VOP1Op.V_SQRT_F64, VOP1Op.V_TRUNC_F64}
def _disasm_vop1(inst: VOP1) -> str:
name = inst.op_name.lower()
if inst.op in (VOP1Op.V_NOP, VOP1Op.V_PIPEFLUSH): return name
if inst.op == VOP1Op.V_READFIRSTLANE_B32: return f"v_readfirstlane_b32 {decode_src(inst.vdst)}, v{inst.src0 - 256 if inst.src0 >= 256 else inst.src0}"
# 16-bit dst: uses .h/.l suffix (determined by name pattern, not dtype - e.g. sat_pk_u8_i16 outputs 8-bit but uses 16-bit encoding)
# Use architecture-specific op enum
if 'rdna4' in inst.__class__.__module__:
from extra.assembly.amd.autogen.rdna4.enum import VOP1Op as OpEnum
else:
OpEnum = VOP1Op
op, name = OpEnum(inst.op), OpEnum(inst.op).name.lower()
if name in ('v_nop', 'v_pipeflush'): return name
if name == 'v_readfirstlane_b32': return f"v_readfirstlane_b32 {decode_src(inst.vdst)}, v{inst.src0 - 256 if inst.src0 >= 256 else inst.src0}"
parts = name.split('_')
is_f64_d = 'f64' in name and any(x in name for x in ['ceil', 'floor', 'fract', 'frexp_mant', 'rcp', 'rndne', 'rsq', 'sqrt', 'trunc', 'cvt_f64_f32', 'cvt_f64_i32', 'cvt_f64_u32'])
is_f64_s = 'f64' in name and any(x in name for x in ['ceil', 'floor', 'fract', 'frexp_mant', 'rcp', 'rndne', 'rsq', 'sqrt', 'trunc', 'cvt_f32_f64', 'cvt_i32_f64', 'cvt_u32_f64', 'frexp_exp_i32_f64'])
# v_cvt_pk_f32_bf8/fp8 output 2 VGPRs (packed f32x2) and take packed 8-bit (16-bit VGPR with .l/.h) source
is_pk_f32 = 'cvt_pk_f32_bf8' in name or 'cvt_pk_f32_fp8' in name
is_16d = any(p in ('f16','i16','u16','b16') for p in parts[-2:-1]) or (len(parts) >= 2 and parts[-1] in ('f16','i16','u16','b16') and 'cvt' not in name)
dst = _vreg(inst.vdst, inst.dst_regs()) if inst.dst_regs() > 1 else _fmt_v16(inst.vdst, 0, 128) if is_16d else f"v{inst.vdst}"
src = _fmt_src(inst.src0, inst.src_regs(0)) if inst.src_regs(0) > 1 else _src16(inst, inst.src0) if inst.is_src_16(0) and 'sat_pk' not in name else inst.lit(inst.src0)
# Only packed bf8/fp8 (cvt_pk_*) use 16-bit VGPR encoding; non-packed versions use regular VGPRs
is_16s = (parts[-1] in ('f16','i16','u16','b16') and 'sat_pk' not in name) or (parts[-1] in ('bf8', 'fp8') and 'pk' in name)
dst = _vreg(inst.vdst, 2) if is_f64_d or is_pk_f32 else _fmt_v16(inst.vdst, 0, 128) if is_16d else f"v{inst.vdst}"
src = _fmt_src(inst.src0, 2) if is_f64_s else _src16(inst, inst.src0) if is_16s else inst.lit(inst.src0)
return f"{name}_e32 {dst}, {src}"
def _disasm_vop2(inst: VOP2) -> str:
name = inst.op_name.lower()
suf = "" if inst.op == VOP2Op.V_DOT2ACC_F32_F16 else "_e32"
# Use architecture-specific op enum
if 'rdna4' in inst.__class__.__module__:
from extra.assembly.amd.autogen.rdna4.enum import VOP2Op as OpEnum
else:
OpEnum = VOP2Op
op, name = OpEnum(inst.op), OpEnum(inst.op).name.lower()
suf, is16 = "" if name == 'v_dot2acc_f32_f16' else "_e32", _is16(name) and 'pk_' not in name
is64 = _is64(name)
# For shift ops with b64, src0 is 32-bit (shift amount), dst/vsrc1 are 64-bit
is_shift64 = 'lshlrev_b64' in name
# fmaak: dst = src0 * vsrc1 + K, fmamk: dst = src0 * K + vsrc1
if inst.op in (VOP2Op.V_FMAAK_F32, VOP2Op.V_FMAAK_F16): return f"{name}{suf} v{inst.vdst}, {inst.lit(inst.src0)}, v{inst.vsrc1}, 0x{inst._literal:x}"
if inst.op in (VOP2Op.V_FMAMK_F32, VOP2Op.V_FMAMK_F16): return f"{name}{suf} v{inst.vdst}, {inst.lit(inst.src0)}, 0x{inst._literal:x}, v{inst.vsrc1}"
if inst.is_16bit(): return f"{name}{suf} {_fmt_v16(inst.vdst, 0, 128)}, {_src16(inst, inst.src0)}, {_fmt_v16(inst.vsrc1, 0, 128)}"
return f"{name}{suf} v{inst.vdst}, {inst.lit(inst.src0)}, v{inst.vsrc1}" + (", vcc_lo" if inst.op == VOP2Op.V_CNDMASK_B32 else "")
if 'fmaak' in name: return f"{name}{suf} v{inst.vdst}, {inst.lit(inst.src0)}, v{inst.vsrc1}, 0x{inst._literal:x}"
if 'fmamk' in name: return f"{name}{suf} v{inst.vdst}, {inst.lit(inst.src0)}, 0x{inst._literal:x}, v{inst.vsrc1}"
if is16: return f"{name}{suf} {_fmt_v16(inst.vdst, 0, 128)}, {_src16(inst, inst.src0)}, {_fmt_v16(inst.vsrc1, 0, 128)}"
if is_shift64: return f"{name}{suf} {_vreg(inst.vdst, 2)}, {inst.lit(inst.src0)}, {_vreg(inst.vsrc1, 2)}"
if is64: return f"{name}{suf} {_vreg(inst.vdst, 2)}, {_fmt_src(inst.src0, 2)}, {_vreg(inst.vsrc1, 2)}"
return f"{name}{suf} v{inst.vdst}, {inst.lit(inst.src0)}, v{inst.vsrc1}" + (", vcc_lo" if name == 'v_cndmask_b32' else "")
def _disasm_vopc(inst: VOPC) -> str:
name = inst.op_name.lower()
s0 = _fmt_src(inst.src0, inst.src_regs(0)) if inst.src_regs(0) > 1 else _src16(inst, inst.src0) if inst.is_16bit() else inst.lit(inst.src0)
s1 = _vreg(inst.vsrc1, inst.src_regs(1)) if inst.src_regs(1) > 1 else _fmt_v16(inst.vsrc1, 0, 128) if inst.is_16bit() else f"v{inst.vsrc1}"
return f"{name}_e32 {s0}, {s1}" if inst.op.value >= 128 else f"{name}_e32 vcc_lo, {s0}, {s1}"
# Use architecture-specific op enum
if 'rdna4' in inst.__class__.__module__:
from extra.assembly.amd.autogen.rdna4.enum import VOPCOp as OpEnum
else:
OpEnum = VOPCOp
op, name = OpEnum(inst.op), OpEnum(inst.op).name.lower()
is64, is16 = _is64(name), _is16(name)
is_class = 'class' in name
s0 = _fmt_src(inst.src0, 2) if is64 else _src16(inst, inst.src0) if is16 else inst.lit(inst.src0)
s1 = _vreg(inst.vsrc1, 2) if is64 and not is_class else _fmt_v16(inst.vsrc1, 0, 128) if is16 else f"v{inst.vsrc1}"
return f"{name}_e32 {s0}, {s1}" if op.value >= 128 else f"{name}_e32 vcc_lo, {s0}, {s1}"
NO_ARG_SOPP = {SOPPOp.S_ENDPGM, SOPPOp.S_BARRIER, SOPPOp.S_WAKEUP, SOPPOp.S_ICACHE_INV,
SOPPOp.S_WAIT_IDLE, SOPPOp.S_ENDPGM_SAVED, SOPPOp.S_CODE_END, SOPPOp.S_ENDPGM_ORDERED_PS_DONE}
def _disasm_sopp(inst: SOPP) -> str:
name = inst.op_name.lower()
if inst.op in NO_ARG_SOPP: return name
if inst.op == SOPPOp.S_WAITCNT:
op, name = SOPPOp(inst.op), SOPPOp(inst.op).name.lower()
if op in NO_ARG_SOPP: return name
if op == SOPPOp.S_WAITCNT:
vm, exp, lgkm = (inst.simm16 >> 10) & 0x3f, inst.simm16 & 0xf, (inst.simm16 >> 4) & 0x3f
p = [f"vmcnt({vm})" if vm != 0x3f else "", f"expcnt({exp})" if exp != 7 else "", f"lgkmcnt({lgkm})" if lgkm != 0x3f else ""]
return f"s_waitcnt {' '.join(x for x in p if x) or '0'}"
if inst.op == SOPPOp.S_DELAY_ALU:
if op == SOPPOp.S_DELAY_ALU:
deps, skips = ['VALU_DEP_1','VALU_DEP_2','VALU_DEP_3','VALU_DEP_4','TRANS32_DEP_1','TRANS32_DEP_2','TRANS32_DEP_3','FMA_ACCUM_CYCLE_1','SALU_CYCLE_1','SALU_CYCLE_2','SALU_CYCLE_3'], ['SAME','NEXT','SKIP_1','SKIP_2','SKIP_3','SKIP_4']
id0, skip, id1 = inst.simm16 & 0xf, (inst.simm16 >> 4) & 0x7, (inst.simm16 >> 7) & 0xf
dep = lambda v: deps[v-1] if 0 < v <= len(deps) else str(v)
@ -145,20 +184,51 @@ def _disasm_sopp(inst: SOPP) -> str:
return f"{name} {inst.simm16}" if name.startswith(('s_cbranch', 's_branch')) else f"{name} 0x{inst.simm16:x}"
def _disasm_smem(inst: SMEM) -> str:
name = inst.op_name.lower()
if inst.op in (SMEMOp.S_GL1_INV, SMEMOp.S_DCACHE_INV): return name
off_s = f"{decode_src(inst.soffset)} offset:0x{inst.offset:x}" if inst.offset and inst.soffset != 124 else f"0x{inst.offset:x}" if inst.offset else decode_src(inst.soffset)
sbase_idx, sbase_count = inst.sbase * 2, 4 if (8 <= inst.op.value <= 12 or name == 's_atc_probe_buffer') else 2
sbase_str = _fmt_src(sbase_idx, sbase_count) if sbase_count == 2 else _sreg(sbase_idx, sbase_count) if sbase_idx <= 105 else _reg("ttmp", sbase_idx - 108, sbase_count)
if name in ('s_atc_probe', 's_atc_probe_buffer'): return f"{name} {inst.sdata}, {sbase_str}, {off_s}"
return f"{name} {_fmt_sdst(inst.sdata, inst.dst_regs())}, {sbase_str}, {off_s}" + _mods((inst.glc, " glc"), (inst.dlc, " dlc"))
is_rdna4 = 'rdna4' in inst.__class__.__module__
if is_rdna4:
from extra.assembly.amd.autogen.rdna4.enum import SMEMOp as SMEMOp4
op = SMEMOp4(inst.op)
name = op.name.lower()
if op == SMEMOp4.S_DCACHE_INV: return name
# RDNA4: s_buffer_* uses 4-SGPR descriptor, s_load/s_prefetch uses 2-SGPR
is_buffer = 'buffer' in name
sbase_idx = inst.sbase * 2
sbase_count = 4 if is_buffer else 2
sbase_str = _fmt_src(sbase_idx, sbase_count) if sbase_count == 2 else _sreg(sbase_idx, sbase_count) if sbase_idx <= 105 else _reg("ttmp", sbase_idx - 108, sbase_count)
# Format offset - ioffset is signed 24-bit, show as hex
ioff = inst.ioffset if inst.ioffset < 0x800000 else inst.ioffset - 0x1000000 # sign extend
off_hex = f"0x{ioff & 0xffffff:x}" if ioff >= 0 else f"-0x{(-ioff) & 0xffffff:x}"
off_s = f"{decode_src(inst.soffset)} offset:{off_hex}" if inst.soffset != 124 else off_hex
# Data width from opcode
width_map = {0:1, 1:2, 2:4, 3:8, 4:16, 5:3, 8:1, 9:1, 10:1, 11:1, 16:1, 17:2, 18:4, 19:8, 20:16, 21:3, 24:1, 25:1, 26:1, 27:1}
width = width_map.get(inst.op, 1)
if 'prefetch' in name:
# Prefetch has different format: s_prefetch_* sbase, offset, soffset, length
# But we need to handle various prefetch types differently
if name == 's_prefetch_inst_pc_rel' or name == 's_prefetch_data_pc_rel':
return f"{name} {off_hex}, {decode_src(inst.soffset)}, {inst.sdata}"
return f"{name} {sbase_str}, {off_hex}, {decode_src(inst.soffset)}, {inst.sdata}"
return f"{name} {_fmt_sdst(inst.sdata, width)}, {sbase_str}, {off_s}"
else:
op = SMEMOp(inst.op)
name = op.name.lower()
if op in (SMEMOp.S_GL1_INV, SMEMOp.S_DCACHE_INV): return name
off_s = f"{decode_src(inst.soffset)} offset:0x{inst.offset:x}" if inst.offset and inst.soffset != 124 else f"0x{inst.offset:x}" if inst.offset else decode_src(inst.soffset)
sbase_idx, sbase_count = inst.sbase * 2, 4 if (8 <= inst.op <= 12 or name == 's_atc_probe_buffer') else 2
sbase_str = _fmt_src(sbase_idx, sbase_count) if sbase_count == 2 else _sreg(sbase_idx, sbase_count) if sbase_idx <= 105 else _reg("ttmp", sbase_idx - 108, sbase_count)
if name in ('s_atc_probe', 's_atc_probe_buffer'): return f"{name} {inst.sdata}, {sbase_str}, {off_s}"
width = {0:1, 1:2, 2:4, 3:8, 4:16, 8:1, 9:2, 10:4, 11:8, 12:16}.get(inst.op, 1)
return f"{name} {_fmt_sdst(inst.sdata, width)}, {sbase_str}, {off_s}" + _mods((inst.glc, " glc"), (inst.dlc, " dlc"))
def _disasm_flat(inst: FLAT) -> str:
name = inst.op_name.lower()
name = FLATOp(inst.op).name.lower()
seg = ['flat', 'scratch', 'global'][inst.seg] if inst.seg < 3 else 'flat'
instr = f"{seg}_{name.split('_', 1)[1] if '_' in name else name}"
off_val = inst.offset if seg == 'flat' else (inst.offset if inst.offset < 4096 else inst.offset - 8192)
w = inst.dst_regs() * (2 if 'cmpswap' in name else 1)
suffix = name.split('_')[-1]
w = {'b32':1,'b64':2,'b96':3,'b128':4,'u8':1,'i8':1,'u16':1,'i16':1,'u32':1,'i32':1,'u64':2,'i64':2,'f32':1,'f64':2}.get(suffix, 1)
if 'cmpswap' in name: w *= 2
if name.endswith('_x2') or 'x2' in suffix: w = max(w, 2)
mods = f"{f' offset:{off_val}' if off_val else ''}{' glc' if inst.glc else ''}{' slc' if inst.slc else ''}{' dlc' if inst.dlc else ''}"
# saddr
if seg == 'flat' or inst.saddr == 0x7F: saddr_s = ""
@ -178,11 +248,11 @@ def _disasm_flat(inst: FLAT) -> str:
return f"{instr} {_vreg(inst.vdst, w)}, {addr_s}{saddr_s}{mods}"
def _disasm_ds(inst: DS) -> str:
op, name = inst.op, inst.op_name.lower()
op, name = DSOp(inst.op), DSOp(inst.op).name.lower()
gds = " gds" if inst.gds else ""
off = f" offset:{inst.offset0 | (inst.offset1 << 8)}" if inst.offset0 or inst.offset1 else ""
off2 = f" offset0:{inst.offset0} offset1:{inst.offset1}" if inst.offset0 or inst.offset1 else ""
w = inst.dst_regs()
w = 4 if '128' in name else 3 if '96' in name else 2 if (name.endswith('64') or 'gs_reg' in name) else 1
d0, d1, dst, addr = _vreg(inst.data0, w), _vreg(inst.data1, w), _vreg(inst.vdst, w), f"v{inst.addr}"
if op == DSOp.DS_NOP: return name
@ -206,34 +276,56 @@ def _disasm_ds(inst: DS) -> str:
return f"{name} {dst}, {addr}, {d0}{off}{gds}" if '_rtn' in name else f"{name} {addr}, {d0}{off}{gds}"
def _disasm_vop3(inst: VOP3) -> str:
op, name = inst.op, inst.op_name.lower()
is_rdna4 = 'rdna4' in inst.__class__.__module__
if is_rdna4:
from extra.assembly.amd.autogen.rdna4.enum import VOP3Op as VOP3Op4, VOP3SDOp as VOP3SDOp4
op = VOP3SDOp4(inst.op) if inst.op in VOP3SD_OPS else VOP3Op4(inst.op)
else:
op = VOP3SDOp(inst.op) if inst.op in VOP3SD_OPS else VOP3Op(inst.op)
name = op.name.lower()
# VOP3SD (shared encoding)
if isinstance(op, VOP3SDOp):
if inst.op in VOP3SD_OPS:
sdst = (inst.clmp << 7) | (inst.opsel << 3) | inst.abs
def src(v, neg, n): s = _fmt_src(v, n) if n > 1 else inst.lit(v); return f"-{s}" if neg else s
s0, s1, s2 = src(inst.src0, inst.neg & 1, inst.src_regs(0)), src(inst.src1, inst.neg & 2, inst.src_regs(1)), src(inst.src2, inst.neg & 4, inst.src_regs(2))
dst = _vreg(inst.vdst, inst.dst_regs()) if inst.dst_regs() > 1 else f"v{inst.vdst}"
srcs = f"{s0}, {s1}, {s2}" if inst.num_srcs() == 3 else f"{s0}, {s1}"
return f"{name} {dst}, {_fmt_sdst(sdst, 1)}, {srcs}" + _omod(inst.omod)
is64, mad64 = 'f64' in name, _has(name, 'mad_i64_i32', 'mad_u64_u32', 'mad_co_i64_i32', 'mad_co_u64_u32')
def src(v, neg, ext=False): s = _fmt_src(v, 2) if ext or is64 else inst.lit(v); return f"-{s}" if neg else s
s0, s1, s2 = src(inst.src0, inst.neg & 1), src(inst.src1, inst.neg & 2), src(inst.src2, inst.neg & 4, mad64)
dst = _vreg(inst.vdst, 2) if is64 or mad64 else f"v{inst.vdst}"
if op in (VOP3SDOp.V_ADD_CO_U32, VOP3SDOp.V_SUB_CO_U32, VOP3SDOp.V_SUBREV_CO_U32): return f"{name} {dst}, {_fmt_sdst(sdst, 1)}, {s0}, {s1}"
if op in (VOP3SDOp.V_ADD_CO_CI_U32, VOP3SDOp.V_SUB_CO_CI_U32, VOP3SDOp.V_SUBREV_CO_CI_U32): return f"{name} {dst}, {_fmt_sdst(sdst, 1)}, {s0}, {s1}, {s2}"
return f"{name} {dst}, {_fmt_sdst(sdst, 1)}, {s0}, {s1}, {s2}" + _omod(inst.omod)
# Detect 16-bit operand sizes (for .h/.l suffix handling)
# Detect operand sizes
is64 = _is64(name)
is64_src, is64_dst = False, False
is16_d = is16_s = is16_s2 = False
if 'cvt_pk' in name: is16_s = name.endswith('16')
# v_cvt_pk_f32_bf8/fp8 outputs a VGPR pair (f32x2) from 16-bit packed input
if 'cvt_pk_f32_bf8' in name or 'cvt_pk_f32_fp8' in name: is64_dst = True
elif 'cvt_pk' in name: is16_s = name.endswith('16')
elif m := re.match(r'v_(?:cvt|frexp_exp)_([a-z0-9_]+)_([a-z0-9]+)', name):
is16_d, is16_s = _has(m.group(1), 'f16','i16','u16','b16'), _has(m.group(2), 'f16','i16','u16','b16')
is16_s2 = is16_s
is64_src, is64_dst = '64' in m.group(2), '64' in m.group(1)
is16_s2, is64 = is16_s, False
elif re.match(r'v_mad_[iu]32_[iu]16', name): is16_s = True
elif 'pack_b32' in name: is16_s = is16_s2 = True
else: is16_d = is16_s = is16_s2 = inst.is_16bit()
else: is16_d = is16_s = is16_s2 = _is16(name) and not _has(name, 'dot2', 'pk_', 'sad', 'msad', 'qsad', 'mqsad')
# Source counts
shift64 = 'rev' in name and '64' in name and name.startswith('v_')
ldexp64 = op == VOP3Op.V_LDEXP_F64
trig = op == VOP3Op.V_TRIG_PREOP_F64
sad64, mqsad = _has(name, 'qsad_pk', 'mqsad_pk'), 'mqsad_u32' in name
s0n = 2 if ((is64 and not shift64) or sad64 or mqsad or is64_src) else 1
s1n = 2 if (is64 and not _has(name, 'class') and not ldexp64 and not trig) else 1
s2n = 4 if mqsad else 2 if (is64 or sad64) else 1
any_hi = inst.opsel != 0
s0 = _vop3_src(inst, inst.src0, inst.neg&1, inst.abs&1, inst.opsel&1, inst.src_regs(0), is16_s, any_hi)
s1 = _vop3_src(inst, inst.src1, inst.neg&2, inst.abs&2, inst.opsel&2, inst.src_regs(1), is16_s, any_hi)
s2 = _vop3_src(inst, inst.src2, inst.neg&4, inst.abs&4, inst.opsel&4, inst.src_regs(2), is16_s2, any_hi)
s0 = _vop3_src(inst, inst.src0, inst.neg&1, inst.abs&1, inst.opsel&1, s0n, is16_s, any_hi)
s1 = _vop3_src(inst, inst.src1, inst.neg&2, inst.abs&2, inst.opsel&2, s1n, is16_s, any_hi)
s2 = _vop3_src(inst, inst.src2, inst.neg&4, inst.abs&4, inst.opsel&4, s2n, is16_s2, any_hi)
# Destination
dn = inst.dst_regs()
dn = 4 if mqsad else 2 if (is64 or sad64 or is64_dst) else 1
if op == VOP3Op.V_READLANE_B32: dst = _fmt_sdst(inst.vdst, 1)
elif dn > 1: dst = _vreg(inst.vdst, dn)
elif is16_d: dst = f"v{inst.vdst}.h" if (inst.opsel & 8) else f"v{inst.vdst}.l" if any_hi else f"v{inst.vdst}"
@ -243,54 +335,126 @@ def _disasm_vop3(inst: VOP3) -> str:
nonvgpr_opsel = (inst.src0 < 256 and (inst.opsel & 1)) or (inst.src1 < 256 and (inst.opsel & 2)) or (inst.src2 < 256 and (inst.opsel & 4))
need_opsel = nonvgpr_opsel or (inst.opsel and not is16_s)
# RDNA4 v_s_* instructions (pseudo-scalar VOP1-like) have SGPR destination
if name.startswith('v_s_') and is_rdna4:
return f"{name} {_fmt_sdst(inst.vdst, 1)}, {s0}{cl}{om}"
if inst.op < 256: # VOPC
return f"{name}_e64 {s0}, {s1}" if name.startswith('v_cmpx') else f"{name}_e64 {_fmt_sdst(inst.vdst, 1)}, {s0}, {s1}"
if inst.op < 384: # VOP2
n = inst.num_srcs()
os = _opsel_str(inst.opsel, n, need_opsel, is16_d)
return f"{name}_e64 {dst}, {s0}, {s1}, {s2}{os}{cl}{om}" if n == 3 else f"{name}_e64 {dst}, {s0}, {s1}{os}{cl}{om}"
os = _opsel_str(inst.opsel, 3, need_opsel, is16_d) if 'cndmask' in name else _opsel_str(inst.opsel, 2, need_opsel, is16_d)
return f"{name}_e64 {dst}, {s0}, {s1}, {s2}{os}{cl}{om}" if 'cndmask' in name else f"{name}_e64 {dst}, {s0}, {s1}{os}{cl}{om}"
if inst.op < 512: # VOP1
return f"{name}_e64" if op in (VOP3Op.V_NOP, VOP3Op.V_PIPEFLUSH) else f"{name}_e64 {dst}, {s0}{_opsel_str(inst.opsel, 1, need_opsel, is16_d)}{cl}{om}"
if name in ('v_nop', 'v_pipeflush'): return f"{name}_e64"
# Handle byte_sel for non-pk fp8/bf8 conversions
if ('cvt_f32_fp8' in name or 'cvt_f32_bf8' in name) and 'pk' not in name:
byte_sel = inst.opsel & 3
os = f" byte_sel:{byte_sel}" if byte_sel else ""
elif 'cvt_pk_f32_bf8' in name or 'cvt_pk_f32_fp8' in name:
os = _opsel_str(inst.opsel, 2, need_opsel, is16_d) # 2-element for pk variants
else:
os = _opsel_str(inst.opsel, 1, need_opsel, is16_d) # 1-element for other VOP1
return f"{name}_e64 {dst}, {s0}{os}{cl}{om}"
# Native VOP3
n = inst.num_srcs()
os = _opsel_str(inst.opsel, n, need_opsel, is16_d)
return f"{name} {dst}, {s0}, {s1}, {s2}{os}{cl}{om}" if n == 3 else f"{name} {dst}, {s0}, {s1}{os}{cl}{om}"
is3 = _has(name, 'fma', 'mad', 'min3', 'max3', 'med3', 'div_fix', 'div_fmas', 'sad', 'lerp', 'align', 'cube', 'bfe', 'bfi',
'perm_b32', 'cndmask', 'xor3', 'or3', 'add3', 'lshl_or', 'and_or', 'lshl_add', 'add_lshl', 'xad', 'maxmin', 'minmax', 'dot2', 'cvt_pk_u8', 'mullit',
'minimummaximum', 'maximumminimum', 'minimum3', 'maximum3')
# permlane16/permlanex16 have 3 sources, but _var variants have 2
if 'permlane' in name and 'var' not in name: is3 = True
# Handle byte_sel for fp8/bf8 instructions (opsel encodes byte_sel, not op_sel)
# For VOP1-encoded VOP3 (op < 512): cvt_f32_fp8, cvt_f32_bf8
# For native VOP3: cvt_sr_fp8, cvt_sr_bf8
if ('cvt_f32_fp8' in name or 'cvt_f32_bf8' in name or 'cvt_sr_fp8' in name or 'cvt_sr_bf8' in name) and 'pk' not in name:
# For VOP1 encoding (op < 512), byte_sel is in bits[1:0] of opsel; for native VOP3, it's bits[3:2]
byte_sel = (inst.opsel & 3) if inst.op < 512 else ((inst.opsel >> 2) & 3)
os = f" byte_sel:{byte_sel}" if byte_sel else ""
else:
os = _opsel_str(inst.opsel, 3 if is3 else 2, need_opsel, is16_d)
return f"{name} {dst}, {s0}, {s1}, {s2}{os}{cl}{om}" if is3 else f"{name} {dst}, {s0}, {s1}{os}{cl}{om}"
def _disasm_vop3sd(inst: VOP3SD) -> str:
name = inst.op_name.lower()
def src(v, neg, n): s = _fmt_src(v, n) if n > 1 else inst.lit(v); return f"-{s}" if neg else s
s0, s1, s2 = src(inst.src0, inst.neg & 1, inst.src_regs(0)), src(inst.src1, inst.neg & 2, inst.src_regs(1)), src(inst.src2, inst.neg & 4, inst.src_regs(2))
dst = _vreg(inst.vdst, inst.dst_regs()) if inst.dst_regs() > 1 else f"v{inst.vdst}"
srcs = f"{s0}, {s1}, {s2}" if inst.num_srcs() == 3 else f"{s0}, {s1}"
op, name = VOP3SDOp(inst.op), VOP3SDOp(inst.op).name.lower()
is64, mad64 = 'f64' in name, _has(name, 'mad_i64_i32', 'mad_u64_u32')
def src(v, neg, ext=False): s = _fmt_src(v, 2) if ext or is64 else inst.lit(v); return f"-{s}" if neg else s
s0, s1, s2 = src(inst.src0, inst.neg & 1), src(inst.src1, inst.neg & 2), src(inst.src2, inst.neg & 4, mad64)
dst, is2src = _vreg(inst.vdst, 2) if is64 or mad64 else f"v{inst.vdst}", op in (VOP3SDOp.V_ADD_CO_U32, VOP3SDOp.V_SUB_CO_U32, VOP3SDOp.V_SUBREV_CO_U32)
suffix = "_e64" if name.startswith('v_') and 'co_' in name else ""
return f"{name}{suffix} {dst}, {_fmt_sdst(inst.sdst, 1)}, {srcs}{' clamp' if inst.clmp else ''}{_omod(inst.omod)}"
return f"{name}{suffix} {dst}, {_fmt_sdst(inst.sdst, 1)}, {s0}, {s1}{'' if is2src else f', {s2}'}{' clamp' if inst.clmp else ''}{_omod(inst.omod)}"
def _disasm_vopd(inst: VOPD) -> str:
is_rdna4 = 'rdna4' in inst.__class__.__module__
if is_rdna4:
from extra.assembly.amd.autogen.rdna4.enum import VOPDOp as VOPDOp4
OpEnum = VOPDOp4
else:
OpEnum = VOPDOp
lit = inst._literal or inst.literal
vdst_y, nx, ny = (inst.vdsty << 1) | ((inst.vdstx & 1) ^ 1), VOPDOp(inst.opx).name.lower(), VOPDOp(inst.opy).name.lower()
vdst_y, nx, ny = (inst.vdsty << 1) | ((inst.vdstx & 1) ^ 1), OpEnum(inst.opx).name.lower(), OpEnum(inst.opy).name.lower()
def half(n, vd, s0, vs1): return f"{n} v{vd}, {inst.lit(s0)}{f', 0x{lit:x}' if lit and _has(n, 'fmaak', 'fmamk') else ''}" if 'mov' in n else f"{n} v{vd}, {inst.lit(s0)}, v{vs1}{f', 0x{lit:x}' if lit and _has(n, 'fmaak', 'fmamk') else ''}"
return f"{half(nx, inst.vdstx, inst.srcx0, inst.vsrcx1)} :: {half(ny, vdst_y, inst.srcy0, inst.vsrcy1)}"
def _disasm_vop3p(inst: VOP3P) -> str:
name = inst.op_name.lower()
is_wmma, n, is_fma_mix = 'wmma' in name, inst.num_srcs(), 'fma_mix' in name
if is_wmma:
sc = 2 if 'iu4' in name else 4 if 'iu8' in name else 8
src0, src1, src2, dst = _fmt_src(inst.src0, sc), _fmt_src(inst.src1, sc), _fmt_src(inst.src2, 8), _vreg(inst.vdst, 8)
def _disasm_vop3p(inst: VOP3P, wave_size: int = 32) -> str:
is_rdna4 = 'rdna4' in inst.__class__.__module__
if is_rdna4:
from extra.assembly.amd.autogen.rdna4.enum import VOP3POp as OpEnum
else:
OpEnum = VOP3POp
name = OpEnum(inst.op).name.lower()
is_wmma, is_swmmac = 'wmma' in name and 'swmmac' not in name, 'swmmac' in name
is_3src, is_fma_mix = _has(name, 'fma', 'mad', 'dot', 'wmma'), 'fma_mix' in name
# Wave64 uses half the register widths of wave32 for WMMA
wave_div = 2 if wave_size == 64 else 1
if is_swmmac and is_rdna4:
# SWMMAC (sparse WMMA): src2 is a single VGPR index, not an accumulator
# Determine src0/src1/dst sizes based on instruction type
if 'f16' in name or 'bf16' in name:
if 'f16_16x16x32_f16' in name or 'bf16_16x16x32_bf16' in name:
s0c, s1c, dc = 4, 8, 4 # f16/bf16 output
else:
s0c, s1c, dc = 4, 8, 8 # f32 output
elif 'iu8' in name: s0c, s1c, dc = 2, 4, 8
elif 'iu4' in name:
if '16x16x64' in name: s0c, s1c, dc = 2, 4, 8
else: s0c, s1c, dc = 1, 2, 8
elif 'fp8' in name or 'bf8' in name: s0c, s1c, dc = 2, 4, 8
else: s0c, s1c, dc = 4, 8, 8
s0c, s1c, dc = max(1, s0c // wave_div), max(1, s1c // wave_div), max(1, dc // wave_div)
src0, src1, src2, dst = _fmt_src(inst.src0, s0c), _fmt_src(inst.src1, s1c), _fmt_src(inst.src2, 1), _vreg(inst.vdst, dc)
elif is_wmma:
# RDNA4 WMMA uses smaller source register widths than RDNA3
if is_rdna4:
# RDNA4 wave32 source widths: iu4->1/2, iu8->2, fp8/bf8->2, f16/bf16->4
if 'iu4' in name:
sc = 2 if '16x16x32' in name else 1
elif 'iu8' in name or 'fp8' in name or 'bf8' in name: sc = 2
else: sc = 4 # f16/bf16
# Destination width: f16/bf16 output->4, f32/i32 output->8
dc = 4 if name.startswith('v_wmma_f16') or name.startswith('v_wmma_bf16') else 8
else:
# RDNA3: iu4->2, iu8->4, f16/bf16->8
sc = 2 if 'iu4' in name else 4 if 'iu8' in name else 8
dc = 8
sc, dc = max(1, sc // wave_div), max(1, dc // wave_div)
src0, src1, src2, dst = _fmt_src(inst.src0, sc), _fmt_src(inst.src1, sc), _fmt_src(inst.src2, dc), _vreg(inst.vdst, dc)
else: src0, src1, src2, dst = _fmt_src(inst.src0, 1), _fmt_src(inst.src1, 1), _fmt_src(inst.src2, 1), f"v{inst.vdst}"
opsel_hi = inst.opsel_hi | (inst.opsel_hi2 << 2)
if is_fma_mix:
n, opsel_hi = 3 if is_3src else 2, inst.opsel_hi | (inst.opsel_hi2 << 2)
if is_swmmac and is_rdna4:
# SWMMAC uses index_key instead of op_sel; opsel bits encode the key value
mods = ([f"index_key:{inst.opsel & 7}"] if inst.opsel else []) + \
([_fmt_bits("neg_lo", inst.neg, n)] if inst.neg else []) + ([_fmt_bits("neg_hi", inst.neg_hi, n)] if inst.neg_hi else []) + (["clamp"] if inst.clmp else [])
elif is_fma_mix:
def m(s, neg, abs_): return f"-{f'|{s}|' if abs_ else s}" if neg else (f"|{s}|" if abs_ else s)
src0, src1, src2 = m(src0, inst.neg & 1, inst.neg_hi & 1), m(src1, inst.neg & 2, inst.neg_hi & 2), m(src2, inst.neg & 4, inst.neg_hi & 4)
mods = ([_fmt_bits("op_sel", inst.opsel, n)] if inst.opsel else []) + ([_fmt_bits("op_sel_hi", opsel_hi, n)] if opsel_hi else []) + (["clamp"] if inst.clmp else [])
else:
mods = ([_fmt_bits("op_sel", inst.opsel, n)] if inst.opsel else []) + ([_fmt_bits("op_sel_hi", opsel_hi, n)] if opsel_hi != (7 if n == 3 else 3) else []) + \
mods = ([_fmt_bits("op_sel", inst.opsel, n)] if inst.opsel else []) + ([_fmt_bits("op_sel_hi", opsel_hi, n)] if opsel_hi != (7 if is_3src else 3) else []) + \
([_fmt_bits("neg_lo", inst.neg, n)] if inst.neg else []) + ([_fmt_bits("neg_hi", inst.neg_hi, n)] if inst.neg_hi else []) + (["clamp"] if inst.clmp else [])
return f"{name} {dst}, {src0}, {src1}, {src2}{' ' + ' '.join(mods) if mods else ''}" if n == 3 else f"{name} {dst}, {src0}, {src1}{' ' + ' '.join(mods) if mods else ''}"
return f"{name} {dst}, {src0}, {src1}, {src2}{' ' + ' '.join(mods) if mods else ''}" if is_3src else f"{name} {dst}, {src0}, {src1}{' ' + ' '.join(mods) if mods else ''}"
def _disasm_buf(inst: MUBUF | MTBUF) -> str:
name = inst.op_name.lower()
if inst.op in (MUBUFOp.BUFFER_GL0_INV, MUBUFOp.BUFFER_GL1_INV): return name
op = MTBUFOp(inst.op) if isinstance(inst, MTBUF) else MUBUFOp(inst.op)
name = op.name.lower()
if op in (MUBUFOp.BUFFER_GL0_INV, MUBUFOp.BUFFER_GL1_INV): return name
w = (2 if _has(name, 'xyz', 'xyzw') else 1) if 'd16' in name else \
((2 if _has(name, 'b64', 'u64', 'i64') else 1) * (2 if 'cmpswap' in name else 1)) if 'atomic' in name else \
{'b32':1,'b64':2,'b96':3,'b128':4,'b16':1,'x':1,'xy':2,'xyz':3,'xyzw':4}.get(name.split('_')[-1], 1)
@ -318,7 +482,7 @@ def _mimg_vaddr_width(name: str, dim: int, a16: bool) -> int:
return (base + packed + 1) // 2 + unpacked if a16 else base + packed + unpacked
def _disasm_mimg(inst: MIMG) -> str:
name = inst.op_name.lower()
name = MIMGOp(inst.op).name.lower()
srsrc_base = inst.srsrc * 4
srsrc_str = _sreg_or_ttmp(srsrc_base, 8)
# BVH intersect ray: special case with 4 SGPR srsrc
@ -334,8 +498,8 @@ def _disasm_mimg(inst: MIMG) -> str:
dim = dim_names[inst.dim] if inst.dim < len(dim_names) else f"dim_{inst.dim}"
vaddr = _mimg_vaddr_width(name, inst.dim, inst.a16)
vaddr_str = f"v{inst.vaddr}" if vaddr == 1 else _vreg(inst.vaddr, vaddr)
# modifiers
mods = [f"dmask:0x{inst.dmask:x}"] if inst.dmask and (inst.dmask != 15 or 'atomic' in name) else []
# modifiers - always include dmask for image load/store/atomic (LLVM uses it for vdata size validation)
mods = [f"dmask:0x{inst.dmask:x}"] if inst.dmask else []
mods.append(f"dim:SQ_RSRC_IMG_{dim.upper()}")
for flag, mod in [(inst.unrm,"unorm"),(inst.glc,"glc"),(inst.slc,"slc"),(inst.dlc,"dlc"),(inst.r128,"r128"),
(inst.a16,"a16"),(inst.tfe,"tfe"),(inst.lwe,"lwe"),(inst.d16,"d16")]:
@ -346,38 +510,329 @@ def _disasm_mimg(inst: MIMG) -> str:
ssamp_str = ", " + _sreg_or_ttmp(inst.ssamp * 4, 4)
return f"{name} {_vreg(inst.vdata, vdata)}, {vaddr_str}, {srsrc_str}{ssamp_str} {' '.join(mods)}"
def _sop_widths(name: str) -> tuple[int, int, int]:
"""Return (dst_width, src0_width, src1_width) in register count for SOP instructions."""
if name in ('s_bitset0_b64', 's_bitset1_b64', 's_bfm_b64'): return 2, 1, 1
if name in ('s_lshl_b64', 's_lshr_b64', 's_ashr_i64', 's_bfe_u64', 's_bfe_i64'): return 2, 2, 1
if name in ('s_bitcmp0_b64', 's_bitcmp1_b64'): return 1, 2, 1
if m := re.search(r'_(b|i|u)(32|64)_(b|i|u)(32|64)$', name): return 2 if m.group(2) == '64' else 1, 2 if m.group(4) == '64' else 1, 1
if m := re.search(r'_(b|i|u)(32|64)$', name): sz = 2 if m.group(2) == '64' else 1; return sz, sz, sz
return 1, 1, 1
def _disasm_sop1(inst: SOP1) -> str:
op, name = inst.op, inst.op_name.lower()
op, name = SOP1Op(inst.op), SOP1Op(inst.op).name.lower()
if op == SOP1Op.S_GETPC_B64: return f"{name} {_fmt_sdst(inst.sdst, 2)}"
if op in (SOP1Op.S_SETPC_B64, SOP1Op.S_RFE_B64): return f"{name} {_fmt_src(inst.ssrc0, 2)}"
if op == SOP1Op.S_SWAPPC_B64: return f"{name} {_fmt_sdst(inst.sdst, 2)}, {_fmt_src(inst.ssrc0, 2)}"
if op in (SOP1Op.S_SENDMSG_RTN_B32, SOP1Op.S_SENDMSG_RTN_B64): return f"{name} {_fmt_sdst(inst.sdst, inst.dst_regs())}, sendmsg({MSG.get(inst.ssrc0, str(inst.ssrc0))})"
return f"{name} {_fmt_sdst(inst.sdst, inst.dst_regs())}, {inst.lit(inst.ssrc0) if inst.src_regs(0) == 1 else _fmt_src(inst.ssrc0, inst.src_regs(0))}"
if op in (SOP1Op.S_SENDMSG_RTN_B32, SOP1Op.S_SENDMSG_RTN_B64): return f"{name} {_fmt_sdst(inst.sdst, 2 if 'b64' in name else 1)}, sendmsg({MSG.get(inst.ssrc0, str(inst.ssrc0))})"
dn, s0n, _ = _sop_widths(name)
return f"{name} {_fmt_sdst(inst.sdst, dn)}, {inst.lit(inst.ssrc0) if s0n == 1 else _fmt_src(inst.ssrc0, s0n)}"
def _disasm_sop2(inst: SOP2) -> str:
return f"{inst.op_name.lower()} {_fmt_sdst(inst.sdst, inst.dst_regs())}, {inst.lit(inst.ssrc0) if inst.ssrc0 == 255 else _fmt_src(inst.ssrc0, inst.src_regs(0))}, {inst.lit(inst.ssrc1) if inst.ssrc1 == 255 else _fmt_src(inst.ssrc1, inst.src_regs(1))}"
name = SOP2Op(inst.op).name.lower()
dn, s0n, s1n = _sop_widths(name)
return f"{name} {_fmt_sdst(inst.sdst, dn)}, {inst.lit(inst.ssrc0) if inst.ssrc0 == 255 else _fmt_src(inst.ssrc0, s0n)}, {inst.lit(inst.ssrc1) if inst.ssrc1 == 255 else _fmt_src(inst.ssrc1, s1n)}"
def _disasm_sopc(inst: SOPC) -> str:
return f"{inst.op_name.lower()} {_fmt_src(inst.ssrc0, inst.src_regs(0))}, {_fmt_src(inst.ssrc1, inst.src_regs(1))}"
name = SOPCOp(inst.op).name.lower()
_, s0n, s1n = _sop_widths(name)
return f"{name} {_fmt_src(inst.ssrc0, s0n)}, {_fmt_src(inst.ssrc1, s1n)}"
def _disasm_sopk(inst: SOPK) -> str:
op, name = inst.op, inst.op_name.lower()
if op == SOPKOp.S_VERSION: return f"{name} 0x{inst.simm16:x}"
if op in (SOPKOp.S_SETREG_B32, SOPKOp.S_GETREG_B32):
# Use architecture-specific SOPK op enum
if 'rdna4' in inst.__class__.__module__:
from extra.assembly.amd.autogen.rdna4.enum import SOPKOp as OpEnum
hwreg_map = HWREG_GFX12
else:
OpEnum = SOPKOp
hwreg_map = HWREG
op, name = OpEnum(inst.op), OpEnum(inst.op).name.lower()
if name == 's_version': return f"{name} 0x{inst.simm16:x}"
if name in ('s_setreg_b32', 's_getreg_b32'):
hid, hoff, hsz = inst.simm16 & 0x3f, (inst.simm16 >> 6) & 0x1f, ((inst.simm16 >> 11) & 0x1f) + 1
hs = f"0x{inst.simm16:x}" if hid in (16, 17) else f"hwreg({HWREG.get(hid, str(hid))}, {hoff}, {hsz})"
return f"{name} {hs}, {_fmt_sdst(inst.sdst, 1)}" if op == SOPKOp.S_SETREG_B32 else f"{name} {_fmt_sdst(inst.sdst, 1)}, {hs}"
return f"{name} {_fmt_sdst(inst.sdst, inst.dst_regs())}, 0x{inst.simm16:x}"
hreg_name = hwreg_map.get(hid, str(hid))
# If offset=0 and size=32, use short form hwreg(NAME), otherwise hwreg(NAME, off, sz)
if hid in (16, 17): hs = f"0x{inst.simm16:x}"
elif hoff == 0 and hsz == 32: hs = f"hwreg({hreg_name})"
else: hs = f"hwreg({hreg_name}, {hoff}, {hsz})"
return f"{name} {hs}, {_fmt_sdst(inst.sdst, 1)}" if name == 's_setreg_b32' else f"{name} {_fmt_sdst(inst.sdst, 1)}, {hs}"
dn, _, _ = _sop_widths(name)
return f"{name} {_fmt_sdst(inst.sdst, dn)}, 0x{inst.simm16:x}"
def _disasm_vinterp(inst: VINTERP) -> str:
name = VINTERPOp(inst.op).name.lower()
src0 = f"-{inst.lit(inst.src0)}" if inst.neg & 1 else inst.lit(inst.src0)
src1 = f"-{inst.lit(inst.src1)}" if inst.neg & 2 else inst.lit(inst.src1)
src2 = f"-{inst.lit(inst.src2)}" if inst.neg & 4 else inst.lit(inst.src2)
mods = _mods((inst.waitexp, f"wait_exp:{inst.waitexp}"), (inst.clmp, "clamp"))
return f"{inst.op_name.lower()} v{inst.vdst}, {inst.lit(inst.src0, inst.neg & 1)}, {inst.lit(inst.src1, inst.neg & 2)}, {inst.lit(inst.src2, inst.neg & 4)}" + (" " + mods if mods else "")
return f"{name} v{inst.vdst}, {src0}, {src1}, {src2}" + (" " + mods if mods else "")
# Export targets: mrt0-7, mrtz, pos0-4, prim, dual_src_blend0/1
_EXP_TARGETS = {**{i: f'mrt{i}' for i in range(8)}, 8: 'mrtz', **{i+12: f'pos{i}' for i in range(5)}, 20: 'prim', 21: 'dual_src_blend0', 22: 'dual_src_blend1'}
def _disasm_exp(inst) -> str:
target = _EXP_TARGETS.get(inst.target, f"invalid_target_{inst.target}")
en = inst.en
vsrc = lambda i, v: f"v{v}" if (en >> i) & 1 else "off"
srcs = f"{vsrc(0, inst.vsrc0)}, {vsrc(1, inst.vsrc1)}, {vsrc(2, inst.vsrc2)}, {vsrc(3, inst.vsrc3)}"
mods = _mods((inst.done, "done"), (inst.row, "row_en"))
prefix = "export" if 'rdna4' in inst.__class__.__module__ else "exp"
return f"{prefix} {target} {srcs}" + (" " + mods if mods else "")
def _disasm_ldsdir(inst) -> str:
is_rdna4 = 'rdna4' in inst.__class__.__module__
if is_rdna4:
# RDNA4 uses ds_* prefix and wait_va_vdst/wait_vm_vsrc modifiers
wait = f" wait_va_vdst:{inst.wait_va} wait_vm_vsrc:{inst.wait_vm}"
if inst.op == 1: return f"ds_direct_load v{inst.vdst}{wait}"
if inst.op == 0: return f"ds_param_load v{inst.vdst}, attr{inst.attr}.{['x','y','z','w'][inst.attr_chan]}{wait}"
else:
# RDNA3 uses lds_* prefix and wait_vdst modifier
wait = f" wait_vdst:{inst.wait_va}" if inst.wait_va != 0 else ""
if inst.op == 1: return f"lds_direct_load v{inst.vdst}{wait}"
if inst.op == 0: return f"lds_param_load v{inst.vdst}, attr{inst.attr}.{['x','y','z','w'][inst.attr_chan]}{wait}"
raise ValueError(f"unknown LDSDIR op: {inst.op}")
# ═══════════════════════════════════════════════════════════════════════════════
# RDNA4-specific disassemblers (GFX12)
# ═══════════════════════════════════════════════════════════════════════════════
# th values for RDNA4 memory instructions (based on AMDGPU ISA docs and LLVM SIDefines.h)
# Load: 0=RT(default), 1=NT, 2=HT, 3=LU, 4=NT_RT, 5=RT_NT, 6=NT_HT, 7=BYPASS(only with scope=SYS)
_TH_LOAD = {0: '', 1: 'th:TH_LOAD_NT', 2: 'th:TH_LOAD_HT', 3: 'th:TH_LOAD_LU', 4: 'th:TH_LOAD_NT_RT', 5: 'th:TH_LOAD_RT_NT', 6: 'th:TH_LOAD_NT_HT', 7: 'th:TH_LOAD_BYPASS'}
# Store: 0=RT(default), 1=NT, 2=HT, 3=WB, 4=NT_RT, 5=RT_NT, 6=NT_HT, 7=NT_WB (BYPASS is th=7 + scope=SYS)
_TH_STORE = {0: '', 1: 'th:TH_STORE_NT', 2: 'th:TH_STORE_HT', 3: 'th:TH_STORE_WB', 4: 'th:TH_STORE_NT_RT', 5: 'th:TH_STORE_RT_NT', 6: 'th:TH_STORE_NT_HT', 7: 'th:TH_STORE_NT_WB'}
# Atomic: bit0=RETURN, bit1=NT, bit2=CASCADE -> 0=none, 1=RETURN, 2=NT, 3=NT_RETURN, 4=CASCADE_RT, 5=RT_RETURN(N/A), 6=CASCADE_NT, 7=N/A
_TH_ATOMIC = {0: '', 1: 'th:TH_ATOMIC_RETURN', 2: 'th:TH_ATOMIC_NT', 3: 'th:TH_ATOMIC_NT_RETURN', 4: 'th:TH_ATOMIC_CASCADE_RT', 5: 'th:TH_ATOMIC_RT_RETURN', 6: 'th:TH_ATOMIC_CASCADE_NT', 7: 'th:TH_ATOMIC_CASCADE_NT'}
_SCOPE = {0: '', 1: 'scope:SCOPE_SE', 2: 'scope:SCOPE_DEV', 3: 'scope:SCOPE_SYS'}
def _rdna4_mem_mods(th: int, scope: int, is_store: bool, is_atomic: bool) -> str:
th_map = _TH_ATOMIC if is_atomic else _TH_STORE if is_store else _TH_LOAD
# Special case: th=3 with scope=SYS means BYPASS for load/store (otherwise th=3 means LU/WB)
if th == 3 and scope == 3 and not is_atomic:
th_s = 'th:TH_STORE_BYPASS' if is_store else 'th:TH_LOAD_BYPASS'
else:
th_s = th_map.get(th, f'th:{th}' if th else '')
scope_s = _SCOPE.get(scope, f'scope:{scope}' if scope else '')
return ' '.join(x for x in [th_s, scope_s] if x)
def _disasm_vflat(inst) -> str:
"""Disassemble RDNA4 VFLAT/VGLOBAL/VSCRATCH instructions."""
from extra.assembly.amd.autogen.rdna4.enum import VFLATOp, VGLOBALOp, VSCRATCHOp
cls_name = type(inst).__name__
if cls_name == 'VGLOBAL': op_enum, prefix = VGLOBALOp, 'global'
elif cls_name == 'VSCRATCH': op_enum, prefix = VSCRATCHOp, 'scratch'
else: op_enum, prefix = VFLATOp, 'flat'
name = op_enum(inst.op).name.lower()
# global_wb, global_wbinv, global_inv are cache control instructions with no operands
if name in ('global_wb', 'global_wbinv', 'global_inv'):
mods = _rdna4_mem_mods(inst.th, inst.scope, False, False)
return f"{name}" + (f" {mods}" if mods else "")
# addtid instructions use thread ID as address offset, no vaddr operand
is_addtid = 'addtid' in name
# Data width based on instruction name suffix
suffix = name.split('_')[-1]
# block loads/stores use 32 VGPRs
if 'block' in name:
base_w = 32
else:
base_w = {'b32':1,'b64':2,'b96':3,'b128':4,'u8':1,'i8':1,'u16':1,'i16':1,'u32':1,'i32':1,'u64':2,'i64':2,'f32':1,'f64':2}.get(suffix, 1)
# For cmpswap: vsrc holds cmp+data pairs (2x base), vdst is base width
vsrc_w = base_w * 2 if 'cmpswap' in name else base_w
vdst_w = base_w
# Offset: signed 24-bit (stored as unsigned, needs sign extension)
off = inst.ioffset if inst.ioffset < 0x800000 else inst.ioffset - 0x1000000
off_s = f" offset:{off}" if off else ""
# Memory modifiers
is_store, is_atomic = 'store' in name, 'atomic' in name
mods = _rdna4_mem_mods(inst.th, inst.scope, is_store, is_atomic)
# saddr handling - VGLOBAL and VSCRATCH need explicit "off" when saddr=124
if inst.saddr == 124: saddr_s = "off" if prefix in ('global', 'scratch') else ""
elif inst.saddr in SPECIAL_PAIRS: saddr_s = SPECIAL_PAIRS[inst.saddr]
else: saddr_s = _sreg(inst.saddr, 2) if prefix == 'global' else decode_src(inst.saddr)
# Address width: 1 for scratch with saddr, 2 otherwise
addr_w = 1 if (prefix == 'scratch' or (inst.saddr != 124 and prefix != 'flat')) else 2
vaddr_s = _vreg(inst.vaddr, addr_w)
vsrc_s = _vreg(inst.vsrc, vsrc_w)
vdst_s = _vreg(inst.vdst, vdst_w)
# addtid instructions don't have vaddr, just vdata and saddr
if is_addtid:
if is_store: return f"{name} {vsrc_s}, {saddr_s}{off_s}" + (f" {mods}" if mods else "")
return f"{name} {vdst_s}, {saddr_s}{off_s}" + (f" {mods}" if mods else "")
# Regular instructions need comma before saddr
saddr_s = f", {saddr_s}" if saddr_s else ""
if is_atomic:
if inst.th == 1: # TH_ATOMIC_RETURN
return f"{name} {vdst_s}, {vaddr_s}, {vsrc_s}{saddr_s}{off_s}" + (f" {mods}" if mods else "")
return f"{name} {vaddr_s}, {vsrc_s}{saddr_s}{off_s}" + (f" {mods}" if mods else "")
if is_store: return f"{name} {vaddr_s}, {vsrc_s}{saddr_s}{off_s}" + (f" {mods}" if mods else "")
return f"{name} {vdst_s}, {vaddr_s}{saddr_s}{off_s}" + (f" {mods}" if mods else "")
def _disasm_vbuffer(inst) -> str:
"""Disassemble RDNA4 VBUFFER instructions."""
from extra.assembly.amd.autogen.rdna4.enum import VBUFFEROp
name = VBUFFEROp(inst.op).name.lower()
# Determine if this is a typed buffer instruction (MTBUF format)
is_format = 'format' in name
# Data width based on instruction name
if is_format:
w = {'x': 1, 'xy': 2, 'xyz': 3, 'xyzw': 4}.get(name.split('_')[-1], 1)
if 'd16' in name: w = (w + 1) // 2
else:
suffix = name.split('_')[-1]
w = {'b32':1,'b64':2,'b96':3,'b128':4,'u8':1,'i8':1,'u16':1,'i16':1,'u32':1,'i32':1,'u64':2,'i64':2,'b8':1,'b16':1,'f16':1,'f32':1,'bf16':1}.get(suffix, 1)
if 'cmpswap' in name: w *= 2
if inst.tfe: w += 1
vdata_s = _vreg(inst.vdata, w)
vaddr_s = _vreg(inst.vaddr, 2) if inst.offen and inst.idxen else (f"v{inst.vaddr}" if inst.offen or inst.idxen else "off")
# RDNA4 VBUFFER rsrc field stores the SGPR index directly (not /4 like RDNA3)
srsrc_s = _sreg_or_ttmp(inst.rsrc, 4)
soffset_s = decode_src(inst.soffset)
off = inst.ioffset if inst.ioffset < 0x800000 else inst.ioffset - 0x1000000
is_store, is_atomic = 'store' in name, 'atomic' in name
mods = _rdna4_mem_mods(inst.th, inst.scope, is_store, is_atomic)
# Format field is only for MTBUF (tbuffer_*) instructions, not buffer_*_format_* instructions
# We don't output format for buffer instructions since they use implicit format
parts = []
if inst.idxen: parts.append("idxen")
if inst.offen: parts.append("offen")
if off: parts.append(f"offset:{off}")
if mods: parts.append(mods)
if inst.tfe: parts.append("tfe")
return f"{name} {vdata_s}, {vaddr_s}, {srsrc_s}, {soffset_s}" + (f" {' '.join(parts)}" if parts else "")
def _disasm_vimage(inst) -> str:
"""Disassemble RDNA4 VIMAGE instructions."""
from extra.assembly.amd.autogen.rdna4.enum import VIMAGEOp
name = VIMAGEOp(inst.op).name.lower()
# RDNA4 VIMAGE rsrc field stores the SGPR index directly (not /4 like RDNA3)
if 'bvh' in name:
# BVH intersect ray: special format with individual/range vaddr components
# Format: [node_ptr, ray_extent, ray_origin(3), ray_dir(3), ray_inv_dir(3)]
# bvh64 has 2-VGPR node_ptr, a16 removes ray_inv_dir
if 'dual' in name or 'bvh8' in name:
# dual/bvh8: [node_ptr(2), ray_extent(2), ray_origin(3), ray_dir(3), ...]
parts = [_vreg(inst.vaddr0, 2), _vreg(inst.vaddr1, 2), _vreg(inst.vaddr2, 3), _vreg(inst.vaddr3, 3)]
if not inst.a16: parts.append(_vreg(inst.vaddr4, 1 if 'bvh8' in name else 2))
dst_w = 10
elif '64' in name:
parts = [_vreg(inst.vaddr0, 2), f"v{inst.vaddr1}", _vreg(inst.vaddr2, 3), _vreg(inst.vaddr3, 3)]
if not inst.a16: parts.append(_vreg(inst.vaddr4, 3))
dst_w = 4
else:
parts = [f"v{inst.vaddr0}", f"v{inst.vaddr1}", _vreg(inst.vaddr2, 3), _vreg(inst.vaddr3, 3)]
if not inst.a16: parts.append(_vreg(inst.vaddr4, 3))
dst_w = 4
return f"{name} {_vreg(inst.vdata, dst_w)}, [{', '.join(parts)}], {_sreg_or_ttmp(inst.rsrc, 4)}{' a16' if inst.a16 else ''}"
# vdata width - msaa_load always uses 4 VGPRs per channel
vdata = 4 if 'gather4' in name or 'msaa_load' in name else (bin(inst.dmask).count('1') or 1)
if inst.d16: vdata = (vdata + 1) // 2
if inst.tfe: vdata += 1
# dim names
dim_names = ['1d', '2d', '3d', 'cube', '1d_array', '2d_array', '2d_msaa', '2d_msaa_array']
dim = dim_names[inst.dim] if inst.dim < len(dim_names) else f"{inst.dim}"
# vaddr width calculation (RDNA4 uses vaddr0-4 for address components)
vaddr_w = _mimg_vaddr_width(name, inst.dim, inst.a16)
if vaddr_w == 1: vaddr_s = f"v{inst.vaddr0}"
elif vaddr_w == 2: vaddr_s = f"[v{inst.vaddr0}, v{inst.vaddr1}]"
elif vaddr_w == 3: vaddr_s = f"[v{inst.vaddr0}, v{inst.vaddr1}, v{inst.vaddr2}]"
elif vaddr_w == 4: vaddr_s = f"[v{inst.vaddr0}, v{inst.vaddr1}, v{inst.vaddr2}, v{inst.vaddr3}]"
else: vaddr_s = f"[v{inst.vaddr0}, v{inst.vaddr1}, v{inst.vaddr2}, v{inst.vaddr3}, v{inst.vaddr4}]"
srsrc_s = _sreg_or_ttmp(inst.rsrc, 8)
# RDNA4 always requires dmask for size calculation (even if 0xf)
mods = [f"dmask:0x{inst.dmask:x}"]
mods.append(f"dim:SQ_RSRC_IMG_{dim.upper()}")
# Add th/scope before other modifiers, then r128, then a16/tfe/d16 (LLVM expects this order)
if inst.th or inst.scope:
is_store, is_atomic = 'store' in name, 'atomic' in name
mem_mods = _rdna4_mem_mods(inst.th, inst.scope, is_store, is_atomic)
if mem_mods: mods.append(mem_mods)
if inst.r128: mods.append("r128")
mods.extend([m for c, m in [(inst.a16, "a16"), (inst.tfe, "tfe"), (inst.d16, "d16")] if c])
return f"{name} {_vreg(inst.vdata, vdata)}, {vaddr_s}, {srsrc_s} {' '.join(mods)}"
def _disasm_vsample(inst) -> str:
"""Disassemble RDNA4 VSAMPLE instructions."""
from extra.assembly.amd.autogen.rdna4.enum import VSAMPLEOp
name = VSAMPLEOp(inst.op).name.lower()
# vdata width
vdata = 4 if 'gather4' in name or 'msaa_load' in name else (bin(inst.dmask).count('1') or 1)
if inst.d16: vdata = (vdata + 1) // 2
if inst.tfe: vdata += 1
dim_names = ['1d', '2d', '3d', 'cube', '1d_array', '2d_array', '2d_msaa', '2d_msaa_array']
dim = dim_names[inst.dim] if inst.dim < len(dim_names) else f"{inst.dim}"
vaddr = _mimg_vaddr_width(name, inst.dim, inst.a16)
if vaddr == 1: vaddr_s = f"v{inst.vaddr0}"
elif vaddr == 2: vaddr_s = f"[v{inst.vaddr0}, v{inst.vaddr1}]"
elif vaddr == 3: vaddr_s = f"[v{inst.vaddr0}, v{inst.vaddr1}, v{inst.vaddr2}]"
elif vaddr == 4: vaddr_s = f"[v{inst.vaddr0}, v{inst.vaddr1}, v{inst.vaddr2}, v{inst.vaddr3}]"
else:
# More than 4 vaddrs: vaddr3 becomes start of a contiguous range for the remaining coords
extra = vaddr - 3 # vaddr0-2 are individual, vaddr3 starts range of remaining
vaddr_s = f"[v{inst.vaddr0}, v{inst.vaddr1}, v{inst.vaddr2}, {_vreg(inst.vaddr3, extra)}]"
# RDNA4 VSAMPLE rsrc/samp fields store the SGPR index directly (not /4 like RDNA3)
srsrc_s = _sreg_or_ttmp(inst.rsrc, 8)
# msaa_load doesn't use a sampler (it's a load, not a sample), but gather4h does
uses_sampler = 'msaa_load' not in name
ssamp_s = f", {_sreg_or_ttmp(inst.samp, 4)}" if uses_sampler else ""
mods = [f"dmask:0x{inst.dmask:x}"] if inst.dmask else []
mods.append(f"dim:SQ_RSRC_IMG_{dim.upper()}")
if inst.unrm: mods.append("unorm")
# th/scope must come before r128, a16, tfe, lwe, d16
if inst.th or inst.scope:
mem_mods = _rdna4_mem_mods(inst.th, inst.scope, False, False)
if mem_mods: mods.append(mem_mods)
mods.extend([m for c, m in [(inst.r128, "r128"), (inst.a16, "a16"), (inst.tfe, "tfe"), (inst.lwe, "lwe"), (inst.d16, "d16")] if c])
return f"{name} {_vreg(inst.vdata, vdata)}, {vaddr_s}, {srsrc_s}{ssamp_s} {' '.join(mods)}"
DISASM_HANDLERS = {VOP1: _disasm_vop1, VOP2: _disasm_vop2, VOPC: _disasm_vopc, VOP3: _disasm_vop3, VOP3SD: _disasm_vop3sd, VOPD: _disasm_vopd, VOP3P: _disasm_vop3p,
VINTERP: _disasm_vinterp, SOPP: _disasm_sopp, SMEM: _disasm_smem, DS: _disasm_ds, FLAT: _disasm_flat, MUBUF: _disasm_buf, MTBUF: _disasm_buf,
MIMG: _disasm_mimg, SOP1: _disasm_sop1, SOP2: _disasm_sop2, SOPC: _disasm_sopc, SOPK: _disasm_sopk}
MIMG: _disasm_mimg, SOP1: _disasm_sop1, SOP2: _disasm_sop2, SOPC: _disasm_sopc, SOPK: _disasm_sopk, EXP: _disasm_exp, LDSDIR: _disasm_ldsdir}
def disasm(inst: Inst) -> str: return DISASM_HANDLERS[type(inst)](inst)
# RDNA4 uses different class names, dispatch by name for cross-arch support
_DISASM_BY_NAME = {'VOP1': _disasm_vop1, 'VOP2': _disasm_vop2, 'VOPC': _disasm_vopc, 'VOP3': _disasm_vop3, 'VOP3SD': _disasm_vop3sd,
'VOPD': _disasm_vopd, 'VOP3P': _disasm_vop3p, 'VINTERP': _disasm_vinterp, 'SOPP': _disasm_sopp, 'SMEM': _disasm_smem,
'DS': _disasm_ds, 'VDS': _disasm_ds, 'FLAT': _disasm_flat, 'MUBUF': _disasm_buf, 'MTBUF': _disasm_buf, 'MIMG': _disasm_mimg,
'SOP1': _disasm_sop1, 'SOP2': _disasm_sop2, 'SOPC': _disasm_sopc, 'SOPK': _disasm_sopk,
'EXP': _disasm_exp, 'LDSDIR': _disasm_ldsdir, 'VEXPORT': _disasm_exp, 'VDSDIR': _disasm_ldsdir,
'VFLAT': _disasm_vflat, 'VGLOBAL': _disasm_vflat, 'VSCRATCH': _disasm_vflat,
'VBUFFER': _disasm_vbuffer, 'VIMAGE': _disasm_vimage, 'VSAMPLE': _disasm_vsample}
def disasm(inst: Inst, wave_size: int = 32) -> str:
handler = DISASM_HANDLERS.get(type(inst)) or _DISASM_BY_NAME.get(type(inst).__name__)
if handler is None: raise KeyError(f"no disasm handler for {type(inst).__name__}")
# For VOP3P (includes WMMA), pass wave_size if handler supports it
if handler == _disasm_vop3p: return _disasm_vop3p(inst, wave_size)
return handler(inst)
# ═══════════════════════════════════════════════════════════════════════════════
# ASSEMBLER

View file

@ -1,7 +1,7 @@
# autogenerated from AMD RDNA4 ISA PDF by pdf.py - do not edit
# ruff: noqa: F401,F403
from typing import Annotated
from extra.assembly.amd.dsl import bits, BitField, Inst32, Inst64, SGPR, VGPR, TTMP as TTMP, s as s, v as v, ttmp as ttmp, SSrc, Src, SImm, Imm, VDSTYEnc, SGPRField, VGPRField
from extra.assembly.amd.dsl import bits, BitField, Inst32, Inst64, Inst96, SGPR, VGPR, TTMP as TTMP, s as s, v as v, ttmp as ttmp, SSrc, Src, SImm, Imm, VDSTYEnc, SGPRField, VGPRField
from extra.assembly.amd.autogen.rdna4.enum import *
import functools
@ -69,7 +69,7 @@ class SOPP(Inst32):
op:Annotated[BitField, SOPPOp] = bits[22:16]
simm16:SImm = bits[15:0]
class VBUFFER(Inst64):
class VBUFFER(Inst96):
encoding = bits[31:26] == 0b110001
soffset:SSrc = bits[6:0]
op:Annotated[BitField, VBUFFEROp] = bits[21:14]
@ -94,17 +94,14 @@ class VDS(Inst64):
data1:VGPRField = bits[55:48]
vdst:VGPRField = bits[63:56]
class VDSDIR(Inst64):
encoding = bits[31:24] == 0b11001101
class VDSDIR(Inst32):
encoding = bits[31:24] == 0b11001110
vdst:VGPRField = bits[7:0]
waitexp = bits[10:8]
opsel = bits[14:11]
cm = bits[15]
op:Annotated[BitField, VDSDIROp] = bits[20:16]
src0:Src = bits[40:32]
src1:Src = bits[49:41]
src2:Src = bits[58:50]
neg = bits[63:61]
attr_chan = bits[9:8]
attr = bits[15:10]
wait_va = bits[19:16]
op:Annotated[BitField, VDSDIROp] = bits[21:20]
wait_vm = bits[23] # single bit, bit 22 is reserved
class VEXPORT(Inst64):
encoding = bits[31:26] == 0b111110
@ -117,6 +114,49 @@ class VEXPORT(Inst64):
vsrc2 = bits[55:48]
vsrc3 = bits[63:56]
class VFLAT(Inst96):
encoding = bits[31:24] == 0b11101100
saddr:SSrc = bits[6:0]
op:Annotated[BitField, VFLATOp] = bits[20:14]
vdst:VGPRField = bits[39:32]
sve = bits[49]
scope = bits[51:50]
th = bits[54:52]
vsrc = bits[62:55]
vaddr:VGPRField = bits[71:64]
ioffset = bits[95:72]
class VGLOBAL(Inst96):
encoding = bits[31:24] == 0b11101110
saddr:SSrc = bits[6:0]
op:Annotated[BitField, VGLOBALOp] = bits[20:14]
vdst:VGPRField = bits[39:32]
sve = bits[49]
scope = bits[51:50]
th = bits[54:52]
vsrc = bits[62:55]
vaddr:VGPRField = bits[71:64]
ioffset = bits[95:72]
class VIMAGE(Inst96):
encoding = bits[31:26] == 0b110100
dim = bits[2:0]
r128 = bits[4]
d16 = bits[5]
a16 = bits[6]
op:Annotated[BitField, VIMAGEOp] = bits[21:14]
dmask = bits[25:22]
vdata:VGPRField = bits[39:32]
rsrc = bits[49:41]
scope = bits[51:50]
th = bits[54:52]
tfe = bits[55]
vaddr4 = bits[63:56]
vaddr0 = bits[71:64]
vaddr1 = bits[79:72]
vaddr2 = bits[87:80]
vaddr3 = bits[95:88]
class VINTERP(Inst64):
encoding = bits[31:24] == 0b11001101
op:Annotated[BitField, VINTERPOp] = bits[20:16]
@ -199,6 +239,39 @@ class VOPD(Inst64):
srcy0:Src = bits[40:32]
vsrcy1:VGPRField = bits[48:41]
class VSAMPLE(Inst96):
encoding = bits[31:26] == 0b111001
dim = bits[2:0]
tfe = bits[3]
r128 = bits[4]
d16 = bits[5]
a16 = bits[6]
unrm = bits[13]
op:Annotated[BitField, VSAMPLEOp] = bits[21:14]
dmask = bits[25:22]
vdata:VGPRField = bits[39:32]
lwe = bits[40]
rsrc = bits[49:41]
scope = bits[51:50]
th = bits[54:52]
samp = bits[63:55]
vaddr0 = bits[71:64]
vaddr1 = bits[79:72]
vaddr2 = bits[87:80]
vaddr3 = bits[95:88]
class VSCRATCH(Inst96):
encoding = bits[31:24] == 0b11101101
saddr:SSrc = bits[6:0]
op:Annotated[BitField, VSCRATCHOp] = bits[20:14]
vdst:VGPRField = bits[39:32]
sve = bits[49]
scope = bits[51:50]
th = bits[54:52]
vsrc = bits[62:55]
vaddr:VGPRField = bits[71:64]
ioffset = bits[95:72]
# instruction helpers
s_load_b32 = functools.partial(SMEM, SMEMOp.S_LOAD_B32)
s_load_b64 = functools.partial(SMEM, SMEMOp.S_LOAD_B64)
@ -571,6 +644,159 @@ tbuffer_store_d16_format_xyz = functools.partial(VBUFFER, VBUFFEROp.TBUFFER_STOR
tbuffer_store_d16_format_xyzw = functools.partial(VBUFFER, VBUFFEROp.TBUFFER_STORE_D16_FORMAT_XYZW)
ds_param_load = functools.partial(VDSDIR, VDSDIROp.DS_PARAM_LOAD)
ds_direct_load = functools.partial(VDSDIR, VDSDIROp.DS_DIRECT_LOAD)
flat_load_u8 = functools.partial(VFLAT, VFLATOp.FLAT_LOAD_U8)
flat_load_i8 = functools.partial(VFLAT, VFLATOp.FLAT_LOAD_I8)
flat_load_u16 = functools.partial(VFLAT, VFLATOp.FLAT_LOAD_U16)
flat_load_i16 = functools.partial(VFLAT, VFLATOp.FLAT_LOAD_I16)
flat_load_b32 = functools.partial(VFLAT, VFLATOp.FLAT_LOAD_B32)
flat_load_b64 = functools.partial(VFLAT, VFLATOp.FLAT_LOAD_B64)
flat_load_b96 = functools.partial(VFLAT, VFLATOp.FLAT_LOAD_B96)
flat_load_b128 = functools.partial(VFLAT, VFLATOp.FLAT_LOAD_B128)
flat_store_b8 = functools.partial(VFLAT, VFLATOp.FLAT_STORE_B8)
flat_store_b16 = functools.partial(VFLAT, VFLATOp.FLAT_STORE_B16)
flat_store_b32 = functools.partial(VFLAT, VFLATOp.FLAT_STORE_B32)
flat_store_b64 = functools.partial(VFLAT, VFLATOp.FLAT_STORE_B64)
flat_store_b96 = functools.partial(VFLAT, VFLATOp.FLAT_STORE_B96)
flat_store_b128 = functools.partial(VFLAT, VFLATOp.FLAT_STORE_B128)
flat_load_d16_u8 = functools.partial(VFLAT, VFLATOp.FLAT_LOAD_D16_U8)
flat_load_d16_i8 = functools.partial(VFLAT, VFLATOp.FLAT_LOAD_D16_I8)
flat_load_d16_b16 = functools.partial(VFLAT, VFLATOp.FLAT_LOAD_D16_B16)
flat_load_d16_hi_u8 = functools.partial(VFLAT, VFLATOp.FLAT_LOAD_D16_HI_U8)
flat_load_d16_hi_i8 = functools.partial(VFLAT, VFLATOp.FLAT_LOAD_D16_HI_I8)
flat_load_d16_hi_b16 = functools.partial(VFLAT, VFLATOp.FLAT_LOAD_D16_HI_B16)
flat_store_d16_hi_b8 = functools.partial(VFLAT, VFLATOp.FLAT_STORE_D16_HI_B8)
flat_store_d16_hi_b16 = functools.partial(VFLAT, VFLATOp.FLAT_STORE_D16_HI_B16)
flat_atomic_swap_b32 = functools.partial(VFLAT, VFLATOp.FLAT_ATOMIC_SWAP_B32)
flat_atomic_cmpswap_b32 = functools.partial(VFLAT, VFLATOp.FLAT_ATOMIC_CMPSWAP_B32)
flat_atomic_add_u32 = functools.partial(VFLAT, VFLATOp.FLAT_ATOMIC_ADD_U32)
flat_atomic_sub_u32 = functools.partial(VFLAT, VFLATOp.FLAT_ATOMIC_SUB_U32)
flat_atomic_sub_clamp_u32 = functools.partial(VFLAT, VFLATOp.FLAT_ATOMIC_SUB_CLAMP_U32)
flat_atomic_min_i32 = functools.partial(VFLAT, VFLATOp.FLAT_ATOMIC_MIN_I32)
flat_atomic_min_u32 = functools.partial(VFLAT, VFLATOp.FLAT_ATOMIC_MIN_U32)
flat_atomic_max_i32 = functools.partial(VFLAT, VFLATOp.FLAT_ATOMIC_MAX_I32)
flat_atomic_max_u32 = functools.partial(VFLAT, VFLATOp.FLAT_ATOMIC_MAX_U32)
flat_atomic_and_b32 = functools.partial(VFLAT, VFLATOp.FLAT_ATOMIC_AND_B32)
flat_atomic_or_b32 = functools.partial(VFLAT, VFLATOp.FLAT_ATOMIC_OR_B32)
flat_atomic_xor_b32 = functools.partial(VFLAT, VFLATOp.FLAT_ATOMIC_XOR_B32)
flat_atomic_inc_u32 = functools.partial(VFLAT, VFLATOp.FLAT_ATOMIC_INC_U32)
flat_atomic_dec_u32 = functools.partial(VFLAT, VFLATOp.FLAT_ATOMIC_DEC_U32)
flat_atomic_swap_b64 = functools.partial(VFLAT, VFLATOp.FLAT_ATOMIC_SWAP_B64)
flat_atomic_cmpswap_b64 = functools.partial(VFLAT, VFLATOp.FLAT_ATOMIC_CMPSWAP_B64)
flat_atomic_add_u64 = functools.partial(VFLAT, VFLATOp.FLAT_ATOMIC_ADD_U64)
flat_atomic_sub_u64 = functools.partial(VFLAT, VFLATOp.FLAT_ATOMIC_SUB_U64)
flat_atomic_min_i64 = functools.partial(VFLAT, VFLATOp.FLAT_ATOMIC_MIN_I64)
flat_atomic_min_u64 = functools.partial(VFLAT, VFLATOp.FLAT_ATOMIC_MIN_U64)
flat_atomic_max_i64 = functools.partial(VFLAT, VFLATOp.FLAT_ATOMIC_MAX_I64)
flat_atomic_max_u64 = functools.partial(VFLAT, VFLATOp.FLAT_ATOMIC_MAX_U64)
flat_atomic_and_b64 = functools.partial(VFLAT, VFLATOp.FLAT_ATOMIC_AND_B64)
flat_atomic_or_b64 = functools.partial(VFLAT, VFLATOp.FLAT_ATOMIC_OR_B64)
flat_atomic_xor_b64 = functools.partial(VFLAT, VFLATOp.FLAT_ATOMIC_XOR_B64)
flat_atomic_inc_u64 = functools.partial(VFLAT, VFLATOp.FLAT_ATOMIC_INC_U64)
flat_atomic_dec_u64 = functools.partial(VFLAT, VFLATOp.FLAT_ATOMIC_DEC_U64)
flat_atomic_cond_sub_u32 = functools.partial(VFLAT, VFLATOp.FLAT_ATOMIC_COND_SUB_U32)
flat_atomic_min_num_f32 = functools.partial(VFLAT, VFLATOp.FLAT_ATOMIC_MIN_NUM_F32)
flat_atomic_max_num_f32 = functools.partial(VFLAT, VFLATOp.FLAT_ATOMIC_MAX_NUM_F32)
flat_atomic_add_f32 = functools.partial(VFLAT, VFLATOp.FLAT_ATOMIC_ADD_F32)
flat_atomic_pk_add_f16 = functools.partial(VFLAT, VFLATOp.FLAT_ATOMIC_PK_ADD_F16)
flat_atomic_pk_add_bf16 = functools.partial(VFLAT, VFLATOp.FLAT_ATOMIC_PK_ADD_BF16)
global_load_u8 = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_LOAD_U8)
global_load_i8 = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_LOAD_I8)
global_load_u16 = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_LOAD_U16)
global_load_i16 = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_LOAD_I16)
global_load_b32 = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_LOAD_B32)
global_load_b64 = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_LOAD_B64)
global_load_b96 = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_LOAD_B96)
global_load_b128 = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_LOAD_B128)
global_store_b8 = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_STORE_B8)
global_store_b16 = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_STORE_B16)
global_store_b32 = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_STORE_B32)
global_store_b64 = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_STORE_B64)
global_store_b96 = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_STORE_B96)
global_store_b128 = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_STORE_B128)
global_load_d16_u8 = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_LOAD_D16_U8)
global_load_d16_i8 = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_LOAD_D16_I8)
global_load_d16_b16 = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_LOAD_D16_B16)
global_load_d16_hi_u8 = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_LOAD_D16_HI_U8)
global_load_d16_hi_i8 = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_LOAD_D16_HI_I8)
global_load_d16_hi_b16 = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_LOAD_D16_HI_B16)
global_store_d16_hi_b8 = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_STORE_D16_HI_B8)
global_store_d16_hi_b16 = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_STORE_D16_HI_B16)
global_load_addtid_b32 = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_LOAD_ADDTID_B32)
global_store_addtid_b32 = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_STORE_ADDTID_B32)
global_inv = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_INV)
global_wb = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_WB)
global_atomic_swap_b32 = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_ATOMIC_SWAP_B32)
global_atomic_cmpswap_b32 = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_ATOMIC_CMPSWAP_B32)
global_atomic_add_u32 = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_ATOMIC_ADD_U32)
global_atomic_sub_u32 = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_ATOMIC_SUB_U32)
global_atomic_sub_clamp_u32 = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_ATOMIC_SUB_CLAMP_U32)
global_atomic_min_i32 = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_ATOMIC_MIN_I32)
global_atomic_min_u32 = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_ATOMIC_MIN_U32)
global_atomic_max_i32 = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_ATOMIC_MAX_I32)
global_atomic_max_u32 = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_ATOMIC_MAX_U32)
global_atomic_and_b32 = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_ATOMIC_AND_B32)
global_atomic_or_b32 = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_ATOMIC_OR_B32)
global_atomic_xor_b32 = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_ATOMIC_XOR_B32)
global_atomic_inc_u32 = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_ATOMIC_INC_U32)
global_atomic_dec_u32 = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_ATOMIC_DEC_U32)
global_atomic_swap_b64 = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_ATOMIC_SWAP_B64)
global_atomic_cmpswap_b64 = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_ATOMIC_CMPSWAP_B64)
global_atomic_add_u64 = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_ATOMIC_ADD_U64)
global_atomic_sub_u64 = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_ATOMIC_SUB_U64)
global_atomic_min_i64 = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_ATOMIC_MIN_I64)
global_atomic_min_u64 = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_ATOMIC_MIN_U64)
global_atomic_max_i64 = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_ATOMIC_MAX_I64)
global_atomic_max_u64 = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_ATOMIC_MAX_U64)
global_atomic_and_b64 = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_ATOMIC_AND_B64)
global_atomic_or_b64 = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_ATOMIC_OR_B64)
global_atomic_xor_b64 = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_ATOMIC_XOR_B64)
global_atomic_inc_u64 = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_ATOMIC_INC_U64)
global_atomic_dec_u64 = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_ATOMIC_DEC_U64)
global_wbinv = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_WBINV)
global_atomic_cond_sub_u32 = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_ATOMIC_COND_SUB_U32)
global_atomic_min_num_f32 = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_ATOMIC_MIN_NUM_F32)
global_atomic_max_num_f32 = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_ATOMIC_MAX_NUM_F32)
global_load_block = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_LOAD_BLOCK)
global_store_block = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_STORE_BLOCK)
global_atomic_add_f32 = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_ATOMIC_ADD_F32)
global_load_tr_b128 = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_LOAD_TR_B128)
global_load_tr_b64 = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_LOAD_TR_B64)
global_atomic_pk_add_f16 = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_ATOMIC_PK_ADD_F16)
global_atomic_pk_add_bf16 = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_ATOMIC_PK_ADD_BF16)
global_atomic_ordered_add_b64 = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_ATOMIC_ORDERED_ADD_B64)
image_load = functools.partial(VIMAGE, VIMAGEOp.IMAGE_LOAD)
image_load_mip = functools.partial(VIMAGE, VIMAGEOp.IMAGE_LOAD_MIP)
image_load_pck = functools.partial(VIMAGE, VIMAGEOp.IMAGE_LOAD_PCK)
image_load_pck_sgn = functools.partial(VIMAGE, VIMAGEOp.IMAGE_LOAD_PCK_SGN)
image_load_mip_pck = functools.partial(VIMAGE, VIMAGEOp.IMAGE_LOAD_MIP_PCK)
image_load_mip_pck_sgn = functools.partial(VIMAGE, VIMAGEOp.IMAGE_LOAD_MIP_PCK_SGN)
image_store = functools.partial(VIMAGE, VIMAGEOp.IMAGE_STORE)
image_store_mip = functools.partial(VIMAGE, VIMAGEOp.IMAGE_STORE_MIP)
image_store_pck = functools.partial(VIMAGE, VIMAGEOp.IMAGE_STORE_PCK)
image_store_mip_pck = functools.partial(VIMAGE, VIMAGEOp.IMAGE_STORE_MIP_PCK)
image_atomic_swap = functools.partial(VIMAGE, VIMAGEOp.IMAGE_ATOMIC_SWAP)
image_atomic_cmpswap = functools.partial(VIMAGE, VIMAGEOp.IMAGE_ATOMIC_CMPSWAP)
image_atomic_add_uint = functools.partial(VIMAGE, VIMAGEOp.IMAGE_ATOMIC_ADD_UINT)
image_atomic_sub_uint = functools.partial(VIMAGE, VIMAGEOp.IMAGE_ATOMIC_SUB_UINT)
image_atomic_min_int = functools.partial(VIMAGE, VIMAGEOp.IMAGE_ATOMIC_MIN_INT)
image_atomic_min_uint = functools.partial(VIMAGE, VIMAGEOp.IMAGE_ATOMIC_MIN_UINT)
image_atomic_max_int = functools.partial(VIMAGE, VIMAGEOp.IMAGE_ATOMIC_MAX_INT)
image_atomic_max_uint = functools.partial(VIMAGE, VIMAGEOp.IMAGE_ATOMIC_MAX_UINT)
image_atomic_and = functools.partial(VIMAGE, VIMAGEOp.IMAGE_ATOMIC_AND)
image_atomic_or = functools.partial(VIMAGE, VIMAGEOp.IMAGE_ATOMIC_OR)
image_atomic_xor = functools.partial(VIMAGE, VIMAGEOp.IMAGE_ATOMIC_XOR)
image_atomic_inc_uint = functools.partial(VIMAGE, VIMAGEOp.IMAGE_ATOMIC_INC_UINT)
image_atomic_dec_uint = functools.partial(VIMAGE, VIMAGEOp.IMAGE_ATOMIC_DEC_UINT)
image_get_resinfo = functools.partial(VIMAGE, VIMAGEOp.IMAGE_GET_RESINFO)
image_bvh_intersect_ray = functools.partial(VIMAGE, VIMAGEOp.IMAGE_BVH_INTERSECT_RAY)
image_bvh64_intersect_ray = functools.partial(VIMAGE, VIMAGEOp.IMAGE_BVH64_INTERSECT_RAY)
image_bvh_dual_intersect_ray = functools.partial(VIMAGE, VIMAGEOp.IMAGE_BVH_DUAL_INTERSECT_RAY)
image_bvh8_intersect_ray = functools.partial(VIMAGE, VIMAGEOp.IMAGE_BVH8_INTERSECT_RAY)
image_atomic_add_flt = functools.partial(VIMAGE, VIMAGEOp.IMAGE_ATOMIC_ADD_FLT)
image_atomic_min_flt = functools.partial(VIMAGE, VIMAGEOp.IMAGE_ATOMIC_MIN_FLT)
image_atomic_max_flt = functools.partial(VIMAGE, VIMAGEOp.IMAGE_ATOMIC_MAX_FLT)
image_atomic_pk_add_f16 = functools.partial(VIMAGE, VIMAGEOp.IMAGE_ATOMIC_PK_ADD_F16)
image_atomic_pk_add_bf16 = functools.partial(VIMAGE, VIMAGEOp.IMAGE_ATOMIC_PK_ADD_BF16)
v_interp_p10_f32 = functools.partial(VINTERP, VINTERPOp.V_INTERP_P10_F32)
v_interp_p2_f32 = functools.partial(VINTERP, VINTERPOp.V_INTERP_P2_F32)
v_interp_p10_f16_f32 = functools.partial(VINTERP, VINTERPOp.V_INTERP_P10_F16_F32)
@ -1396,6 +1622,88 @@ v_dual_dot2acc_f32_bf16 = functools.partial(VOPD, VOPDOp.V_DUAL_DOT2ACC_F32_BF16
v_dual_add_nc_u32 = functools.partial(VOPD, VOPDOp.V_DUAL_ADD_NC_U32)
v_dual_lshlrev_b32 = functools.partial(VOPD, VOPDOp.V_DUAL_LSHLREV_B32)
v_dual_and_b32 = functools.partial(VOPD, VOPDOp.V_DUAL_AND_B32)
image_msaa_load = functools.partial(VSAMPLE, VSAMPLEOp.IMAGE_MSAA_LOAD)
image_sample = functools.partial(VSAMPLE, VSAMPLEOp.IMAGE_SAMPLE)
image_sample_d = functools.partial(VSAMPLE, VSAMPLEOp.IMAGE_SAMPLE_D)
image_sample_l = functools.partial(VSAMPLE, VSAMPLEOp.IMAGE_SAMPLE_L)
image_sample_b = functools.partial(VSAMPLE, VSAMPLEOp.IMAGE_SAMPLE_B)
image_sample_lz = functools.partial(VSAMPLE, VSAMPLEOp.IMAGE_SAMPLE_LZ)
image_sample_c = functools.partial(VSAMPLE, VSAMPLEOp.IMAGE_SAMPLE_C)
image_sample_c_d = functools.partial(VSAMPLE, VSAMPLEOp.IMAGE_SAMPLE_C_D)
image_sample_c_l = functools.partial(VSAMPLE, VSAMPLEOp.IMAGE_SAMPLE_C_L)
image_sample_c_b = functools.partial(VSAMPLE, VSAMPLEOp.IMAGE_SAMPLE_C_B)
image_sample_c_lz = functools.partial(VSAMPLE, VSAMPLEOp.IMAGE_SAMPLE_C_LZ)
image_sample_o = functools.partial(VSAMPLE, VSAMPLEOp.IMAGE_SAMPLE_O)
image_sample_d_o = functools.partial(VSAMPLE, VSAMPLEOp.IMAGE_SAMPLE_D_O)
image_sample_l_o = functools.partial(VSAMPLE, VSAMPLEOp.IMAGE_SAMPLE_L_O)
image_sample_b_o = functools.partial(VSAMPLE, VSAMPLEOp.IMAGE_SAMPLE_B_O)
image_sample_lz_o = functools.partial(VSAMPLE, VSAMPLEOp.IMAGE_SAMPLE_LZ_O)
image_sample_c_o = functools.partial(VSAMPLE, VSAMPLEOp.IMAGE_SAMPLE_C_O)
image_sample_c_d_o = functools.partial(VSAMPLE, VSAMPLEOp.IMAGE_SAMPLE_C_D_O)
image_sample_c_l_o = functools.partial(VSAMPLE, VSAMPLEOp.IMAGE_SAMPLE_C_L_O)
image_sample_c_b_o = functools.partial(VSAMPLE, VSAMPLEOp.IMAGE_SAMPLE_C_B_O)
image_sample_c_lz_o = functools.partial(VSAMPLE, VSAMPLEOp.IMAGE_SAMPLE_C_LZ_O)
image_gather4 = functools.partial(VSAMPLE, VSAMPLEOp.IMAGE_GATHER4)
image_gather4_l = functools.partial(VSAMPLE, VSAMPLEOp.IMAGE_GATHER4_L)
image_gather4_b = functools.partial(VSAMPLE, VSAMPLEOp.IMAGE_GATHER4_B)
image_gather4_lz = functools.partial(VSAMPLE, VSAMPLEOp.IMAGE_GATHER4_LZ)
image_gather4_c = functools.partial(VSAMPLE, VSAMPLEOp.IMAGE_GATHER4_C)
image_gather4_c_lz = functools.partial(VSAMPLE, VSAMPLEOp.IMAGE_GATHER4_C_LZ)
image_gather4_o = functools.partial(VSAMPLE, VSAMPLEOp.IMAGE_GATHER4_O)
image_gather4_lz_o = functools.partial(VSAMPLE, VSAMPLEOp.IMAGE_GATHER4_LZ_O)
image_gather4_c_lz_o = functools.partial(VSAMPLE, VSAMPLEOp.IMAGE_GATHER4_C_LZ_O)
image_get_lod = functools.partial(VSAMPLE, VSAMPLEOp.IMAGE_GET_LOD)
image_sample_d_g16 = functools.partial(VSAMPLE, VSAMPLEOp.IMAGE_SAMPLE_D_G16)
image_sample_c_d_g16 = functools.partial(VSAMPLE, VSAMPLEOp.IMAGE_SAMPLE_C_D_G16)
image_sample_d_o_g16 = functools.partial(VSAMPLE, VSAMPLEOp.IMAGE_SAMPLE_D_O_G16)
image_sample_c_d_o_g16 = functools.partial(VSAMPLE, VSAMPLEOp.IMAGE_SAMPLE_C_D_O_G16)
image_sample_cl = functools.partial(VSAMPLE, VSAMPLEOp.IMAGE_SAMPLE_CL)
image_sample_d_cl = functools.partial(VSAMPLE, VSAMPLEOp.IMAGE_SAMPLE_D_CL)
image_sample_b_cl = functools.partial(VSAMPLE, VSAMPLEOp.IMAGE_SAMPLE_B_CL)
image_sample_c_cl = functools.partial(VSAMPLE, VSAMPLEOp.IMAGE_SAMPLE_C_CL)
image_sample_c_d_cl = functools.partial(VSAMPLE, VSAMPLEOp.IMAGE_SAMPLE_C_D_CL)
image_sample_c_b_cl = functools.partial(VSAMPLE, VSAMPLEOp.IMAGE_SAMPLE_C_B_CL)
image_sample_cl_o = functools.partial(VSAMPLE, VSAMPLEOp.IMAGE_SAMPLE_CL_O)
image_sample_d_cl_o = functools.partial(VSAMPLE, VSAMPLEOp.IMAGE_SAMPLE_D_CL_O)
image_sample_b_cl_o = functools.partial(VSAMPLE, VSAMPLEOp.IMAGE_SAMPLE_B_CL_O)
image_sample_c_cl_o = functools.partial(VSAMPLE, VSAMPLEOp.IMAGE_SAMPLE_C_CL_O)
image_sample_c_d_cl_o = functools.partial(VSAMPLE, VSAMPLEOp.IMAGE_SAMPLE_C_D_CL_O)
image_sample_c_b_cl_o = functools.partial(VSAMPLE, VSAMPLEOp.IMAGE_SAMPLE_C_B_CL_O)
image_sample_c_d_cl_g16 = functools.partial(VSAMPLE, VSAMPLEOp.IMAGE_SAMPLE_C_D_CL_G16)
image_sample_d_cl_o_g16 = functools.partial(VSAMPLE, VSAMPLEOp.IMAGE_SAMPLE_D_CL_O_G16)
image_sample_c_d_cl_o_g16 = functools.partial(VSAMPLE, VSAMPLEOp.IMAGE_SAMPLE_C_D_CL_O_G16)
image_sample_d_cl_g16 = functools.partial(VSAMPLE, VSAMPLEOp.IMAGE_SAMPLE_D_CL_G16)
image_gather4_cl = functools.partial(VSAMPLE, VSAMPLEOp.IMAGE_GATHER4_CL)
image_gather4_b_cl = functools.partial(VSAMPLE, VSAMPLEOp.IMAGE_GATHER4_B_CL)
image_gather4_c_cl = functools.partial(VSAMPLE, VSAMPLEOp.IMAGE_GATHER4_C_CL)
image_gather4_c_l = functools.partial(VSAMPLE, VSAMPLEOp.IMAGE_GATHER4_C_L)
image_gather4_c_b = functools.partial(VSAMPLE, VSAMPLEOp.IMAGE_GATHER4_C_B)
image_gather4_c_b_cl = functools.partial(VSAMPLE, VSAMPLEOp.IMAGE_GATHER4_C_B_CL)
image_gather4h = functools.partial(VSAMPLE, VSAMPLEOp.IMAGE_GATHER4H)
scratch_load_u8 = functools.partial(VSCRATCH, VSCRATCHOp.SCRATCH_LOAD_U8)
scratch_load_i8 = functools.partial(VSCRATCH, VSCRATCHOp.SCRATCH_LOAD_I8)
scratch_load_u16 = functools.partial(VSCRATCH, VSCRATCHOp.SCRATCH_LOAD_U16)
scratch_load_i16 = functools.partial(VSCRATCH, VSCRATCHOp.SCRATCH_LOAD_I16)
scratch_load_b32 = functools.partial(VSCRATCH, VSCRATCHOp.SCRATCH_LOAD_B32)
scratch_load_b64 = functools.partial(VSCRATCH, VSCRATCHOp.SCRATCH_LOAD_B64)
scratch_load_b96 = functools.partial(VSCRATCH, VSCRATCHOp.SCRATCH_LOAD_B96)
scratch_load_b128 = functools.partial(VSCRATCH, VSCRATCHOp.SCRATCH_LOAD_B128)
scratch_store_b8 = functools.partial(VSCRATCH, VSCRATCHOp.SCRATCH_STORE_B8)
scratch_store_b16 = functools.partial(VSCRATCH, VSCRATCHOp.SCRATCH_STORE_B16)
scratch_store_b32 = functools.partial(VSCRATCH, VSCRATCHOp.SCRATCH_STORE_B32)
scratch_store_b64 = functools.partial(VSCRATCH, VSCRATCHOp.SCRATCH_STORE_B64)
scratch_store_b96 = functools.partial(VSCRATCH, VSCRATCHOp.SCRATCH_STORE_B96)
scratch_store_b128 = functools.partial(VSCRATCH, VSCRATCHOp.SCRATCH_STORE_B128)
scratch_load_d16_u8 = functools.partial(VSCRATCH, VSCRATCHOp.SCRATCH_LOAD_D16_U8)
scratch_load_d16_i8 = functools.partial(VSCRATCH, VSCRATCHOp.SCRATCH_LOAD_D16_I8)
scratch_load_d16_b16 = functools.partial(VSCRATCH, VSCRATCHOp.SCRATCH_LOAD_D16_B16)
scratch_load_d16_hi_u8 = functools.partial(VSCRATCH, VSCRATCHOp.SCRATCH_LOAD_D16_HI_U8)
scratch_load_d16_hi_i8 = functools.partial(VSCRATCH, VSCRATCHOp.SCRATCH_LOAD_D16_HI_I8)
scratch_load_d16_hi_b16 = functools.partial(VSCRATCH, VSCRATCHOp.SCRATCH_LOAD_D16_HI_B16)
scratch_store_d16_hi_b8 = functools.partial(VSCRATCH, VSCRATCHOp.SCRATCH_STORE_D16_HI_B8)
scratch_store_d16_hi_b16 = functools.partial(VSCRATCH, VSCRATCHOp.SCRATCH_STORE_D16_HI_B16)
scratch_load_block = functools.partial(VSCRATCH, VSCRATCHOp.SCRATCH_LOAD_BLOCK)
scratch_store_block = functools.partial(VSCRATCH, VSCRATCHOp.SCRATCH_STORE_BLOCK)
VCC_LO = SrcEnum.VCC_LO
VCC_HI = SrcEnum.VCC_HI

View file

@ -62,8 +62,10 @@ _SPECIAL_REGS = {
'V_CMP_CLASS_F64': (1, 2, 1, 1), 'V_CMPX_CLASS_F64': (1, 2, 1, 1),
'V_CMP_CLASS_F32': (1, 1, 1, 1), 'V_CMPX_CLASS_F32': (1, 1, 1, 1),
'V_CMP_CLASS_F16': (1, 1, 1, 1), 'V_CMPX_CLASS_F16': (1, 1, 1, 1),
'V_MAD_U64_U32': (2, 1, 1, 2), 'V_MAD_I64_I32': (2, 1, 1, 2),
'V_MAD_U64_U32': (2, 1, 1, 2), 'V_MAD_I64_I32': (2, 1, 1, 2), 'V_MAD_CO_U64_U32': (2, 1, 1, 2), 'V_MAD_CO_I64_I32': (2, 1, 1, 2),
'V_QSAD_PK_U16_U8': (2, 2, 1, 2), 'V_MQSAD_PK_U16_U8': (2, 2, 1, 2), 'V_MQSAD_U32_U8': (4, 2, 1, 4),
# RDNA4 CVT_PK_F32 instructions output 2 F32 values (64-bit)
'V_CVT_PK_F32_BF8': (2, 1, 1, 1), 'V_CVT_PK_F32_FP8': (2, 1, 1, 1),
}
_SPECIAL_DTYPE = {
'V_LSHLREV_B64': ('B64', 'U32', 'B64', None), 'V_LSHRREV_B64': ('B64', 'U32', 'B64', None), 'V_ASHRREV_I64': ('I64', 'U32', 'I64', None),
@ -78,6 +80,8 @@ _SPECIAL_DTYPE = {
'V_MAD_U64_U32': ('U64', 'U32', 'U32', 'U64'), 'V_MAD_I64_I32': ('I64', 'I32', 'I32', 'I64'),
'V_QSAD_PK_U16_U8': ('B64', 'B64', 'B64', 'B64'), 'V_MQSAD_PK_U16_U8': ('B64', 'B64', 'B64', 'B64'),
'V_MQSAD_U32_U8': ('B128', 'B64', 'B64', 'B128'),
# RDNA4 CVT_PK_F32 instructions: source is 8-bit packed as 16-bit operand
'V_CVT_PK_F32_BF8': ('F32', 'B16', None, None), 'V_CVT_PK_F32_FP8': ('F32', 'B16', None, None),
}
@cache
def spec_regs(name: str) -> tuple[int, int, int, int]:
@ -108,7 +112,7 @@ def spec_is_16bit(name: str) -> bool:
def spec_is_64bit(name: str) -> bool: return bool(_F64_RE.search(name.upper()))
_3SRC = {'FMA', 'MAD', 'MIN3', 'MAX3', 'MED3', 'DIV_FIX', 'DIV_FMAS', 'DIV_SCALE', 'SAD', 'LERP', 'ALIGN', 'CUBE', 'BFE', 'BFI',
'PERM_B32', 'PERMLANE', 'CNDMASK', 'XOR3', 'OR3', 'ADD3', 'LSHL_OR', 'AND_OR', 'LSHL_ADD', 'ADD_LSHL', 'XAD', 'MAXMIN',
'MINMAX', 'DOT2', 'DOT4', 'DOT8', 'WMMA', 'CVT_PK_U8', 'MULLIT', 'CO_CI'}
'MINMAX', 'MAXIMUMMINIMUM', 'MINIMUMMAXIMUM', 'MAXIMUM3', 'MINIMUM3', 'DOT2', 'DOT4', 'DOT8', 'WMMA', 'CVT_PK_U8', 'MULLIT', 'CO_CI'}
_2SRC = {'FMAC'} # FMAC uses dst as implicit accumulator, so only 2 explicit sources
def spec_num_srcs(name: str) -> int:
name = name.upper()
@ -138,17 +142,35 @@ class BitField:
def __get__(self, obj: None, objtype: type) -> BitField: ...
@overload
def __get__(self, obj: object, objtype: type | None = None) -> int: ...
# Map RDNA4 class names to their corresponding enum names for op field dynamic lookup
_RDNA4_OP_ENUMS = {'VDS': 'DSOp', 'VBUFFER': 'VBUFFEROp', 'VEXPORT': 'EXPOp', 'VFLAT': 'VFLATOp', 'VGLOBAL': 'VGLOBALOp',
'VSCRATCH': 'VSCRATCHOp', 'VIMAGE': 'VIMAGEOp', 'VSAMPLE': 'VSAMPLEOp', 'VDSDIR': 'VDSDIROp'}
def __get__(self, obj, objtype=None):
if obj is None: return self
val = unwrap(obj._values.get(self.name, 0))
# Convert to IntEnum if marker is an IntEnum subclass
if self.marker and isinstance(self.marker, type) and issubclass(self.marker, IntEnum):
# VOP3 with VOPC opcodes (0-255) -> VOPCOp, VOP3SD opcodes -> VOP3SDOp
if self.marker is VOP3Op:
if val < 256: return VOPCOp(val)
if val in Inst._VOP3SD_OPS: return VOP3SDOp(val)
# Check by name to handle both RDNA3 and RDNA4 enums
if self.marker.__name__ == 'VOP3Op':
# Get the appropriate enums from the same module as the marker
marker_mod = self.marker.__module__
import importlib
enum_mod = importlib.import_module(marker_mod)
if val < 256: return enum_mod.VOPCOp(val)
if val in Inst._VOP3SD_OPS: return enum_mod.VOP3SDOp(val)
try: return self.marker(val)
except ValueError: pass
# For RDNA4 op fields without type annotations, dynamically look up enum
elif self.name == 'op' and 'rdna4' in obj.__class__.__module__:
import importlib
enum_mod = importlib.import_module('extra.assembly.amd.autogen.rdna4.enum')
cls_name = obj.__class__.__name__
enum_name = self._RDNA4_OP_ENUMS.get(cls_name, cls_name + 'Op')
if hasattr(enum_mod, enum_name):
try: return getattr(enum_mod, enum_name)(val)
except ValueError: pass
return val
class _Bits:
@ -429,7 +451,7 @@ class Inst:
return result + (lit32 & MASK32).to_bytes(4, 'little')
@classmethod
def _size(cls) -> int: return 4 if issubclass(cls, Inst32) else 8
def _size(cls) -> int: return 4 if issubclass(cls, Inst32) else 12 if issubclass(cls, Inst96) else 8
def size(self) -> int:
# Literal is always 4 bytes in the binary (for 64-bit ops, it's in high 32 bits)
return self._size() + (4 if self._literal is not None else 0)
@ -489,9 +511,9 @@ class Inst:
def __hash__(self): return hash((self.__class__.__name__, tuple(sorted((k, repr(v)) for k, v in self._values.items())), self._literal))
def disasm(self) -> str:
def disasm(self, wave_size: int = 32) -> str:
from extra.assembly.amd.asm import disasm
return disasm(self)
return disasm(self, wave_size)
_enum_map = {'VOP1': VOP1Op, 'VOP2': VOP2Op, 'VOP3': VOP3Op, 'VOP3SD': VOP3SDOp, 'VOP3P': VOP3POp, 'VOPC': VOPCOp,
'SOP1': SOP1Op, 'SOP2': SOP2Op, 'SOPC': SOPCOp, 'SOPK': SOPKOp, 'SOPP': SOPPOp,
@ -499,6 +521,10 @@ class Inst:
'VOPD': VOPDOp, 'VINTERP': VINTERPOp}
_VOP3SD_OPS = {288, 289, 290, 764, 765, 766, 767, 768, 769, 770}
# Map RDNA4 class names to their corresponding enum names
_rdna4_enum_names = {'VDS': 'DSOp', 'VBUFFER': 'VBUFFEROp', 'VEXPORT': 'EXPOp', 'VFLAT': 'VFLATOp', 'VGLOBAL': 'VGLOBALOp',
'VSCRATCH': 'VSCRATCHOp', 'VIMAGE': 'VIMAGEOp', 'VSAMPLE': 'VSAMPLEOp', 'VDSDIR': 'VDSDIROp'}
@property
def op(self):
"""Return the op as an enum (e.g., VOP1Op.V_MOV_B32). VOP3 returns VOPCOp/VOP3SDOp for those op ranges."""
@ -506,6 +532,20 @@ class Inst:
if val is None: return None
if hasattr(val, 'name'): return val # already an enum
cls_name = self.__class__.__name__
# First check if op field has an annotated enum type
import typing
if 'op' in self.__class__.__annotations__:
ann = self.__class__.__annotations__['op']
if hasattr(ann, '__metadata__'):
for m in typing.get_args(ann)[1:]:
if isinstance(m, type) and issubclass(m, IntEnum): return m(val)
# Check if this is an RDNA4 class (module path contains rdna4) and get enum from its module
if 'rdna4' in self.__class__.__module__:
import importlib
enum_mod = importlib.import_module('extra.assembly.amd.autogen.rdna4.enum')
enum_name = self._rdna4_enum_names.get(cls_name, cls_name + 'Op')
if hasattr(enum_mod, enum_name): return getattr(enum_mod, enum_name)(val)
# Fall back to static enum map
assert cls_name in self._enum_map, f"no enum map for {cls_name}"
return self._enum_map[cls_name](val)
@ -531,3 +571,4 @@ class Inst:
class Inst32(Inst): pass
class Inst64(Inst): pass
class Inst96(Inst): pass

View file

@ -220,8 +220,16 @@ def _parse_fields_table(table: list, fmt: str, enums: set[str]) -> list[tuple]:
if not (bits := _parse_bits(bits_str)): continue
enc_val, hi, lo = None, bits[0], bits[1]
if name == 'ENCODING' and row[2]:
if m := re.search(r"(?:'b|Must be:\s*)([01_]+)", row[2]):
desc = row[2]
# Handle shared FLAT/GLOBAL/SCRATCH table: look for format-specific encoding
fmt_key = fmt.lstrip('V').lower().capitalize() # VFLAT -> Flat, VGLOBAL -> Global
if m := re.search(rf"{fmt_key}='b([01_]+)", desc):
enc_bits = m.group(1).replace('_', '')
elif m := re.search(r"(?:'b|Must be:\s*)([01_]+)", desc):
enc_bits = m.group(1).replace('_', '')
else:
enc_bits = None
if enc_bits:
enc_val, declared_width, actual_width = int(enc_bits, 2), hi - lo + 1, len(enc_bits)
if actual_width > declared_width: lo = hi - actual_width + 1
ftype = f"{fmt}Op" if name == 'OP' and f"{fmt}Op" in enums else FIELD_TYPES.get(name.upper())
@ -292,6 +300,16 @@ def _parse_single_pdf(url: str):
next_text = pdf.text(microcode_start + i + 1).lstrip()
if next_text.startswith('Description') or (next_text.startswith('"RDNA') and 'Description' in next_text[:200]):
format_headers.append((fmt_name, i, m.start()))
# RDNA4: Look for "Table X. Y Fields" patterns (e.g., VIMAGE, VSAMPLE, or shared FLAT/GLOBAL/SCRATCH)
for m in re.finditer(r'Table \d+\.\s+([\w,\s]+?)\s+Fields', text):
table_name = m.group(1).strip()
# Handle shared table like "FLAT, GLOBAL and SCRATCH"
if ',' in table_name or ' and ' in table_name:
for part in re.split(r',\s*|\s+and\s+', table_name):
fmt_name = 'V' + part.strip()
if fmt_name not in [h[0] for h in format_headers]: format_headers.append((fmt_name, i, m.start()))
elif table_name.startswith('V'):
if table_name not in [h[0] for h in format_headers]: format_headers.append((table_name, i, m.start()))
formats: dict[str, list] = {}
for fmt_name, rel_idx, header_pos in format_headers:
@ -324,6 +342,10 @@ def _parse_single_pdf(url: str):
if 'SMEM' in formats:
formats['SMEM'] = [(n, 13 if n == 'DLC' else 14 if n == 'GLC' else h, 13 if n == 'DLC' else 14 if n == 'GLC' else l, e, t)
for n, h, l, e, t in formats['SMEM']]
# RDNA4: VFLAT/VGLOBAL/VSCRATCH OP field is [20:14] not [20:13] (PDF documentation error)
for fmt_name in ['VFLAT', 'VGLOBAL', 'VSCRATCH']:
if fmt_name in formats:
formats[fmt_name] = [(n, h, 14 if n == 'OP' else l, e, t) for n, h, l, e, t in formats[fmt_name]]
if doc_name in ('RDNA3', 'RDNA3.5'):
if 'SOPPOp' in enums: assert 8 not in enums['SOPPOp']; enums['SOPPOp'][8] = 'S_WAITCNT_DEPCTR'
if 'DSOp' in enums:
@ -409,13 +431,14 @@ def _generate_ins_py(formats, enums, src_enum, doc_name) -> str:
def field_key(f, order): return order.index(f[0].lower()) if f[0].lower() in order else 1000
lines = [f"# autogenerated from AMD {doc_name} ISA PDF by pdf.py - do not edit",
"# ruff: noqa: F401,F403", "from typing import Annotated",
"from extra.assembly.amd.dsl import bits, BitField, Inst32, Inst64, SGPR, VGPR, TTMP as TTMP, s as s, v as v, ttmp as ttmp, SSrc, Src, SImm, Imm, VDSTYEnc, SGPRField, VGPRField",
"from extra.assembly.amd.dsl import bits, BitField, Inst32, Inst64, Inst96, SGPR, VGPR, TTMP as TTMP, s as s, v as v, ttmp as ttmp, SSrc, Src, SImm, Imm, VDSTYEnc, SGPRField, VGPRField",
"from extra.assembly.amd.autogen.{arch}.enum import *",
"import functools", ""]
format_defaults = {'VOP3P': {'opsel_hi': 3, 'opsel_hi2': 1}}
lines.append("# instruction formats")
for fmt_name, fields in sorted(formats.items()):
base = "Inst64" if max(f[1] for f in fields) > 31 or fmt_name == 'VOP3SD' else "Inst32"
max_bit = max(f[1] for f in fields)
base = "Inst96" if max_bit > 63 else "Inst64" if max_bit > 31 or fmt_name == 'VOP3SD' else "Inst32"
order = FIELD_ORDER.get(fmt_name, [])
lines.append(f"class {fmt_name}({base}):")
if enc := next((f for f in fields if f[0] == 'ENCODING'), None):

View file

@ -1,98 +1,76 @@
#!/usr/bin/env python3
"""Test RDNA3 assembler/disassembler against LLVM test vectors."""
"""Test RDNA3/RDNA4 assembler/disassembler against LLVM test vectors."""
import unittest, re, subprocess
from tinygrad.helpers import fetch
from extra.assembly.amd.autogen.rdna3.ins import *
from extra.assembly.amd.asm import asm
from extra.assembly.amd.test.helpers import get_llvm_mc
LLVM_BASE = "https://raw.githubusercontent.com/llvm/llvm-project/main/llvm/test/MC/AMDGPU"
# Format info: (filename, format_class, op_enum)
LLVM_TEST_FILES = {
# Scalar ALU
'sop1': ('gfx11_asm_sop1.s', SOP1, SOP1Op),
'sop2': ('gfx11_asm_sop2.s', SOP2, SOP2Op),
'sopp': ('gfx11_asm_sopp.s', SOPP, SOPPOp),
'sopk': ('gfx11_asm_sopk.s', SOPK, SOPKOp),
'sopc': ('gfx11_asm_sopc.s', SOPC, SOPCOp),
# Vector ALU
'vop1': ('gfx11_asm_vop1.s', VOP1, VOP1Op),
'vop2': ('gfx11_asm_vop2.s', VOP2, VOP2Op),
'vopc': ('gfx11_asm_vopc.s', VOPC, VOPCOp),
'vop3': ('gfx11_asm_vop3.s', VOP3, VOP3Op),
'vop3p': ('gfx11_asm_vop3p.s', VOP3P, VOP3POp),
'vop3sd': ('gfx11_asm_vop3.s', VOP3SD, VOP3SDOp), # VOP3SD shares file with VOP3
'vinterp': ('gfx11_asm_vinterp.s', VINTERP, VINTERPOp),
'vopd': ('gfx11_asm_vopd.s', VOPD, VOPDOp),
'vopcx': ('gfx11_asm_vopcx.s', VOPC, VOPCOp), # VOPCX uses VOPC format
# VOP3 promotions (VOP1/VOP2/VOPC promoted to VOP3 encoding)
'vop3_from_vop1': ('gfx11_asm_vop3_from_vop1.s', VOP3, VOP3Op),
'vop3_from_vop2': ('gfx11_asm_vop3_from_vop2.s', VOP3, VOP3Op),
'vop3_from_vopc': ('gfx11_asm_vop3_from_vopc.s', VOP3, VOP3Op),
'vop3_from_vopcx': ('gfx11_asm_vop3_from_vopcx.s', VOP3, VOP3Op),
# Memory
'ds': ('gfx11_asm_ds.s', DS, DSOp),
'smem': ('gfx11_asm_smem.s', SMEM, SMEMOp),
'flat': ('gfx11_asm_flat.s', FLAT, FLATOp),
'mubuf': ('gfx11_asm_mubuf.s', MUBUF, MUBUFOp),
'mtbuf': ('gfx11_asm_mtbuf.s', MTBUF, MTBUFOp),
'mimg': ('gfx11_asm_mimg.s', MIMG, MIMGOp),
# WMMA (matrix multiply)
'wmma': ('gfx11_asm_wmma.s', VOP3P, VOP3POp),
# Additional features
'vop3_features': ('gfx11_asm_vop3_features.s', VOP3, VOP3Op),
'vop3p_features': ('gfx11_asm_vop3p_features.s', VOP3P, VOP3POp),
'vopd_features': ('gfx11_asm_vopd_features.s', VOPD, VOPDOp),
# Alias files (alternative mnemonics)
'vop3_alias': ('gfx11_asm_vop3_alias.s', VOP3, VOP3Op),
'vop3p_alias': ('gfx11_asm_vop3p_alias.s', VOP3P, VOP3POp),
'vopc_alias': ('gfx11_asm_vopc_alias.s', VOPC, VOPCOp),
'vopcx_alias': ('gfx11_asm_vopcx_alias.s', VOPC, VOPCOp),
'vinterp_alias': ('gfx11_asm_vinterp_alias.s', VINTERP, VINTERPOp),
'smem_alias': ('gfx11_asm_smem_alias.s', SMEM, SMEMOp),
'mubuf_alias': ('gfx11_asm_mubuf_alias.s', MUBUF, MUBUFOp),
'mtbuf_alias': ('gfx11_asm_mtbuf_alias.s', MTBUF, MTBUFOp),
RDNA3_TEST_FILES = {
'sop1': 'gfx11_asm_sop1.s', 'sop2': 'gfx11_asm_sop2.s', 'sopp': 'gfx11_asm_sopp.s', 'sopk': 'gfx11_asm_sopk.s', 'sopc': 'gfx11_asm_sopc.s',
'vop1': 'gfx11_asm_vop1.s', 'vop2': 'gfx11_asm_vop2.s', 'vopc': 'gfx11_asm_vopc.s', 'vop3': 'gfx11_asm_vop3.s', 'vop3p': 'gfx11_asm_vop3p.s',
'vinterp': 'gfx11_asm_vinterp.s', 'vopd': 'gfx11_asm_vopd.s', 'vopcx': 'gfx11_asm_vopcx.s',
'vop3_from_vop1': 'gfx11_asm_vop3_from_vop1.s', 'vop3_from_vop2': 'gfx11_asm_vop3_from_vop2.s',
'vop3_from_vopc': 'gfx11_asm_vop3_from_vopc.s', 'vop3_from_vopcx': 'gfx11_asm_vop3_from_vopcx.s',
'ds': 'gfx11_asm_ds.s', 'smem': 'gfx11_asm_smem.s', 'flat': 'gfx11_asm_flat.s',
'mubuf': 'gfx11_asm_mubuf.s', 'mtbuf': 'gfx11_asm_mtbuf.s', 'mimg': 'gfx11_asm_mimg.s', 'mimg_features': 'gfx11_asm_mimg_features.s', 'ldsdir': 'gfx11_asm_ldsdir.s',
'exp': 'gfx11_asm_exp.s', 'wmma': 'gfx11_asm_wmma.s',
'vop3_features': 'gfx11_asm_vop3_features.s', 'vop3p_features': 'gfx11_asm_vop3p_features.s', 'vopd_features': 'gfx11_asm_vopd_features.s',
'vop3_alias': 'gfx11_asm_vop3_alias.s', 'vop3p_alias': 'gfx11_asm_vop3p_alias.s', 'vopc_alias': 'gfx11_asm_vopc_alias.s',
'vopcx_alias': 'gfx11_asm_vopcx_alias.s', 'vinterp_alias': 'gfx11_asm_vinterp_alias.s',
'smem_alias': 'gfx11_asm_smem_alias.s', 'mubuf_alias': 'gfx11_asm_mubuf_alias.s', 'mtbuf_alias': 'gfx11_asm_mtbuf_alias.s',
}
def parse_llvm_tests(text: str) -> list[tuple[str, bytes]]:
RDNA4_TEST_FILES = {
'sop1': 'gfx12_asm_sop1.s', 'sop2': 'gfx12_asm_sop2.s', 'sop2_alias': 'gfx12_asm_sop2_alias.s',
'sopp': 'gfx12_asm_sopp.s', 'sopk': 'gfx12_asm_sopk.s', 'sopk_alias': 'gfx12_asm_sopk_alias.s', 'sopc': 'gfx12_asm_sopc.s',
'vop1': 'gfx12_asm_vop1.s', 'vop2': 'gfx12_asm_vop2.s', 'vop2_aliases': 'gfx12_asm_vop2_aliases.s',
'vopc': 'gfx12_asm_vopc.s', 'vopcx': 'gfx12_asm_vopcx.s',
'vop3': 'gfx12_asm_vop3.s', 'vop3_aliases': 'gfx12_asm_vop3_aliases.s', 'vop3c': 'gfx12_asm_vop3c.s', 'vop3cx': 'gfx12_asm_vop3cx.s',
'vop3p': 'gfx12_asm_vop3p.s', 'vop3p_aliases': 'gfx12_asm_vop3p_aliases.s', 'vop3p_features': 'gfx12_asm_vop3p_features.s',
'vopd': 'gfx12_asm_vopd.s', 'vopd_features': 'gfx12_asm_vopd_features.s',
'vop3_from_vop1': 'gfx12_asm_vop3_from_vop1.s', 'vop3_from_vop2': 'gfx12_asm_vop3_from_vop2.s',
'ds': 'gfx12_asm_ds.s', 'ds_alias': 'gfx12_asm_ds_alias.s', 'smem': 'gfx12_asm_smem.s',
'vflat': 'gfx12_asm_vflat.s', 'vflat_alias': 'gfx12_asm_vflat_alias.s',
'vglobal': 'gfx12_asm_vflat.s', 'vglobal_alias': 'gfx12_asm_vflat_alias.s', # global instructions in vflat files
'vscratch': 'gfx12_asm_vflat.s', # scratch instructions in vflat file
'vbuffer_mubuf': 'gfx12_asm_vbuffer_mubuf.s', 'vbuffer_mubuf_alias': 'gfx12_asm_vbuffer_mubuf_alias.s',
'vbuffer_mtbuf': 'gfx12_asm_vbuffer_mtbuf.s', 'vbuffer_mtbuf_alias': 'gfx12_asm_vbuffer_mtbuf_alias.s',
'vimage': 'gfx12_asm_vimage.s', 'vimage_alias': 'gfx12_asm_vimage_alias.s', 'vsample': 'gfx12_asm_vsample.s',
'vdsdir': 'gfx12_asm_vdsdir.s', 'vdsdir_alias': 'gfx12_asm_vdsdir_alias.s',
'exp': 'gfx12_asm_exp.s', 'wmma_w32': 'gfx12_asm_wmma_w32.s', 'wmma_w64': 'gfx12_asm_wmma_w64.s',
'global_load_tr': 'gfx12_asm_global_load_tr.s',
# NOTE: 'features' (gfx12_asm_features.s) tests DPP instruction variants which require separate format decoders
}
def parse_llvm_tests(text: str, gfx_prefix: str) -> list[tuple[str, bytes]]:
"""Parse LLVM test format into (asm, expected_bytes) pairs."""
tests, lines = [], text.split('\n')
pattern = rf'(?:{gfx_prefix}|W32|W64)[^:]*:.*?encoding:\s*\[(.*?)\]'
pattern2 = rf'(?:{gfx_prefix}|W32|W64)[^:]*:\s*\[(0x[0-9a-fA-F,x\s]+)\]'
for i, line in enumerate(lines):
line = line.strip()
if not line or line.startswith(('//', '.', ';')): continue
asm_text = line.split('//')[0].strip()
if not asm_text: continue
for j in range(i, min(i + 3, len(lines))):
# Match GFX11, W32, or W64 encodings (all valid for gfx11)
# Format 1: "// GFX11: v_foo ... ; encoding: [0x01,0x02,...]"
# Format 2: "// GFX11: [0x01,0x02,...]" (used by DS, older files)
if m := re.search(r'(?:GFX11|W32|W64)[^:]*:.*?encoding:\s*\[(.*?)\]', lines[j]):
if m := re.search(pattern, lines[j]):
hex_bytes = m.group(1).replace('0x', '').replace(',', '').replace(' ', '')
elif m := re.search(r'(?:GFX11|W32|W64)[^:]*:\s*\[(0x[0-9a-fA-F,x\s]+)\]', lines[j]):
elif m := re.search(pattern2, lines[j]):
hex_bytes = m.group(1).replace('0x', '').replace(',', '').replace(' ', '')
else:
continue
else: continue
if hex_bytes:
try: tests.append((asm_text, bytes.fromhex(hex_bytes)))
except ValueError: pass
break
return tests
def try_assemble(text: str):
"""Try to assemble instruction text, return bytes or None on failure."""
try: return asm(text).to_bytes()
except: return None
def compile_asm_batch(instrs: list[str]) -> list[bytes]:
def compile_asm_batch(instrs: list[str], mcpu: str, mattr: str = '+real-true16,+wavefrontsize32') -> list[bytes]:
"""Compile multiple instructions with a single llvm-mc call."""
if not instrs: return []
asm_text = ".text\n" + "\n".join(instrs) + "\n"
result = subprocess.run(
[get_llvm_mc(), '-triple=amdgcn', '-mcpu=gfx1100', '-mattr=+real-true16,+wavefrontsize32', '-show-encoding'],
input=asm_text, capture_output=True, text=True, timeout=30)
if result.returncode != 0: raise RuntimeError(f"llvm-mc batch failed: {result.stderr.strip()}")
# Parse all encodings from output
result = subprocess.run([get_llvm_mc(), '-triple=amdgcn', f'-mcpu={mcpu}', f'-mattr={mattr}', '-show-encoding'],
input=".text\n" + "\n".join(instrs) + "\n", capture_output=True, text=True, timeout=30)
if result.returncode != 0: raise RuntimeError(f"llvm-mc failed: {result.stderr.strip()}")
results = []
for line in result.stdout.split('\n'):
if 'encoding:' not in line: continue
@ -102,100 +80,127 @@ def compile_asm_batch(instrs: list[str]) -> list[bytes]:
if len(results) != len(instrs): raise RuntimeError(f"expected {len(instrs)} encodings, got {len(results)}")
return results
class TestLLVM(unittest.TestCase):
"""Test assembler and disassembler against all LLVM test vectors."""
def matches_encoding(data: bytes, fmt) -> bool:
"""Check if instruction bytes match format's expected encoding bits."""
if not hasattr(fmt, '_encoding') or fmt._encoding is None: return True
bf, expected = fmt._encoding
val = int.from_bytes(data[:fmt._size()], 'little')
return ((val >> bf.lo) & bf.mask()) == expected
class TestLLVMBase(unittest.TestCase):
"""Base class for LLVM assembler tests."""
tests: dict[str, list[tuple[str, bytes]]] = {}
formats: dict[str, type] = {}
gfx_prefix: str = ""
mcpu: str = ""
arch_name: str = ""
@classmethod
def setUpClass(cls):
for name, (filename, _, _) in LLVM_TEST_FILES.items():
def _load_tests(cls, test_files: dict[str, str]):
for name, filename in test_files.items():
try:
data = fetch(f"{LLVM_BASE}/{filename}").read_bytes()
cls.tests[name] = parse_llvm_tests(data.decode('utf-8', errors='ignore'))
cls.tests[name] = parse_llvm_tests(data.decode('utf-8', errors='ignore'), cls.gfx_prefix)
except Exception as e:
print(f"Warning: couldn't fetch {filename}: {e}")
cls.tests[name] = []
# Generate test methods dynamically for each format
def _make_asm_test(name):
def test(self):
passed, failed, skipped = 0, 0, 0
for asm_text, expected in self.tests.get(name, []):
result = try_assemble(asm_text)
if result is None: skipped += 1
elif result == expected: passed += 1
else: failed += 1
print(f"{name.upper()} asm: {passed} passed, {failed} failed, {skipped} skipped")
self.assertEqual(failed, 0)
return test
def _test_disasm(self, name: str):
"""Test decoding instructions and verify disassembly produces correct bytes."""
if name not in self.tests or not self.tests[name]: self.skipTest(f"No test data for {name}")
fmt_cls = self.formats.get(name)
if fmt_cls is None: self.skipTest(f"No format class for {name}")
def _make_disasm_test(name):
def test(self):
_, fmt_cls, op_enum = LLVM_TEST_FILES[name]
# VOP3SD opcodes that share encoding with VOP3 (only for vop3sd test, not vopc promotions)
vop3sd_opcodes = {288, 289, 290, 764, 765, 766, 767, 768, 769, 770}
is_vopc_promotion = name in ('vop3_from_vopc', 'vop3_from_vopcx')
undocumented = {'smem': {34, 35}, 'sopk': {22, 23}, 'sopp': {8, 58, 59}}
# Determine wave size from test name (w64 = wave64, otherwise wave32)
wave_size = 64 if 'w64' in name else 32
mattr = f'+real-true16,+wavefrontsize{wave_size}'
# First pass: decode all instructions and collect disasm strings
to_test: list[tuple[str, bytes, str | None, str | None]] = [] # (asm_text, data, disasm_str, error)
skipped = 0
to_test: list[tuple[str, bytes, str | None, str | None]] = []
for asm_text, data in self.tests.get(name, []):
if len(data) > fmt_cls._size(): continue
temp_inst = fmt_cls.from_bytes(data)
temp_op = temp_inst._values.get('op', 0)
temp_op = temp_op.val if hasattr(temp_op, 'val') else temp_op
if temp_op in undocumented.get(name, set()): skipped += 1; continue
if name == 'sopp':
simm16 = temp_inst._values.get('simm16', 0)
simm16 = simm16.val if hasattr(simm16, 'val') else simm16
sopp_no_imm = {48, 54, 53, 55, 60, 61, 62}
if temp_op in sopp_no_imm and simm16 != 0: skipped += 1; continue
if not matches_encoding(data, fmt_cls): continue
try:
if fmt_cls.__name__ in ('VOP3', 'VOP3SD'):
temp = VOP3.from_bytes(data)
op_val = temp._values.get('op', 0)
op_val = op_val.val if hasattr(op_val, 'val') else op_val
is_vop3sd = (op_val in vop3sd_opcodes) and not is_vopc_promotion
decoded = VOP3SD.from_bytes(data) if is_vop3sd else VOP3.from_bytes(data)
if is_vop3sd: VOP3SDOp(op_val)
else: VOP3Op(op_val)
else:
decoded = fmt_cls.from_bytes(data)
op_val = decoded._values.get('op', 0)
op_val = op_val.val if hasattr(op_val, 'val') else op_val
op_enum(op_val)
decoded = fmt_cls.from_bytes(data)
if decoded.to_bytes()[:len(data)] != data:
to_test.append((asm_text, data, None, "decode roundtrip failed"))
continue
to_test.append((asm_text, data, decoded.disasm(), None))
to_test.append((asm_text, data, decoded.disasm(wave_size), None))
except Exception as e:
to_test.append((asm_text, data, None, f"exception: {e}"))
# Batch compile all disasm strings with single llvm-mc call
disasm_strs = [(i, t[2]) for i, t in enumerate(to_test) if t[2] is not None]
llvm_results = compile_asm_batch([s for _, s in disasm_strs]) if disasm_strs else []
llvm_map = {i: llvm_results[j] for j, (i, _) in enumerate(disasm_strs)}
llvm_map = {}
if disasm_strs:
llvm_results = compile_asm_batch([s for _, s in disasm_strs], self.mcpu, mattr)
llvm_map = {i: llvm_results[j] for j, (i, _) in enumerate(disasm_strs)}
# Match results back
passed, failed = 0, 0
failures: list[str] = []
passed, failed, failures = 0, 0, []
for idx, (asm_text, data, disasm_str, error) in enumerate(to_test):
if error:
failed += 1; failures.append(f"{error} for {data.hex()}")
elif disasm_str is not None and idx in llvm_map:
llvm_bytes = llvm_map[idx]
if llvm_bytes is not None and llvm_bytes == data: passed += 1
elif llvm_bytes is not None: failed += 1; failures.append(f"'{disasm_str}': expected={data.hex()} got={llvm_bytes.hex()}")
if llvm_bytes == data: passed += 1
else: failed += 1; failures.append(f"'{disasm_str}': expected={data.hex()} got={llvm_bytes.hex()}")
print(f"{name.upper()} disasm: {passed} passed, {failed} failed" + (f", {skipped} skipped" if skipped else ""))
if failures[:10]: print(" " + "\n ".join(failures[:10]))
self.assertEqual(failed, 0)
print(f"{self.arch_name} {name.upper()} disasm: {passed} passed, {failed} failed")
if failures[:5]: print(" " + "\n ".join(failures[:5]))
self.assertGreater(passed, 0, f"No tests passed for {name}")
class TestLLVMRDNA3(TestLLVMBase):
"""Test RDNA3 assembler against LLVM test vectors."""
gfx_prefix, mcpu, arch_name = "GFX11", "gfx1100", "RDNA3"
@classmethod
def setUpClass(cls):
from extra.assembly.amd.autogen.rdna3.ins import SOP1, SOP2, SOPC, SOPK, SOPP, VOP1, VOP2, VOP3, VOP3P, VOPC, VOPD, VINTERP, DS, SMEM, FLAT, MUBUF, MTBUF, MIMG, LDSDIR, EXP
cls.formats = {
'sop1': SOP1, 'sop2': SOP2, 'sopc': SOPC, 'sopk': SOPK, 'sopp': SOPP,
'vop1': VOP1, 'vop2': VOP2, 'vopc': VOPC, 'vopcx': VOPC, 'vop3': VOP3, 'vop3p': VOP3P,
'vinterp': VINTERP, 'vopd': VOPD, 'ds': DS, 'smem': SMEM, 'flat': FLAT,
'mubuf': MUBUF, 'mtbuf': MTBUF, 'mimg': MIMG, 'mimg_features': MIMG, 'wmma': VOP3P, 'ldsdir': LDSDIR, 'exp': EXP,
'vop3_from_vop1': VOP3, 'vop3_from_vop2': VOP3, 'vop3_from_vopc': VOP3, 'vop3_from_vopcx': VOP3,
'vop3_features': VOP3, 'vop3p_features': VOP3P, 'vopd_features': VOPD,
'vop3_alias': VOP3, 'vop3p_alias': VOP3P, 'vopc_alias': VOPC, 'vopcx_alias': VOPC,
'vinterp_alias': VINTERP, 'smem_alias': SMEM, 'mubuf_alias': MUBUF, 'mtbuf_alias': MTBUF,
}
cls._load_tests(RDNA3_TEST_FILES)
class TestLLVMRDNA4(TestLLVMBase):
"""Test RDNA4 assembler against LLVM test vectors."""
gfx_prefix, mcpu, arch_name = "GFX12", "gfx1200", "RDNA4"
@classmethod
def setUpClass(cls):
import extra.assembly.amd.autogen.rdna4.ins as rdna4
get = lambda n: getattr(rdna4, n, None)
cls.formats = {
'sop1': get('SOP1'), 'sop2': get('SOP2'), 'sop2_alias': get('SOP2'), 'sopc': get('SOPC'),
'sopk': get('SOPK'), 'sopk_alias': get('SOPK'), 'sopp': get('SOPP'),
'vop1': get('VOP1'), 'vop2': get('VOP2'), 'vop2_aliases': get('VOP2'), 'vopc': get('VOPC'), 'vopcx': get('VOPC'),
'vop3': get('VOP3'), 'vop3_aliases': get('VOP3'), 'vop3c': get('VOP3'), 'vop3cx': get('VOP3'),
'vop3p': get('VOP3P'), 'vop3p_aliases': get('VOP3P'), 'vop3p_features': get('VOP3P'),
'vopd': get('VOPD'), 'vopd_features': get('VOPD'),
'vop3_from_vop1': get('VOP3'), 'vop3_from_vop2': get('VOP3'),
'ds': get('VDS'), 'ds_alias': get('VDS'), 'smem': get('SMEM'), 'vinterp': get('VINTERP'), 'exp': get('VEXPORT'),
'vbuffer_mubuf': get('VBUFFER'), 'vbuffer_mubuf_alias': get('VBUFFER'),
'vbuffer_mtbuf': get('VBUFFER'), 'vbuffer_mtbuf_alias': get('VBUFFER'),
'vdsdir': get('VDSDIR'), 'vdsdir_alias': get('VDSDIR'),
'vflat': get('VFLAT'), 'vflat_alias': get('VFLAT'),
'vglobal': get('VGLOBAL'), 'vglobal_alias': get('VGLOBAL'),
'vscratch': get('VSCRATCH'),
'vimage': get('VIMAGE'), 'vimage_alias': get('VIMAGE'), 'vsample': get('VSAMPLE'),
'wmma_w32': get('VOP3P'), 'wmma_w64': get('VOP3P'),
'global_load_tr': get('VGLOBAL'),
}
cls._load_tests(RDNA4_TEST_FILES)
# Generate test methods dynamically
def _make_test(name):
def test(self): self._test_disasm(name)
return test
for name in LLVM_TEST_FILES:
setattr(TestLLVM, f'test_{name}_asm', _make_asm_test(name))
setattr(TestLLVM, f'test_{name}_disasm', _make_disasm_test(name))
for name in RDNA3_TEST_FILES: setattr(TestLLVMRDNA3, f'test_{name}_disasm', _make_test(name))
for name in RDNA4_TEST_FILES: setattr(TestLLVMRDNA4, f'test_{name}_disasm', _make_test(name))
if __name__ == "__main__":
unittest.main()
if __name__ == "__main__": unittest.main()

View file

@ -1,90 +1,37 @@
#!/usr/bin/env python3
"""Roundtrip tests: generate tinygrad kernels, decode instructions, re-encode, verify match."""
import unittest, io, sys, re, subprocess, os
from extra.assembly.amd.autogen.rdna3.ins import *
from extra.assembly.amd.dsl import Inst
from extra.assembly.amd.asm import asm
from extra.assembly.amd.asm import detect_format
from extra.assembly.amd.test.helpers import get_llvm_mc, get_llvm_objdump
def disassemble_lib(lib: bytes, compiler) -> list[tuple[str, bytes]]:
"""Disassemble ELF binary and return list of (instruction_text, machine_code_bytes)."""
old_stdout = sys.stdout
sys.stdout = io.StringIO()
compiler.disassemble(lib)
output = sys.stdout.getvalue()
sys.stdout = old_stdout
results = []
for line in output.splitlines():
if '//' not in line: continue
instr = line.split('//')[0].strip()
if not instr: continue
comment = line.split('//')[1].strip()
if ':' not in comment: continue
hex_str = comment.split(':')[1].strip().split()[0]
try:
machine_bytes = bytes.fromhex(hex_str)[::-1] # big-endian to little-endian
results.append((instr, machine_bytes))
except ValueError:
continue
return results
def compile_asm(instr: str, compiler=None) -> bytes:
"""Compile a single instruction with llvm-mc and return the machine code bytes."""
llvm_mc = get_llvm_mc()
result = subprocess.run(
[llvm_mc, '-triple=amdgcn', '-mcpu=gfx1100', '-mattr=+real-true16,+wavefrontsize32', '-show-encoding'],
input=f".text\n{instr}\n", capture_output=True, text=True)
if result.returncode != 0: raise RuntimeError(f"llvm-mc failed for '{instr}': {result.stderr.strip()}")
# Parse encoding: [0x01,0x39,0x0a,0x7e]
for line in result.stdout.split('\n'):
if 'encoding:' in line:
enc = line.split('encoding:')[1].strip()
if enc.startswith('[') and enc.endswith(']'):
hex_vals = enc[1:-1].replace('0x', '').replace(',', '').replace(' ', '')
return bytes.fromhex(hex_vals)
raise RuntimeError(f"no encoding found in llvm-mc output for: {instr}")
def compile_asm_batch(instrs: list[str]) -> list[bytes]:
def compile_asm_batch(instrs: list[str], mcpu: str = 'gfx1100') -> list[bytes]:
"""Compile multiple instructions with a single llvm-mc call."""
if not instrs: return []
llvm_mc = get_llvm_mc()
src = ".text\n" + "\n".join(instrs) + "\n"
result = subprocess.run(
[llvm_mc, '-triple=amdgcn', '-mcpu=gfx1100', '-mattr=+real-true16,+wavefrontsize32', '-show-encoding'],
input=src, capture_output=True, text=True)
result = subprocess.run([get_llvm_mc(), '-triple=amdgcn', f'-mcpu={mcpu}', '-mattr=+real-true16,+wavefrontsize32', '-show-encoding'],
input=".text\n" + "\n".join(instrs) + "\n", capture_output=True, text=True)
if result.returncode != 0: raise RuntimeError(f"llvm-mc batch failed: {result.stderr.strip()}")
# Parse all encodings in order
encodings = []
for line in result.stdout.split('\n'):
if 'encoding:' in line:
enc = line.split('encoding:')[1].strip()
if enc.startswith('[') and enc.endswith(']'):
hex_vals = enc[1:-1].replace('0x', '').replace(',', '').replace(' ', '')
encodings.append(bytes.fromhex(hex_vals))
encodings.append(bytes.fromhex(enc[1:-1].replace('0x', '').replace(',', '').replace(' ', '')))
if len(encodings) != len(instrs): raise RuntimeError(f"expected {len(instrs)} encodings, got {len(encodings)}")
return encodings
def compile_and_disasm_batch(instrs: list[str], compiler) -> list[str]:
def compile_and_disasm_batch(instrs: list[str], mcpu: str = 'gfx1100') -> list[str]:
"""Compile instructions with LLVM and get LLVM's disassembly."""
import tempfile, os
import tempfile
if not instrs: return []
# Build assembly source with all instructions
src = ".text\n.globl test\n.p2align 8\n.type test,@function\ntest:\n"
src += "\n".join(f" {instr}" for instr in instrs) + "\n"
# Use llvm-mc to assemble to object file
src = ".text\n.globl test\n.p2align 8\n.type test,@function\ntest:\n" + "\n".join(f" {instr}" for instr in instrs) + "\n"
with tempfile.NamedTemporaryFile(suffix='.o', delete=False) as f:
obj_path = f.name
try:
result = subprocess.run(
[get_llvm_mc(), '-triple=amdgcn', '-mcpu=gfx1100', '-mattr=+real-true16,+wavefrontsize32', '-filetype=obj', '-o', obj_path],
input=src, capture_output=True, text=True)
result = subprocess.run([get_llvm_mc(), '-triple=amdgcn', f'-mcpu={mcpu}', '-mattr=+real-true16,+wavefrontsize32', '-filetype=obj', '-o', obj_path],
input=src, capture_output=True, text=True)
if result.returncode != 0: raise RuntimeError(f"llvm-mc failed: {result.stderr.strip()}")
# Disassemble with llvm-objdump
result = subprocess.run([get_llvm_objdump(), '-d', '--mcpu=gfx1100', obj_path], capture_output=True, text=True)
result = subprocess.run([get_llvm_objdump(), '-d', f'--mcpu={mcpu}', obj_path], capture_output=True, text=True)
if result.returncode != 0: raise RuntimeError(f"llvm-objdump failed: {result.stderr.strip()}")
# Parse disassembly output
results: list[str] = []
for line in result.stdout.splitlines():
if '//' not in line: continue
@ -94,127 +41,143 @@ def compile_and_disasm_batch(instrs: list[str], compiler) -> list[str]:
finally:
os.unlink(obj_path)
class TestTinygradKernelRoundtrip(unittest.TestCase):
"""Test roundtrip on real tinygrad-generated kernels using get_kernels_from_tinygrad pattern."""
class TestRoundtripBase(unittest.TestCase):
"""Base class for roundtrip tests."""
mcpu: str = 'gfx1100'
arch: str = 'rdna3'
@classmethod
def _get_modules(cls):
if cls.arch == 'rdna3':
from extra.assembly.amd.autogen.rdna3 import ins
from extra.assembly.amd.asm import detect_format, asm
else:
import extra.assembly.amd.autogen.rdna4.ins as ins
from extra.assembly.amd.asm import asm
detect_format = None # RDNA4 uses different detection
return ins, detect_format, asm
def _test_kernel_roundtrip(self, op_fn):
"""Generate kernel from op_fn, test:
1. decode -> reencode matches original bytes
2. asm(disasm()) matches LLVM output
3. our disasm() matches LLVM's disassembly string exactly
"""
"""Generate kernel from op_fn, test decode -> reencode and asm(disasm()) matches LLVM."""
from extra.assembly.amd.test.test_compare_emulators import get_kernels_from_tinygrad
from tinygrad.runtime.support.compiler_amd import HIPCompiler
ins, detect_format, asm = self._get_modules()
kernels, _, _ = get_kernels_from_tinygrad(op_fn)
compiler = HIPCompiler('gfx1100')
compiler = HIPCompiler(self.mcpu)
# First pass: decode all instructions and collect info
decoded_instrs: list[tuple] = [] # list of (ki, offset, orig_bytes, decoded, our_disasm, decode_ok, decode_err)
# First pass: decode all instructions
decoded_instrs: list[tuple] = []
for ki, kernel in enumerate(kernels):
offset = 0
while offset < len(kernel.code):
remaining = kernel.code[offset:]
fmt = detect_format(remaining)
if fmt is None:
decoded_instrs.append((ki, offset, None, None, None, False, "no format"))
offset += 4
continue
if len(remaining) < 4: break
# Try to detect format
if detect_format is not None:
try:
fmt = detect_format(remaining)
except ValueError:
decoded_instrs.append((ki, offset, None, None, None, False, "no format"))
offset += 4
continue
else:
# For RDNA4, try formats in order
fmt = None
from extra.assembly.amd.autogen.rdna4.ins import SOP1, SOP2, SOPC, SOPK, SOPP, VOP1, VOP2, VOP3, VOP3P, VOPC, VOPD, VDS, SMEM, VFLAT, VBUFFER, VIMAGE, VSAMPLE, VEXPORT, VDSDIR
word = int.from_bytes(remaining[:4], 'little')
for cls in [VOPD, VOP3P, VOP3, VDS, VFLAT, VBUFFER, VIMAGE, VSAMPLE, SMEM, VEXPORT, SOP1, SOPC, SOPP, SOPK, VOPC, VOP1, SOP2, VOP2, VDSDIR]:
if cls._encoding is not None:
bf, val = cls._encoding
if ((word >> bf.lo) & bf.mask()) == val:
fmt = cls
break
if fmt is None:
decoded_instrs.append((ki, offset, None, None, None, False, "no format"))
offset += 4
continue
base_size = fmt._size()
if len(remaining) < base_size:
break
if len(remaining) < base_size: break
try:
decoded = fmt.from_bytes(remaining) # pass all remaining bytes so from_bytes can read literal
size = decoded.size() # actual size including literal
decoded = fmt.from_bytes(remaining)
size = decoded.size()
orig_bytes = remaining[:size]
reencoded = decoded.to_bytes()
our_disasm = decoded.disasm()
decode_ok = reencoded == orig_bytes
decode_err: str | None = None if decode_ok else f"orig={orig_bytes.hex()} reenc={reencoded.hex()}"
decode_err = None if decode_ok else f"orig={orig_bytes.hex()} reenc={reencoded.hex()}"
decoded_instrs.append((ki, offset, orig_bytes, decoded, our_disasm, decode_ok, decode_err))
except Exception as e:
decoded_instrs.append((ki, offset, remaining[:base_size], None, None, False, str(e)))
size = base_size
offset += size
# Collect disasm strings for batched LLVM calls - skip unknown opcodes (op_X) that LLVM can't compile
asm_test_instrs: list[tuple[int, str]] = [] # (idx, our_disasm) for asm test
disasm_test_instrs: list[tuple[int, str]] = [] # (idx, our_disasm) for disasm comparison test
# Collect disasm strings for batched LLVM calls
asm_test_instrs: list[tuple[int, str]] = []
for idx, (ki, offset, orig_bytes, decoded, our_disasm, decode_ok, decode_err) in enumerate(decoded_instrs):
if our_disasm is None: continue
# Skip unknown opcodes and malformed instructions for both tests
if our_disasm.startswith('op_') or re.search(r', \d+, \d+, \d+,', our_disasm): continue
asm_test_instrs.append((idx, our_disasm))
disasm_test_instrs.append((idx, our_disasm))
# Batch compile for asm test
asm_llvm_results = compile_asm_batch([d for _, d in asm_test_instrs])
asm_llvm_results = compile_asm_batch([d for _, d in asm_test_instrs], self.mcpu)
asm_llvm_map = {idx: result for (idx, _), result in zip(asm_test_instrs, asm_llvm_results)}
# Batch compile+disasm for disasm comparison test
disasm_llvm_results = compile_and_disasm_batch([d for _, d in disasm_test_instrs], compiler)
disasm_llvm_map = {idx: result for (idx, _), result in zip(disasm_test_instrs, disasm_llvm_results)}
disasm_llvm_results = compile_and_disasm_batch([d for _, d in asm_test_instrs], self.mcpu)
disasm_llvm_map = {idx: result for (idx, _), result in zip(asm_test_instrs, disasm_llvm_results)}
# Now evaluate results
# Evaluate results
decode_passed, decode_failed, decode_skipped = 0, 0, 0
asm_passed, asm_failed, asm_skipped = 0, 0, 0
disasm_passed, disasm_failed, disasm_skipped = 0, 0, 0
decode_failures: list[str] = []
asm_failures: list[str] = []
disasm_failures: list[str] = []
decode_failures, asm_failures, disasm_failures = [], [], []
for idx, (ki, offset, orig_bytes, decoded, our_disasm, decode_ok, decode_err) in enumerate(decoded_instrs):
# Decode test
if decode_ok:
decode_passed += 1
elif decode_err == "no format":
decode_skipped += 1
if decode_ok: decode_passed += 1
elif decode_err == "no format": decode_skipped += 1
else:
decode_failed += 1
decode_failures.append(f"K{ki}@{offset}: {our_disasm}: {decode_err}")
# Asm test
if our_disasm is None:
asm_skipped += 1
disasm_skipped += 1
elif idx in asm_llvm_map:
llvm_bytes = asm_llvm_map[idx]
try:
our_bytes = asm(our_disasm).to_bytes()
if our_bytes[:len(llvm_bytes)] == llvm_bytes:
asm_passed += 1
if our_bytes[:len(llvm_bytes)] == llvm_bytes: asm_passed += 1
else:
asm_failed += 1
asm_failures.append(f"K{ki}@{offset}: '{our_disasm}': ours={our_bytes[:len(llvm_bytes)].hex()} llvm={llvm_bytes.hex()}")
except Exception:
asm_skipped += 1
if idx in disasm_llvm_map:
if our_disasm == disasm_llvm_map[idx]: disasm_passed += 1
else:
disasm_failed += 1
disasm_failures.append(f"K{ki}@{offset}: ours='{our_disasm}' llvm='{disasm_llvm_map[idx]}'")
else:
disasm_skipped += 1
else:
asm_skipped += 1
# Disasm comparison test
if our_disasm is None:
disasm_skipped += 1
elif idx in disasm_llvm_map:
llvm_disasm = disasm_llvm_map[idx]
if our_disasm == llvm_disasm:
disasm_passed += 1
else:
disasm_failed += 1
disasm_failures.append(f"K{ki}@{offset}: ours='{our_disasm}' llvm='{llvm_disasm}'")
else:
disasm_skipped += 1
print(f"decode roundtrip: {decode_passed} passed, {decode_failed} failed, {decode_skipped} skipped")
print(f"asm vs llvm: {asm_passed} passed, {asm_failed} failed, {asm_skipped} skipped")
print(f"disasm vs llvm: {disasm_passed} passed, {disasm_failed} failed, {disasm_skipped} skipped")
print(f"{self.arch.upper()} decode roundtrip: {decode_passed} passed, {decode_failed} failed, {decode_skipped} skipped")
print(f"{self.arch.upper()} asm vs llvm: {asm_passed} passed, {asm_failed} failed, {asm_skipped} skipped")
print(f"{self.arch.upper()} disasm vs llvm: {disasm_passed} passed, {disasm_failed} failed, {disasm_skipped} skipped")
self.assertEqual(decode_failed, 0, f"Decode failures:\n" + "\n".join(decode_failures[:20]))
self.assertEqual(asm_failed, 0, f"Asm failures:\n" + "\n".join(asm_failures[:20]))
# Note: disasm string comparison is informational only - formatting differences between LLVM versions are expected
# Basic unary ops
class TestRoundtripRDNA3(TestRoundtripBase):
"""Roundtrip tests for RDNA3 (gfx1100)."""
mcpu, arch = 'gfx1100', 'rdna3'
def test_neg(self): self._test_kernel_roundtrip(lambda T: -T([1.0, -2.0, 3.0, -4.0]))
def test_relu(self): self._test_kernel_roundtrip(lambda T: T([-1.0, 0.0, 1.0, 2.0]).relu())
def test_exp(self): self._test_kernel_roundtrip(lambda T: T([0.0, 1.0, 2.0]).exp())
@ -222,42 +185,62 @@ class TestTinygradKernelRoundtrip(unittest.TestCase):
def test_sin(self): self._test_kernel_roundtrip(lambda T: T([0.0, 1.0, 2.0]).sin())
def test_sqrt(self): self._test_kernel_roundtrip(lambda T: T([1.0, 4.0, 9.0]).sqrt())
def test_recip(self): self._test_kernel_roundtrip(lambda T: T([1.0, 2.0, 4.0]).reciprocal())
# Binary ops
def test_add(self): self._test_kernel_roundtrip(lambda T: T([1.0, 2.0]) + T([3.0, 4.0]))
def test_sub(self): self._test_kernel_roundtrip(lambda T: T([5.0, 6.0]) - T([1.0, 2.0]))
def test_mul(self): self._test_kernel_roundtrip(lambda T: T([2.0, 3.0]) * T([4.0, 5.0]))
def test_div(self): self._test_kernel_roundtrip(lambda T: T([10.0, 20.0]) / T([2.0, 4.0]))
def test_max_binary(self): self._test_kernel_roundtrip(lambda T: T([1.0, 5.0]).maximum(T([3.0, 2.0])))
# Reductions
def test_sum_reduce(self): self._test_kernel_roundtrip(lambda T: T.empty(64).sum())
def test_max_reduce(self): self._test_kernel_roundtrip(lambda T: T.empty(64).max())
def test_mean_reduce(self): self._test_kernel_roundtrip(lambda T: T.empty(32).mean())
# Matmul
def test_gemm_4x4(self): self._test_kernel_roundtrip(lambda T: T.empty(4, 4) @ T.empty(4, 4))
def test_gemv(self): self._test_kernel_roundtrip(lambda T: T.empty(1, 16) @ T.empty(16, 16))
# Complex ops
def test_softmax(self): self._test_kernel_roundtrip(lambda T: T.empty(16).softmax())
def test_layernorm(self): self._test_kernel_roundtrip(lambda T: T.empty(8, 8).layernorm())
# Memory patterns
def test_contiguous(self): self._test_kernel_roundtrip(lambda T: T.empty(4, 4).permute(1, 0).contiguous())
def test_reshape(self): self._test_kernel_roundtrip(lambda T: (T.empty(16) + 1).reshape(4, 4).contiguous())
def test_expand(self): self._test_kernel_roundtrip(lambda T: T.empty(4, 1).expand(4, 4).contiguous())
# Cast ops
def test_cast_int(self): self._test_kernel_roundtrip(lambda T: T.empty(16).int().float())
def test_cast_half(self): self._test_kernel_roundtrip(lambda T: T.empty(16).half().float())
# Comparison ops
def test_cmp_lt(self): self._test_kernel_roundtrip(lambda T: (T.empty(64) < T.empty(64)).where(T.empty(64), T.empty(64)))
def test_where(self): self._test_kernel_roundtrip(lambda T: (T.empty(64) > 0).where(T.empty(64), T.empty(64)))
# Fused ops
def test_fma(self): self._test_kernel_roundtrip(lambda T: (T([1.0, 2.0]) * T([3.0, 4.0]) + T([5.0, 6.0])))
@unittest.skipUnless(os.environ.get("TEST_RDNA4"), "RDNA4 roundtrip tests require TEST_RDNA4=1 and gfx1200 hardware")
class TestRoundtripRDNA4(TestRoundtripBase):
"""Roundtrip tests for RDNA4 (gfx1200)."""
mcpu, arch = 'gfx1200', 'rdna4'
def test_neg(self): self._test_kernel_roundtrip(lambda T: -T([1.0, -2.0, 3.0, -4.0]))
def test_relu(self): self._test_kernel_roundtrip(lambda T: T([-1.0, 0.0, 1.0, 2.0]).relu())
def test_exp(self): self._test_kernel_roundtrip(lambda T: T([0.0, 1.0, 2.0]).exp())
def test_log(self): self._test_kernel_roundtrip(lambda T: T([1.0, 2.0, 3.0]).log())
def test_sin(self): self._test_kernel_roundtrip(lambda T: T([0.0, 1.0, 2.0]).sin())
def test_sqrt(self): self._test_kernel_roundtrip(lambda T: T([1.0, 4.0, 9.0]).sqrt())
def test_recip(self): self._test_kernel_roundtrip(lambda T: T([1.0, 2.0, 4.0]).reciprocal())
def test_add(self): self._test_kernel_roundtrip(lambda T: T([1.0, 2.0]) + T([3.0, 4.0]))
def test_sub(self): self._test_kernel_roundtrip(lambda T: T([5.0, 6.0]) - T([1.0, 2.0]))
def test_mul(self): self._test_kernel_roundtrip(lambda T: T([2.0, 3.0]) * T([4.0, 5.0]))
def test_div(self): self._test_kernel_roundtrip(lambda T: T([10.0, 20.0]) / T([2.0, 4.0]))
def test_max_binary(self): self._test_kernel_roundtrip(lambda T: T([1.0, 5.0]).maximum(T([3.0, 2.0])))
def test_sum_reduce(self): self._test_kernel_roundtrip(lambda T: T.empty(64).sum())
def test_max_reduce(self): self._test_kernel_roundtrip(lambda T: T.empty(64).max())
def test_mean_reduce(self): self._test_kernel_roundtrip(lambda T: T.empty(32).mean())
def test_gemm_4x4(self): self._test_kernel_roundtrip(lambda T: T.empty(4, 4) @ T.empty(4, 4))
def test_gemv(self): self._test_kernel_roundtrip(lambda T: T.empty(1, 16) @ T.empty(16, 16))
def test_softmax(self): self._test_kernel_roundtrip(lambda T: T.empty(16).softmax())
def test_layernorm(self): self._test_kernel_roundtrip(lambda T: T.empty(8, 8).layernorm())
def test_contiguous(self): self._test_kernel_roundtrip(lambda T: T.empty(4, 4).permute(1, 0).contiguous())
def test_reshape(self): self._test_kernel_roundtrip(lambda T: (T.empty(16) + 1).reshape(4, 4).contiguous())
def test_expand(self): self._test_kernel_roundtrip(lambda T: T.empty(4, 1).expand(4, 4).contiguous())
def test_cast_int(self): self._test_kernel_roundtrip(lambda T: T.empty(16).int().float())
def test_cast_half(self): self._test_kernel_roundtrip(lambda T: T.empty(16).half().float())
def test_cmp_lt(self): self._test_kernel_roundtrip(lambda T: (T.empty(64) < T.empty(64)).where(T.empty(64), T.empty(64)))
def test_where(self): self._test_kernel_roundtrip(lambda T: (T.empty(64) > 0).where(T.empty(64), T.empty(64)))
def test_fma(self): self._test_kernel_roundtrip(lambda T: (T([1.0, 2.0]) * T([3.0, 4.0]) + T([5.0, 6.0])))
# Keep old class name for backwards compatibility
TestTinygradKernelRoundtrip = TestRoundtripRDNA3
if __name__ == "__main__":
unittest.main()