tinygrad/extra/assembly/amd/disasm.py
George Hotz 91bde927ef
assembly/amd: split asm.py into asm.py and disasm.py (#14101)
* split asm.py into asm.py and disasm.py

* split decoder

* move to pcode

* tests
2026-01-12 07:22:02 +09:00

745 lines
49 KiB
Python

# RDNA3/RDNA4/CDNA disassembler
from __future__ import annotations
import re
from extra.assembly.amd.dsl import Inst, decode_src, SPECIAL_GPRS, SPECIAL_GPRS_CDNA, SPECIAL_PAIRS, SPECIAL_PAIRS_CDNA
from extra.assembly.amd.autogen.rdna3.ins import (VOP1, VOP2, VOP3, VOP3SD, VOP3P, VOPC, VOPD, VINTERP, SOP1, SOP2, SOPC, SOPK, SOPP, SMEM, DS, FLAT, MUBUF, MTBUF, MIMG, EXP,
VOP1Op, VOP2Op, VOP3Op, VOP3SDOp, VOPDOp, SOP1Op, SOPKOp, SOPPOp, SMEMOp, DSOp, MUBUFOp)
from extra.assembly.amd.autogen.rdna3.enum import BufFmt
from extra.assembly.amd.autogen.rdna4.ins import (VOP1 as R4_VOP1, VOP2 as R4_VOP2, VOP3 as R4_VOP3, VOP3SD as R4_VOP3SD, VOP3P as R4_VOP3P,
VOPC as R4_VOPC, VOPD as R4_VOPD, VINTERP as R4_VINTERP, SOP1 as R4_SOP1, SOP2 as R4_SOP2, SOPC as R4_SOPC, SOPK as R4_SOPK, SOPP as R4_SOPP,
SMEM as R4_SMEM, DS as R4_DS, VBUFFER as R4_VBUFFER, VEXPORT as R4_VEXPORT, VOPDOp as R4_VOPDOp)
from extra.assembly.amd.autogen.cdna.ins import FLAT as C_FLAT, MUBUF as C_MUBUF, MTBUF as C_MTBUF
def _is_cdna(inst: Inst) -> bool:
  """Return True when the defining module path of the instruction's class contains 'cdna'."""
  module_path = inst.__class__.__module__
  return 'cdna' in module_path
# ═══════════════════════════════════════════════════════════════════════════════
# CONSTANTS
# ═══════════════════════════════════════════════════════════════════════════════
# Hardware-register id -> symbolic name. _disasm_sopk looks up the low 6 bits of simm16 here and
# prints the raw immediate when the id is absent.
HWREG = {1: 'HW_REG_MODE', 2: 'HW_REG_STATUS', 3: 'HW_REG_TRAPSTS', 4: 'HW_REG_HW_ID', 5: 'HW_REG_GPR_ALLOC',
6: 'HW_REG_LDS_ALLOC', 7: 'HW_REG_IB_STS', 15: 'HW_REG_SH_MEM_BASES', 18: 'HW_REG_PERF_SNAPSHOT_PC_LO',
19: 'HW_REG_PERF_SNAPSHOT_PC_HI', 20: 'HW_REG_FLAT_SCR_LO', 21: 'HW_REG_FLAT_SCR_HI', 22: 'HW_REG_XNACK_MASK',
23: 'HW_REG_HW_ID1', 24: 'HW_REG_HW_ID2', 25: 'HW_REG_POPS_PACKER', 28: 'HW_REG_IB_STS2'}
# Hardware-register id -> symbolic name using the HW_REG_WAVE_* naming (RDNA4, per the table's name).
# NOTE(review): no reader of this table is visible in the reviewed part of the file - confirm it is still used.
HWREG_RDNA4 = {1: 'HW_REG_WAVE_MODE', 2: 'HW_REG_WAVE_STATUS', 4: 'HW_REG_WAVE_STATE_PRIV', 5: 'HW_REG_WAVE_GPR_ALLOC',
6: 'HW_REG_WAVE_LDS_ALLOC', 7: 'HW_REG_IB_STS', 10: 'HW_REG_PERF_SNAPSHOT_DATA', 11: 'HW_REG_PERF_SNAPSHOT_PC_LO',
12: 'HW_REG_PERF_SNAPSHOT_PC_HI', 15: 'HW_REG_PERF_SNAPSHOT_DATA1', 16: 'HW_REG_PERF_SNAPSHOT_DATA2',
17: 'HW_REG_WAVE_EXCP_FLAG_PRIV', 18: 'HW_REG_WAVE_EXCP_FLAG_USER', 19: 'HW_REG_WAVE_TRAP_CTRL',
20: 'HW_REG_WAVE_SCRATCH_BASE_LO', 21: 'HW_REG_WAVE_SCRATCH_BASE_HI', 23: 'HW_REG_WAVE_HW_ID1',
24: 'HW_REG_WAVE_HW_ID2', 26: 'HW_REG_WAVE_SCHED_MODE', 29: 'HW_REG_SHADER_CYCLES_LO',
30: 'HW_REG_SHADER_CYCLES_HI', 31: 'HW_REG_WAVE_DVGPR_ALLOC_LO', 32: 'HW_REG_WAVE_DVGPR_ALLOC_HI'}
# Message id -> name printed inside sendmsg(...) by _disasm_sop1 for s_sendmsg_rtn_*; ids not listed
# here are printed as plain numbers.
MSG = {128: 'MSG_RTN_GET_DOORBELL', 129: 'MSG_RTN_GET_DDID', 130: 'MSG_RTN_GET_TMA',
131: 'MSG_RTN_GET_REALTIME', 132: 'MSG_RTN_SAVE_WAVE', 133: 'MSG_RTN_GET_TBA',
134: 'MSG_RTN_GET_TBA_TO_PC', 135: 'MSG_RTN_GET_SE_AID_ID'}
# CDNA opcode name aliases for disasm (new name -> old name expected by tests).
# Applied by _disasm_vop2 to the lowercased op name when the instruction is a CDNA one.
_CDNA_DISASM_ALIASES = {'v_fmac_f64': 'v_mul_legacy_f32', 'v_dot2c_f32_bf16': 'v_mac_f32', 'v_fmamk_f32': 'v_madmk_f32', 'v_fmaak_f32': 'v_madak_f32'}
# ═══════════════════════════════════════════════════════════════════════════════
# HELPERS
# ═══════════════════════════════════════════════════════════════════════════════
def _reg(p: str, b: int, n: int = 1) -> str:
  """Format a register operand: '<p><b>' for a single register, '<p>[<b>:<b+n-1>]' for n registers."""
  return f"{p}{b}" if n == 1 else f"{p}[{b}:{b+n-1}]"
def _sreg(b: int, n: int = 1) -> str:
  """Format an 's'-prefixed register or register range."""
  return _reg("s", b, n)
def _vreg(b: int, n: int = 1) -> str:
  """Format a 'v'-prefixed register or register range."""
  return _reg("v", b, n)
def _areg(b: int, n: int = 1) -> str:
  """Format an 'a'-prefixed register or register range (accumulator registers for GFX90a)."""
  return _reg("a", b, n)
def _ttmp(b: int, n: int = 1) -> str | None:
  """Format encodings 108..123 as ttmp0..ttmp15 (or a ttmp range); return None for any other encoding."""
  return _reg("ttmp", b - 108, n) if 108 <= b <= 123 else None
def _sreg_or_ttmp(b: int, n: int = 1) -> str:
  """Format `b` as a ttmp register when it is in the ttmp window, otherwise as an s-register."""
  return _ttmp(b, n) or _sreg(b, n)
def _fmt_sdst(v: int, n: int = 1, cdna: bool = False) -> str:
  """Format scalar destination encoding `v` covering `n` registers.

  A ttmp name wins when `_ttmp` yields one. Otherwise the special-register tables (the CDNA
  variants when `cdna` is set) are consulted before falling back to a plain s-register name.
  """
  ttmp_name = _ttmp(v, n)
  if ttmp_name: return ttmp_name
  if cdna: pair_names, single_names = SPECIAL_PAIRS_CDNA, SPECIAL_GPRS_CDNA
  else: pair_names, single_names = SPECIAL_PAIRS, SPECIAL_GPRS
  if n <= 1: return single_names.get(v, f"s{v}")
  # multi-register destination: the single-register table is checked too (for null/m0)
  return pair_names.get(v) or single_names.get(v) or _sreg(v, n)
def _fmt_src(v: int, n: int = 1, cdna: bool = False) -> str:
  """Format source encoding `v` covering `n` registers.

  Single-register operands are delegated to `decode_src`. For wider operands, encodings >= 256
  are VGPR ranges, <= 101 are SGPR ranges, and the remainder go through the special-pair table,
  the 102..105 SGPR range, the ttmp window and finally `decode_src`.
  """
  if n == 1: return decode_src(v, cdna)
  if v >= 256: return _vreg(v - 256, n)
  if v > 101:
    # 102 and above can be special registers, so the pair table is consulted first
    special = SPECIAL_PAIRS_CDNA if cdna else SPECIAL_PAIRS
    if n == 2 and v in special: return special[v]
    if v > 105: return _ttmp(v, n) or decode_src(v, cdna)
  # s0-s101 always, and s102-s105 when no special pair matched
  return _sreg(v, n)
def _fmt_v16(v: int, base: int = 256, hi_thresh: int = 384) -> str:
  """Format a 16-bit VGPR half as 'v<idx>.l' or 'v<idx>.h'.

  The index is `(v - base)` masked to 7 bits; the half is 'h' when `v >= hi_thresh`, else 'l'.
  """
  half = 'h' if v >= hi_thresh else 'l'
  index = (v - base) & 0x7f
  return f"v{index}.{half}"
def _has(op: str, *subs) -> bool:
  """Return True when at least one of `subs` occurs inside `op`."""
  for needle in subs:
    if needle in op: return True
  return False
def _omod(v: int) -> str:
  """Return the output-modifier suffix for `v`: 1 -> ' mul:2', 2 -> ' mul:4', 3 -> ' div:2', else ''."""
  if v == 1: return " mul:2"
  if v == 2: return " mul:4"
  if v == 3: return " div:2"
  return ""
def _src16(inst, v: int) -> str:
  """Format a 16-bit source: a VGPR half ('vN.l'/'vN.h') for encodings >= 256, else `inst.lit(v)`."""
  if v >= 256: return _fmt_v16(v)
  return inst.lit(v)
def _mods(*pairs) -> str:
  """Join, with single spaces, the text of every (condition, text) pair whose condition is truthy."""
  enabled = [text for flag, text in pairs if flag]
  return " ".join(enabled)
def _fmt_bits(label: str, val: int, count: int) -> str:
  """Format the low `count` bits of `val`, least-significant first, as '<label>:[b0,b1,...]'."""
  bits = []
  for i in range(count):
    bits.append(str((val >> i) & 1))
  joined = ','.join(bits)
  return f"{label}:[{joined}]"
def _vop3_src(inst, v: int, neg: int, abs_: int, hi: int, n: int, f16: bool) -> str:
  """Format a VOP3 source operand together with its abs/neg modifiers.

  `hi` picks the '.h' half for a 16-bit VGPR operand; `n` is the operand's register count.
  Abs is wrapped first and neg is prefixed afterwards, so both together print as '-|x|'.
  """
  if v == 255: text = inst.lit(v)  # literal constant takes priority
  elif n > 1: text = _fmt_src(v, n)
  elif f16 and v >= 256: text = f"v{v - 256}.{'h' if hi else 'l'}"
  elif v == 253: text = "src_scc"  # VOP3 sources use src_scc not scc
  else: text = inst.lit(v)
  if abs_: text = f"|{text}|"
  if neg: text = f"-{text}"
  return text
def _opsel_str(opsel: int, n: int, need: bool, is16_d: bool) -> str:
  """Format the ' op_sel:[...]' modifier, or return '' when `need` is falsy.

  One bit is listed per source (1 or 2 sources as given, otherwise 3), followed by bit 3 of
  `opsel`, which is the destination-half bit. `is16_d` is accepted but not consulted.
  """
  if not need: return ""
  src_count = n if n in (1, 2) else 3
  bits = [(opsel >> i) & 1 for i in range(src_count)]
  bits.append((opsel >> 3) & 1)
  return f" op_sel:[{','.join(str(b) for b in bits)}]"
# ═══════════════════════════════════════════════════════════════════════════════
# DISASSEMBLER
# ═══════════════════════════════════════════════════════════════════════════════
def _disasm_vop1(inst: VOP1) -> str:
  """Render a VOP1-encoded instruction as assembly text.

  CDNA mnemonics are printed bare; all others get the '_e32' suffix.
  """
  name = inst.op_name.lower() or f'vop1_op_{inst.op}'
  cdna = _is_cdna(inst)
  if name in ('v_nop', 'v_pipeflush', 'v_clrexcp'): return name  # no operands
  if 'readfirstlane' in name:
    # destination is formatted as a scalar register; a VGPR source gets no .l/.h suffix
    lane_src = decode_src(inst.src0, cdna) if inst.src0 < 256 else f"v{inst.src0 - 256}"
    return f"{name} {_fmt_sdst(inst.vdst, 1, cdna)}, {lane_src}"
  # 16-bit dst: uses .h/.l suffix for RDNA (CDNA uses plain vN)
  types16 = ('f16', 'i16', 'u16', 'b16')
  pieces = name.split('_')
  if cdna:
    is_16d = False
  else:
    second_last_is_16 = any(p in types16 for p in pieces[-2:-1])
    last_is_16 = len(pieces) >= 2 and pieces[-1] in types16 and 'cvt' not in name
    is_16d = second_last_is_16 or last_is_16
  # v_cvt_pk_f32_fp8 and v_cvt_pk_f32_bf8 output to 2 VGPRs, and take 16-bit src
  is_pk_fp8 = 'cvt_pk_f32_fp8' in name or 'cvt_pk_f32_bf8' in name
  dregs = 2 if is_pk_fp8 else inst.dst_regs()
  if dregs > 1: dst = _vreg(inst.vdst, dregs)
  elif is_16d: dst = _fmt_v16(inst.vdst, 0, 128)
  else: dst = f"v{inst.vdst}"
  if inst.src0 == 255: src = inst.lit(inst.src0)
  elif inst.src_regs(0) > 1: src = _fmt_src(inst.src0, inst.src_regs(0), cdna)
  elif not cdna and (inst.is_src_16(0) or is_pk_fp8) and 'sat_pk' not in name: src = _src16(inst, inst.src0)
  else: src = inst.lit(inst.src0)
  suffix = "" if cdna else "_e32"
  return f"{name}{suffix} {dst}, {src}"
_VOP2_CARRY_OUT = {'v_add_co_u32', 'v_sub_co_u32', 'v_subrev_co_u32'} # carry out only
_VOP2_CARRY_INOUT = {'v_addc_co_u32', 'v_subb_co_u32', 'v_subbrev_co_u32'} # carry in and out (CDNA)
_VOP2_CARRY_INOUT_RDNA = {'v_add_co_ci_u32', 'v_sub_co_ci_u32', 'v_subrev_co_ci_u32'} # carry in and out (RDNA)
def _disasm_vop2(inst: VOP2) -> str:
  """Render a VOP2-encoded instruction as assembly text.

  CDNA mnemonics are printed bare (after alias substitution). Other mnemonics get the '_e32'
  suffix, except V_DOT2ACC_F32_F16, which is printed without one. The suffix is applied on
  every return path, including the multi-register (64-bit) one.
  """
  name, cdna = inst.op_name.lower(), _is_cdna(inst)
  if cdna: name = _CDNA_DISASM_ALIASES.get(name, name) # apply CDNA aliases
  suf = "" if cdna or inst.op == VOP2Op.V_DOT2ACC_F32_F16 else "_e32"
  lit = getattr(inst, '_literal', None)
  is16 = not cdna and inst.is_16bit()
  # fmaak/madak: dst = src0 * vsrc1 + K, fmamk/madmk: dst = src0 * K + vsrc1
  if 'fmaak' in name or 'madak' in name or (not cdna and inst.op in (VOP2Op.V_FMAAK_F32, VOP2Op.V_FMAAK_F16)):
    if is16: return f"{name}{suf} {_fmt_v16(inst.vdst, 0, 128)}, {_src16(inst, inst.src0)}, {_fmt_v16(inst.vsrc1, 0, 128)}, 0x{lit:x}"
    return f"{name}{suf} v{inst.vdst}, {inst.lit(inst.src0)}, v{inst.vsrc1}, 0x{lit:x}"
  if 'fmamk' in name or 'madmk' in name or (not cdna and inst.op in (VOP2Op.V_FMAMK_F32, VOP2Op.V_FMAMK_F16)):
    if is16: return f"{name}{suf} {_fmt_v16(inst.vdst, 0, 128)}, {_src16(inst, inst.src0)}, 0x{lit:x}, {_fmt_v16(inst.vsrc1, 0, 128)}"
    return f"{name}{suf} v{inst.vdst}, {inst.lit(inst.src0)}, 0x{lit:x}, v{inst.vsrc1}"
  if is16: return f"{name}{suf} {_fmt_v16(inst.vdst, 0, 128)}, {_src16(inst, inst.src0)}, {_fmt_v16(inst.vsrc1, 0, 128)}"
  vcc = "vcc" if cdna else "vcc_lo"
  # CDNA carry ops output vcc after vdst
  if cdna and name in _VOP2_CARRY_OUT: return f"{name}{suf} v{inst.vdst}, {vcc}, {inst.lit(inst.src0)}, v{inst.vsrc1}"
  if cdna and name in _VOP2_CARRY_INOUT: return f"{name}{suf} v{inst.vdst}, {vcc}, {inst.lit(inst.src0)}, v{inst.vsrc1}, {vcc}"
  # RDNA carry-in/out ops: v_add_co_ci_u32, etc. - format: vdst, vcc_lo, src0, vsrc1, vcc_lo
  if not cdna and name in _VOP2_CARRY_INOUT_RDNA: return f"{name}{suf} v{inst.vdst}, {vcc}, {inst.lit(inst.src0)}, v{inst.vsrc1}, {vcc}"
  # Handle 64-bit register operands (v_add_f64, v_mul_f64, etc.)
  dn, sn0, sn1 = inst.dst_regs(), inst.src_regs(0), inst.src_regs(1)
  if dn > 1 or sn0 > 1 or sn1 > 1:
    dst = _vreg(inst.vdst, dn) if dn > 1 else f"v{inst.vdst}"
    src0 = inst.lit(inst.src0) if inst.src0 == 255 else _fmt_src(inst.src0, sn0, cdna)
    src1 = _vreg(inst.vsrc1, sn1) if sn1 > 1 else f"v{inst.vsrc1}"
    # keep the encoding suffix here as well; it is "" on CDNA so CDNA output is unchanged
    return f"{name}{suf} {dst}, {src0}, {src1}"
  return f"{name}{suf} v{inst.vdst}, {inst.lit(inst.src0)}, v{inst.vsrc1}" + (f", {vcc}" if name == 'v_cndmask_b32' else "")
def _disasm_vopc(inst: VOPC) -> str:
  """Render a VOPC-encoded compare as assembly text.

  CDNA compares always list 'vcc' as the destination. On the other targets, v_cmpx_* lists no
  destination and every other compare lists 'vcc_lo'.
  """
  name, cdna = inst.op_name.lower(), _is_cdna(inst)
  if cdna:
    if inst.src0 == 255: lhs = inst.lit(inst.src0)
    else: lhs = _fmt_src(inst.src0, inst.src_regs(0), cdna)
    if inst.src_regs(1) > 1: rhs = _vreg(inst.vsrc1, inst.src_regs(1))
    else: rhs = f"v{inst.vsrc1}"
    return f"{name} vcc, {lhs}, {rhs}"  # CDNA VOPC always outputs vcc
  # RDNA: v_cmpx_* writes to exec (no vcc), v_cmp_* writes to vcc_lo
  if inst.src0 == 255: lhs = inst.lit(inst.src0)
  elif inst.src_regs(0) > 1: lhs = _fmt_src(inst.src0, inst.src_regs(0))
  elif inst.is_16bit(): lhs = _src16(inst, inst.src0)
  else: lhs = inst.lit(inst.src0)
  if inst.src_regs(1) > 1: rhs = _vreg(inst.vsrc1, inst.src_regs(1))
  elif inst.is_16bit(): rhs = _fmt_v16(inst.vsrc1, 0, 128)
  else: rhs = f"v{inst.vsrc1}"
  if 'cmpx' in name: return f"{name}_e32 {lhs}, {rhs}"
  return f"{name}_e32 vcc_lo, {lhs}, {rhs}"
NO_ARG_SOPP = {SOPPOp.S_BARRIER, SOPPOp.S_WAKEUP, SOPPOp.S_ICACHE_INV,
  SOPPOp.S_WAIT_IDLE, SOPPOp.S_ENDPGM_SAVED, SOPPOp.S_CODE_END, SOPPOp.S_ENDPGM_ORDERED_PS_DONE, SOPPOp.S_TTRACEDATA}
_CDNA_NO_ARG_SOPP = {'s_endpgm', 's_barrier', 's_wakeup', 's_icache_inv', 's_ttracedata', 's_nop', 's_sethalt', 's_sleep',
  's_setprio', 's_trap', 's_incperflevel', 's_decperflevel', 's_sendmsg', 's_sendmsghalt'}
def _disasm_sopp(inst: SOPP) -> str:
  """Render a SOPP-encoded instruction as assembly text.

  s_waitcnt and s_delay_alu are decoded into their symbolic fields. Branches print simm16 in
  decimal. On CDNA every other instruction prints simm16 in hex only when it is non-zero; on
  the other targets it is always printed in hex unless the op is in NO_ARG_SOPP.
  """
  name, cdna = inst.op_name.lower(), _is_cdna(inst)
  if cdna:
    if name == 's_endpgm': return name if inst.simm16 == 0 else f"{name} {inst.simm16}"
    if name in ('s_barrier', 's_wakeup', 's_icache_inv', 's_ttracedata'): return name
    if name == 's_waitcnt':
      vm, lgkm, exp = inst.simm16 & 0xf, (inst.simm16 >> 8) & 0x3f, (inst.simm16 >> 4) & 0x7
      p = [f"vmcnt({vm})" if vm != 0xf else "", f"expcnt({exp})" if exp != 7 else "", f"lgkmcnt({lgkm})" if lgkm != 0x3f else ""]
      return f"s_waitcnt {' '.join(x for x in p if x) or '0'}"
    if name.startswith(('s_cbranch', 's_branch')): return f"{name} {inst.simm16}"
    return f"{name} 0x{inst.simm16:x}" if inst.simm16 else name
  # RDNA
  if inst.op in NO_ARG_SOPP: return name
  if inst.op == SOPPOp.S_ENDPGM: return name if inst.simm16 == 0 else f"{name} {inst.simm16}"
  if inst.op == SOPPOp.S_WAITCNT:
    vm, exp, lgkm = (inst.simm16 >> 10) & 0x3f, inst.simm16 & 0xf, (inst.simm16 >> 4) & 0x3f
    p = [f"vmcnt({vm})" if vm != 0x3f else "", f"expcnt({exp})" if exp != 7 else "", f"lgkmcnt({lgkm})" if lgkm != 0x3f else ""]
    return f"s_waitcnt {' '.join(x for x in p if x) or '0'}"
  if inst.op == SOPPOp.S_DELAY_ALU:
    deps = ['VALU_DEP_1','VALU_DEP_2','VALU_DEP_3','VALU_DEP_4','TRANS32_DEP_1','TRANS32_DEP_2','TRANS32_DEP_3','FMA_ACCUM_CYCLE_1','SALU_CYCLE_1','SALU_CYCLE_2','SALU_CYCLE_3']
    skips = ['SAME','NEXT','SKIP_1','SKIP_2','SKIP_3','SKIP_4']
    id0, skip, id1 = inst.simm16 & 0xf, (inst.simm16 >> 4) & 0x7, (inst.simm16 >> 7) & 0xf
    def dep(v): return deps[v-1] if 0 < v <= len(deps) else str(v)
    # the skip field is 3 bits wide but only 6 names exist: print unnamed values as numbers
    # rather than raising IndexError, mirroring what dep() does for the instid fields
    skip_s = skips[skip] if skip < len(skips) else str(skip)
    p = [f"instid0({dep(id0)})" if id0 else "", f"instskip({skip_s})" if skip else "", f"instid1({dep(id1)})" if id1 else ""]
    return f"s_delay_alu {' | '.join(x for x in p if x) or '0'}"
  return f"{name} {inst.simm16}" if name.startswith(('s_cbranch', 's_branch')) else f"{name} 0x{inst.simm16:x}"
def _disasm_smem(inst: SMEM) -> str:
  """Render an SMEM-encoded (scalar memory) instruction as assembly text.

  Handles the no-operand cache invalidates, s_atc_probe*, prefetch ops, th/scope modifiers
  and the glc/dlc modifiers.
  """
  name, cdna = inst.op_name.lower(), _is_cdna(inst)
  if inst.op in (SMEMOp.S_GL1_INV, SMEMOp.S_DCACHE_INV): return name
  soe, imm = getattr(inst, 'soe', 0), getattr(inst, 'imm', 1)
  is_rdna4 = 'rdna4' in inst.__class__.__module__
  offset = inst.ioffset if is_rdna4 else getattr(inst, 'offset', 0)
  if cdna:
    if soe and imm: off_s = f"{decode_src(inst.soffset, cdna)} offset:0x{offset:x}"
    elif imm: off_s = f"0x{offset:x}"
    elif offset < 256: off_s = decode_src(offset, cdna)
    else: off_s = decode_src(inst.soffset, cdna)
  elif offset and inst.soffset != 124: off_s = f"{decode_src(inst.soffset, cdna)} offset:0x{offset:x}"
  elif offset: off_s = f"0x{offset:x}"
  else: off_s = decode_src(inst.soffset, cdna)
  is_buffer = 'buffer' in name or 's_atc_probe_buffer' == name
  # buffer ops take a 4-register base, everything else a 2-register base
  sbase_idx, sbase_count = inst.sbase * 2, 4 if is_buffer else 2
  sbase_str = _fmt_src(sbase_idx, sbase_count, cdna) if sbase_count == 2 else _sreg(sbase_idx, sbase_count) if sbase_idx <= 105 else _reg("ttmp", sbase_idx - 108, sbase_count)
  if name in ('s_atc_probe', 's_atc_probe_buffer'): return f"{name} {inst.sdata}, {sbase_str}, {off_s}"
  if 'prefetch' in name:
    # read `offset` only when there is no `ioffset`; a getattr default would be evaluated eagerly
    off = inst.ioffset if hasattr(inst, 'ioffset') else inst.offset
    if off >= 0x800000: off = off - 0x1000000  # values >= 0x800000 are printed as negative
    off_s = f"0x{off:x}" if off > 255 else str(off)
    soff_s = decode_src(inst.soffset, cdna) if inst.soffset != 124 else "null"
    if 'pc_rel' in name: return f"{name} {off_s}, {soff_s}, {inst.sdata}"
    return f"{name} {sbase_str}, {off_s}, {soff_s}, {inst.sdata}"
  th, scope = getattr(inst, 'th', 0), getattr(inst, 'scope', 0)
  if th or scope:
    th_names = ['TH_LOAD_RT', 'TH_LOAD_NT', 'TH_LOAD_HT', 'TH_LOAD_LU']
    scope_names = ['SCOPE_CU', 'SCOPE_SE', 'SCOPE_DEV', 'SCOPE_SYS']
    mods = (f" th:{th_names[th]}" if th else "") + (f" scope:{scope_names[scope]}" if scope else "")
    return f"{name} {_fmt_sdst(inst.sdata, inst.dst_regs(), cdna)}, {sbase_str}, {off_s}{mods}"
  # each modifier carries its own leading space, so they are concatenated directly; joining
  # them with another " " would print two spaces between glc and dlc
  cache_mods = (" glc" if inst.glc else "") + (" dlc" if getattr(inst, 'dlc', 0) else "")
  return f"{name} {_fmt_sdst(inst.sdata, inst.dst_regs(), cdna)}, {sbase_str}, {off_s}{cache_mods}"
def _disasm_flat(inst: FLAT) -> str:
  """Render a FLAT-encoded (flat / scratch / global segment) memory instruction as assembly text."""
  name, cdna = inst.op_name.lower(), _is_cdna(inst)
  acc = getattr(inst, 'acc', 0)
  fmt_reg = _areg if acc else _vreg
  seg = ('flat', 'scratch', 'global')[inst.seg] if inst.seg < 3 else 'flat'
  # swap the mnemonic's first '_'-separated component for the segment name
  _, sep, tail = name.partition('_')
  instr = f"{seg}_{tail if sep else name}"
  # outside the flat segment, offsets >= 4096 are printed as negative (value - 8192)
  off_val = inst.offset if seg == 'flat' or inst.offset < 4096 else inst.offset - 8192
  width = inst.dst_regs()
  if '_x2' in name: width *= 2
  if 'cmpswap' in name: width *= 2
  mods = f" offset:{off_val}" if off_val else ""
  if cdna:
    # CDNA prints the sc0 / nt bits as glc / slc
    mods += f"{' glc' if inst.sc0 else ''}{' slc' if inst.nt else ''}"
  else:
    mods += f"{' glc' if inst.glc else ''}{' slc' if inst.slc else ''}{' dlc' if inst.dlc else ''}"
  special_pairs = SPECIAL_PAIRS_CDNA if cdna else SPECIAL_PAIRS
  if seg == 'flat' or inst.saddr == 0x7F: saddr_s = ""
  elif inst.saddr == 124: saddr_s = ", off"
  elif seg == 'scratch': saddr_s = f", {decode_src(inst.saddr, cdna)}"
  elif inst.saddr in special_pairs: saddr_s = f", {special_pairs[inst.saddr]}"
  else:
    ttmp_pair = _ttmp(inst.saddr, 2)
    if ttmp_pair: saddr_s = f", {ttmp_pair}"
    elif inst.saddr < 106: saddr_s = f", {_sreg(inst.saddr, 2)}"
    else: saddr_s = f", {decode_src(inst.saddr, cdna)}"
  if 'addtid' in name:
    reg_index = inst.data if 'store' in name else inst.vdst
    return f"{instr} {'a' if acc else 'v'}{reg_index}{saddr_s}{mods}"
  # address width: 1 VGPR for scratch, and on non-CDNA also whenever saddr is a real register
  if seg == 'scratch' or (not cdna and inst.saddr not in (0x7F, 124)): addr_w = 1
  else: addr_w = 2
  addr_s = "off" if not inst.sve and seg == 'scratch' else _vreg(inst.addr, addr_w)
  data_s = fmt_reg(inst.data, width)
  vdst_s = fmt_reg(inst.vdst, width // 2 if 'cmpswap' in name else width)
  if 'atomic' in name:
    # saddr_s is already "" for the flat segment, so it can be appended unconditionally
    returns_value = inst.sc0 if cdna else inst.glc
    if returns_value: return f"{instr} {vdst_s}, {addr_s}, {data_s}{saddr_s}{mods}"
    return f"{instr} {addr_s}, {data_s}{saddr_s}{mods}"
  if 'store' in name: return f"{instr} {addr_s}, {data_s}{saddr_s}{mods}"
  return f"{instr} {fmt_reg(inst.vdst, width)}, {addr_s}{saddr_s}{mods}"
def _disasm_ds(inst: DS) -> str:
  """Render a DS-encoded (LDS/GDS) instruction as assembly text.

  The operand layout is chosen by opcode and mnemonic substring checks, tried in order.
  """
  op, name = inst.op, inst.op_name.lower()
  acc = getattr(inst, 'acc', 0)
  fmt_reg, rp = (_areg, 'a') if acc else (_vreg, 'v')
  gds = " gds" if inst.gds else ""
  # combined 16-bit offset for single-address ops, separate offset0/offset1 for two-address ops
  if inst.offset0 or inst.offset1: off = f" offset:{inst.offset0 | (inst.offset1 << 8)}"
  else: off = ""
  off2 = ""
  if inst.offset0: off2 += " offset0:" + str(inst.offset0)
  if inst.offset1: off2 += " offset1:" + str(inst.offset1)
  w = inst.dst_regs()
  addr = f"v{inst.addr}"
  d0, d1, dst = fmt_reg(inst.data0, w), fmt_reg(inst.data1, w), fmt_reg(inst.vdst, w)
  tail, tail2 = off + gds, off2 + gds
  has_rtn = '_rtn' in name
  if op == DSOp.DS_NOP: return name
  if op == DSOp.DS_BVH_STACK_RTN_B32: return f"{name} v{inst.vdst}, {addr}, v{inst.data0}, {_vreg(inst.data1, 4)}{tail}"
  if 'bvh_stack_push' in name:
    data1_regs = 8 if 'push8' in name else 4
    vdst_s = _vreg(inst.vdst, 2) if 'pop2' in name else f"v{inst.vdst}"
    return f"{name} {vdst_s}, {addr}, v{inst.data0}, {_vreg(inst.data1, data1_regs)}{tail}"
  if 'gws_sema' in name and op != DSOp.DS_GWS_SEMA_BR: return f"{name}{tail}"
  if 'gws_' in name: return f"{name} {addr}{tail}"
  if op in (DSOp.DS_CONSUME, DSOp.DS_APPEND): return f"{name} {rp}{inst.vdst}{tail}"
  if 'gs_reg' in name: return f"{name} {fmt_reg(inst.vdst, 2)}, {rp}{inst.data0}{tail}"
  if '2addr' in name:
    if 'load' in name: return f"{name} {fmt_reg(inst.vdst, w*2)}, {addr}{tail2}"
    if 'store' in name and 'xchg' not in name: return f"{name} {addr}, {d0}, {d1}{tail2}"
    return f"{name} {fmt_reg(inst.vdst, w*2)}, {addr}, {d0}, {d1}{tail2}"
  if 'write2' in name: return f"{name} {addr}, {d0}, {d1}{tail2}"
  if 'read2' in name: return f"{name} {fmt_reg(inst.vdst, w*2)}, {addr}{tail2}"
  if 'load' in name:
    if 'addtid' in name: return f"{name} {rp}{inst.vdst}{tail}"
    return f"{name} {dst}, {addr}{tail}"
  if 'store' in name and not _has(name, 'cmp', 'xchg'):
    if 'addtid' in name: return f"{name} {rp}{inst.data0}{tail}"
    return f"{name} {addr}, {d0}{tail}"
  if 'swizzle' in name or op == DSOp.DS_ORDERED_COUNT: return f"{name} {rp}{inst.vdst}, {addr}{tail}"
  if 'permute' in name: return f"{name} {rp}{inst.vdst}, {addr}, {rp}{inst.data0}{tail}"
  if 'condxchg' in name: return f"{name} {fmt_reg(inst.vdst, 2)}, {addr}, {fmt_reg(inst.data0, 2)}{tail}"
  if _has(name, 'cmpstore', 'mskor', 'wrap'):
    # two data operands; the destination is listed only for the _rtn forms
    if has_rtn: return f"{name} {dst}, {addr}, {d0}, {d1}{tail}"
    return f"{name} {addr}, {d0}, {d1}{tail}"
  if has_rtn: return f"{name} {dst}, {addr}, {d0}{tail}"
  return f"{name} {addr}, {d0}{tail}"
def _disasm_vop3(inst: VOP3) -> str:
  """Render a VOP3-encoded instruction as assembly text.

  Covers the v_s_* scalar forms, the VOP3SD ops that share this encoding, and the VOPC / VOP2 /
  VOP1 opcode ranges promoted to VOP3 (printed with '_e64') as well as native VOP3 ops.
  """
  op, name = inst.op, inst.op_name.lower()
  # RDNA4 v_s_* scalar VOP3 instructions
  if name.startswith('v_s_'):
    src = inst.lit(inst.src0) if inst.src0 == 255 else ("src_scc" if inst.src0 == 253 else _fmt_src(inst.src0, inst.src_regs(0)))
    # wrap abs first, then prefix neg, so both together print as -|x| exactly like _vop3_src does
    if inst.abs & 1: src = f"|{src}|"
    if inst.neg & 1: src = f"-{src}"
    clamp = inst.cm if 'cm' in inst._fields else getattr(inst, 'clmp', 0)
    return f"{name} s{inst.vdst}, {src}" + (" clamp" if clamp else "") + _omod(inst.omod)
  # VOP3SD (shared encoding)
  if isinstance(op, VOP3SDOp):
    # the scalar destination is reassembled from the clmp, opsel and abs fields of this encoding
    sdst = (inst.clmp << 7) | (inst.opsel << 3) | inst.abs
    def src(v, neg, n):
      s = inst.lit(v) if v == 255 else ("src_scc" if v == 253 else (_fmt_src(v, n) if n > 1 else inst.lit(v)))
      return f"neg({s})" if neg and v == 255 else (f"-{s}" if neg else s)
    s0, s1, s2 = src(inst.src0, inst.neg & 1, inst.src_regs(0)), src(inst.src1, inst.neg & 2, inst.src_regs(1)), src(inst.src2, inst.neg & 4, inst.src_regs(2))
    dst = _vreg(inst.vdst, inst.dst_regs()) if inst.dst_regs() > 1 else f"v{inst.vdst}"
    srcs = f"{s0}, {s1}, {s2}" if inst.num_srcs() == 3 else f"{s0}, {s1}"
    return f"{name} {dst}, {_fmt_sdst(sdst, 1)}, {srcs}" + _omod(inst.omod)
  # Detect 16-bit operand sizes
  is16_d = is16_s = is16_s2 = False
  if 'cvt_pk' in name: is16_s = name.endswith('16')
  elif m := re.match(r'v_(?:cvt|frexp_exp)_([a-z0-9_]+)_([a-z0-9]+)', name):
    is16_d, is16_s = _has(m.group(1), 'f16','i16','u16','b16'), _has(m.group(2), 'f16','i16','u16','b16')
    is16_s2 = is16_s
  elif re.match(r'v_mad_[iu]32_[iu]16', name): is16_s = True
  elif 'pack_b32' in name: is16_s = is16_s2 = True
  elif 'sat_pk' in name: is16_d = True
  else: is16_d = is16_s = is16_s2 = inst.is_16bit()
  s0 = _vop3_src(inst, inst.src0, inst.neg&1, inst.abs&1, inst.opsel&1, inst.src_regs(0), is16_s)
  s1 = _vop3_src(inst, inst.src1, inst.neg&2, inst.abs&2, inst.opsel&2, inst.src_regs(1), is16_s)
  s2 = _vop3_src(inst, inst.src2, inst.neg&4, inst.abs&4, inst.opsel&4, inst.src_regs(2), is16_s2)
  # Destination
  dn = inst.dst_regs()
  if op == VOP3Op.V_READLANE_B32: dst = _fmt_sdst(inst.vdst, 1)
  elif dn > 1: dst = _vreg(inst.vdst, dn)
  elif is16_d: dst = f"v{inst.vdst}.h" if (inst.opsel & 8) else f"v{inst.vdst}.l"
  else: dst = f"v{inst.vdst}"
  clamp = inst.cm if 'cm' in inst._fields else getattr(inst, 'clmp', 0)
  cl, om = " clamp" if clamp else "", _omod(inst.omod)
  # op_sel is printed when a set bit targets a non-VGPR source, or when any bit is set and the
  # sources are not already printed with .l/.h suffixes
  nonvgpr_opsel = (inst.src0 < 256 and (inst.opsel & 1)) or (inst.src1 < 256 and (inst.opsel & 2)) or (inst.src2 < 256 and (inst.opsel & 4))
  need_opsel = nonvgpr_opsel or (inst.opsel and not is16_s)
  if inst.op < 256: # VOPC
    return f"{name}_e64 {s0}, {s1}{cl}" if name.startswith('v_cmpx') else f"{name}_e64 {_fmt_sdst(inst.vdst, 1)}, {s0}, {s1}{cl}"
  if inst.op < 384: # VOP2
    n = inst.num_srcs()
    os = _opsel_str(inst.opsel, n, need_opsel, is16_d)
    return f"{name}_e64 {dst}, {s0}, {s1}, {s2}{os}{cl}{om}" if n == 3 else f"{name}_e64 {dst}, {s0}, {s1}{os}{cl}{om}"
  if inst.op < 512: # VOP1
    if re.match(r'v_cvt_f32_(bf|fp)8', name) and inst.opsel:
      # byte_sel is formed from the two low opsel bits in swapped order
      os = f" byte_sel:{((inst.opsel & 1) << 1) | ((inst.opsel >> 1) & 1)}"
    else:
      os = _opsel_str(inst.opsel, 1, need_opsel, is16_d)
    return f"{name}_e64" if op in (VOP3Op.V_NOP, VOP3Op.V_PIPEFLUSH) else f"{name}_e64 {dst}, {s0}{os}{cl}{om}"
  # Native VOP3
  n = inst.num_srcs()
  if 'cvt_sr' in name and inst.opsel:
    os = f" byte_sel:{inst.opsel >> 2}"
  else:
    os = _opsel_str(inst.opsel, n, need_opsel, is16_d)
  return f"{name} {dst}, {s0}, {s1}, {s2}{os}{cl}{om}" if n == 3 else f"{name} {dst}, {s0}, {s1}{os}{cl}{om}"
def _disasm_vop3sd(inst: VOP3SD) -> str:
  """Render a VOP3SD-encoded instruction (vector dst plus scalar dst) as assembly text."""
  name = inst.op_name.lower()
  def operand(enc, negated, regs):
    # literal and src_scc take priority over register formatting
    if enc == 255: text = inst.lit(enc)
    elif enc == 253: text = "src_scc"
    elif regs > 1: text = _fmt_src(enc, regs)
    else: text = inst.lit(enc)
    if not negated: return text
    # a negated literal is printed as neg(...), everything else with a leading '-'
    return f"neg({text})" if enc == 255 else f"-{text}"
  # '_co_' ops whose name contains '64' take a 2-register src2
  src2_regs = 2 if '_co_' in name and '64' in name else inst.src_regs(2)
  a = operand(inst.src0, inst.neg & 1, inst.src_regs(0))
  b = operand(inst.src1, inst.neg & 2, inst.src_regs(1))
  c = operand(inst.src2, inst.neg & 4, src2_regs)
  dst_regs = inst.dst_regs()
  dst = _vreg(inst.vdst, dst_regs) if dst_regs > 1 else f"v{inst.vdst}"
  if inst.num_srcs() == 3: srcs = f"{a}, {b}, {c}"
  else: srcs = f"{a}, {b}"
  clamp = inst.cm if 'cm' in inst._fields else getattr(inst, 'clmp', 0)
  clamp_s = ' clamp' if clamp else ''
  return f"{name} {dst}, {_fmt_sdst(inst.sdst, 1)}, {srcs}{clamp_s}{_omod(inst.omod)}"
def _disasm_vopd(inst: VOPD) -> str:
  """Render a VOPD (dual-issue) instruction as '<x half> :: <y half>'."""
  lit = inst._literal or inst.literal
  op_enum = R4_VOPDOp if 'rdna4' in inst.__class__.__module__ else VOPDOp
  # the low bit of the Y destination is the inverse of the X destination's low bit
  vdst_y = (inst.vdsty << 1) | ((inst.vdstx & 1) ^ 1)
  x_name = op_enum(inst.opx).name.lower()
  y_name = op_enum(inst.opy).name.lower()
  def half(mnemonic, vd, s0, vs1):
    head = f"{mnemonic} v{vd}, {inst.lit(s0)}"
    if 'mov' in mnemonic: return head  # single-source op
    if lit and 'fmamk' in mnemonic: return f"{head}, 0x{lit:x}, v{vs1}"
    if lit and 'fmaak' in mnemonic: return f"{head}, v{vs1}, 0x{lit:x}"
    return f"{head}, v{vs1}"
  x_text = half(x_name, inst.vdstx, inst.srcx0, inst.vsrcx1)
  y_text = half(y_name, vdst_y, inst.srcy0, inst.vsrcy1)
  return f"{x_text} :: {y_text}"
def _swmmac_regs(name: str) -> tuple[int, int, int, int]:
  """Return (dst, src0, src1, src2) register counts for SWMMAC instructions."""
  # rows are tried in order; the first row with any substring present in `name` wins
  table = (
    (('f16_16x16x32', 'bf16_16x16x32'), (4, 4, 8, 1)),
    (('f32_16x16x32_f16', 'f32_16x16x32_bf16'), (8, 4, 8, 1)),
    (('i32_16x16x32_iu4',), (8, 1, 2, 1)),
    (('i32_16x16x64_iu4', 'i32_16x16x32_iu8', 'f32_16x16x32_fp8', 'f32_16x16x32_bf8'), (8, 2, 4, 1)),
  )
  for needles, counts in table:
    if any(needle in name for needle in needles): return counts
  return (8, 8, 8, 8)
def _disasm_vop3p(inst: VOP3P) -> str:
  """Render a VOP3P-encoded instruction (packed math, fma_mix, WMMA, SWMMAC) as assembly text."""
  name = inst.op_name.lower()
  n = inst.num_srcs()
  is_swmmac, is_wmma, is_fma_mix = 'swmmac' in name, 'wmma' in name, 'fma_mix' in name
  def operand(enc, regs):
    return inst.lit(enc) if enc == 255 else _fmt_src(enc, regs)
  # register widths per operand: (src0, src1, src2), plus the destination text
  if is_swmmac:
    dst_regs, w0, w1, w2 = _swmmac_regs(name)
    dst = _vreg(inst.vdst, dst_regs)
  elif is_wmma:
    rdna4 = 'rdna4' in inst.__class__.__module__
    if '16x16x16_iu4' in name: src_regs = 1
    elif 'iu4' in name or 'iu8' in name or 'fp8' in name or 'bf8' in name: src_regs = 2
    else: src_regs = 4
    if not rdna4: src_regs *= 2
    if rdna4 and ('f16_16x16' in name or 'bf16_16x16' in name) and 'f32' not in name: acc_regs = 4
    else: acc_regs = 8
    w0, w1, w2 = src_regs, src_regs, acc_regs
    dst = _vreg(inst.vdst, acc_regs)
  else:
    w0 = w1 = w2 = 1
    dst = f"v{inst.vdst}"
  src0, src1, src2 = operand(inst.src0, w0), operand(inst.src1, w1), operand(inst.src2, w2)
  opsel_hi = inst.opsel_hi | (inst.opsel_hi2 << 2)
  clamp = inst.cm if 'cm' in inst._fields else getattr(inst, 'clmp', 0)
  mods = []
  if is_fma_mix:
    # for fma_mix the neg field negates and the neg_hi field takes the absolute value
    def with_mods(text, negate, absolute):
      if absolute: text = f"|{text}|"
      return f"-{text}" if negate else text
    src0 = with_mods(src0, inst.neg & 1, inst.neg_hi & 1)
    src1 = with_mods(src1, inst.neg & 2, inst.neg_hi & 2)
    src2 = with_mods(src2, inst.neg & 4, inst.neg_hi & 4)
    if inst.opsel: mods.append(_fmt_bits("op_sel", inst.opsel, n))
    if opsel_hi: mods.append(_fmt_bits("op_sel_hi", opsel_hi, n))
  elif is_swmmac:
    if inst.opsel: mods.append(f"index_key:{inst.opsel}")
    if inst.neg: mods.append(_fmt_bits("neg_lo", inst.neg, n))
    if inst.neg_hi: mods.append(_fmt_bits("neg_hi", inst.neg_hi, n))
  else:
    # op_sel_hi is printed only when it differs from the all-ones default for n sources
    opsel_hi_default = 7 if n == 3 else 3
    if inst.opsel: mods.append(_fmt_bits("op_sel", inst.opsel, n))
    if opsel_hi != opsel_hi_default: mods.append(_fmt_bits("op_sel_hi", opsel_hi, n))
    if inst.neg: mods.append(_fmt_bits("neg_lo", inst.neg, n))
    if inst.neg_hi: mods.append(_fmt_bits("neg_hi", inst.neg_hi, n))
  if clamp: mods.append("clamp")
  tail = ' ' + ' '.join(mods) if mods else ''
  if n == 3: return f"{name} {dst}, {src0}, {src1}, {src2}{tail}"
  return f"{name} {dst}, {src0}, {src1}{tail}"
def _disasm_buf(inst: MUBUF | MTBUF) -> str:
  """Render a MUBUF/MTBUF-encoded buffer instruction as assembly text."""
  name, cdna = inst.op_name.lower(), _is_cdna(inst)
  acc = getattr(inst, 'acc', 0)
  fmt_reg = _areg if acc else _vreg
  # operand-less cache maintenance ops
  if cdna and name in ('buffer_wbl2', 'buffer_inv'): return name
  if not cdna and inst.op in (MUBUFOp.BUFFER_GL0_INV, MUBUFOp.BUFFER_GL1_INV): return name
  # vdata register count is derived from the mnemonic
  if 'd16' in name:
    width = 2 if _has(name, 'xyz', 'xyzw') else 1
  elif 'atomic' in name:
    width = 2 if _has(name, 'b64', 'u64', 'i64') else 1
    if 'cmpswap' in name: width *= 2
  else:
    width = {'b32':1,'b64':2,'b96':3,'b128':4,'b16':1,'x':1,'xy':2,'xyz':3,'xyzw':4}.get(name.split('_')[-1], 1)
  if getattr(inst, 'tfe', False): width += 1
  if inst.offen and inst.idxen: vaddr = _vreg(inst.vaddr, 2)
  elif inst.offen or inst.idxen: vaddr = f"v{inst.vaddr}"
  else: vaddr = "off"
  srsrc = _sreg_or_ttmp(inst.srsrc*4, 4)
  is_mtbuf = isinstance(inst, MTBUF) or isinstance(inst, C_MTBUF)
  fmt_s = ""
  if is_mtbuf:
    dfmt, nfmt = inst.format & 0xf, (inst.format >> 4) & 0x7
    if acc:
      fmt_s = f" dfmt:{dfmt}, nfmt:{nfmt},"
    elif not cdna:
      if inst.format: fmt_s = f" format:{inst.format}"
    else:
      dfmt_names = ['INVALID', '8', '16', '8_8', '32', '16_16', '10_11_11', '11_11_10', '10_10_10_2', '2_10_10_10', '8_8_8_8', '32_32', '16_16_16_16', '32_32_32', '32_32_32_32', 'RESERVED_15']
      nfmt_names = ['UNORM', 'SNORM', 'USCALED', 'SSCALED', 'UINT', 'SINT', 'RESERVED_6', 'FLOAT']
      # dfmt 1 with nfmt 0 prints nothing; otherwise only the non-default halves are printed
      data_default, num_default = dfmt == 1, nfmt == 0
      if data_default and num_default: fmt_s = ""
      elif num_default: fmt_s = f" format:[BUF_DATA_FORMAT_{dfmt_names[dfmt]}]"
      elif data_default: fmt_s = f" format:[BUF_NUM_FORMAT_{nfmt_names[nfmt]}]"
      else: fmt_s = f" format:[BUF_DATA_FORMAT_{dfmt_names[dfmt]},BUF_NUM_FORMAT_{nfmt_names[nfmt]}]"
  if cdna:
    flags = [(inst.idxen,"idxen"),(inst.offen,"offen"),(inst.offset,f"offset:{inst.offset}"),(inst.sc0,"glc"),(inst.nt,"slc"),(inst.sc1,"sc1")]
  else:
    flags = [(inst.idxen,"idxen"),(inst.offen,"offen"),(inst.offset,f"offset:{inst.offset}"),(inst.glc,"glc"),(inst.dlc,"dlc"),(inst.slc,"slc"),(inst.tfe,"tfe")]
  mods = [text for enabled, text in flags if enabled]
  mods_s = ' ' + ' '.join(mods) if mods else ''
  soffset_s = decode_src(inst.soffset, cdna)
  # CDNA MTBUF without acc prints the format after soffset; every other case prints it before
  if cdna and not acc and is_mtbuf: return f"{name} {fmt_reg(inst.vdata, width)}, {vaddr}, {srsrc}, {soffset_s}{fmt_s}{mods_s}"
  return f"{name} {fmt_reg(inst.vdata, width)}, {vaddr}, {srsrc},{fmt_s} {soffset_s}{mods_s}"
def _mimg_vaddr_width(name: str, dim: int, a16: bool) -> int:
  """Return the number of address VGPRs for MIMG mnemonic `name` with dimension index `dim`.

  Components counted as `packed` are halved together with the base coordinates when `a16` is
  set (rounding up); components counted as `unpacked` always take a full register each.
  """
  base = (1, 2, 3, 3, 2, 3, 3, 4)[dim]
  grad = (1, 2, 3, 2, 1, 2, 2, 2)[dim]
  if 'get_resinfo' in name: return 1
  packed = unpacked = 0
  if '_mip' in name:
    packed = 1
  elif 'sample' in name or 'gather' in name:
    if '_o' in name: unpacked += 1
    if re.search(r'_c(_|$)', name): unpacked += 1
    if '_d' in name:
      # with _g16 the derivative count is rounded up to an even number, otherwise it is doubled
      if '_g16' in name: unpacked += (grad + 1) & ~1
      else: unpacked += grad*2
    if '_b' in name: unpacked += 1
    if '_l' in name and not ('_cl' in name or '_lz' in name): packed += 1
    if '_cl' in name: packed += 1
  if a16: return (base + packed + 1) // 2 + unpacked
  return base + packed + unpacked
def _disasm_mimg(inst: MIMG) -> str:
  """Render a MIMG-encoded (image / BVH) instruction as assembly text."""
  name = inst.op_name.lower()
  srsrc_base = inst.srsrc * 4
  srsrc_str = _sreg_or_ttmp(srsrc_base, 8)
  if 'bvh' in name:
    # 8 address registers with a16 and 11 without, plus one more when the name contains '64'
    addr_regs = (8 if inst.a16 else 11) + (1 if '64' in name else 0)
    a16_s = ' a16' if inst.a16 else ''
    return f"{name} {_vreg(inst.vdata, 4)}, {_vreg(inst.vaddr, addr_regs)}, {_sreg_or_ttmp(srsrc_base, 4)}{a16_s}"
  # data width: 4 for gather4/msaa_load, otherwise one register per set dmask bit (minimum 1)
  if 'gather4' in name or 'msaa_load' in name: data_regs = 4
  else: data_regs = bin(inst.dmask).count('1') or 1
  if inst.d16: data_regs = (data_regs + 1) // 2
  if inst.tfe: data_regs += 1
  dim_names = ('1d', '2d', '3d', 'cube', '1d_array', '2d_array', '2d_msaa', '2d_msaa_array')
  dim = dim_names[inst.dim] if inst.dim < len(dim_names) else f"dim_{inst.dim}"
  addr_regs = _mimg_vaddr_width(name, inst.dim, inst.a16)
  # _vreg already prints a single register as plain vN
  vaddr_str = _vreg(inst.vaddr, addr_regs)
  mods = []
  if inst.dmask and (inst.dmask != 15 or 'atomic' in name): mods.append(f"dmask:0x{inst.dmask:x}")
  mods.append(f"dim:SQ_RSRC_IMG_{dim.upper()}")
  flags = [(inst.unrm,"unorm"),(inst.glc,"glc"),(inst.slc,"slc"),(inst.dlc,"dlc"),(inst.r128,"r128"),
           (inst.a16,"a16"),(inst.tfe,"tfe"),(inst.lwe,"lwe"),(inst.d16,"d16")]
  mods.extend(text for enabled, text in flags if enabled)
  uses_sampler = 'sample' in name or 'gather' in name or 'get_lod' in name
  ssamp_str = ", " + _sreg_or_ttmp(inst.ssamp * 4, 4) if uses_sampler else ""
  return f"{name} {_vreg(inst.vdata, data_regs)}, {vaddr_str}, {srsrc_str}{ssamp_str} {' '.join(mods)}"
def _disasm_sop1(inst: SOP1) -> str:
  """Render a SOP1-encoded instruction as assembly text."""
  op, name, cdna = inst.op, inst.op_name.lower(), _is_cdna(inst)
  if inst.ssrc0 == 255: src = inst.lit(inst.ssrc0)
  else: src = _fmt_src(inst.ssrc0, inst.src_regs(0), cdna)
  if not cdna:
    # ops with a non-standard operand list; these opcode comparisons apply to non-CDNA only
    if op == SOP1Op.S_GETPC_B64: return f"{name} {_fmt_sdst(inst.sdst, 2)}"
    if op in (SOP1Op.S_SETPC_B64, SOP1Op.S_RFE_B64): return f"{name} {src}"
    if op == SOP1Op.S_SWAPPC_B64: return f"{name} {_fmt_sdst(inst.sdst, 2)}, {src}"
    if op in (SOP1Op.S_SENDMSG_RTN_B32, SOP1Op.S_SENDMSG_RTN_B64):
      msg_name = MSG.get(inst.ssrc0, str(inst.ssrc0))
      return f"{name} {_fmt_sdst(inst.sdst, inst.dst_regs())}, sendmsg({msg_name})"
  # ops that print only their source operand
  src_only_ops = ('S_ALLOC_VGPR', 'S_SLEEP_VAR', 'S_BARRIER_SIGNAL', 'S_BARRIER_SIGNAL_ISFIRST', 'S_BARRIER_INIT', 'S_BARRIER_JOIN')
  if inst.op_name in src_only_ops: return f"{name} {src}"
  dst = _fmt_sdst(inst.sdst, inst.dst_regs(), cdna)
  return f"{name} {dst}, {src}"
def _disasm_sop2(inst: SOP2) -> str:
  """Return the assembly text for a SOP2 instruction: ``name dst, src0, src1``.

  When the instruction carries a ``_literal`` and its name contains ``fmamk`` / ``fmaak``,
  the literal is printed in hex as an extra operand (middle / last position respectively).
  """
  is_cdna = _is_cdna(inst)
  mnemonic = inst.op_name.lower()
  literal = getattr(inst, '_literal', None)

  def operand(code: int, index: int) -> str:
    # source code 255 is rendered by inst.lit; anything else is a register or constant
    if code == 255: return inst.lit(code)
    return _fmt_src(code, inst.src_regs(index), is_cdna)

  first, second = operand(inst.ssrc0, 0), operand(inst.ssrc1, 1)
  dst = _fmt_sdst(inst.sdst, inst.dst_regs(), is_cdna)
  if literal is not None:
    if 'fmamk' in mnemonic: return f"{mnemonic} {dst}, {first}, 0x{literal:x}, {second}"
    if 'fmaak' in mnemonic: return f"{mnemonic} {dst}, {first}, {second}, 0x{literal:x}"
  return f"{mnemonic} {dst}, {first}, {second}"
def _disasm_sopc(inst: SOPC) -> str:
  """Return the assembly text for a SOPC instruction: ``name src0, src1`` (no destination operand)."""
  is_cdna = _is_cdna(inst)

  def operand(code: int, index: int) -> str:
    # source code 255 is rendered by inst.lit; anything else goes through _fmt_src
    if code == 255: return inst.lit(code)
    return _fmt_src(code, inst.src_regs(index), is_cdna)

  lhs = operand(inst.ssrc0, 0)
  rhs = operand(inst.ssrc1, 1)
  return f"{inst.op_name.lower()} {lhs}, {rhs}"
def _disasm_sopk(inst: SOPK) -> str:
  """Return the assembly text for a SOPK instruction.

  The setreg/getreg family decodes ``simm16`` into ``hwreg(id[, offset, size])``; every
  other opcode prints ``simm16`` as a hex immediate, with or without a destination.
  """
  op, name, cdna = inst.op, inst.op_name.lower(), _is_cdna(inst)
  imm = inst.simm16
  # instruction classes from an rdna4 module print the register by number instead of by name
  numeric_hwreg = 'rdna4' in inst.__class__.__module__

  def hwreg_operand() -> str:
    # simm16 layout as decoded here: bits 0-5 register id, bits 6-10 offset, bits 11-15 size minus one
    reg_id, offset, size = imm & 0x3f, (imm >> 6) & 0x1f, ((imm >> 11) & 0x1f) + 1
    # ids missing from HWREG fall back to the raw immediate
    if reg_id not in HWREG: return f"0x{imm:x}"
    label = str(reg_id) if numeric_hwreg else HWREG[reg_id]
    # a full 32-bit window at offset 0 uses the short form
    if offset == 0 and size == 32: return f"hwreg({label})"
    return f"hwreg({label}, {offset}, {size})"

  if name == 's_setreg_imm32_b32' or (not cdna and op == SOPKOp.S_SETREG_IMM32_B32):
    return f"{name} {hwreg_operand()}, 0x{inst._literal:x}"
  if not cdna and op == SOPKOp.S_VERSION:
    return f"{name} 0x{imm:x}"
  # CDNA is matched by mnemonic, everything else by the SOPKOp enum
  if cdna: is_reg_access = name in ('s_setreg_b32', 's_getreg_b32')
  else: is_reg_access = op in (SOPKOp.S_SETREG_B32, SOPKOp.S_GETREG_B32)
  if is_reg_access:
    hwreg, sreg = hwreg_operand(), _fmt_sdst(inst.sdst, 1, cdna)
    # setreg lists the hardware register first, getreg lists it last
    return f"{name} {hwreg}, {sreg}" if 'setreg' in name else f"{name} {sreg}, {hwreg}"
  if not cdna and op in (SOPKOp.S_SUBVECTOR_LOOP_BEGIN, SOPKOp.S_SUBVECTOR_LOOP_END):
    return f"{name} {_fmt_sdst(inst.sdst, 1)}, 0x{imm:x}"
  return f"{name} {_fmt_sdst(inst.sdst, inst.dst_regs(), cdna)}, 0x{imm:x}"
def _disasm_vinterp(inst: VINTERP) -> str:
  """Return the assembly text for a VINTERP instruction: ``name vdst, src0, src1, src2 [mods]``."""
  mods = _mods((inst.waitexp, f"wait_exp:{inst.waitexp}"), (inst.clmp, "clamp"))
  # bit i of inst.neg is passed to inst.lit together with source i
  s0 = inst.lit(inst.src0, inst.neg & 1)
  s1 = inst.lit(inst.src1, inst.neg & 2)
  s2 = inst.lit(inst.src2, inst.neg & 4)
  text = f"{inst.op_name.lower()} v{inst.vdst}, {s0}, {s1}, {s2}"
  if mods: text += " " + mods
  return text
# Export target id -> printed name. Ids that are absent here are printed as plain numbers by the caller.
EXP_TARGETS = {
  **{i: f'mrt{i}' for i in range(8)},           # 0-7
  8: 'mrtz', 9: 'null',
  **{12 + i: f'pos{i}' for i in range(5)},      # 12-16
  **{32 + i: f'param{i}' for i in range(6)},    # 32-37
}
def _disasm_vexport(inst) -> str:
  """Return the assembly text for an export instruction: ``export target v0, v1, v2, v3 [mods]``.

  A source whose bit in ``inst.en`` is clear is printed as ``off``; a target id without
  an entry in EXP_TARGETS is printed as its number.
  """
  target = EXP_TARGETS.get(inst.target, f'{inst.target}')
  operands = []
  for slot in range(4):
    if inst.en & (1 << slot): operands.append(f'v{getattr(inst, f"vsrc{slot}")}')
    else: operands.append('off')
  mods = _mods((inst.done, "done"), (inst.row, "row_en"))
  text = f"export {target} {', '.join(operands)}"
  return f"{text} {mods}" if mods else text
def _disasm_vbuffer(inst) -> str:
  """Return the assembly text for a VBUFFER instruction.

  Output shape: ``name vdata, vaddr, srsrc, soffset[ format:...][ mods]``. The vdata register
  count is derived from the mnemonic (d16 / atomic / size suffix) plus one when ``tfe`` is set.
  """
  name = inst.op_name.lower()

  # number of data registers, picked from the mnemonic
  if 'd16' in name:
    width = 2 if _has(name, 'xyz', 'xyzw') else 1
  elif 'atomic' in name:
    width = (2 if _has(name, 'b64', 'u64', 'i64') else 1) * (2 if 'cmpswap' in name else 1)
  else:
    suffix_width = {'b32':1,'b64':2,'b96':3,'b128':4,'b16':1,'x':1,'xy':2,'xyz':3,'xyzw':4}
    width = suffix_width.get(name.split('_')[-1], inst.dst_regs())
  if getattr(inst, 'tfe', 0): width += 1
  vdata = _vreg(inst.vdata, width) if width else f'v{inst.vdata}'

  # address operand: a register pair with both offen and idxen, one register with either, else 'off'
  if inst.offen and inst.idxen: vaddr = _vreg(inst.vaddr, 2)
  elif inst.offen or inst.idxen: vaddr = f'v{inst.vaddr}'
  else: vaddr = 'off'

  # resource descriptor spans four registers; indices from 108 up are printed as ttmp
  rsrc = inst.rsrc
  if rsrc >= 108: srsrc = f'ttmp[{rsrc - 108}:{rsrc - 108 + 3}]'
  else: srsrc = f's[{rsrc}:{rsrc + 3}]'
  soff = decode_src(inst.soffset) if inst.soffset >= 106 else f's{inst.soffset}'

  # format values 0 and 1 are not printed; known values use the BufFmt name, others the number
  fmt = getattr(inst, 'format', 0)
  fmt_names = {e.value: e.name for e in BufFmt}
  if fmt <= 1: fmt_s = ""
  elif fmt in fmt_names: fmt_s = f" format:[{fmt_names[fmt]}]"
  else: fmt_s = f" format:{fmt}"

  # the names printed for inst.th depend on the kind of operation
  if 'atomic' in name: th_names = {1: 'TH_ATOMIC_RETURN', 6: 'TH_ATOMIC_CASCADE_NT'}
  elif 'store' in name: th_names = {3: 'TH_STORE_BYPASS', 6: 'TH_STORE_NT_HT'}
  else: th_names = {3: 'TH_LOAD_BYPASS', 6: 'TH_LOAD_NT_HT'}
  scope_names = {1: 'SCOPE_SE', 2: 'SCOPE_DEV', 3: 'SCOPE_SYS'}
  mods = _mods((inst.idxen, "idxen"), (inst.offen, "offen"), (inst.ioffset, f"offset:{inst.ioffset}"),
               (inst.th in th_names, f"th:{th_names.get(inst.th, '')}"),
               (inst.scope in scope_names, f"scope:{scope_names.get(inst.scope, '')}"))
  text = f"{name} {vdata}, {vaddr}, {srsrc}, {soff}{fmt_s}"
  return f"{text} {mods}" if mods else text
# Dispatch table: instruction format class -> function that renders an instance of that class as text.
# disasm() indexes this table with type(inst), so a class has to be listed here itself to be found.
# The first group holds the format classes imported from the rdna3 autogen module.
DISASM_HANDLERS: dict[type, callable] = {
VOP1: _disasm_vop1, VOP2: _disasm_vop2, VOPC: _disasm_vopc, VOP3: _disasm_vop3, VOP3SD: _disasm_vop3sd, VOPD: _disasm_vopd, VOP3P: _disasm_vop3p,
VINTERP: _disasm_vinterp, SOPP: _disasm_sopp, SMEM: _disasm_smem, DS: _disasm_ds, FLAT: _disasm_flat, MUBUF: _disasm_buf, MTBUF: _disasm_buf,
MIMG: _disasm_mimg, SOP1: _disasm_sop1, SOP2: _disasm_sop2, SOPC: _disasm_sopc, SOPK: _disasm_sopk,
# RDNA4 format classes: these reuse the handlers above, except VEXPORT and VBUFFER which have their own
R4_VOP1: _disasm_vop1, R4_VOP2: _disasm_vop2, R4_VOPC: _disasm_vopc, R4_VOP3: _disasm_vop3, R4_VOP3SD: _disasm_vop3sd,
R4_VOPD: _disasm_vopd, R4_VOP3P: _disasm_vop3p, R4_VINTERP: _disasm_vinterp, R4_SOPP: _disasm_sopp, R4_SMEM: _disasm_smem,
R4_DS: _disasm_ds, R4_SOP1: _disasm_sop1, R4_SOP2: _disasm_sop2, R4_SOPC: _disasm_sopc, R4_SOPK: _disasm_sopk,
R4_VEXPORT: _disasm_vexport, R4_VBUFFER: _disasm_vbuffer}
def disasm(inst: Inst) -> str:
  """Render `inst` as assembly text.

  The formatter is looked up in DISASM_HANDLERS by the exact class of `inst`;
  a class that is not registered raises KeyError.
  """
  handler = DISASM_HANDLERS[type(inst)]
  return handler(inst)
# ═══════════════════════════════════════════════════════════════════════════════
# CDNA DISASSEMBLER SUPPORT
# ═══════════════════════════════════════════════════════════════════════════════
try:
from extra.assembly.amd.autogen.cdna.ins import (VOP1 as CDNA_VOP1, VOP2 as CDNA_VOP2, VOPC as CDNA_VOPC, VOP3A, VOP3B, VOP3P as CDNA_VOP3P,
SOP1 as CDNA_SOP1, SOP2 as CDNA_SOP2, SOPC as CDNA_SOPC, SOPK as CDNA_SOPK, SOPP as CDNA_SOPP, SMEM as CDNA_SMEM, DS as CDNA_DS,
FLAT as CDNA_FLAT, MUBUF as CDNA_MUBUF, MTBUF as CDNA_MTBUF, SDWA, DPP, VOP1Op as CDNA_VOP1Op, VOP2Op as CDNA_VOP2Op, VOPCOp as CDNA_VOPCOp)
def _cdna_src(inst, v, neg, abs_=0, n=1):
s = inst.lit(v) if v == 255 else _fmt_src(v, n, cdna=True)
if abs_: s = f"|{s}|"
return f"neg({s})" if neg and v == 255 else (f"-{s}" if neg else s)
_CDNA_VOP3_ALIASES = {'v_fmac_f64': 'v_mul_legacy_f32', 'v_dot2c_f32_bf16': 'v_mac_f32'}
def _disasm_vop3a(inst) -> str:
  """Return the assembly text for a CDNA VOP3A instruction.

  The printed shape depends on the numeric opcode range:
    * >= 512: no ``_e64`` suffix, two or three sources
    * <  256: ``_e64`` with a scalar register-pair destination (presumably promoted compares)
    * 320..511: ``_e64`` with a single source (presumably promoted VOP1)
    * 256..319: ``_e64`` with two or three sources (presumably promoted VOP2)
  Names listed in _CDNA_VOP3_ALIASES are printed under their alias with single-register operands.
  """
  from extra.assembly.amd.dsl import spec_num_srcs, spec_regs
  opcode = inst._values.get('op', 0)
  opcode = getattr(opcode, 'value', opcode)  # unwrap an enum member to its integer
  # the left side of `or` is empty for an unnamed opcode, so the decoded name is never empty
  decoded = inst.op_name.lower() or f'vop3a_op_{opcode}'
  nsrc = spec_num_srcs(decoded)
  clamp, omod = (" clamp" if inst.clmp else ""), _omod(inst.omod)
  name = _CDNA_VOP3_ALIASES.get(decoded, decoded)

  if name != decoded:
    # aliased opcode: single-register operands and no third source
    s0 = _cdna_src(inst, inst.src0, inst.neg&1, inst.abs&1, 1)
    s1 = _cdna_src(inst, inst.src1, inst.neg&2, inst.abs&2, 1)
    s2 = ""
    dst = f"v{inst.vdst}"
  else:
    dst_w, w0, w1, w2 = spec_regs(name)
    s0 = _cdna_src(inst, inst.src0, inst.neg&1, inst.abs&1, w0)
    s1 = _cdna_src(inst, inst.src1, inst.neg&2, inst.abs&2, w1)
    s2 = _cdna_src(inst, inst.src2, inst.neg&4, inst.abs&4, w2)
    dst = _vreg(inst.vdst, dst_w) if dst_w > 1 else f"v{inst.vdst}"

  third = f", {s2}" if nsrc == 3 else ""
  if opcode >= 512:
    return f"{name} {dst}, {s0}, {s1}{third}{clamp}{omod}"
  if opcode < 256:
    # vdst is printed as a scalar register pair here, and omod is not printed
    return f"{name}_e64 {_fmt_sdst(inst.vdst, 2, cdna=True)}, {s0}, {s1}{clamp}"
  if opcode >= 320:
    if name in ('v_nop', 'v_clrexcp'):
      return f"{name}_e64"
    return f"{name}_e64 {dst}, {s0}{clamp}{omod}"
  # remaining opcodes are 256..319
  if name == 'v_cndmask_b32':
    # the third operand is printed as a scalar register pair
    return f"{name}_e64 {dst}, {s0}, {s1}, {_fmt_src(inst.src2, 2, cdna=True)}{clamp}{omod}"
  if name in ('v_mul_legacy_f32', 'v_mac_f32'):
    return f"{name}_e64 {dst}, {s0}, {s1}{clamp}{omod}"
  return f"{name}_e64 {dst}, {s0}, {s1}{third}{clamp}{omod}"
def _disasm_vop3b(inst) -> str:
  """Return the assembly text for a CDNA VOP3B instruction (vector dst plus a scalar-pair dst).

  Output shape: ``name[_e64] vdst, sdst, src0, src1[, src2][ clamp][omod]``. The ``_e64``
  suffix is added when the name contains ``co_``.
  """
  from extra.assembly.amd.dsl import spec_num_srcs, spec_regs
  opcode = inst._values.get('op', 0)
  opcode = getattr(opcode, 'value', opcode)  # unwrap an enum member to its integer
  # the left side of `or` is empty for an unnamed opcode, so the name is never empty
  name = inst.op_name.lower() or f'vop3b_op_{opcode}'
  nsrc = spec_num_srcs(name)
  dst_w, w0, w1, w2 = spec_regs(name)
  s0 = _cdna_src(inst, inst.src0, inst.neg&1, n=w0)
  s1 = _cdna_src(inst, inst.src1, inst.neg&2, n=w1)
  s2 = _cdna_src(inst, inst.src2, inst.neg&4, n=w2)
  dst = _vreg(inst.vdst, dst_w) if dst_w > 1 else f"v{inst.vdst}"
  sdst = _fmt_sdst(inst.sdst, 2, cdna=True)
  clamp, omod = (" clamp" if inst.clmp else ""), _omod(inst.omod)

  if name in ('v_addc_co_u32', 'v_subb_co_u32', 'v_subbrev_co_u32'):
    # for these three the third source is printed as a scalar register pair, without modifiers
    pair = _fmt_src(inst.src2, 2, cdna=True)
    return f"{name}_e64 {dst}, {sdst}, {s0}, {s1}, {pair}{clamp}{omod}"
  suffix = "_e64" if 'co_' in name else ""
  third = f", {s2}" if nsrc == 3 else ""
  return f"{name}{suffix} {dst}, {sdst}, {s0}, {s1}{third}{clamp}{omod}"
def _disasm_cdna_vop3p(inst) -> str:
  """Return the assembly text for a CDNA VOP3P instruction.

  MFMA/SMFMAC mnemonics use multi-register operands whose width is chosen from the
  mnemonic; everything else uses single registers. Modifiers (op_sel, op_sel_hi,
  neg_lo, neg_hi, clamp) are appended only when they differ from their defaults.
  """
  name, nsrc = inst.op_name.lower(), inst.num_srcs()

  def operand(code, width):
    # source code 255 is rendered by inst.lit; anything else goes through _fmt_src
    return inst.lit(code) if code == 255 else _fmt_src(code, width, cdna=True)

  if 'mfma' in inst.op_name.lower() or 'smfmac' in inst.op_name.lower():
    # register count of the first two sources, picked from the mnemonic
    if 'iu4' in name: ab_width = 2
    elif 'iu8' in name or 'i4' in name: ab_width = 4
    elif 'f16' in name or 'bf16' in name: ab_width = 8
    else: ab_width = 4
    src0, src1 = operand(inst.src0, ab_width), operand(inst.src1, ab_width)
    # third source and destination are always printed as 16 registers
    src2, dst = operand(inst.src2, 16), _vreg(inst.vdst, 16)
  else:
    src0, src1, src2 = operand(inst.src0, 1), operand(inst.src1, 1), operand(inst.src2, 1)
    dst = f"v{inst.vdst}"

  # op_sel_hi is stored split across two fields; all-ones for the source count is the default
  opsel_hi = inst.opsel_hi | (inst.opsel_hi2 << 2)
  mods = []
  if inst.opsel: mods.append(_fmt_bits("op_sel", inst.opsel, nsrc))
  if opsel_hi != (7 if nsrc == 3 else 3): mods.append(_fmt_bits("op_sel_hi", opsel_hi, nsrc))
  if inst.neg: mods.append(_fmt_bits("neg_lo", inst.neg, nsrc))
  if inst.neg_hi: mods.append(_fmt_bits("neg_hi", inst.neg_hi, nsrc))
  if inst.clmp: mods.append("clamp")
  tail = ' ' + ' '.join(mods) if mods else ''
  if nsrc == 3:
    return f"{name} {dst}, {src0}, {src1}, {src2}{tail}"
  return f"{name} {dst}, {src0}, {src1}{tail}"
_SEL = {0: 'BYTE_0', 1: 'BYTE_1', 2: 'BYTE_2', 3: 'BYTE_3', 4: 'WORD_0', 5: 'WORD_1', 6: 'DWORD'}
_UNUSED = {0: 'UNUSED_PAD', 1: 'UNUSED_SEXT', 2: 'UNUSED_PRESERVE'}
_DPP = {0x130: "wave_shl:1", 0x134: "wave_rol:1", 0x138: "wave_shr:1", 0x13c: "wave_ror:1", 0x140: "row_mirror", 0x141: "row_half_mirror", 0x142: "row_bcast:15", 0x143: "row_bcast:31"}
def _sdwa_src0(v, is_sgpr, sext=0, neg=0, abs_=0):
s = decode_src(v, cdna=True) if is_sgpr else f"v{v}"
if sext: s = f"sext({s})"
if abs_: s = f"|{s}|"
return f"-{s}" if neg else s
def _sdwa_vsrc1(v, sext=0, neg=0, abs_=0):
s = f"v{v}"
if sext: s = f"sext({s})"
if abs_: s = f"|{s}|"
return f"-{s}" if neg else s
_OMOD_SDWA = {0: "", 1: " mul:2", 2: " mul:4", 3: " div:2"}
def _disasm_sdwa(inst) -> str:
  """Return the assembly text for a CDNA SDWA-encoded instruction.

  ``inst.vop2_op`` selects the layout:
    * 63: ``inst.vop_op`` is looked up as a VOP1 opcode
    * 62: ``inst.vdst`` is looked up as a VOPC opcode, and the dst_sel/dst_u/clmp/omod
      fields are recombined into the scalar destination
    * anything else: ``vop2_op`` itself is looked up as a VOP2 opcode
  Opcodes without a name print as ``vop1_op_N`` / ``vopc_op_N`` / ``vop2_op_N``.

  Selector and dst_unused codes that have no entry in _SEL / _UNUSED are printed as
  their number rather than raising KeyError, matching how unknown opcodes, omod and
  other fields are already handled.
  """
  def sel(field: str, code: int) -> str:
    # unnamed selector codes fall back to the raw number
    return f"{field}:{_SEL.get(code, str(code))}"

  def dst_mods() -> list:
    return [sel("dst_sel", inst.dst_sel), f"dst_unused:{_UNUSED.get(inst.dst_u, str(inst.dst_u))}"]

  vop2_op = inst.vop2_op
  src0 = _sdwa_src0(inst.src0, inst.s0, inst.src0_sext, inst.src0_neg, inst.src0_abs)
  clamp = " clamp" if inst.clmp else ""
  omod = _OMOD_SDWA.get(inst.omod, "")

  if vop2_op == 63:
    try: name = CDNA_VOP1Op(inst.vop_op).name.lower()
    except ValueError: name = f"vop1_op_{inst.vop_op}"
    mods = dst_mods() + [sel("src0_sel", inst.src0_sel)]
    return f"{name}_sdwa v{inst.vdst}, {src0}{clamp}{omod} " + " ".join(mods)

  if vop2_op == 62:
    try: name = CDNA_VOPCOp(inst.vdst).name.lower()
    except ValueError: name = f"vopc_op_{inst.vdst}"
    src1 = _sdwa_vsrc1(inst.vop_op, inst.src1_sext, inst.src1_neg, inst.src1_abs)
    # in this layout the dst_sel/dst_u/clmp/omod bits together hold the scalar destination
    sdst_enc = inst.dst_sel | (inst.dst_u << 3) | (inst.clmp << 5) | (inst.omod << 6)
    if sdst_enc == 0:
      sdst = "vcc"
    else:
      # bit 7 is stripped before formatting the register pair
      sdst_val = sdst_enc - 128 if sdst_enc >= 128 else sdst_enc
      sdst = _fmt_sdst(sdst_val, 2, cdna=True)
    mods = [sel("src0_sel", inst.src0_sel), sel("src1_sel", inst.src1_sel)]
    return f"{name}_sdwa {sdst}, {src0}, {src1} " + " ".join(mods)

  try: name = CDNA_VOP2Op(vop2_op).name.lower()
  except ValueError: name = f"vop2_op_{vop2_op}"
  name = _CDNA_DISASM_ALIASES.get(name, name)
  dst = f"v{inst.vdst}"
  src1 = _sdwa_vsrc1(inst.vop_op, inst.src1_sext, inst.src1_neg, inst.src1_abs)
  mods = dst_mods() + [sel("src0_sel", inst.src0_sel), sel("src1_sel", inst.src1_sel)]
  # some opcodes print implicit vcc operands
  if name == 'v_cndmask_b32':
    return f"{name}_sdwa {dst}, {src0}, {src1}, vcc{clamp}{omod} " + " ".join(mods)
  if name in ('v_addc_co_u32', 'v_subb_co_u32', 'v_subbrev_co_u32'):
    return f"{name}_sdwa {dst}, vcc, {src0}, {src1}, vcc{clamp}{omod} " + " ".join(mods)
  if '_co_' in name:
    return f"{name}_sdwa {dst}, vcc, {src0}, {src1}{clamp}{omod} " + " ".join(mods)
  return f"{name}_sdwa {dst}, {src0}, {src1}{clamp}{omod} " + " ".join(mods)
def _dpp_src(v, neg=0, abs_=0):
s = f"v{v}" if v < 256 else f"v{v - 256}"
if abs_: s = f"|{s}|"
return f"-{s}" if neg else s
def _disasm_dpp(inst) -> str:
  """Return the assembly text for a CDNA DPP-encoded instruction.

  ``inst.vop2_op == 63`` means ``inst.vop_op`` is looked up as a VOP1 opcode; otherwise
  ``vop2_op`` itself is looked up as a VOP2 opcode and ``inst.vop_op`` is the second source.
  The dpp control, row_mask and bank_mask modifiers are always printed.
  """
  vop2_op = inst.vop2_op
  ctrl = inst.dpp_ctrl
  # decode the control word by range; values from 0x130 up come from the _DPP table
  if ctrl < 0x100:
    lanes = ",".join(str((ctrl >> shift) & 3) for shift in (0, 2, 4, 6))
    dpp = f"quad_perm:[{lanes}]"
  elif ctrl < 0x110:
    dpp = f"row_shl:{ctrl&0xf}"
  elif ctrl < 0x120:
    dpp = f"row_shr:{ctrl&0xf}"
  elif ctrl < 0x130:
    dpp = f"row_ror:{ctrl&0xf}"
  else:
    dpp = _DPP.get(ctrl, f"dpp_ctrl:0x{ctrl:x}")
  src0 = _dpp_src(inst.src0, inst.src0_neg, inst.src0_abs)
  mods = [dpp, f"row_mask:0x{inst.row_mask:x}", f"bank_mask:0x{inst.bank_mask:x}"]
  if inst.bound_ctrl:
    mods.append("bound_ctrl:0")
  tail = " ".join(mods)

  if vop2_op == 63:
    try: name = CDNA_VOP1Op(inst.vop_op).name.lower()
    except ValueError: name = f"vop1_op_{inst.vop_op}"
    return f"{name}_dpp v{inst.vdst}, {src0} " + tail

  try: name = CDNA_VOP2Op(vop2_op).name.lower()
  except ValueError: name = f"vop2_op_{vop2_op}"
  name = _CDNA_DISASM_ALIASES.get(name, name)
  src1 = _dpp_src(inst.vop_op, inst.src1_neg, inst.src1_abs)
  # some opcodes print implicit vcc operands
  if name == 'v_cndmask_b32':
    operands = f"{src0}, {src1}, vcc"
  elif name in ('v_addc_co_u32', 'v_subb_co_u32', 'v_subbrev_co_u32'):
    operands = f"vcc, {src0}, {src1}, vcc"
  elif '_co_' in name:
    operands = f"vcc, {src0}, {src1}"
  else:
    operands = f"{src0}, {src1}"
  return f"{name}_dpp v{inst.vdst}, {operands} " + tail
# Register the CDNA format classes in the shared dispatch table. This statement sits inside the
# try block, so it only runs when the CDNA autogen import succeeded. Most formats reuse the
# handlers used for the other architectures; VOP3A, VOP3B, VOP3P, SDWA and DPP use the
# CDNA-specific formatters defined in this try block.
DISASM_HANDLERS.update({CDNA_VOP1: _disasm_vop1, CDNA_VOP2: _disasm_vop2, CDNA_VOPC: _disasm_vopc,
CDNA_SOP1: _disasm_sop1, CDNA_SOP2: _disasm_sop2, CDNA_SOPC: _disasm_sopc, CDNA_SOPK: _disasm_sopk, CDNA_SOPP: _disasm_sopp,
CDNA_SMEM: _disasm_smem, CDNA_DS: _disasm_ds, CDNA_FLAT: _disasm_flat, CDNA_MUBUF: _disasm_buf, CDNA_MTBUF: _disasm_buf,
VOP3A: _disasm_vop3a, VOP3B: _disasm_vop3b, CDNA_VOP3P: _disasm_cdna_vop3p, SDWA: _disasm_sdwa, DPP: _disasm_dpp})
except ImportError:
pass