rdna4 works

This commit is contained in:
George Hotz 2025-12-31 21:20:47 +00:00
commit 9302f38f5b
4 changed files with 225 additions and 141 deletions

View file

@ -1,4 +1,4 @@
# RDNA3 assembler and disassembler
# RDNA3/RDNA4 assembler and disassembler
from __future__ import annotations
import re
from extra.assembly.amd.dsl import Inst, RawImm, Reg, SrcMod, SGPR, VGPR, TTMP, s, v, ttmp, _RegFactory
@ -9,26 +9,21 @@ from extra.assembly.amd.autogen.rdna3.ins import (VOP1, VOP2, VOP3, VOP3SD, VOP3
VOP1Op, VOP2Op, VOP3Op, VOP3SDOp, VOPDOp, SOP1Op, SOPKOp, SOPPOp, SMEMOp, DSOp, MUBUFOp)
def _matches_encoding(word: int, cls: type[Inst]) -> bool:
  """Return True when `word` carries the encoding value declared by instruction class `cls`.

  `cls._encoding` is either None (the class never matches) or a `(bitfield, value)` pair; the
  bitfield's `lo` and `mask()` select the bits of `word` that are compared against `value`.
  """
  encoding = cls._encoding
  if encoding is None: return False
  field, expected = encoding
  extracted = (word >> field.lo) & field.mask()
  return extracted == expected
# Order matters: more specific encodings first, VOP2 last (it's a catch-all for bit31=0)
_FORMATS_64 = [VOPD, VOP3P, VINTERP, VOP3, DS, FLAT, MUBUF, MTBUF, MIMG, SMEM, EXP]
_FORMATS_32 = [SOP1, SOPC, SOPP, SOPK, VOPC, VOP1, SOP2, VOP2] # SOP2/VOP2 are catch-alls
_FORMATS_32 = [SOP1, SOPC, SOPP, SOPK, VOPC, VOP1, SOP2, VOP2]
def detect_format(data: bytes) -> type[Inst]:
  """Pick the instruction class whose encoding matches the first little-endian dword of `data`.

  Words whose top two bits are 0b11 are tried against `_FORMATS_64`, every other word against
  `_FORMATS_32`; in both lists the first matching class wins. A VOP3 match is reported as VOP3SD
  when bits [25:16] of the word are listed in `Inst._VOP3SD_OPS`.

  Raises:
    AssertionError: fewer than 4 bytes were supplied.
    ValueError: no class in the selected list matches.
  """
  assert len(data) >= 4, f"need at least 4 bytes, got {len(data)}"
  word = int.from_bytes(data[:4], 'little')
  if (word >> 30) != 0b11:
    hit = next((fmt for fmt in _FORMATS_32 if _matches_encoding(word, fmt)), None)
    if hit is None: raise ValueError(f"unknown 32-bit format word={word:#010x}")
    return hit
  hit = next((fmt for fmt in _FORMATS_64 if _matches_encoding(word, fmt)), None)
  if hit is None: raise ValueError(f"unknown 64-bit format word={word:#010x}")
  # VOP3 and VOP3SD share one encoding; the opcode field decides which class is returned
  if hit is VOP3 and ((word >> 16) & 0x3ff) in Inst._VOP3SD_OPS: return VOP3SD
  return hit
@ -41,10 +36,22 @@ HWREG = {1: 'HW_REG_MODE', 2: 'HW_REG_STATUS', 3: 'HW_REG_TRAPSTS', 4: 'HW_REG_H
6: 'HW_REG_LDS_ALLOC', 7: 'HW_REG_IB_STS', 15: 'HW_REG_SH_MEM_BASES', 18: 'HW_REG_PERF_SNAPSHOT_PC_LO',
19: 'HW_REG_PERF_SNAPSHOT_PC_HI', 20: 'HW_REG_FLAT_SCR_LO', 21: 'HW_REG_FLAT_SCR_HI', 22: 'HW_REG_XNACK_MASK',
23: 'HW_REG_HW_ID1', 24: 'HW_REG_HW_ID2', 25: 'HW_REG_POPS_PACKER', 28: 'HW_REG_IB_STS2'}
# RDNA4 hardware-register id -> name. _disasm_sopk prints ids missing from this table as a raw hex immediate.
HWREG_RDNA4 = {1: 'HW_REG_MODE', 2: 'HW_REG_STATUS', 5: 'HW_REG_GPR_ALLOC', 6: 'HW_REG_LDS_ALLOC',
7: 'HW_REG_IB_STS', 23: 'HW_REG_HW_ID1', 24: 'HW_REG_HW_ID2'}
# Reverse lookup built from HWREG only: lower-cased register name -> id.
HWREG_IDS = {v.lower(): k for k, v in HWREG.items()}
# Message id -> name for ids 128..133.
MSG = {128: 'MSG_RTN_GET_DOORBELL', 129: 'MSG_RTN_GET_DDID', 130: 'MSG_RTN_GET_TMA',
131: 'MSG_RTN_GET_REALTIME', 132: 'MSG_RTN_SAVE_WAVE', 133: 'MSG_RTN_GET_TBA'}
# RDNA4 cache policy tables: th/scope field value -> modifier name (None means nothing is printed).
# Read by _disasm_vsample and _disasm_vimage.
# NOTE(review): _disasm_smem_rdna4 and _disasm_vbuffer use their own local lists whose entries differ
# from these (e.g. scope 2 is 'SCOPE_DEV' there but 'SCOPE_SA' here, and the th orderings disagree) —
# confirm against the assembler reference which spelling/ordering is intended.
_TH_LOAD = {0: None, 1: 'TH_LOAD_NT', 2: 'TH_LOAD_HT', 3: 'TH_LOAD_LU', 4: 'TH_LOAD_NT_RT', 5: 'TH_LOAD_RT_NT', 6: 'TH_LOAD_NT_HT'}
_TH_STORE = {0: None, 1: 'TH_STORE_NT', 2: 'TH_STORE_HT', 3: 'TH_STORE_LU', 4: 'TH_STORE_NT_RT', 5: 'TH_STORE_RT_NT', 6: 'TH_STORE_NT_HT'}
_TH_ATOMIC = {0: None, 1: 'TH_ATOMIC_NT', 2: 'TH_ATOMIC_RETURN'}
_SCOPE = {0: None, 1: 'SCOPE_SE', 2: 'SCOPE_SA', 3: 'SCOPE_SYS'}
# Export target mapping: 0-7 -> mrt0..mrt7, 8 -> mrtz, 12-16 -> pos0..pos4, 20 -> prim, 21/22 -> dual_src_blend0/1.
# Any other index is absent, so lookups need a fallback.
_EXP_TARGETS = {**{i: f'mrt{i}' for i in range(8)}, 8: 'mrtz', **{i+12: f'pos{i}' for i in range(5)},
20: 'prim', 21: 'dual_src_blend0', 22: 'dual_src_blend1'}
# ═══════════════════════════════════════════════════════════════════════════════
# HELPERS
# ═══════════════════════════════════════════════════════════════════════════════
@ -77,12 +84,11 @@ def waitcnt(vmcnt: int = 0x3f, expcnt: int = 0x7, lgkmcnt: int = 0x3f) -> int:
def _has(op: str, *subs) -> bool:
  """Return True when at least one of `subs` occurs as a substring of `op`."""
  for needle in subs:
    if needle in op: return True
  return False
def _omod(v: int) -> str:
  """Map an output-modifier value to its assembly suffix; values other than 1, 2, 3 give ''."""
  if v == 1: return " mul:2"
  if v == 2: return " mul:4"
  if v == 3: return " div:2"
  return ""
def _src16(inst, v: int) -> str:
  """Format a 16-bit source operand: values >= 256 go through `_fmt_v16`, smaller ones through `inst.lit`."""
  if v >= 256: return _fmt_v16(v)
  return inst.lit(v)
def _mods(*pairs) -> str:
  """Join, with single spaces, the text of every `(condition, text)` pair whose condition is truthy."""
  enabled = []
  for cond, text in pairs:
    if cond: enabled.append(text)
  return " ".join(enabled)
def _fmt_bits(label: str, val: int, count: int) -> str:
  """Render the low `count` bits of `val`, least-significant first, as `label:[b0,b1,...]`."""
  bits = [str((val >> i) & 1) for i in range(count)]
  return "{}:[{}]".format(label, ",".join(bits))
def _vop3_src(inst, v: int, neg: int, abs_: int, hi: int, n: int, f16: bool, any_hi: bool) -> str:
"""Format VOP3 source operand with modifiers."""
if n > 1: s = _fmt_src(v, n)
elif f16 and v >= 256: s = f"v{v - 256}.h" if hi else (f"v{v - 256}.l" if any_hi else inst.lit(v))
else: s = inst.lit(v)
@ -90,12 +96,37 @@ def _vop3_src(inst, v: int, neg: int, abs_: int, hi: int, n: int, f16: bool, any
return f"-{s}" if neg else s
def _opsel_str(opsel: int, n: int, need: bool, is16_d: bool) -> str:
  """Build the ` op_sel:[...]` modifier text, or '' when `need` is false.

  Four bits are printed when `n == 3`, otherwise three. When `is16_d` is set and bit 3 of `opsel`
  is set, every printed position is forced to 1 instead of reflecting the individual bits.
  """
  if not need: return ""
  width = 4 if n == 3 else 3
  if is16_d and (opsel & 8): bits = ["1"] * width
  else: bits = [str((opsel >> i) & 1) for i in range(width)]
  return " op_sel:[" + ",".join(bits) + "]"
def _mimg_vaddr_width(name: str, dim: int, a16: bool) -> int:
  """Return how many address registers an image op uses, derived from its mnemonic and `dim`.

  `name` is the lower-case mnemonic; its suffixes (_mip, _o, _c, _d, _g16, _b, _l, _cl) each add
  address components. Components counted as "packable" are halved (rounded up) together with the
  base coordinates when `a16` is set; the remaining ones are always added at full width.
  get_resinfo always uses one register. `dim` must index the 8-entry tables below.
  """
  # indexed by dim: 1d, 2d, 3d, cube, 1d_arr, 2d_arr, 2d_msaa, 2d_msaa_arr
  coords = [1, 2, 3, 3, 2, 3, 3, 4][dim]
  derivs = [1, 2, 3, 2, 1, 2, 2, 2][dim]
  if 'get_resinfo' in name: return 1
  packable = full = 0
  if '_mip' in name:
    packable = 1
  elif 'sample' in name or 'gather' in name:
    full += '_o' in name
    full += re.search(r'_c(_|$)', name) is not None  # "_c" as a whole suffix component, not "_cl"
    if '_d' in name: full += ((derivs + 1) & ~1) if '_g16' in name else derivs * 2
    full += '_b' in name
    packable += '_l' in name and '_cl' not in name and '_lz' not in name
    packable += '_cl' in name
  packable += coords
  if a16: return (packable + 1) // 2 + full
  return packable + full
def _collect_vaddrs(inst, count: int) -> list[int]:
  """Gather the first `count` address-register fields (`vaddr0`, `vaddr1`, ...) of `inst`.

  `vaddr0` is always read; `vaddr4` is only used when `inst` actually has that attribute, so the
  result can be shorter than `count`.
  """
  regs = [inst.vaddr0]
  for idx, field in enumerate(('vaddr1', 'vaddr2', 'vaddr3'), start=1):
    if count > idx: regs.append(getattr(inst, field))
  if count > 4 and hasattr(inst, 'vaddr4'): regs.append(inst.vaddr4)
  return regs[:count]
def _fmt_vaddr_nsa(vaddrs: list[int]) -> str:
  """Format address registers: a single entry as `vN`, anything else as a bracketed `[vA, vB, ...]` list."""
  if len(vaddrs) == 1: return f"v{vaddrs[0]}"
  return "[{}]".format(", ".join(f"v{reg}" for reg in vaddrs))
# ═══════════════════════════════════════════════════════════════════════════════
# DISASSEMBLER
# ═══════════════════════════════════════════════════════════════════════════════
@ -104,7 +135,6 @@ def _disasm_vop1(inst: VOP1) -> str:
name = inst.op_name.lower()
if inst.op in (VOP1Op.V_NOP, VOP1Op.V_PIPEFLUSH): return name
if inst.op == VOP1Op.V_READFIRSTLANE_B32: return f"v_readfirstlane_b32 {decode_src(inst.vdst)}, v{inst.src0 - 256 if inst.src0 >= 256 else inst.src0}"
# 16-bit dst: uses .h/.l suffix (determined by name pattern, not dtype - e.g. sat_pk_u8_i16 outputs 8-bit but uses 16-bit encoding)
parts = name.split('_')
is_16d = any(p in ('f16','i16','u16','b16') for p in parts[-2:-1]) or (len(parts) >= 2 and parts[-1] in ('f16','i16','u16','b16') and 'cvt' not in name)
dst = _vreg(inst.vdst, inst.dst_regs()) if inst.dst_regs() > 1 else _fmt_v16(inst.vdst, 0, 128) if is_16d else f"v{inst.vdst}"
@ -113,11 +143,9 @@ def _disasm_vop1(inst: VOP1) -> str:
def _disasm_vop2(inst: VOP2) -> str:
name = inst.op_name.lower()
# RDNA4 removed V_DOT2ACC_F32_F16, check if op exists
try: is_dot2acc = inst.op == VOP2Op.V_DOT2ACC_F32_F16
except ValueError: is_dot2acc = False
suf = "" if is_dot2acc else "_e32"
# fmaak: dst = src0 * vsrc1 + K, fmamk: dst = src0 * K + vsrc1
try:
if inst.op in (VOP2Op.V_FMAAK_F32, VOP2Op.V_FMAAK_F16): return f"{name}{suf} v{inst.vdst}, {inst.lit(inst.src0)}, v{inst.vsrc1}, 0x{inst._literal:x}"
if inst.op in (VOP2Op.V_FMAMK_F32, VOP2Op.V_FMAMK_F16): return f"{name}{suf} v{inst.vdst}, {inst.lit(inst.src0)}, 0x{inst._literal:x}, v{inst.vsrc1}"
@ -126,7 +154,6 @@ def _disasm_vop2(inst: VOP2) -> str:
if inst.op == VOP2Op.V_CNDMASK_B32: return f"{name}{suf} v{inst.vdst}, {inst.lit(inst.src0)}, v{inst.vsrc1}, vcc_lo"
except ValueError: pass
if inst.is_16bit(): return f"{name}{suf} {_fmt_v16(inst.vdst, 0, 128)}, {_src16(inst, inst.src0)}, {_fmt_v16(inst.vsrc1, 0, 128)}"
# Handle 64-bit VOP2 instructions (RDNA4)
dn, sn0, sn1 = inst.dst_regs(), inst.src_regs(0), inst.src_regs(1)
dst = _vreg(inst.vdst, dn) if dn > 1 else f"v{inst.vdst}"
src0 = _fmt_src(inst.src0, sn0) if sn0 > 1 else inst.lit(inst.src0)
@ -144,13 +171,15 @@ NO_ARG_SOPP = {SOPPOp.S_ENDPGM, SOPPOp.S_BARRIER, SOPPOp.S_WAKEUP, SOPPOp.S_ICAC
def _disasm_sopp(inst: SOPP) -> str:
name = inst.op_name.lower()
if not name: raise ValueError(f"undefined SOPP op: {inst.op}")
if inst.op in NO_ARG_SOPP: return name
if inst.op == SOPPOp.S_WAITCNT:
vm, exp, lgkm = (inst.simm16 >> 10) & 0x3f, inst.simm16 & 0xf, (inst.simm16 >> 4) & 0x3f
p = [f"vmcnt({vm})" if vm != 0x3f else "", f"expcnt({exp})" if exp != 7 else "", f"lgkmcnt({lgkm})" if lgkm != 0x3f else ""]
return f"s_waitcnt {' '.join(x for x in p if x) or '0'}"
if inst.op == SOPPOp.S_DELAY_ALU:
deps, skips = ['VALU_DEP_1','VALU_DEP_2','VALU_DEP_3','VALU_DEP_4','TRANS32_DEP_1','TRANS32_DEP_2','TRANS32_DEP_3','FMA_ACCUM_CYCLE_1','SALU_CYCLE_1','SALU_CYCLE_2','SALU_CYCLE_3'], ['SAME','NEXT','SKIP_1','SKIP_2','SKIP_3','SKIP_4']
deps = ['VALU_DEP_1','VALU_DEP_2','VALU_DEP_3','VALU_DEP_4','TRANS32_DEP_1','TRANS32_DEP_2','TRANS32_DEP_3','FMA_ACCUM_CYCLE_1','SALU_CYCLE_1','SALU_CYCLE_2','SALU_CYCLE_3']
skips = ['SAME','NEXT','SKIP_1','SKIP_2','SKIP_3','SKIP_4']
id0, skip, id1 = inst.simm16 & 0xf, (inst.simm16 >> 4) & 0x7, (inst.simm16 >> 7) & 0xf
dep = lambda v: deps[v-1] if 0 < v <= len(deps) else str(v)
p = [f"instid0({dep(id0)})" if id0 else "", f"instskip({skips[skip]})" if skip else "", f"instid1({dep(id1)})" if id1 else ""]
@ -159,6 +188,7 @@ def _disasm_sopp(inst: SOPP) -> str:
def _disasm_smem(inst: SMEM) -> str:
name = inst.op_name.lower()
if 'rdna4' in inst.__class__.__module__: return _disasm_smem_rdna4(inst)
if inst.op in (SMEMOp.S_GL1_INV, SMEMOp.S_DCACHE_INV): return name
off_s = f"{decode_src(inst.soffset)} offset:0x{inst.offset:x}" if inst.offset and inst.soffset != 124 else f"0x{inst.offset:x}" if inst.offset else decode_src(inst.soffset)
sbase_idx, sbase_count = inst.sbase * 2, 4 if (8 <= inst.op.value <= 12 or name == 's_atc_probe_buffer') else 2
@ -166,6 +196,33 @@ def _disasm_smem(inst: SMEM) -> str:
if name in ('s_atc_probe', 's_atc_probe_buffer'): return f"{name} {inst.sdata}, {sbase_str}, {off_s}"
return f"{name} {_fmt_sdst(inst.sdata, inst.dst_regs())}, {sbase_str}, {off_s}" + _mods((inst.glc, " glc"), (inst.dlc, " dlc"))
def _disasm_smem_rdna4(inst) -> str:
  """Disassemble an RDNA4 SMEM instruction into assembly text.

  When the opcode has no name, a fallback mnemonic is chosen from the raw `op` value. Cache
  invalidate ops print as the bare mnemonic. `ioffset` is interpreted as a 24-bit two's-complement
  value and printed as signed hex. `th:`/`scope:` modifiers are appended only for non-zero field
  values that have a name in the local tables.
  """
  name = inst.op_name.lower()
  op_val = inst._values.get('op')
  if not name:
    name = {34: 's_atc_probe', 35: 's_atc_probe_buffer', 32: 's_gl1_inv'}.get(op_val, f's_smem_op{op_val}')
  if name in ('s_gl1_inv', 's_dcache_inv'): return name

  # sbase is stored in units of register pairs; buffer ops take a 4-register base
  sbase_idx = inst.sbase * 2
  sbase_count = 4 if 'buffer' in name else 2
  if sbase_idx == 106: sbase_str = "vcc"
  elif 108 <= sbase_idx <= 123: sbase_str = _reg("ttmp", sbase_idx - 108, sbase_count)
  else: sbase_str = _sreg(sbase_idx, sbase_count)

  raw_off = inst.ioffset
  ioffset = raw_off - 0x1000000 if raw_off >= 0x800000 else raw_off
  off_str = f"-0x{-ioffset:x}" if ioffset < 0 else f"0x{ioffset:x}"
  soffset_str = decode_src(inst.soffset)
  th_names = ['','TH_LOAD_NT','TH_LOAD_HT','TH_LOAD_LU','TH_LOAD_NT_RT','TH_LOAD_NT_HT','TH_LOAD_BYPASS']
  scope_names = ['','SCOPE_SE','SCOPE_DEV','SCOPE_SYS']

  if 'prefetch' in name:
    if 'pc_rel' in name: return f"{name} {off_str}, {soffset_str}, {inst.sdata}"
    return f"{name} {sbase_str}, {off_str}, {soffset_str}, {inst.sdata}"
  if 'atc_probe' in name:
    return f"{name} {inst.sdata}, {sbase_str}, {soffset_str}" + (f" offset:{off_str}" if ioffset else "")

  sdst = _fmt_sdst(inst.sdata, inst.dst_regs())
  # soffset 124 prints only the immediate; otherwise the register, plus the immediate when non-zero
  if inst.soffset == 124: tail = off_str
  elif ioffset: tail = f"{soffset_str} offset:{off_str}"
  else: tail = soffset_str
  base_str = f"{name} {sdst}, {sbase_str}, {tail}"

  def pick(value, table): return table[value] if value and value < len(table) else ''
  mods = []
  if (th := pick(inst.th, th_names)): mods.append(f"th:{th}")
  if (scope := pick(inst.scope, scope_names)): mods.append(f"scope:{scope}")
  return base_str + (" " + " ".join(mods) if mods else "")
def _disasm_flat(inst: FLAT) -> str:
name = inst.op_name.lower()
seg = ['flat', 'scratch', 'global'][inst.seg] if inst.seg < 3 else 'flat'
@ -173,16 +230,13 @@ def _disasm_flat(inst: FLAT) -> str:
off_val = inst.offset if seg == 'flat' else (inst.offset if inst.offset < 4096 else inst.offset - 8192)
w = inst.dst_regs() * (2 if 'cmpswap' in name else 1)
mods = f"{f' offset:{off_val}' if off_val else ''}{' glc' if inst.glc else ''}{' slc' if inst.slc else ''}{' dlc' if inst.dlc else ''}"
# saddr
if seg == 'flat' or inst.saddr == 0x7F: saddr_s = ""
elif inst.saddr == 124: saddr_s = ", off"
elif seg == 'scratch': saddr_s = f", {decode_src(inst.saddr)}"
elif inst.saddr in SPECIAL_PAIRS: saddr_s = f", {SPECIAL_PAIRS[inst.saddr]}"
elif t := _ttmp(inst.saddr, 2): saddr_s = f", {t}"
else: saddr_s = f", {_sreg(inst.saddr, 2) if inst.saddr < 106 else decode_src(inst.saddr)}"
# addtid: no addr
if 'addtid' in name: return f"{instr} v{inst.data if 'store' in name else inst.vdst}{saddr_s}{mods}"
# addr width
addr_s = "off" if not inst.sve and seg == 'scratch' else _vreg(inst.addr, 1 if seg == 'scratch' or (inst.saddr not in (0x7F, 124)) else 2)
data_s, vdst_s = _vreg(inst.data, w), _vreg(inst.vdst, w // 2 if 'cmpswap' in name else w)
if 'atomic' in name:
@ -197,9 +251,11 @@ def _disasm_ds(inst: DS) -> str:
off2 = f" offset0:{inst.offset0} offset1:{inst.offset1}" if inst.offset0 or inst.offset1 else ""
w = inst.dst_regs()
d0, d1, dst, addr = _vreg(inst.data0, w), _vreg(inst.data1, w), _vreg(inst.vdst, w), f"v{inst.addr}"
if op == DSOp.DS_NOP: return name
if op == DSOp.DS_BVH_STACK_RTN_B32: return f"{name} v{inst.vdst}, {addr}, v{inst.data0}, {_vreg(inst.data1, 4)}{off}{gds}"
if 'bvh_stack_push4_pop1' in name: return f"{name} v{inst.vdst}, {addr}, v{inst.data0}, {_vreg(inst.data1, 4)}{off}{gds}"
if 'bvh_stack_push8_pop1' in name: return f"{name} v{inst.vdst}, {addr}, v{inst.data0}, {_vreg(inst.data1, 8)}{off}{gds}"
if 'bvh_stack_push8_pop2' in name: return f"{name} {_vreg(inst.vdst, 2)}, {addr}, v{inst.data0}, {_vreg(inst.data1, 8)}{off}{gds}"
if 'gws_sema' in name and op != DSOp.DS_GWS_SEMA_BR: return f"{name}{off}{gds}"
if 'gws_' in name: return f"{name} {addr}{off}{gds}"
if op in (DSOp.DS_CONSUME, DSOp.DS_APPEND): return f"{name} v{inst.vdst}{off}{gds}"
@ -220,17 +276,15 @@ def _disasm_ds(inst: DS) -> str:
def _disasm_vop3(inst: VOP3) -> str:
op, name = inst.op, inst.op_name.lower()
# VOP3SD (shared encoding)
if isinstance(op, VOP3SDOp):
if name.startswith('v_s_'):
return f"{name} {_fmt_sdst(inst.vdst, 1)}, {_fmt_src(inst.src0, 1)}"
if hasattr(op, '__class__') and op.__class__.__name__ == 'VOP3SDOp':
sdst = (inst.clmp << 7) | (inst.opsel << 3) | inst.abs
def src(v, neg, n): s = _fmt_src(v, n) if n > 1 else inst.lit(v); return f"-{s}" if neg else s
s0, s1, s2 = src(inst.src0, inst.neg & 1, inst.src_regs(0)), src(inst.src1, inst.neg & 2, inst.src_regs(1)), src(inst.src2, inst.neg & 4, inst.src_regs(2))
dst = _vreg(inst.vdst, inst.dst_regs()) if inst.dst_regs() > 1 else f"v{inst.vdst}"
srcs = f"{s0}, {s1}, {s2}" if inst.num_srcs() == 3 else f"{s0}, {s1}"
return f"{name} {dst}, {_fmt_sdst(sdst, 1)}, {srcs}" + _omod(inst.omod)
# Detect 16-bit operand sizes (for .h/.l suffix handling)
is16_d = is16_s = is16_s2 = False
if 'cvt_pk' in name: is16_s = name.endswith('16')
elif m := re.match(r'v_(?:cvt|frexp_exp)_([a-z0-9_]+)_([a-z0-9]+)', name):
@ -239,33 +293,30 @@ def _disasm_vop3(inst: VOP3) -> str:
elif re.match(r'v_mad_[iu]32_[iu]16', name): is16_s = True
elif 'pack_b32' in name: is16_s = is16_s2 = True
else: is16_d = is16_s = is16_s2 = inst.is_16bit()
any_hi = inst.opsel != 0
s0 = _vop3_src(inst, inst.src0, inst.neg&1, inst.abs&1, inst.opsel&1, inst.src_regs(0), is16_s, any_hi)
s1 = _vop3_src(inst, inst.src1, inst.neg&2, inst.abs&2, inst.opsel&2, inst.src_regs(1), is16_s, any_hi)
s2 = _vop3_src(inst, inst.src2, inst.neg&4, inst.abs&4, inst.opsel&4, inst.src_regs(2), is16_s2, any_hi)
# Destination
dn = inst.dst_regs()
if op == VOP3Op.V_READLANE_B32: dst = _fmt_sdst(inst.vdst, 1)
elif dn > 1: dst = _vreg(inst.vdst, dn)
elif is16_d: dst = f"v{inst.vdst}.h" if (inst.opsel & 8) else f"v{inst.vdst}.l" if any_hi else f"v{inst.vdst}"
else: dst = f"v{inst.vdst}"
cl, om = " clamp" if inst.clmp else "", _omod(inst.omod)
nonvgpr_opsel = (inst.src0 < 256 and (inst.opsel & 1)) or (inst.src1 < 256 and (inst.opsel & 2)) or (inst.src2 < 256 and (inst.opsel & 4))
need_opsel = nonvgpr_opsel or (inst.opsel and not is16_s)
if inst.op < 256: # VOPC
if inst.op < 256:
return f"{name}_e64 {s0}, {s1}" if name.startswith('v_cmpx') else f"{name}_e64 {_fmt_sdst(inst.vdst, 1)}, {s0}, {s1}"
if inst.op < 384: # VOP2
if inst.op < 384:
n = inst.num_srcs()
os = _opsel_str(inst.opsel, n, need_opsel, is16_d)
return f"{name}_e64 {dst}, {s0}, {s1}, {s2}{os}{cl}{om}" if n == 3 else f"{name}_e64 {dst}, {s0}, {s1}{os}{cl}{om}"
if inst.op < 512: # VOP1
if inst.op < 512:
if _has(name, 'cvt_f32_fp8', 'cvt_f32_bf8'): need_opsel = False
return f"{name}_e64" if op in (VOP3Op.V_NOP, VOP3Op.V_PIPEFLUSH) else f"{name}_e64 {dst}, {s0}{_opsel_str(inst.opsel, 1, need_opsel, is16_d)}{cl}{om}"
# Native VOP3
n = inst.num_srcs()
if 'permlane' in name and '_var' in name: n = 2
if _has(name, 'cvt_sr_fp8', 'cvt_sr_bf8'): n, need_opsel = 2, False
os = _opsel_str(inst.opsel, n, need_opsel, is16_d)
return f"{name} {dst}, {s0}, {s1}, {s2}{os}{cl}{om}" if n == 3 else f"{name} {dst}, {s0}, {s1}{os}{cl}{om}"
@ -280,22 +331,49 @@ def _disasm_vop3sd(inst: VOP3SD) -> str:
def _disasm_vopd(inst: VOPD) -> str:
  """Disassemble a dual-issue VOPD instruction as `<x half> :: <y half>`.

  The opcode enum is chosen from the instruction class's module: RDNA4 classes use the RDNA4
  `VOPDOp`, everything else the `VOPDOp` already imported by this file. Opcodes are decoded only
  after that choice, so an opcode present only in the RDNA4 enum is never looked up in the RDNA3 one.
  """
  lit = inst._literal or inst.literal
  if 'rdna4' in inst.__class__.__module__:
    import importlib  # local import: the RDNA4 enum module is loaded only when an RDNA4 instruction is seen
    VOPDOpCls = importlib.import_module('extra.assembly.amd.autogen.rdna4.enum').VOPDOp
  else:
    VOPDOpCls = VOPDOp
  # the low bit of the Y destination is not stored; it is the inverse of the X destination's low bit
  vdst_y = (inst.vdsty << 1) | ((inst.vdstx & 1) ^ 1)
  nx, ny = VOPDOpCls(inst.opx).name.lower(), VOPDOpCls(inst.opy).name.lower()

  def half(n, vd, s0, vs1):
    # fmaak/fmamk carry the shared literal as a trailing operand; mov has no second source
    lit_s = f', 0x{lit:x}' if lit and _has(n, 'fmaak', 'fmamk') else ''
    if 'mov' in n: return f"{n} v{vd}, {inst.lit(s0)}{lit_s}"
    return f"{n} v{vd}, {inst.lit(s0)}, v{vs1}{lit_s}"

  return f"{half(nx, inst.vdstx, inst.srcx0, inst.vsrcx1)} :: {half(ny, vdst_y, inst.srcy0, inst.vsrcy1)}"
def _disasm_vop3p(inst: VOP3P) -> str:
name = inst.op_name.lower()
is_wmma, n, is_fma_mix = 'wmma' in name, inst.num_srcs(), 'fma_mix' in name
if is_wmma:
sc = 2 if 'iu4' in name else 4 if 'iu8' in name else 8
src0, src1, src2, dst = _fmt_src(inst.src0, sc), _fmt_src(inst.src1, sc), _fmt_src(inst.src2, 8), _vreg(inst.vdst, 8)
else: src0, src1, src2, dst = _fmt_src(inst.src0, 1), _fmt_src(inst.src1, 1), _fmt_src(inst.src2, 1), f"v{inst.vdst}"
is_wmma, n, is_fma_mix, is_swmmac = 'wmma' in name, inst.num_srcs(), 'fma_mix' in name, 'swmmac' in name
is_rdna4 = 'rdna4' in inst.__class__.__module__
if is_wmma or is_swmmac:
if is_rdna4:
if is_swmmac:
if '16x16x32_iu4' in name: sc0, sc1 = 1, 2
elif '16x16x64_iu4' in name or '16x16x32_iu8' in name or 'fp8' in name or 'bf8' in name: sc0, sc1 = 2, 4
else: sc0, sc1 = 4, 8
sc2 = 1
dst_w = 4 if name.startswith('v_swmmac_f16') or name.startswith('v_swmmac_bf16') else 8
else:
if '16x16x16_iu4' in name: sc0 = 1
elif '16x16x32_iu4' in name or 'iu8' in name or 'fp8' in name or 'bf8' in name: sc0 = 2
else: sc0 = 4
sc1 = sc0
sc2 = 4 if (name.startswith('v_wmma_f16') or name.startswith('v_wmma_bf16')) else 8
dst_w = sc2
src0, src1, src2, dst = _fmt_src(inst.src0, sc0), _fmt_src(inst.src1, sc1), _fmt_src(inst.src2, sc2), _vreg(inst.vdst, dst_w)
else:
sc = 2 if 'iu4' in name else 4 if 'iu8' in name else 8
src0, src1, src2, dst = _fmt_src(inst.src0, sc), _fmt_src(inst.src1, sc), _fmt_src(inst.src2, 8), _vreg(inst.vdst, 8)
else:
src0, src1, src2, dst = _fmt_src(inst.src0, 1), _fmt_src(inst.src1, 1), _fmt_src(inst.src2, 1), f"v{inst.vdst}"
opsel_hi = inst.opsel_hi | (inst.opsel_hi2 << 2)
if is_fma_mix:
def m(s, neg, abs_): return f"-{f'|{s}|' if abs_ else s}" if neg else (f"|{s}|" if abs_ else s)
src0, src1, src2 = m(src0, inst.neg & 1, inst.neg_hi & 1), m(src1, inst.neg & 2, inst.neg_hi & 2), m(src2, inst.neg & 4, inst.neg_hi & 4)
mods = ([_fmt_bits("op_sel", inst.opsel, n)] if inst.opsel else []) + ([_fmt_bits("op_sel_hi", opsel_hi, n)] if opsel_hi else []) + (["clamp"] if inst.clmp else [])
elif is_swmmac:
has_index_key = '16x16x64_iu4' not in name
mods = ([f"index_key:{inst.opsel & 1}"] if has_index_key and (inst.opsel & 1) else []) + \
([_fmt_bits("neg_lo", inst.neg, n)] if inst.neg else []) + ([_fmt_bits("neg_hi", inst.neg_hi, n)] if inst.neg_hi else [])
else:
mods = ([_fmt_bits("op_sel", inst.opsel, n)] if inst.opsel else []) + ([_fmt_bits("op_sel_hi", opsel_hi, n)] if opsel_hi != (7 if n == 3 else 3) else []) + \
([_fmt_bits("neg_lo", inst.neg, n)] if inst.neg else []) + ([_fmt_bits("neg_hi", inst.neg_hi, n)] if inst.neg_hi else []) + (["clamp"] if inst.clmp else [])
@ -313,54 +391,89 @@ def _disasm_buf(inst: MUBUF | MTBUF) -> str:
mods = ([f"format:{inst.format}"] if isinstance(inst, MTBUF) else []) + [m for c, m in [(inst.idxen,"idxen"),(inst.offen,"offen"),(inst.offset,f"offset:{inst.offset}"),(inst.glc,"glc"),(inst.dlc,"dlc"),(inst.slc,"slc"),(inst.tfe,"tfe")] if c]
return f"{name} {_vreg(inst.vdata, w)}, {vaddr}, {srsrc}, {decode_src(inst.soffset)}{' ' + ' '.join(mods) if mods else ''}"
def _mimg_vaddr_width(name: str, dim: int, a16: bool) -> int:
  """Calculate the vaddr register count for an image op from its mnemonic, `dim` and `a16`.

  Components in the "packed" group share the a16 halving with the base coordinates; components in
  the "unpacked" group are always added at full width. get_resinfo always needs one register.
  """
  # tables are indexed by dim: 1d, 2d, 3d, cube, 1d_arr, 2d_arr, 2d_msaa, 2d_msaa_arr
  base = [1, 2, 3, 3, 2, 3, 3, 4][dim]  # address coords
  grad = [1, 2, 3, 2, 1, 2, 2, 2][dim]  # gradient coords (for derivatives)
  if 'get_resinfo' in name: return 1
  packed = unpacked = 0
  if '_mip' in name:
    packed = 1
  elif 'sample' in name or 'gather' in name:
    extras = [
      1 if '_o' in name else 0,                      # offset
      1 if re.search(r'_c(_|$)', name) else 0,       # compare (not _cl)
      0 if '_d' not in name else (((grad + 1) & ~1) if '_g16' in name else grad * 2),  # derivatives
      1 if '_b' in name else 0,                      # bias
    ]
    unpacked = sum(extras)
    # clamp and LOD are mutually exclusive here: a name containing _cl never counts as _l
    if '_cl' in name: packed = 1
    elif '_l' in name and '_lz' not in name: packed = 1
  total = base + packed
  if a16: total = (total + 1) // 2
  return total + unpacked
def _disasm_mimg(inst: MIMG) -> str:
  """Disassemble a MIMG (image) instruction into assembly text.

  BVH ops are printed with a 4-register resource and a fixed address width chosen from `a16` and
  whether the mnemonic contains '64'. For every other op the data width comes from the dmask
  population count (always 4 for gather4/msaa_load), halved (rounded up) for d16 and increased by
  one for tfe; the address width comes from `_mimg_vaddr_width`. Each modifier flag is emitted
  exactly once, in a fixed order.
  """
  name = inst.op_name.lower()
  srsrc_base = inst.srsrc * 4
  srsrc_str = _sreg_or_ttmp(srsrc_base, 8)
  if 'bvh' in name:
    vaddr = (9 if '64' in name else 8) if inst.a16 else (12 if '64' in name else 11)
    return f"{name} {_vreg(inst.vdata, 4)}, {_vreg(inst.vaddr, vaddr)}, {_sreg_or_ttmp(srsrc_base, 4)}{' a16' if inst.a16 else ''}"
  # data register count
  vdata = 4 if 'gather4' in name or 'msaa_load' in name else (bin(inst.dmask).count('1') or 1)
  if inst.d16: vdata = (vdata + 1) // 2
  if inst.tfe: vdata += 1
  # address register count
  dim_names = ['1d', '2d', '3d', 'cube', '1d_array', '2d_array', '2d_msaa', '2d_msaa_array']
  dim = dim_names[inst.dim] if inst.dim < len(dim_names) else f"dim_{inst.dim}"
  vaddr = _mimg_vaddr_width(name, inst.dim, inst.a16)
  vaddr_str = f"v{inst.vaddr}" if vaddr == 1 else _vreg(inst.vaddr, vaddr)
  # modifiers: dmask is omitted when zero, and when 15 on non-atomic ops
  mods = [f"dmask:0x{inst.dmask:x}"] if inst.dmask and (inst.dmask != 15 or 'atomic' in name) else []
  mods.append(f"dim:SQ_RSRC_IMG_{dim.upper()}")
  flags = [(inst.unrm,"unorm"),(inst.glc,"glc"),(inst.slc,"slc"),(inst.dlc,"dlc"),(inst.r128,"r128"),
           (inst.a16,"a16"),(inst.tfe,"tfe"),(inst.lwe,"lwe"),(inst.d16,"d16")]
  mods.extend(mod for flag, mod in flags if flag)
  # only sample/gather/get_lod ops print a sampler operand
  ssamp_str = ", " + _sreg_or_ttmp(inst.ssamp * 4, 4) if 'sample' in name or 'gather' in name or 'get_lod' in name else ""
  return f"{name} {_vreg(inst.vdata, vdata)}, {vaddr_str}, {srsrc_str}{ssamp_str} {' '.join(mods)}"
def _disasm_vsample(inst) -> str:
  """Disassemble a VSAMPLE (image sample/gather) instruction into assembly text.

  Raises:
    ValueError: the opcode has no name, the op is image_msaa_load, more than four address
      registers would be required, or the th/scope/tfe/d16 combination is one of those
      rejected below.
  """
  name = inst.op_name.lower()
  if not name: raise ValueError(f"undefined VSAMPLE op: {inst.op}")
  if 'msaa_load' in name: raise ValueError("image_msaa_load not supported in VSAMPLE for gfx1200")

  dim = inst.dim
  dim_names = ['1d', '2d', '3d', 'cube', '1d_array', '2d_array', '2d_msaa', '2d_msaa_array']
  dim_str = f"dim_{dim}" if dim >= len(dim_names) else dim_names[dim]

  # data width: gather4 always 4, otherwise dmask popcount (at least 1); d16 halves (rounded up), tfe adds one
  if 'gather4' in name: vdata = 4
  else: vdata = bin(inst.dmask).count('1') or 1
  if inst.d16: vdata = (vdata + 1) // 2
  if inst.tfe: vdata += 1

  vaddr_count = _mimg_vaddr_width(name, dim, inst.a16)
  if vaddr_count > 4: raise ValueError(f"{name} with dim={dim} needs {vaddr_count} vaddrs (>4, unsupported)")
  vaddr_str = _fmt_vaddr_nsa(_collect_vaddrs(inst, vaddr_count))
  srsrc_str = _sreg_or_ttmp(inst.rsrc, 8)
  ssamp_str = _sreg_or_ttmp(inst.samp, 4)

  mods = []
  if inst.dmask: mods.append(f"dmask:0x{inst.dmask:x}")
  mods.append(f"dim:SQ_RSRC_IMG_{dim_str.upper()}")
  flags = ((inst.unrm, "unorm"), (inst.r128, "r128"), (inst.a16, "a16"), (inst.tfe, "tfe"), (inst.lwe, "lwe"), (inst.d16, "d16"))
  mods.extend(mod for flag, mod in flags if flag)

  # cache-policy combinations this disassembler refuses to print
  th_val, scope_val = inst.th, inst.scope
  if (th_val, scope_val) == (3, 3): raise ValueError("invalid th/scope: TH_LOAD_LU with SCOPE_SYS")
  if scope_val == 2 and th_val == 0: raise ValueError("invalid scope SCOPE_SA without th")
  if inst.tfe and inst.d16 and th_val != 0: raise ValueError("invalid th with tfe+d16")
  th_name = _TH_LOAD.get(th_val)
  if th_name: mods.append(f"th:{th_name}")
  scope_name = _SCOPE.get(scope_val)
  if scope_name: mods.append(f"scope:{scope_name}")
  return f"{name} {_vreg(inst.vdata, vdata)}, {vaddr_str}, {srsrc_str}, {ssamp_str} {' '.join(mods)}"
def _disasm_vimage(inst) -> str:
  """Disassemble a VIMAGE (image load/store/atomic/resinfo) instruction into assembly text.

  Raises:
    ValueError: the op is a BVH instruction, or the th/scope/tfe/d16 combination is one of
      those rejected below.
  """
  name = inst.op_name.lower()
  if 'bvh' in name: raise ValueError(f"BVH instruction {name} not supported")

  dim = inst.dim
  dim_names = ['1d', '2d', '3d', 'cube', '1d_array', '2d_array', '2d_msaa', '2d_msaa_array']
  dim_str = f"dim_{dim}" if dim >= len(dim_names) else dim_names[dim]
  is_resinfo = 'resinfo' in name
  is_atomic = 'atomic' in name
  is_store = 'store' in name

  # data width: atomics use 1 or 2 registers (64-bit types), doubled for cmpswap; msaa_load always 4;
  # everything else uses the dmask popcount (at least 1). d16 halving and the tfe extra apply to all.
  if is_atomic:
    vdata = 2 if _has(name, 'b64', 'u64', 'i64') else 1
    if 'cmpswap' in name: vdata *= 2
  elif 'msaa_load' in name: vdata = 4
  else: vdata = bin(inst.dmask).count('1') or 1
  if inst.d16: vdata = (vdata + 1) // 2
  if inst.tfe: vdata += 1

  if is_resinfo:
    vaddr_count = 1
  else:
    total_coords = ([1, 2, 3, 3, 2, 3, 3, 4][dim] if dim < 8 else 1) + (1 if '_mip' in name else 0)
    vaddr_count = (total_coords + 1) // 2 if inst.a16 else total_coords
  vaddr_str = _fmt_vaddr_nsa(_collect_vaddrs(inst, vaddr_count))
  srsrc_str = _sreg_or_ttmp(inst.rsrc, 8)

  # atomics print a fixed dmask regardless of the encoded field
  if is_atomic: mods = ["dmask:0x3" if 'cmpswap' in name else "dmask:0x1"]
  else: mods = [f"dmask:0x{inst.dmask:x}"]
  mods.append(f"dim:SQ_RSRC_IMG_{dim_str.upper()}")
  flags = ((inst.r128, "r128"), (inst.a16, "a16"), (inst.tfe, "tfe"), (inst.d16, "d16"))
  mods.extend(mod for flag, mod in flags if flag)

  # cache-policy combinations this disassembler refuses to print (checked in this order)
  th_val, scope_val = inst.th, inst.scope
  if th_val == 3 and scope_val == 3 and not is_atomic: raise ValueError("invalid th/scope: TH_LOAD_LU with SCOPE_SYS")
  if is_atomic and th_val > 2: raise ValueError(f"invalid th value {th_val} for atomic")
  if is_store and th_val == 3: raise ValueError("invalid TH_STORE_LU for store")
  if scope_val == 2 and th_val == 0: raise ValueError("invalid SCOPE_SA without th")
  if inst.tfe and inst.d16 and th_val != 0: raise ValueError("invalid th with tfe+d16")

  if is_atomic: th_table = _TH_ATOMIC
  elif is_store: th_table = _TH_STORE
  else: th_table = _TH_LOAD
  th_name = th_table.get(th_val)
  if th_name: mods.append(f"th:{th_name}")
  scope_name = _SCOPE.get(scope_val)
  if scope_name: mods.append(f"scope:{scope_name}")
  return f"{name} {_vreg(inst.vdata, vdata)}, {vaddr_str}, {srsrc_str} {' '.join(mods)}"
def _disasm_sop1(inst: SOP1) -> str:
op, name = inst.op, inst.op_name.lower()
if not name: raise ValueError(f"undefined SOP1 op: {inst.op}")
if _has(name, 'alloc_vgpr', 'sleep_var', 'barrier_signal', 'barrier_wait', 'wakeup_barrier'):
return f"{name} {inst.lit(inst.ssrc0) if inst.src_regs(0) == 1 else _fmt_src(inst.ssrc0, inst.src_regs(0))}"
if op == SOP1Op.S_GETPC_B64: return f"{name} {_fmt_sdst(inst.sdst, 2)}"
if op in (SOP1Op.S_SETPC_B64, SOP1Op.S_RFE_B64): return f"{name} {_fmt_src(inst.ssrc0, 2)}"
if op == SOP1Op.S_SWAPPC_B64: return f"{name} {_fmt_sdst(inst.sdst, 2)}, {_fmt_src(inst.ssrc0, 2)}"
@ -375,10 +488,14 @@ def _disasm_sopc(inst: SOPC) -> str:
def _disasm_sopk(inst: SOPK) -> str:
  """Disassemble a SOPK instruction into assembly text.

  s_version prints only its immediate. s_setreg_b32/s_getreg_b32 decode simm16 into a
  `hwreg(name, offset, size)` operand, using the RDNA4 register table for RDNA4 instruction
  classes and the RDNA3 table otherwise; ids 16 and 17, and (on RDNA4 only) ids missing from the
  table, are printed as the raw hex immediate instead. All other ops print `sdst, 0x<simm16>`.

  Raises:
    ValueError: the opcode has no name.
  """
  op, name = inst.op, inst.op_name.lower()
  if not name: raise ValueError(f"undefined SOPK op: {inst.op}")
  if op == SOPKOp.S_VERSION: return f"{name} 0x{inst.simm16:x}"
  if op in (SOPKOp.S_SETREG_B32, SOPKOp.S_GETREG_B32):
    # simm16 layout: id in bits [5:0], offset in [10:6], size-1 in [15:11]
    hid, hoff, hsz = inst.simm16 & 0x3f, (inst.simm16 >> 6) & 0x1f, ((inst.simm16 >> 11) & 0x1f) + 1
    is_rdna4 = 'rdna4' in inst.__class__.__module__
    hwreg_map = HWREG_RDNA4 if is_rdna4 else HWREG
    if hid in (16, 17) or (is_rdna4 and hid not in hwreg_map): hs = f"0x{inst.simm16:x}"
    else: hs = f"hwreg({hwreg_map.get(hid, str(hid))}, {hoff}, {hsz})"
    # setreg takes the hwreg operand first, getreg takes it last
    if op == SOPKOp.S_SETREG_B32: return f"{name} {hs}, {_fmt_sdst(inst.sdst, 1)}"
    return f"{name} {_fmt_sdst(inst.sdst, 1)}, {hs}"
  return f"{name} {_fmt_sdst(inst.sdst, inst.dst_regs())}, 0x{inst.simm16:x}"
@ -386,43 +503,36 @@ def _disasm_vinterp(inst: VINTERP) -> str:
mods = _mods((inst.waitexp, f"wait_exp:{inst.waitexp}"), (inst.clmp, "clamp"))
return f"{inst.op_name.lower()} v{inst.vdst}, {inst.lit(inst.src0, inst.neg & 1)}, {inst.lit(inst.src1, inst.neg & 2)}, {inst.lit(inst.src2, inst.neg & 4)}" + (" " + mods if mods else "")
# RDNA4-specific handlers
def _disasm_ldsdir(inst) -> str:
  """Disassemble an LDSDIR instruction: op 1 is lds_direct_load, op 0 is lds_param_load.

  A ` wait_vdst:N` suffix is appended when `inst.wait_va` is non-zero.

  Raises:
    ValueError: `inst.op` is neither 0 nor 1.
  """
  wait_suffix = "" if inst.wait_va == 0 else f" wait_vdst:{inst.wait_va}"
  if inst.op == 1:
    return f"lds_direct_load v{inst.vdst}{wait_suffix}"
  if inst.op == 0:
    channel = 'xyzw'[inst.attr_chan]
    return f"lds_param_load v{inst.vdst}, attr{inst.attr}.{channel}{wait_suffix}"
  raise ValueError(f"unknown LDSDIR op: {inst.op}")
def _disasm_vexport(inst) -> str:
  """Disassemble an export instruction.

  The target name comes from `_EXP_TARGETS`; unknown indices print as `invalid_target_<n>`.
  Each of the four sources prints as `vN` when its bit in `inst.en` is set and as `off` otherwise.
  The mnemonic is `export` for RDNA4 instruction classes and `exp` for all others.
  """
  target = _EXP_TARGETS.get(inst.target, f"invalid_target_{inst.target}")
  en = inst.en
  vsrc = lambda i, v: f"v{v}" if (en >> i) & 1 else "off"
  srcs = f"{vsrc(0, inst.vsrc0)}, {vsrc(1, inst.vsrc1)}, {vsrc(2, inst.vsrc2)}, {vsrc(3, inst.vsrc3)}"
  mods = _mods((inst.done, "done"), (inst.row, "row_en"))
  prefix = "export" if 'rdna4' in inst.__class__.__module__ else "exp"
  return f"{prefix} {target} {srcs}" + (" " + mods if mods else "")
def _disasm_vbuffer(inst) -> str:
"""Disassemble VBUFFER instruction - RDNA4 96-bit buffer format."""
name = inst.op_name.lower()
# Data width: derive from instruction name (d16 packs 2 components per register)
suffix = name.split('_')[-1]
base_w = {'b32':1,'b64':2,'b96':3,'b128':4,'b16':1,'b8':1,'x':1,'xy':2,'xyz':3,'xyzw':4,'u32':1,'u64':2,'i32':1,'i64':2,'f32':1,'f64':2,'f16':1,'bf16':1}.get(suffix, 1)
# the if/else below and the conditional expression after it compute the same w
if 'd16' in name:
w = (base_w + 1) // 2 # d16: pack 2 values per register (xy=1, xyz=2, xyzw=2)
else:
w = base_w
w = (base_w + 1) // 2 if 'd16' in name else base_w
# cmpswap doubles the data register count
if 'cmpswap' in name: w *= 2
# tfe adds one more register to the data operand
if inst.tfe: w += 1
# vaddr: 2-register range when both offen and idxen are set, a single VGPR when only one is, otherwise "off"
vaddr = _vreg(inst.vaddr, 2) if inst.offen and inst.idxen else f"v{inst.vaddr}" if inst.offen or inst.idxen else "off"
# rsrc is stored directly (SGPR index), not multiplied by 4 like RDNA3 MUBUF
rsrc = _sreg_or_ttmp(inst.rsrc, 4)
# soffset
soffset = decode_src(inst.soffset)
# TH (temporal hint) and SCOPE names for RDNA4
rsrc, soffset = _sreg_or_ttmp(inst.rsrc, 4), decode_src(inst.soffset)
th_load = ['','TH_LOAD_RT','TH_LOAD_NT','TH_LOAD_HT','TH_LOAD_LU','TH_LOAD_NT_RT','TH_LOAD_NT_HT','TH_LOAD_BYPASS']
th_store = ['','TH_STORE_RT','TH_STORE_NT','TH_STORE_HT','','TH_STORE_NT_RT','TH_STORE_NT_HT','TH_STORE_BYPASS']
th_atomic = ['','TH_ATOMIC_NT','','','','TH_ATOMIC_RETURN','TH_ATOMIC_RT_RETURN','TH_ATOMIC_CASCADE_NT']
scope_names = ['','SCOPE_SE','SCOPE_DEV','SCOPE_SYS']
# a name counts as a store only when it contains 'store' and not 'atomic'; names that are neither use the load hint table
is_atomic, is_store = 'atomic' in name, 'store' in name and 'atomic' not in name
th_names = th_atomic if is_atomic else th_store if is_store else th_load
# Modifiers
mods = []
if inst.idxen: mods.append("idxen")
if inst.offen: mods.append("offen")
# NOTE(review): the hunk marker below means some lines of this function are not visible in this view
@ -430,75 +540,53 @@ def _disasm_vbuffer(inst) -> str:
# th/scope are printed only when nonzero, within the table, and mapped to a non-empty name
if inst.th and inst.th < len(th_names) and th_names[inst.th]: mods.append(f"th:{th_names[inst.th]}")
if inst.scope and inst.scope < len(scope_names) and scope_names[inst.scope]: mods.append(f"scope:{scope_names[inst.scope]}")
if inst.tfe: mods.append("tfe")
mod_str = " " + " ".join(mods) if mods else ""
return f"{name} {_vreg(inst.vdata, w)}, {vaddr}, {rsrc}, {soffset}{mod_str}"
return f"{name} {_vreg(inst.vdata, w)}, {vaddr}, {rsrc}, {soffset}" + (" " + " ".join(mods) if mods else "")
def _disasm_vflat(inst) -> str:
  """Disassemble VFLAT/VGLOBAL/VSCRATCH instruction - RDNA4 96-bit flat format.

  The segment (flat/global/scratch) is chosen from the instruction class name and replaces the
  first underscore-separated component of the opcode name. Operand layout:
    store:                  <addr>, <data>[, saddr]
    atomic, th < 5:         <addr>, <data>[, saddr]
    atomic, th >= 5:        <dst>, <addr>, <data>[, saddr]
    anything else (load):   <dst>, <addr>[, saddr]
  followed by optional offset:/th:/scope: modifiers.
  """
  name = inst.op_name.lower()
  cls_name = type(inst).__name__
  seg = 'flat' if cls_name == 'VFLAT' else 'global' if cls_name == 'VGLOBAL' else 'scratch'
  parts = name.split('_', 1)
  instr = f"{seg}_{parts[1]}" if len(parts) > 1 else name

  # data register count comes from the type suffix of the opcode name; cmpswap doubles it
  w = {'b32':1,'b64':2,'b96':3,'b128':4,'b16':1,'b8':1,'u32':1,'u64':2,'i32':1,'i64':2,'f32':1,'f64':2,'f16':1,'bf16':1}.get(name.split('_')[-1], 1)
  if 'cmpswap' in name: w *= 2

  # ioffset is a 24-bit two's complement value
  off_val = inst.ioffset if inst.ioffset < 0x800000 else inst.ioffset - 0x1000000

  # flat has no saddr operand; saddr 0x7F / 124 (or scratch with sve == 0) prints "off"; in both cases the
  # address is a 2-register range. With a real saddr the address is a single VGPR.
  if seg == 'flat': saddr_s, addr_width = "", 2
  elif inst.saddr in (0x7F, 124) or (seg == 'scratch' and hasattr(inst, 'sve') and inst.sve == 0): saddr_s, addr_width = ", off", 2
  else: saddr_s, addr_width = f", {_fmt_src(inst.saddr, 2) if inst.saddr <= 105 else decode_src(inst.saddr)}", 1
  vaddr = f"v{inst.vaddr}" if addr_width == 1 else _vreg(inst.vaddr, 2)

  th_load = ['','TH_LOAD_RT','TH_LOAD_NT','TH_LOAD_HT','TH_LOAD_LU','TH_LOAD_NT_RT','TH_LOAD_NT_HT','TH_LOAD_BYPASS']
  th_store = ['','TH_STORE_RT','TH_STORE_NT','TH_STORE_HT','','TH_STORE_NT_RT','TH_STORE_NT_HT','TH_STORE_BYPASS']
  th_atomic = ['','TH_ATOMIC_NT','','','','TH_ATOMIC_RETURN','TH_ATOMIC_RT_RETURN','TH_ATOMIC_CASCADE_NT']
  scope_names = ['','SCOPE_SE','SCOPE_DEV','SCOPE_SYS']
  is_atomic = 'atomic' in name
  is_store = 'store' in name and not is_atomic
  th_names = th_atomic if is_atomic else th_store if is_store else th_load

  # modifiers are printed only when nonzero (and, for th/scope, mapped to a non-empty name)
  mods = []
  if off_val: mods.append(f"offset:{off_val}")
  if inst.th and inst.th < len(th_names) and th_names[inst.th]: mods.append(f"th:{th_names[inst.th]}")
  if inst.scope and inst.scope < len(scope_names) and scope_names[inst.scope]: mods.append(f"scope:{scope_names[inst.scope]}")
  mod_str = " " + " ".join(mods) if mods else ""

  if is_store: return f"{instr} {vaddr}, {_vreg(inst.vsrc, w)}{saddr_s}{mod_str}"
  if is_atomic:
    # th values from 5 up (TH_ATOMIC_RETURN and the entries after it) also print a destination register
    if inst.th and inst.th >= 5: return f"{instr} {_vreg(inst.vdst, w)}, {vaddr}, {_vreg(inst.vsrc, w)}{saddr_s}{mod_str}"
    return f"{instr} {vaddr}, {_vreg(inst.vsrc, w)}{saddr_s}{mod_str}"
  return f"{instr} {_vreg(inst.vdst, w)}, {vaddr}{saddr_s}{mod_str}"
# Handler mappings: instruction class object -> disassembly function
DISASM_HANDLERS = {VOP1: _disasm_vop1, VOP2: _disasm_vop2, VOPC: _disasm_vopc, VOP3: _disasm_vop3, VOP3SD: _disasm_vop3sd, VOPD: _disasm_vopd, VOP3P: _disasm_vop3p,
  VINTERP: _disasm_vinterp, SOPP: _disasm_sopp, SMEM: _disasm_smem, DS: _disasm_ds, FLAT: _disasm_flat, MUBUF: _disasm_buf, MTBUF: _disasm_buf,
  MIMG: _disasm_mimg, SOP1: _disasm_sop1, SOP2: _disasm_sop2, SOPC: _disasm_sopc, SOPK: _disasm_sopk}
# Name-based lookup for cross-architecture support (RDNA3/RDNA4 have different class objects with same names)
_DISASM_BY_NAME = {
  'VOP1': _disasm_vop1, 'VOP2': _disasm_vop2, 'VOPC': _disasm_vopc, 'VOP3': _disasm_vop3, 'VOP3SD': _disasm_vop3sd,
  'VOPD': _disasm_vopd, 'VOP3P': _disasm_vop3p, 'VINTERP': _disasm_vinterp, 'SOPP': _disasm_sopp, 'SMEM': _disasm_smem,
  'DS': _disasm_ds, 'VDS': _disasm_ds, 'FLAT': _disasm_flat, 'MUBUF': _disasm_buf, 'MTBUF': _disasm_buf, 'MIMG': _disasm_mimg,
  'SOP1': _disasm_sop1, 'SOP2': _disasm_sop2, 'SOPC': _disasm_sopc, 'SOPK': _disasm_sopk,
  # each class name is listed exactly once; EXP and VEXPORT share one handler
  'VEXPORT': _disasm_vexport, 'EXP': _disasm_vexport, 'LDSDIR': _disasm_ldsdir,
  'VBUFFER': _disasm_vbuffer, 'VFLAT': _disasm_vflat, 'VGLOBAL': _disasm_vflat, 'VSCRATCH': _disasm_vflat,
  'VSAMPLE': _disasm_vsample, 'VIMAGE': _disasm_vimage,
}
def disasm(inst: Inst) -> str:
@ -512,7 +600,7 @@ def disasm(inst: Inst) -> str:
# Named special registers mapped to their raw source encodings ('vcc'/'vcc_lo', 'exec'/'exec_lo', 'null'/'off' and 'scc'/'src_scc' are aliases)
SPEC_REGS = {'vcc_lo': RawImm(106), 'vcc_hi': RawImm(107), 'vcc': RawImm(106), 'null': RawImm(124), 'off': RawImm(124), 'm0': RawImm(125),
  'exec_lo': RawImm(126), 'exec_hi': RawImm(127), 'exec': RawImm(126), 'scc': RawImm(253), 'src_scc': RawImm(253)}
FLOATS = {str(k): k for k in FLOAT_ENC}  # Valid float literal strings: '0.5', '-0.5', '1.0', etc.
# Register-name prefix -> register factory ('t' and 'ttmp' both select ttmp)
REG_MAP: dict[str, _RegFactory] = {'s': s, 'v': v, 't': ttmp, 'ttmp': ttmp}
SMEM_OPS = {'s_load_b32', 's_load_b64', 's_load_b128', 's_load_b256', 's_load_b512',
  's_buffer_load_b32', 's_buffer_load_b64', 's_buffer_load_b128', 's_buffer_load_b256', 's_buffer_load_b512'}
@ -554,7 +642,6 @@ def _extract(text: str, pat: str, flags=re.I):
def get_dsl(text: str) -> str:
text, kw = text.strip(), []
# Extract modifiers
for pat, val in [(r'\s+mul:2(?:\s|$)', 1), (r'\s+mul:4(?:\s|$)', 2), (r'\s+div:2(?:\s|$)', 3)]:
if (m := _extract(text, pat))[0]: kw.append(f'omod={val}'); text = m[1]; break
if (m := _extract(text, r'\s+clamp(?:\s|$)'))[0]: kw.append('clmp=1'); text = m[1]
@ -572,13 +659,11 @@ def get_dsl(text: str) -> str:
m, text = _extract(text, r'\s+neg_lo:\[([^\]]+)\]'); neg_lo = sum(int(x.strip()) << i for i, x in enumerate(m.group(1).split(','))) if m else None
m, text = _extract(text, r'\s+neg_hi:\[([^\]]+)\]'); neg_hi = sum(int(x.strip()) << i for i, x in enumerate(m.group(1).split(','))) if m else None
if waitexp: kw.append(f'waitexp={waitexp}')
parts = text.replace(',', ' ').split()
if not parts: raise ValueError("empty instruction")
mn, op_str = parts[0].lower(), text[len(parts[0]):].strip()
ops, args = _parse_ops(op_str), [_op2dsl(o) for o in _parse_ops(op_str)]
# s_waitcnt
if mn == 's_waitcnt':
vm, exp, lgkm = 0x3f, 0x7, 0x3f
for p in op_str.replace(',', ' ').split():
@ -588,7 +673,6 @@ def get_dsl(text: str) -> str:
elif re.match(r'^0x[0-9a-f]+$|^\d+$', p): return f"s_waitcnt(simm16={int(p, 0)})"
return f"s_waitcnt(simm16={waitcnt(vm, exp, lgkm)})"
# VOPD
if '::' in text:
xp, yp = text.split('::')
xps, yps = xp.strip().replace(',', ' ').split(), yp.strip().replace(',', ' ').split()
@ -600,14 +684,12 @@ def get_dsl(text: str) -> str:
elif 'fmamk' in yps[0].lower() and len(yo) > 3: lit, vsy1 = yo[2], yo[3]
return f"VOPD(VOPDOp.{xps[0].upper()}, VOPDOp.{yps[0].upper()}, vdstx={vdx}, vdsty={vdy}, srcx0={sx0}, vsrcx1={vsx1}, srcy0={sy0}, vsrcy1={vsy1}{f', literal={lit}' if lit else ''})"
# Special instructions
if mn == 's_setreg_imm32_b32': raise ValueError(f"unsupported: {mn}")
if mn in ('s_setpc_b64', 's_rfe_b64'): return f"{mn}(ssrc0={args[0]})"
if mn in ('s_sendmsg_rtn_b32', 's_sendmsg_rtn_b64'): return f"{mn}(sdst={args[0]}, ssrc0=RawImm({args[1].strip()}))"
if mn == 's_version': return f"{mn}(simm16={args[0]})"
if mn == 's_setreg_b32': return f"{mn}(simm16={args[0]}, sdst={args[1]})"
# SMEM
if mn in SMEM_OPS:
gs, ds = ", glc=1" if glc else "", ", dlc=1" if dlc else ""
if len(ops) >= 3 and re.match(r'^-?[0-9]|^-?0x', ops[2].strip().lower()):
@ -615,11 +697,9 @@ def get_dsl(text: str) -> str:
if off_val and len(ops) >= 3: return f"{mn}(sdata={args[0]}, sbase={args[1]}, offset={off_val}, soffset={args[2]}{gs}{ds})"
if len(ops) >= 3: return f"{mn}(sdata={args[0]}, sbase={args[1]}, soffset={args[2]}{gs}{ds})"
# Buffer
if mn.startswith('buffer_') and len(ops) >= 2 and ops[1].strip().lower() == 'off':
return f"{mn}(vdata={args[0]}, vaddr=0, srsrc={args[2]}, soffset={f'RawImm({args[3].strip()})' if len(args) > 3 else 'RawImm(0)'})"
# FLAT/GLOBAL/SCRATCH load/store/atomic - saddr needs RawImm(124) for off/null
def _saddr(a): return 'RawImm(124)' if a in ('OFF', 'NULL') else a
flat_mods = f"{f', offset={off_val}' if off_val else ''}{', glc=1' if glc else ''}{', slc=1' if slc else ''}{', dlc=1' if dlc else ''}"
for pre, flds in [('flat_load','vdst,addr,saddr'), ('global_load','vdst,addr,saddr'), ('scratch_load','vdst,addr,saddr'),
@ -632,7 +712,6 @@ def get_dsl(text: str) -> str:
if glc and len(args) >= 3: return f"{mn}(vdst={args[0]}, addr={args[1]}, data={args[2]}{f', saddr={_saddr(args[3])}' if len(args) >= 4 else ', saddr=RawImm(124)'}{flat_mods})"
if len(args) >= 2: return f"{mn}(addr={args[0]}, data={args[1]}{f', saddr={_saddr(args[2])}' if len(args) >= 3 else ', saddr=RawImm(124)'}{flat_mods})"
# DS instructions
if mn.startswith('ds_'):
off0, off1 = (str(int(off_val, 0) & 0xff), str((int(off_val, 0) >> 8) & 0xff)) if off_val else ("0", "0")
gds_s = ", gds=1" if 'gds' in text.lower().split()[-1:] else ""
@ -656,12 +735,10 @@ def get_dsl(text: str) -> str:
return f"{mn}(vdst={args[0]}, addr={args[1]}, data0={args[2]}, data1={args[3]}{off_kw})" if '_rtn' in mn else f"{mn}(addr={args[0]}, data0={args[1]}, data1={args[2]}{off_kw})"
return f"{mn}(vdst={args[0]}, addr={args[1]}, data0={args[2]}{off_kw})" if '_rtn' in mn else f"{mn}(addr={args[0]}, data0={args[1]}{off_kw})"
# v_fmaak/v_fmamk literal extraction
lit_s = ""
if mn in ('v_fmaak_f32', 'v_fmaak_f16') and len(args) == 4: lit_s, args = f", literal={args[3].strip()}", args[:3]
elif mn in ('v_fmamk_f32', 'v_fmamk_f16') and len(args) == 4: lit_s, args = f", literal={args[2].strip()}", [args[0], args[1], args[3]]
# VCC ops cleanup
vcc_ops = {'v_add_co_ci_u32', 'v_sub_co_ci_u32', 'v_subrev_co_ci_u32'}
if mn.replace('_e32', '') in vcc_ops and len(args) >= 5: mn, args = mn.replace('_e32', '') + '_e32', [args[0], args[2], args[3]]
if mn.replace('_e64', '') in vcc_ops and mn.endswith('_e64'): mn = mn.replace('_e64', '')
@ -671,7 +748,6 @@ def get_dsl(text: str) -> str:
fn = mn.replace('.', '_')
if opsel is not None: args = [re.sub(r'\.[hl]$', '', a) for a in args]
# v_fma_mix*: extract inline neg/abs modifiers
if 'fma_mix' in mn and neg_lo is None and neg_hi is None:
inline_neg, inline_abs, clean_args = 0, 0, [args[0]]
for i, op in enumerate(ops[1:4]):

View file

@ -154,7 +154,7 @@ class VIMAGE(Inst96):
scope = bits[51:50]
th = bits[54:52]
tfe = bits[55]
vaddr4 = bits[56:63]
vaddr4 = bits[63:56]
vaddr0 = bits[71:64]
vaddr1 = bits[79:72]
vaddr2 = bits[87:80]

View file

@ -54,7 +54,7 @@ _SPECIAL_REGS = {
'V_CMP_CLASS_F64': (1, 2, 1, 1), 'V_CMPX_CLASS_F64': (1, 2, 1, 1),
'V_CMP_CLASS_F32': (1, 1, 1, 1), 'V_CMPX_CLASS_F32': (1, 1, 1, 1),
'V_CMP_CLASS_F16': (1, 1, 1, 1), 'V_CMPX_CLASS_F16': (1, 1, 1, 1),
'V_MAD_U64_U32': (2, 1, 1, 2), 'V_MAD_I64_I32': (2, 1, 1, 2),
'V_MAD_U64_U32': (2, 1, 1, 2), 'V_MAD_I64_I32': (2, 1, 1, 2), 'V_MAD_CO_U64_U32': (2, 1, 1, 2), 'V_MAD_CO_I64_I32': (2, 1, 1, 2),
'V_QSAD_PK_U16_U8': (2, 2, 1, 2), 'V_MQSAD_PK_U16_U8': (2, 2, 1, 2), 'V_MQSAD_U32_U8': (4, 2, 1, 4),
# RDNA4 CVT_PK_F32 instructions output 2 F32 values (64-bit)
'V_CVT_PK_F32_BF8': (2, 1, 1, 1), 'V_CVT_PK_F32_FP8': (2, 1, 1, 1),
@ -98,7 +98,7 @@ def spec_is_16bit(name: str) -> bool:
def spec_is_64bit(name: str) -> bool:
  """Return True when the upper-cased name contains _F64, _I64, _U64 or _B64 followed by '_' or the end of the string."""
  upper_name = name.upper()
  return re.search(r'_[FIUB]64(?:_|$)', upper_name) is not None
_3SRC = {'FMA', 'MAD', 'MIN3', 'MAX3', 'MED3', 'DIV_FIX', 'DIV_FMAS', 'DIV_SCALE', 'SAD', 'LERP', 'ALIGN', 'CUBE', 'BFE', 'BFI',
'PERM_B32', 'PERMLANE', 'CNDMASK', 'XOR3', 'OR3', 'ADD3', 'LSHL_OR', 'AND_OR', 'LSHL_ADD', 'ADD_LSHL', 'XAD', 'MAXMIN',
'MINMAX', 'DOT2', 'DOT4', 'DOT8', 'WMMA', 'CVT_PK_U8', 'MULLIT', 'CO_CI'}
'MINMAX', 'MAXIMUMMINIMUM', 'MINIMUMMAXIMUM', 'MAXIMUM3', 'MINIMUM3', 'DOT2', 'DOT4', 'DOT8', 'WMMA', 'CVT_PK_U8', 'MULLIT', 'CO_CI'}
_2SRC = {'FMAC'} # FMAC uses dst as implicit accumulator, so only 2 explicit sources
def spec_num_srcs(name: str) -> int:
name = name.upper()
@ -138,9 +138,14 @@ class BitField:
# Convert to IntEnum if marker is an IntEnum subclass
if self.marker and isinstance(self.marker, type) and issubclass(self.marker, IntEnum):
# VOP3 with VOPC opcodes (0-255) -> VOPCOp, VOP3SD opcodes -> VOP3SDOp
if self.marker is VOP3Op:
if val < 256: return VOPCOp(val)
if val in Inst._VOP3SD_OPS: return VOP3SDOp(val)
# Check by name to handle both RDNA3 and RDNA4 enums
if self.marker.__name__ == 'VOP3Op':
# Get the appropriate enums from the same module as the marker
marker_mod = self.marker.__module__
import importlib
enum_mod = importlib.import_module(marker_mod)
if val < 256: return enum_mod.VOPCOp(val)
if val in Inst._VOP3SD_OPS: return enum_mod.VOP3SDOp(val)
try: return self.marker(val)
except ValueError: pass
# For RDNA4 op fields without type annotations, dynamically look up enum

View file

@ -38,6 +38,9 @@ RDNA3_TEST_FILES = {
'mubuf': 'gfx11_asm_mubuf.s',
'mtbuf': 'gfx11_asm_mtbuf.s',
'mimg': 'gfx11_asm_mimg.s',
'ldsdir': 'gfx11_asm_ldsdir.s',
# Export
'exp': 'gfx11_asm_exp.s',
# WMMA
'wmma': 'gfx11_asm_wmma.s',
# Features
@ -162,12 +165,12 @@ class TestLLVMRDNA3(unittest.TestCase):
@classmethod
def setUpClass(cls):
from extra.assembly.amd.autogen.rdna3.ins import SOP1, SOP2, SOPC, SOPK, SOPP, VOP1, VOP2, VOP3, VOP3SD, VOP3P, VOPC, VOPD, VINTERP, DS, SMEM, FLAT, MUBUF, MTBUF, MIMG
from extra.assembly.amd.autogen.rdna3.ins import SOP1, SOP2, SOPC, SOPK, SOPP, VOP1, VOP2, VOP3, VOP3SD, VOP3P, VOPC, VOPD, VINTERP, DS, SMEM, FLAT, MUBUF, MTBUF, MIMG, LDSDIR, EXP
cls.formats = {
'sop1': SOP1, 'sop2': SOP2, 'sopc': SOPC, 'sopk': SOPK, 'sopp': SOPP,
'vop1': VOP1, 'vop2': VOP2, 'vopc': VOPC, 'vopcx': VOPC, 'vop3': VOP3, 'vop3p': VOP3P,
'vinterp': VINTERP, 'vopd': VOPD, 'ds': DS, 'smem': SMEM, 'flat': FLAT,
'mubuf': MUBUF, 'mtbuf': MTBUF, 'mimg': MIMG, 'wmma': VOP3P,
'mubuf': MUBUF, 'mtbuf': MTBUF, 'mimg': MIMG, 'wmma': VOP3P, 'ldsdir': LDSDIR, 'exp': EXP,
'vop3_from_vop1': VOP3, 'vop3_from_vop2': VOP3, 'vop3_from_vopc': VOP3, 'vop3_from_vopcx': VOP3,
'vop3_features': VOP3, 'vop3p_features': VOP3P, 'vopd_features': VOPD,
'vop3_alias': VOP3, 'vop3p_alias': VOP3P, 'vopc_alias': VOPC, 'vopcx_alias': VOPC,
@ -260,7 +263,7 @@ class TestLLVMRDNA4(unittest.TestCase):
'ds': VDS, 'ds_alias': VDS, 'smem': SMEM,
'vinterp': VINTERP, 'exp': VEXPORT,
'vbuffer_mubuf': VBUFFER, 'vbuffer_mubuf_alias': VBUFFER, 'vbuffer_mtbuf': VBUFFER, 'vbuffer_mtbuf_alias': VBUFFER,
'vdsdir': VDSDIR, 'vdsdir_alias': VDSDIR,
'vdsdir': None, 'vdsdir_alias': None, # VDSDIR is 64-bit but ds_direct_load is 32-bit
'vflat': VFLAT, 'vflat_alias': VFLAT,
'vimage': VIMAGE, 'vimage_alias': VIMAGE,
'vsample': VSAMPLE,