tinygrad/extra/assembly/amd/asm.py
George Hotz 8c14d9f427 rdna4
2026-01-01 17:14:52 +00:00

1036 lines
64 KiB
Python

# RDNA3 assembler and disassembler
from __future__ import annotations
import re
from extra.assembly.amd.dsl import Inst, RawImm, Reg, SrcMod, SGPR, VGPR, TTMP, s, v, ttmp, _RegFactory, SRC_FIELDS, unwrap
from extra.assembly.amd.dsl import VCC_LO, VCC_HI, VCC, EXEC_LO, EXEC_HI, EXEC, SCC, M0, NULL, OFF
from extra.assembly.amd.dsl import SPECIAL_GPRS, SPECIAL_PAIRS, FLOAT_DEC, FLOAT_ENC, decode_src
from extra.assembly.amd.autogen.rdna3 import ins
from extra.assembly.amd.autogen.rdna3.ins import (VOP1, VOP2, VOP3, VOP3SD, VOP3P, VOPC, VOPD, VINTERP, SOP1, SOP2, SOPC, SOPK, SOPP, SMEM, DS, FLAT, MUBUF, MTBUF, MIMG, EXP, LDSDIR,
VOP1Op, VOP2Op, VOP3Op, VOP3SDOp, VOP3POp, VOPCOp, VOPDOp, VINTERPOp, SOP1Op, SOP2Op, SOPCOp, SOPKOp, SOPPOp, SMEMOp, DSOp, FLATOp, MUBUFOp, MTBUFOp, MIMGOp)
# VOP3 opcode numbers that actually decode as VOP3SD: same 64-bit encoding, but with an extra scalar
# destination field (see detect_format / _disasm_vop3). 288-290 plus the contiguous run 764-770.
VOP3SD_OPS = {288, 289, 290, *range(764, 771)}
def _matches_encoding(word: int, cls: type[Inst]) -> bool:
"""Check if word matches the encoding pattern of an instruction class."""
if cls._encoding is None: return False
bf, val = cls._encoding
return ((word >> bf.lo) & bf.mask()) == val
# Candidate instruction classes probed by detect_format, tried in list order (first match wins),
# so more specific encodings must come before looser ones.
# 64-bit formats: the first dword has bits[31:30] == 0b11.
_FORMATS_64 = [
  VOPD, VOP3P, VINTERP, VOP3, DS, FLAT,
  MUBUF, MTBUF, MIMG, SMEM, EXP,
]
# 32-bit formats: SOP2 and VOP2 are catch-alls and therefore must stay last.
_FORMATS_32 = [
  SOP1, SOPC, SOPP, SOPK, VOPC, VOP1,
  SOP2, VOP2,
]
def detect_format(data: bytes) -> type[Inst]:
  """Identify the instruction class of the machine code at the start of `data`.

  Only the first little-endian dword is inspected. If bits[31:30] are 0b11 the 64-bit candidate list
  is searched, otherwise the 32-bit one. VOP3 matches whose 10-bit opcode (bits[25:16]) is in
  VOP3SD_OPS are reported as VOP3SD.

  Raises:
    AssertionError: fewer than 4 bytes were supplied.
    ValueError: no candidate class matches the dword.
  """
  assert len(data) >= 4, f"need at least 4 bytes, got {len(data)}"
  word = int.from_bytes(data[:4], 'little')
  is_64bit = (word >> 30) == 0b11
  candidates = _FORMATS_64 if is_64bit else _FORMATS_32
  found = next((cls for cls in candidates if _matches_encoding(word, cls)), None)
  if found is None:
    if is_64bit: raise ValueError(f"unknown 64-bit format word={word:#010x}")
    raise ValueError(f"unknown 32-bit format word={word:#010x}")
  # VOP3SD shares the VOP3 encoding and is told apart only by opcode
  if is_64bit and found is VOP3 and ((word >> 16) & 0x3ff) in VOP3SD_OPS:
    return VOP3SD
  return found
# ═══════════════════════════════════════════════════════════════════════════════
# CONSTANTS
# ═══════════════════════════════════════════════════════════════════════════════
# GFX11 (RDNA3) hardware register IDs -> names printed inside hwreg(...) by s_getreg/s_setreg
HWREG = {
  1: 'HW_REG_MODE', 2: 'HW_REG_STATUS', 3: 'HW_REG_TRAPSTS', 4: 'HW_REG_HW_ID',
  5: 'HW_REG_GPR_ALLOC', 6: 'HW_REG_LDS_ALLOC', 7: 'HW_REG_IB_STS',
  15: 'HW_REG_SH_MEM_BASES',
  18: 'HW_REG_PERF_SNAPSHOT_PC_LO', 19: 'HW_REG_PERF_SNAPSHOT_PC_HI',
  20: 'HW_REG_FLAT_SCR_LO', 21: 'HW_REG_FLAT_SCR_HI', 22: 'HW_REG_XNACK_MASK',
  23: 'HW_REG_HW_ID1', 24: 'HW_REG_HW_ID2', 25: 'HW_REG_POPS_PACKER', 28: 'HW_REG_IB_STS2',
}
# GFX12 (RDNA4) hardware register IDs, spelled with the names LLVM recognizes
HWREG_GFX12 = {
  1: 'HW_REG_MODE', 2: 'HW_REG_STATUS', 5: 'HW_REG_GPR_ALLOC', 6: 'HW_REG_LDS_ALLOC', 7: 'HW_REG_IB_STS',
  18: 'HW_REG_EXCP_FLAG_USER', 19: 'HW_REG_TRAP_CTRL',
  20: 'HW_REG_SCRATCH_BASE_LO', 21: 'HW_REG_SCRATCH_BASE_HI',
  23: 'HW_REG_HW_ID1', 24: 'HW_REG_HW_ID2',
  29: 'HW_REG_SHADER_CYCLES_LO', 30: 'HW_REG_SHADER_CYCLES_HI',
}
# Reverse lookup: lowercase register name -> ID (built from the GFX11 table only)
HWREG_IDS = {reg_name.lower(): reg_id for reg_id, reg_name in HWREG.items()}
# Message IDs printed symbolically inside sendmsg(...) for s_sendmsg_rtn_*
MSG = {
  128: 'MSG_RTN_GET_DOORBELL', 129: 'MSG_RTN_GET_DDID', 130: 'MSG_RTN_GET_TMA',
  131: 'MSG_RTN_GET_REALTIME', 132: 'MSG_RTN_SAVE_WAVE', 133: 'MSG_RTN_GET_TBA',
}
# ═══════════════════════════════════════════════════════════════════════════════
# HELPERS
# ═══════════════════════════════════════════════════════════════════════════════
def _reg(p: str, b: int, n: int = 1) -> str: return f"{p}{b}" if n == 1 else f"{p}[{b}:{b+n-1}]"
def _sreg(b: int, n: int = 1) -> str: return _reg("s", b, n)
def _vreg(b: int, n: int = 1) -> str: return _reg("v", b, n)
def _ttmp(b: int, n: int = 1) -> str: return _reg("ttmp", b - 108, n) if 108 <= b <= 123 else None
def _sreg_or_ttmp(b: int, n: int = 1) -> str: return _ttmp(b, n) or _sreg(b, n)
def _fmt_sdst(v: int, n: int = 1) -> str:
  """Format a scalar-destination encoding `v` covering `n` registers.

  Order of precedence: 124 -> "null"; ttmp encodings -> ttmpN / ttmp[..]; single registers use the
  SPECIAL_GPRS name or sN; multi-register destinations use the SPECIAL_PAIRS name or an s[..] range.
  """
  if v == 124:
    return "null"
  trap_name = _ttmp(v, n)
  if trap_name:
    return trap_name
  if n <= 1:
    return SPECIAL_GPRS.get(v, f"s{v}")
  return SPECIAL_PAIRS.get(v) or _sreg(v, n)
def _fmt_src(v: int, n: int = 1) -> str:
  """Format a source-operand encoding `v` that spans `n` registers.

  Single-register operands are delegated to decode_src. Wider ones print as a VGPR range for v >= 256,
  an SGPR range for v <= 105, a SPECIAL_PAIRS name for 2-wide special pairs, a ttmp range, or finally
  fall back to decode_src.
  """
  if n == 1:
    return decode_src(v)
  if v >= 256:
    return _vreg(v - 256, n)
  if v <= 105:
    return _sreg(v, n)
  if n == 2 and v in SPECIAL_PAIRS:
    return SPECIAL_PAIRS[v]
  return _ttmp(v, n) or decode_src(v)
def _fmt_v16(v: int, base: int = 256, hi_thresh: int = 384) -> str:
return f"v{(v - base) & 0x7f}.{'h' if v >= hi_thresh else 'l'}"
def waitcnt(vmcnt: int = 0x3f, expcnt: int = 0x7, lgkmcnt: int = 0x3f) -> int:
  """Pack s_waitcnt counters into a simm16 value.

  Layout: expcnt in bits[2:0], lgkmcnt in bits[9:4], vmcnt in bits[15:10]. Each count is masked to
  its field width. The defaults are the all-ones value of each field.
  """
  fields = ((expcnt, 0x7, 0), (lgkmcnt, 0x3f, 4), (vmcnt, 0x3f, 10))
  packed = 0
  for value, mask, shift in fields:
    packed |= (value & mask) << shift
  return packed
def _has(op: str, *subs) -> bool: return any(s in op for s in subs)
def _is16(op: str) -> bool: return _has(op, 'f16', 'i16', 'u16', 'b16') and not _has(op, '_f32', '_i32')
def _is64(op: str) -> bool: return _has(op, 'f64', 'i64', 'u64', 'b64')
def _omod(v: int) -> str: return {1: " mul:2", 2: " mul:4", 3: " div:2"}.get(v, "")
def _src16(inst, v: int) -> str: return _fmt_v16(v) if v >= 256 else inst.lit(v) # format 16-bit src: vgpr.h/l or literal
def _mods(*pairs) -> str: return " ".join(m for c, m in pairs if c)
def _fmt_bits(label: str, val: int, count: int) -> str: return f"{label}:[{','.join(str((val >> i) & 1) for i in range(count))}]"
def _vop3_src(inst, v: int, neg: int, abs_: int, hi: int, n: int, f16: bool, any_hi: bool) -> str:
"""Format VOP3 source operand with modifiers."""
if n > 1: s = _fmt_src(v, n)
elif f16 and v >= 256: s = f"v{v - 256}.h" if hi else (f"v{v - 256}.l" if any_hi else inst.lit(v))
else: s = inst.lit(v)
if abs_: s = f"|{s}|"
return f"-{s}" if neg else s
def _opsel_str(opsel: int, n: int, need: bool, is16_d: bool) -> str:
"""Format op_sel modifier string."""
if not need: return ""
if is16_d and (opsel & 8): return f" op_sel:[1,1,1{',1' if n == 3 else ''}]"
if n == 3: return f" op_sel:[{opsel & 1},{(opsel >> 1) & 1},{(opsel >> 2) & 1},{(opsel >> 3) & 1}]"
return f" op_sel:[{opsel & 1},{(opsel >> 1) & 1},{(opsel >> 2) & 1}]"
# ═══════════════════════════════════════════════════════════════════════════════
# DISASSEMBLER
# ═══════════════════════════════════════════════════════════════════════════════
# VOP1 opcodes whose names end in _F64 (double-precision single-source ops)
_VOP1_F64 = {
  VOP1Op.V_CEIL_F64, VOP1Op.V_FLOOR_F64, VOP1Op.V_FRACT_F64,
  VOP1Op.V_FREXP_MANT_F64, VOP1Op.V_RCP_F64, VOP1Op.V_RNDNE_F64,
  VOP1Op.V_RSQ_F64, VOP1Op.V_SQRT_F64, VOP1Op.V_TRUNC_F64,
}
def _disasm_vop1(inst: VOP1) -> str:
  """Disassemble a 32-bit VOP1 instruction as `<name>_e32 <dst>, <src>`.

  Operand widths are inferred from the opcode name:
  - 64-bit ops use register pairs.
  - cvt_pk_f32_bf8/fp8 write a VGPR pair.
  - 16-bit ops use .l/.h VGPR halves.
  v_nop / v_pipeflush print bare, and v_readfirstlane_b32 has its own format.
  """
  # the opcode enum follows the architecture the instruction class was generated for
  if 'rdna4' in inst.__class__.__module__:
    from extra.assembly.amd.autogen.rdna4.enum import VOP1Op as OpEnum
  else:
    OpEnum = VOP1Op
  name = OpEnum(inst.op).name.lower()
  if name in ('v_nop', 'v_pipeflush'): return name
  if name == 'v_readfirstlane_b32':
    lane_src = inst.src0 - 256 if inst.src0 >= 256 else inst.src0
    return f"v_readfirstlane_b32 {decode_src(inst.vdst)}, v{lane_src}"
  parts = name.split('_')
  types16 = ('f16', 'i16', 'u16', 'b16')
  has_f64 = 'f64' in name
  f64_ops = ('ceil', 'floor', 'fract', 'frexp_mant', 'rcp', 'rndne', 'rsq', 'sqrt', 'trunc')
  wide_dst = has_f64 and _has(name, *f64_ops, 'cvt_f64_f32', 'cvt_f64_i32', 'cvt_f64_u32')
  wide_src = has_f64 and _has(name, *f64_ops, 'cvt_f32_f64', 'cvt_i32_f64', 'cvt_u32_f64', 'frexp_exp_i32_f64')
  # v_cvt_pk_f32_bf8/fp8 produce two f32 results, i.e. a VGPR pair
  pk_f32 = 'cvt_pk_f32_bf8' in name or 'cvt_pk_f32_fp8' in name
  dst16 = any(p in types16 for p in parts[-2:-1]) or (len(parts) >= 2 and parts[-1] in types16 and 'cvt' not in name)
  # only the packed bf8/fp8 conversions read a 16-bit VGPR half; the non-packed ones read full VGPRs
  src16 = (parts[-1] in types16 and 'sat_pk' not in name) or (parts[-1] in ('bf8', 'fp8') and 'pk' in name)
  if wide_dst or pk_f32: dst = _vreg(inst.vdst, 2)
  elif dst16: dst = _fmt_v16(inst.vdst, 0, 128)
  else: dst = f"v{inst.vdst}"
  if wide_src: src = _fmt_src(inst.src0, 2)
  elif src16: src = _src16(inst, inst.src0)
  else: src = inst.lit(inst.src0)
  return f"{name}_e32 {dst}, {src}"
def _disasm_vop2(inst: VOP2) -> str:
  """Disassemble a 32-bit VOP2 instruction.

  Forms handled:
  - fmaak / fmamk: the trailing literal K is placed per the op's operand order.
  - 16-bit ops: .l/.h VGPR halves.
  - 64-bit ops: register pairs; lshlrev_b64 keeps a 32-bit shift amount in src0.
  - v_cndmask_b32: gets an explicit vcc_lo operand.
  - Everything else: the plain three-operand form.
  """
  if 'rdna4' in inst.__class__.__module__:
    from extra.assembly.amd.autogen.rdna4.enum import VOP2Op as OpEnum
  else:
    OpEnum = VOP2Op
  name = OpEnum(inst.op).name.lower()
  # v_dot2acc_f32_f16 is printed without the _e32 suffix
  head = name if name == 'v_dot2acc_f32_f16' else f"{name}_e32"
  src0 = inst.lit(inst.src0)
  # fmaak: dst = src0 * vsrc1 + K ; fmamk: dst = src0 * K + vsrc1
  if 'fmaak' in name: return f"{head} v{inst.vdst}, {src0}, v{inst.vsrc1}, 0x{inst._literal:x}"
  if 'fmamk' in name: return f"{head} v{inst.vdst}, {src0}, 0x{inst._literal:x}, v{inst.vsrc1}"
  if _is16(name) and 'pk_' not in name:
    return f"{head} {_fmt_v16(inst.vdst, 0, 128)}, {_src16(inst, inst.src0)}, {_fmt_v16(inst.vsrc1, 0, 128)}"
  # 64-bit shift: the shift amount (src0) stays 32-bit while vdst/vsrc1 are register pairs
  if 'lshlrev_b64' in name: return f"{head} {_vreg(inst.vdst, 2)}, {src0}, {_vreg(inst.vsrc1, 2)}"
  if _is64(name): return f"{head} {_vreg(inst.vdst, 2)}, {_fmt_src(inst.src0, 2)}, {_vreg(inst.vsrc1, 2)}"
  carry_in = ", vcc_lo" if name == 'v_cndmask_b32' else ""
  return f"{head} v{inst.vdst}, {src0}, v{inst.vsrc1}{carry_in}"
def _disasm_vopc(inst: VOPC) -> str:
  """Disassemble a 32-bit VOPC compare.

  Opcode values >= 128 are printed without a destination; all others get an explicit vcc_lo
  destination. Source widths follow the 16/64-bit name heuristics.
  """
  if 'rdna4' in inst.__class__.__module__:
    from extra.assembly.amd.autogen.rdna4.enum import VOPCOp as OpEnum
  else:
    OpEnum = VOPCOp
  op = OpEnum(inst.op)
  name = op.name.lower()
  wide, narrow = _is64(name), _is16(name)
  if wide: src0 = _fmt_src(inst.src0, 2)
  elif narrow: src0 = _src16(inst, inst.src0)
  else: src0 = inst.lit(inst.src0)
  # v_cmp_class_* with a 64-bit src0 still reads a single VGPR for vsrc1
  if wide and 'class' not in name: src1 = _vreg(inst.vsrc1, 2)
  elif narrow: src1 = _fmt_v16(inst.vsrc1, 0, 128)
  else: src1 = f"v{inst.vsrc1}"
  operands = f"{src0}, {src1}" if op.value >= 128 else f"vcc_lo, {src0}, {src1}"
  return f"{name}_e32 {operands}"
# SOPP opcodes that _disasm_sopp prints as a bare mnemonic (simm16 is not shown)
NO_ARG_SOPP = {
  SOPPOp.S_ENDPGM, SOPPOp.S_BARRIER, SOPPOp.S_WAKEUP, SOPPOp.S_ICACHE_INV,
  SOPPOp.S_WAIT_IDLE, SOPPOp.S_ENDPGM_SAVED, SOPPOp.S_CODE_END,
  SOPPOp.S_ENDPGM_ORDERED_PS_DONE,
}
def _disasm_sopp(inst: SOPP) -> str:
  """Disassemble a SOPP instruction; its only operand is the 16-bit simm16 field.

  Printing rules:
  - Ops in NO_ARG_SOPP print the bare mnemonic.
  - s_waitcnt and s_delay_alu use symbolic operand syntax.
  - Branches print simm16 in decimal; every other op prints it in hex.
  """
  op = SOPPOp(inst.op)
  name = op.name.lower()
  if op in NO_ARG_SOPP: return name
  if op == SOPPOp.S_WAITCNT:
    # same layout waitcnt() packs: expcnt bits[2:0], lgkmcnt bits[9:4], vmcnt bits[15:10]; bit 3 is in no counter
    vm, exp, lgkm = (inst.simm16 >> 10) & 0x3f, inst.simm16 & 0x7, (inst.simm16 >> 4) & 0x3f
    counters = [f"vmcnt({vm})", f"expcnt({exp})", f"lgkmcnt({lgkm})"]
    # counters at their all-ones default are omitted; if every counter is default, all are printed,
    # because a bare "0" would instead encode every counter as zero (cf. waitcnt(0, 0, 0) == 0)
    shown = [text for text, value, default in zip(counters, (vm, exp, lgkm), (0x3f, 7, 0x3f)) if value != default]
    return f"s_waitcnt {' '.join(shown or counters)}"
  if op == SOPPOp.S_DELAY_ALU:
    deps = ['VALU_DEP_1','VALU_DEP_2','VALU_DEP_3','VALU_DEP_4','TRANS32_DEP_1','TRANS32_DEP_2','TRANS32_DEP_3',
            'FMA_ACCUM_CYCLE_1','SALU_CYCLE_1','SALU_CYCLE_2','SALU_CYCLE_3']
    skips = ['SAME','NEXT','SKIP_1','SKIP_2','SKIP_3','SKIP_4']
    id0, skip, id1 = inst.simm16 & 0xf, (inst.simm16 >> 4) & 0x7, (inst.simm16 >> 7) & 0xf
    # values outside the name tables print numerically instead of raising IndexError (instskip is 3 bits: 0-7)
    def dep(v: int) -> str: return deps[v-1] if 0 < v <= len(deps) else str(v)
    def skip_name(v: int) -> str: return skips[v] if v < len(skips) else str(v)
    p = [f"instid0({dep(id0)})" if id0 else "", f"instskip({skip_name(skip)})" if skip else "", f"instid1({dep(id1)})" if id1 else ""]
    return f"s_delay_alu {' | '.join(x for x in p if x) or '0'}"
  return f"{name} {inst.simm16}" if name.startswith(('s_cbranch', 's_branch')) else f"{name} 0x{inst.simm16:x}"
def _smem_sbase(idx: int, count: int) -> str:
  """Format an SMEM base-register operand of `count` SGPRs starting at encoding `idx`.

  2-register bases go through _fmt_src, so special pairs and ttmps get their names. Wider bases print
  as s[..] for idx <= 105 and as ttmp[..] otherwise.
  """
  if count == 2: return _fmt_src(idx, count)
  return _sreg(idx, count) if idx <= 105 else _reg("ttmp", idx - 108, count)

def _disasm_smem(inst: SMEM) -> str:
  """Disassemble a scalar memory (SMEM) instruction.

  RDNA4 classes (module name contains 'rdna4') use a signed 24-bit `ioffset`. RDNA3 uses `offset`
  plus optional glc/dlc cache flags.
  """
  if 'rdna4' in inst.__class__.__module__:
    from extra.assembly.amd.autogen.rdna4.enum import SMEMOp as SMEMOp4
    op = SMEMOp4(inst.op)
    name = op.name.lower()
    if op == SMEMOp4.S_DCACHE_INV: return name
    # sbase counts SGPR pairs; s_buffer_* use a 4-SGPR descriptor, s_load/s_prefetch a 2-SGPR base
    sbase_str = _smem_sbase(inst.sbase * 2, 4 if 'buffer' in name else 2)
    # ioffset is a signed 24-bit field, shown as (possibly negative) hex
    ioff = inst.ioffset if inst.ioffset < 0x800000 else inst.ioffset - 0x1000000
    off_hex = f"0x{ioff & 0xffffff:x}" if ioff >= 0 else f"-0x{(-ioff) & 0xffffff:x}"
    # soffset 124 (null) is omitted from the address
    off_s = f"{decode_src(inst.soffset)} offset:{off_hex}" if inst.soffset != 124 else off_hex
    if 'prefetch' in name:
      # prefetch syntax: [sbase,] offset, soffset, length(sdata); the PC-relative forms have no sbase
      if name in ('s_prefetch_inst_pc_rel', 's_prefetch_data_pc_rel'):
        return f"{name} {off_hex}, {decode_src(inst.soffset)}, {inst.sdata}"
      return f"{name} {sbase_str}, {off_hex}, {decode_src(inst.soffset)}, {inst.sdata}"
    # destination SGPR count by opcode (defaults to 1)
    width = {0:1, 1:2, 2:4, 3:8, 4:16, 5:3, 8:1, 9:1, 10:1, 11:1, 16:1, 17:2, 18:4, 19:8, 20:16, 21:3, 24:1, 25:1, 26:1, 27:1}.get(inst.op, 1)
    return f"{name} {_fmt_sdst(inst.sdata, width)}, {sbase_str}, {off_s}"
  op = SMEMOp(inst.op)
  name = op.name.lower()
  if op in (SMEMOp.S_GL1_INV, SMEMOp.S_DCACHE_INV): return name
  off_s = f"{decode_src(inst.soffset)} offset:0x{inst.offset:x}" if inst.offset and inst.soffset != 124 else f"0x{inst.offset:x}" if inst.offset else decode_src(inst.soffset)
  # buffer loads (opcodes 8-12) and s_atc_probe_buffer take a 4-SGPR descriptor
  sbase_str = _smem_sbase(inst.sbase * 2, 4 if (8 <= inst.op <= 12 or name == 's_atc_probe_buffer') else 2)
  if name in ('s_atc_probe', 's_atc_probe_buffer'): return f"{name} {inst.sdata}, {sbase_str}, {off_s}"
  width = {0:1, 1:2, 2:4, 3:8, 4:16, 8:1, 9:2, 10:4, 11:8, 12:16}.get(inst.op, 1)
  # each flag carries its own leading space, so concatenate instead of " ".join (which doubled the space)
  cache = f"{' glc' if inst.glc else ''}{' dlc' if inst.dlc else ''}"
  return f"{name} {_fmt_sdst(inst.sdata, width)}, {sbase_str}, {off_s}{cache}"
def _disasm_flat(inst: FLAT) -> str:
  """Disassemble a FLAT-encoded memory instruction (flat, scratch, or global segment).

  The mnemonic replaces the opcode name's first '_'-separated component with the segment selected by
  `inst.seg` (0=flat, 1=scratch, 2=global; other values fall back to flat). Atomics print a return
  register only when glc is set.
  """
  name = FLATOp(inst.op).name.lower()
  seg = ['flat', 'scratch', 'global'][inst.seg] if inst.seg < 3 else 'flat'
  instr = f"{seg}_{name.split('_', 1)[1] if '_' in name else name}"
  # scratch/global offsets are sign-extended from 13 bits (>= 4096 wraps negative); flat offsets print as-is
  off_val = inst.offset if seg == 'flat' else (inst.offset if inst.offset < 4096 else inst.offset - 8192)
  # data register count comes from the type suffix of the opcode name (default 1)
  suffix = name.split('_')[-1]
  w = {'b32':1,'b64':2,'b96':3,'b128':4,'u8':1,'i8':1,'u16':1,'i16':1,'u32':1,'i32':1,'u64':2,'i64':2,'f32':1,'f64':2}.get(suffix, 1)
  # cmpswap data holds two values, doubling its width
  if 'cmpswap' in name: w *= 2
  if name.endswith('_x2') or 'x2' in suffix: w = max(w, 2)
  mods = f"{f' offset:{off_val}' if off_val else ''}{' glc' if inst.glc else ''}{' slc' if inst.slc else ''}{' dlc' if inst.dlc else ''}"
  # saddr: not printed for the flat segment or saddr 0x7F; 124 prints as "off"; otherwise an SGPR pair / special name
  if seg == 'flat' or inst.saddr == 0x7F: saddr_s = ""
  elif inst.saddr == 124: saddr_s = ", off"
  elif seg == 'scratch': saddr_s = f", {decode_src(inst.saddr)}"
  elif inst.saddr in SPECIAL_PAIRS: saddr_s = f", {SPECIAL_PAIRS[inst.saddr]}"
  elif t := _ttmp(inst.saddr, 2): saddr_s = f", {t}"
  else: saddr_s = f", {_sreg(inst.saddr, 2) if inst.saddr < 106 else decode_src(inst.saddr)}"
  # addtid: no vector address operand; prints data for stores, vdst otherwise
  if 'addtid' in name: return f"{instr} v{inst.data if 'store' in name else inst.vdst}{saddr_s}{mods}"
  # vaddr: "off" for scratch without sve; one VGPR for scratch or when an SGPR base is present; else a VGPR pair
  addr_s = "off" if not inst.sve and seg == 'scratch' else _vreg(inst.addr, 1 if seg == 'scratch' or (inst.saddr not in (0x7F, 124)) else 2)
  # for cmpswap the returned register range is half the data width
  data_s, vdst_s = _vreg(inst.data, w), _vreg(inst.vdst, w // 2 if 'cmpswap' in name else w)
  if 'atomic' in name:
    # with glc set the return register (vdst) is printed before the address
    return f"{instr} {vdst_s}, {addr_s}, {data_s}{saddr_s if seg != 'flat' else ''}{mods}" if inst.glc else f"{instr} {addr_s}, {data_s}{saddr_s if seg != 'flat' else ''}{mods}"
  if 'store' in name: return f"{instr} {addr_s}, {data_s}{saddr_s}{mods}"
  return f"{instr} {_vreg(inst.vdst, w)}, {addr_s}{saddr_s}{mods}"
def _disasm_ds(inst: DS) -> str:
  """Disassemble a DS (LDS/GDS) instruction.

  The operand shape is picked from the opcode name by an ordered chain of checks; the first match
  wins, so the order of the `if` statements below is significant. `gds` appends " gds".
  """
  op, name = DSOp(inst.op), DSOp(inst.op).name.lower()
  gds = " gds" if inst.gds else ""
  # single-offset form: offset1 supplies the high byte of one combined offset
  off = f" offset:{inst.offset0 | (inst.offset1 << 8)}" if inst.offset0 or inst.offset1 else ""
  # 2addr forms print the two offsets separately
  off2 = f" offset0:{inst.offset0} offset1:{inst.offset1}" if inst.offset0 or inst.offset1 else ""
  # per-element register count from the name's width (gs_reg ops use 2)
  w = 4 if '128' in name else 3 if '96' in name else 2 if (name.endswith('64') or 'gs_reg' in name) else 1
  d0, d1, dst, addr = _vreg(inst.data0, w), _vreg(inst.data1, w), _vreg(inst.vdst, w), f"v{inst.addr}"
  if op == DSOp.DS_NOP: return name
  # bvh_stack: one-VGPR vdst and data0, 4-VGPR data1
  if op == DSOp.DS_BVH_STACK_RTN_B32: return f"{name} v{inst.vdst}, {addr}, v{inst.data0}, {_vreg(inst.data1, 4)}{off}{gds}"
  # gws_sema_* (except sema_br) take no register operands; other gws_* take only the address VGPR
  if 'gws_sema' in name and op != DSOp.DS_GWS_SEMA_BR: return f"{name}{off}{gds}"
  if 'gws_' in name: return f"{name} {addr}{off}{gds}"
  if op in (DSOp.DS_CONSUME, DSOp.DS_APPEND): return f"{name} v{inst.vdst}{off}{gds}"
  if 'gs_reg' in name: return f"{name} {_vreg(inst.vdst, 2)}, v{inst.data0}{off}{gds}"
  if '2addr' in name:
    # two-element forms: loads/xchg return 2*w registers
    if 'load' in name: return f"{name} {_vreg(inst.vdst, w*2)}, {addr}{off2}{gds}"
    if 'store' in name and 'xchg' not in name: return f"{name} {addr}, {d0}, {d1}{off2}{gds}"
    return f"{name} {_vreg(inst.vdst, w*2)}, {addr}, {d0}, {d1}{off2}{gds}"
  # addtid loads/stores have no address operand
  if 'load' in name: return f"{name} v{inst.vdst}{off}{gds}" if 'addtid' in name else f"{name} {dst}, {addr}{off}{gds}"
  if 'store' in name and not _has(name, 'cmp', 'xchg'):
    return f"{name} v{inst.data0}{off}{gds}" if 'addtid' in name else f"{name} {addr}, {d0}{off}{gds}"
  if 'swizzle' in name or op == DSOp.DS_ORDERED_COUNT: return f"{name} v{inst.vdst}, {addr}{off}{gds}"
  if 'permute' in name: return f"{name} v{inst.vdst}, {addr}, v{inst.data0}{off}{gds}"
  if 'condxchg' in name: return f"{name} {_vreg(inst.vdst, 2)}, {addr}, {_vreg(inst.data0, 2)}{off}{gds}"
  # two-data atomics; the _rtn variants additionally print the destination first
  if _has(name, 'cmpstore', 'mskor', 'wrap'):
    return f"{name} {dst}, {addr}, {d0}, {d1}{off}{gds}" if '_rtn' in name else f"{name} {addr}, {d0}, {d1}{off}{gds}"
  # remaining one-data atomics
  return f"{name} {dst}, {addr}, {d0}{off}{gds}" if '_rtn' in name else f"{name} {addr}, {d0}{off}{gds}"
def _disasm_vop3(inst: VOP3) -> str:
  """Disassemble a VOP3-encoded (64-bit) VALU instruction.

  Opcodes in VOP3SD_OPS are decoded with a scalar destination rebuilt from overlapping VOP3 fields.
  Otherwise the opcode range selects the printed form:
  - < 256: VOPC promoted to _e64.
  - < 384: VOP2 promoted to _e64.
  - < 512: VOP1 promoted to _e64.
  - else: native VOP3.
  Operand widths and .l/.h halves are inferred from the opcode name. The enum comes from the rdna4
  autogen package when the instruction class's module name contains 'rdna4'.
  """
  is_rdna4 = 'rdna4' in inst.__class__.__module__
  if is_rdna4:
    from extra.assembly.amd.autogen.rdna4.enum import VOP3Op as VOP3Op4, VOP3SDOp as VOP3SDOp4
    op = VOP3SDOp4(inst.op) if inst.op in VOP3SD_OPS else VOP3Op4(inst.op)
  else:
    op = VOP3SDOp(inst.op) if inst.op in VOP3SD_OPS else VOP3Op(inst.op)
  name = op.name.lower()
  # VOP3SD (shared encoding)
  if inst.op in VOP3SD_OPS:
    # the scalar destination overlaps the VOP3 abs/opsel/clmp fields: abs | opsel << 3 | clmp << 7
    sdst = (inst.clmp << 7) | (inst.opsel << 3) | inst.abs
    is64, mad64 = 'f64' in name, _has(name, 'mad_i64_i32', 'mad_u64_u32', 'mad_co_i64_i32', 'mad_co_u64_u32')
    def src(v, neg, ext=False): s = _fmt_src(v, 2) if ext or is64 else inst.lit(v); return f"-{s}" if neg else s
    s0, s1, s2 = src(inst.src0, inst.neg & 1), src(inst.src1, inst.neg & 2), src(inst.src2, inst.neg & 4, mad64)
    dst = _vreg(inst.vdst, 2) if is64 or mad64 else f"v{inst.vdst}"
    # NOTE(review): on RDNA4 `op` comes from the RDNA4 enum but is compared against RDNA3 members; correct only if the opcode values agree — confirm
    if op in (VOP3SDOp.V_ADD_CO_U32, VOP3SDOp.V_SUB_CO_U32, VOP3SDOp.V_SUBREV_CO_U32): return f"{name} {dst}, {_fmt_sdst(sdst, 1)}, {s0}, {s1}"
    if op in (VOP3SDOp.V_ADD_CO_CI_U32, VOP3SDOp.V_SUB_CO_CI_U32, VOP3SDOp.V_SUBREV_CO_CI_U32): return f"{name} {dst}, {_fmt_sdst(sdst, 1)}, {s0}, {s1}, {s2}"
    return f"{name} {dst}, {_fmt_sdst(sdst, 1)}, {s0}, {s1}, {s2}" + _omod(inst.omod)
  # Detect operand sizes
  is64 = _is64(name)
  is64_src, is64_dst = False, False
  is16_d = is16_s = is16_s2 = False
  # v_cvt_pk_f32_bf8/fp8 outputs a VGPR pair (f32x2) from 16-bit packed input
  if 'cvt_pk_f32_bf8' in name or 'cvt_pk_f32_fp8' in name: is64_dst = True
  elif 'cvt_pk' in name: is16_s = name.endswith('16')
  # conversions: group(1) is the destination type, group(2) the source type
  elif m := re.match(r'v_(?:cvt|frexp_exp)_([a-z0-9_]+)_([a-z0-9]+)', name):
    is16_d, is16_s = _has(m.group(1), 'f16','i16','u16','b16'), _has(m.group(2), 'f16','i16','u16','b16')
    is64_src, is64_dst = '64' in m.group(2), '64' in m.group(1)
    is16_s2, is64 = is16_s, False
  elif re.match(r'v_mad_[iu]32_[iu]16', name): is16_s = True
  elif 'pack_b32' in name: is16_s = is16_s2 = True
  else: is16_d = is16_s = is16_s2 = _is16(name) and not _has(name, 'dot2', 'pk_', 'sad', 'msad', 'qsad', 'mqsad')
  # Source counts
  # *rev*64 shifts keep a 32-bit shift amount in src0
  shift64 = 'rev' in name and '64' in name and name.startswith('v_')
  ldexp64 = op == VOP3Op.V_LDEXP_F64
  trig = op == VOP3Op.V_TRIG_PREOP_F64
  sad64, mqsad = _has(name, 'qsad_pk', 'mqsad_pk'), 'mqsad_u32' in name
  s0n = 2 if ((is64 and not shift64) or sad64 or mqsad or is64_src) else 1
  # class / ldexp / trig_preop take a 32-bit src1 even when 64-bit
  s1n = 2 if (is64 and not _has(name, 'class') and not ldexp64 and not trig) else 1
  s2n = 4 if mqsad else 2 if (is64 or sad64) else 1
  # any op_sel bit set makes 16-bit VGPR operands print an explicit .l/.h suffix
  any_hi = inst.opsel != 0
  s0 = _vop3_src(inst, inst.src0, inst.neg&1, inst.abs&1, inst.opsel&1, s0n, is16_s, any_hi)
  s1 = _vop3_src(inst, inst.src1, inst.neg&2, inst.abs&2, inst.opsel&2, s1n, is16_s, any_hi)
  s2 = _vop3_src(inst, inst.src2, inst.neg&4, inst.abs&4, inst.opsel&4, s2n, is16_s2, any_hi)
  # Destination
  dn = 4 if mqsad else 2 if (is64 or sad64 or is64_dst) else 1
  if op == VOP3Op.V_READLANE_B32: dst = _fmt_sdst(inst.vdst, 1)
  elif dn > 1: dst = _vreg(inst.vdst, dn)
  # opsel bit 3 selects the high half of a 16-bit destination
  elif is16_d: dst = f"v{inst.vdst}.h" if (inst.opsel & 8) else f"v{inst.vdst}.l" if any_hi else f"v{inst.vdst}"
  else: dst = f"v{inst.vdst}"
  cl, om = " clamp" if inst.clmp else "", _omod(inst.omod)
  # op_sel is printed when it is set on a non-VGPR source (no .h suffix can express it) or on a non-16-bit op
  nonvgpr_opsel = (inst.src0 < 256 and (inst.opsel & 1)) or (inst.src1 < 256 and (inst.opsel & 2)) or (inst.src2 < 256 and (inst.opsel & 4))
  need_opsel = nonvgpr_opsel or (inst.opsel and not is16_s)
  # RDNA4 v_s_* instructions (pseudo-scalar VOP1-like) have SGPR destination
  if name.startswith('v_s_') and is_rdna4:
    return f"{name} {_fmt_sdst(inst.vdst, 1)}, {s0}{cl}{om}"
  if inst.op < 256: # VOPC
    return f"{name}_e64 {s0}, {s1}" if name.startswith('v_cmpx') else f"{name}_e64 {_fmt_sdst(inst.vdst, 1)}, {s0}, {s1}"
  if inst.op < 384: # VOP2
    os = _opsel_str(inst.opsel, 3, need_opsel, is16_d) if 'cndmask' in name else _opsel_str(inst.opsel, 2, need_opsel, is16_d)
    return f"{name}_e64 {dst}, {s0}, {s1}, {s2}{os}{cl}{om}" if 'cndmask' in name else f"{name}_e64 {dst}, {s0}, {s1}{os}{cl}{om}"
  if inst.op < 512: # VOP1
    if name in ('v_nop', 'v_pipeflush'): return f"{name}_e64"
    # Handle byte_sel for non-pk fp8/bf8 conversions
    if ('cvt_f32_fp8' in name or 'cvt_f32_bf8' in name) and 'pk' not in name:
      byte_sel = inst.opsel & 3
      os = f" byte_sel:{byte_sel}" if byte_sel else ""
    # NOTE(review): _opsel_str prints 3 positions for any n != 3, so both calls below emit a 3-element op_sel
    # despite the "2-element"/"1-element" comments — confirm this matches LLVM's VOP1 e64 syntax
    elif 'cvt_pk_f32_bf8' in name or 'cvt_pk_f32_fp8' in name:
      os = _opsel_str(inst.opsel, 2, need_opsel, is16_d) # 2-element for pk variants
    else:
      os = _opsel_str(inst.opsel, 1, need_opsel, is16_d) # 1-element for other VOP1
    return f"{name}_e64 {dst}, {s0}{os}{cl}{om}"
  # Native VOP3
  # three-source native ops, recognized by name
  is3 = _has(name, 'fma', 'mad', 'min3', 'max3', 'med3', 'div_fix', 'div_fmas', 'sad', 'lerp', 'align', 'cube', 'bfe', 'bfi',
             'perm_b32', 'cndmask', 'xor3', 'or3', 'add3', 'lshl_or', 'and_or', 'lshl_add', 'add_lshl', 'xad', 'maxmin', 'minmax', 'dot2', 'cvt_pk_u8', 'mullit',
             'minimummaximum', 'maximumminimum', 'minimum3', 'maximum3')
  # permlane16/permlanex16 have 3 sources, but _var variants have 2
  if 'permlane' in name and 'var' not in name: is3 = True
  # Handle byte_sel for fp8/bf8 instructions (opsel encodes byte_sel, not op_sel)
  # For VOP1-encoded VOP3 (op < 512): cvt_f32_fp8, cvt_f32_bf8
  # For native VOP3: cvt_sr_fp8, cvt_sr_bf8
  if ('cvt_f32_fp8' in name or 'cvt_f32_bf8' in name or 'cvt_sr_fp8' in name or 'cvt_sr_bf8' in name) and 'pk' not in name:
    # For VOP1 encoding (op < 512), byte_sel is in bits[1:0] of opsel; for native VOP3, it's bits[3:2]
    # (op < 512 already returned above, so only the bits[3:2] form is reachable here)
    byte_sel = (inst.opsel & 3) if inst.op < 512 else ((inst.opsel >> 2) & 3)
    os = f" byte_sel:{byte_sel}" if byte_sel else ""
  else:
    os = _opsel_str(inst.opsel, 3 if is3 else 2, need_opsel, is16_d)
  return f"{name} {dst}, {s0}, {s1}, {s2}{os}{cl}{om}" if is3 else f"{name} {dst}, {s0}, {s1}{os}{cl}{om}"
def _disasm_vop3sd(inst: VOP3SD) -> str:
  """Disassemble a VOP3SD instruction: VOP3 encoding plus an explicit scalar destination `sdst`.

  Printed as `<name>[_e64] vdst, sdst, src0, src1[, src2][ clamp][omod]`. v_add_co_u32,
  v_sub_co_u32 and v_subrev_co_u32 omit src2; names containing 'co_' get the _e64 suffix.
  """
  op = VOP3SDOp(inst.op)
  name = op.name.lower()
  wide = 'f64' in name
  mad64 = _has(name, 'mad_i64_i32', 'mad_u64_u32')
  def operand(v: int, neg: int, pair: bool = False) -> str:
    text = _fmt_src(v, 2) if pair or wide else inst.lit(v)
    return f"-{text}" if neg else text
  s0 = operand(inst.src0, inst.neg & 1)
  s1 = operand(inst.src1, inst.neg & 2)
  # the 64-bit mads read src2 as a register pair
  s2 = operand(inst.src2, inst.neg & 4, mad64)
  dst = _vreg(inst.vdst, 2) if wide or mad64 else f"v{inst.vdst}"
  two_src = op in (VOP3SDOp.V_ADD_CO_U32, VOP3SDOp.V_SUB_CO_U32, VOP3SDOp.V_SUBREV_CO_U32)
  suffix = "_e64" if name.startswith('v_') and 'co_' in name else ""
  operands = [dst, _fmt_sdst(inst.sdst, 1), s0, s1]
  if not two_src:
    operands.append(s2)
  tail = (' clamp' if inst.clmp else '') + _omod(inst.omod)
  return f"{name}{suffix} {', '.join(operands)}{tail}"
def _disasm_vopd(inst: VOPD) -> str:
  """Disassemble a VOPD dual-issue instruction as `<X half> :: <Y half>`.

  Each half prints `vdst, src0[, vsrc1]` (mov ops have no vsrc1). The shared literal K is printed only
  for fmaak/fmamk halves:
  - v_dual_fmaak_*: `vdst, src0, vsrc1, K`.
  - v_dual_fmamk_*: `vdst, src0, K, vsrc1`, the same order _disasm_vop2 uses for fmamk.
  """
  if 'rdna4' in inst.__class__.__module__:
    from extra.assembly.amd.autogen.rdna4.enum import VOPDOp as OpEnum
  else:
    OpEnum = VOPDOp
  lit = inst._literal or inst.literal
  # vdsty omits the low bit of the Y destination; it is rebuilt as the inverse of vdstx's low bit
  vdst_y = (inst.vdsty << 1) | ((inst.vdstx & 1) ^ 1)
  nx, ny = OpEnum(inst.opx).name.lower(), OpEnum(inst.opy).name.lower()
  def half(n: str, vd: int, s0: int, vs1: int) -> str:
    ops = [f"v{vd}", inst.lit(s0)]
    # K is only shown for fmaak/fmamk, and only when lit is truthy
    k = f"0x{lit:x}" if lit and _has(n, 'fmaak', 'fmamk') else None
    if 'mov' in n: ops += [k] if k else []
    elif 'fmamk' in n and k: ops += [k, f"v{vs1}"]
    else: ops += [f"v{vs1}"] + ([k] if k else [])
    return f"{n} {', '.join(ops)}"
  return f"{half(nx, inst.vdstx, inst.srcx0, inst.vsrcx1)} :: {half(ny, vdst_y, inst.srcy0, inst.vsrcy1)}"
def _disasm_vop3p(inst: VOP3P, wave_size: int = 32) -> str:
  """Disassemble a VOP3P (packed-math) instruction, including WMMA/SWMMAC matrix ops.

  Args:
    inst: the decoded VOP3P instruction.
    wave_size: 32 or 64; for WMMA/SWMMAC, wave64 halves the register counts (minimum 1).

  Modifiers are chosen by op family:
  - SWMMAC (RDNA4) prints index_key instead of op_sel.
  - fma_mix prints neg_hi as |abs| around sources.
  - Everything else prints op_sel / op_sel_hi (op_sel_hi omitted at its all-ones default), neg_lo,
    neg_hi and clamp.
  """
  is_rdna4 = 'rdna4' in inst.__class__.__module__
  if is_rdna4:
    from extra.assembly.amd.autogen.rdna4.enum import VOP3POp as OpEnum
  else:
    OpEnum = VOP3POp
  name = OpEnum(inst.op).name.lower()
  is_wmma, is_swmmac = 'wmma' in name and 'swmmac' not in name, 'swmmac' in name
  # 'swmmac' contains 'wmma', so SWMMAC ops are also treated as 3-source
  is_3src, is_fma_mix = _has(name, 'fma', 'mad', 'dot', 'wmma'), 'fma_mix' in name
  # Wave64 uses half the register widths of wave32 for WMMA
  wave_div = 2 if wave_size == 64 else 1
  if is_swmmac and is_rdna4:
    # SWMMAC (sparse WMMA): src2 is a single VGPR index, not an accumulator
    # Determine src0/src1/dst sizes based on instruction type
    if 'f16' in name or 'bf16' in name:
      if 'f16_16x16x32_f16' in name or 'bf16_16x16x32_bf16' in name:
        s0c, s1c, dc = 4, 8, 4 # f16/bf16 output
      else:
        s0c, s1c, dc = 4, 8, 8 # f32 output
    elif 'iu8' in name: s0c, s1c, dc = 2, 4, 8
    elif 'iu4' in name:
      if '16x16x64' in name: s0c, s1c, dc = 2, 4, 8
      else: s0c, s1c, dc = 1, 2, 8
    elif 'fp8' in name or 'bf8' in name: s0c, s1c, dc = 2, 4, 8
    else: s0c, s1c, dc = 4, 8, 8
    s0c, s1c, dc = max(1, s0c // wave_div), max(1, s1c // wave_div), max(1, dc // wave_div)
    src0, src1, src2, dst = _fmt_src(inst.src0, s0c), _fmt_src(inst.src1, s1c), _fmt_src(inst.src2, 1), _vreg(inst.vdst, dc)
  elif is_wmma:
    # RDNA4 WMMA uses smaller source register widths than RDNA3
    if is_rdna4:
      # RDNA4 wave32 source widths: iu4->1/2, iu8->2, fp8/bf8->2, f16/bf16->4
      if 'iu4' in name:
        sc = 2 if '16x16x32' in name else 1
      elif 'iu8' in name or 'fp8' in name or 'bf8' in name: sc = 2
      else: sc = 4 # f16/bf16
      # Destination width: f16/bf16 output->4, f32/i32 output->8
      dc = 4 if name.startswith('v_wmma_f16') or name.startswith('v_wmma_bf16') else 8
    else:
      # RDNA3: iu4->2, iu8->4, f16/bf16->8
      sc = 2 if 'iu4' in name else 4 if 'iu8' in name else 8
      dc = 8
    sc, dc = max(1, sc // wave_div), max(1, dc // wave_div)
    # WMMA src2 is the accumulator, same width as the destination
    src0, src1, src2, dst = _fmt_src(inst.src0, sc), _fmt_src(inst.src1, sc), _fmt_src(inst.src2, dc), _vreg(inst.vdst, dc)
  else: src0, src1, src2, dst = _fmt_src(inst.src0, 1), _fmt_src(inst.src1, 1), _fmt_src(inst.src2, 1), f"v{inst.vdst}"
  # n = number of modifier positions; op_sel_hi bits 0-1 live in opsel_hi and bit 2 in opsel_hi2
  n, opsel_hi = 3 if is_3src else 2, inst.opsel_hi | (inst.opsel_hi2 << 2)
  if is_swmmac and is_rdna4:
    # SWMMAC uses index_key instead of op_sel; opsel bits encode the key value
    mods = ([f"index_key:{inst.opsel & 7}"] if inst.opsel else []) + \
           ([_fmt_bits("neg_lo", inst.neg, n)] if inst.neg else []) + ([_fmt_bits("neg_hi", inst.neg_hi, n)] if inst.neg_hi else []) + (["clamp"] if inst.clmp else [])
  elif is_fma_mix:
    # fma_mix: neg bits print as -x and neg_hi bits as |x| on the corresponding source
    def m(s, neg, abs_): return f"-{f'|{s}|' if abs_ else s}" if neg else (f"|{s}|" if abs_ else s)
    src0, src1, src2 = m(src0, inst.neg & 1, inst.neg_hi & 1), m(src1, inst.neg & 2, inst.neg_hi & 2), m(src2, inst.neg & 4, inst.neg_hi & 4)
    mods = ([_fmt_bits("op_sel", inst.opsel, n)] if inst.opsel else []) + ([_fmt_bits("op_sel_hi", opsel_hi, n)] if opsel_hi else []) + (["clamp"] if inst.clmp else [])
  else:
    # op_sel_hi is omitted when it equals its all-ones default (7 for 3 sources, 3 for 2)
    mods = ([_fmt_bits("op_sel", inst.opsel, n)] if inst.opsel else []) + ([_fmt_bits("op_sel_hi", opsel_hi, n)] if opsel_hi != (7 if is_3src else 3) else []) + \
           ([_fmt_bits("neg_lo", inst.neg, n)] if inst.neg else []) + ([_fmt_bits("neg_hi", inst.neg_hi, n)] if inst.neg_hi else []) + (["clamp"] if inst.clmp else [])
  return f"{name} {dst}, {src0}, {src1}, {src2}{' ' + ' '.join(mods) if mods else ''}" if is_3src else f"{name} {dst}, {src0}, {src1}{' ' + ' '.join(mods) if mods else ''}"
def _disasm_buf(inst: MUBUF | MTBUF) -> str:
  """Disassemble a MUBUF or MTBUF buffer instruction.

  Output form: `<name> vdata, vaddr, srsrc, soffset [modifiers]`. Modifier order:
  - format (MTBUF only)
  - idxen, offen, offset
  - cache policy as glc, slc, dlc — the order LLVM's instruction printer uses, and the one
    _disasm_flat and _disasm_mimg already follow
  - tfe
  buffer_gl0_inv / buffer_gl1_inv print as a bare mnemonic.
  """
  op = MTBUFOp(inst.op) if isinstance(inst, MTBUF) else MUBUFOp(inst.op)
  name = op.name.lower()
  if op in (MUBUFOp.BUFFER_GL0_INV, MUBUFOp.BUFFER_GL1_INV): return name
  # vdata register count:
  # - d16 ops: 2 for xyz/xyzw, else 1
  # - atomics: doubled for 64-bit types and again for cmpswap
  # - everything else: from the trailing type/component suffix
  if 'd16' in name: w = 2 if _has(name, 'xyz', 'xyzw') else 1
  elif 'atomic' in name: w = (2 if _has(name, 'b64', 'u64', 'i64') else 1) * (2 if 'cmpswap' in name else 1)
  else: w = {'b32':1,'b64':2,'b96':3,'b128':4,'b16':1,'x':1,'xy':2,'xyz':3,'xyzw':4}.get(name.split('_')[-1], 1)
  # tfe adds one more VGPR to vdata
  if inst.tfe: w += 1
  # vaddr: a VGPR pair when both idxen and offen are set, one VGPR for either, "off" for neither
  if inst.offen and inst.idxen: vaddr = _vreg(inst.vaddr, 2)
  elif inst.offen or inst.idxen: vaddr = f"v{inst.vaddr}"
  else: vaddr = "off"
  srsrc = _sreg_or_ttmp(inst.srsrc*4, 4)
  flags = [(inst.idxen, "idxen"), (inst.offen, "offen"), (inst.offset, f"offset:{inst.offset}"),
           (inst.glc, "glc"), (inst.slc, "slc"), (inst.dlc, "dlc"), (inst.tfe, "tfe")]
  mods = ([f"format:{inst.format}"] if isinstance(inst, MTBUF) else []) + [m for c, m in flags if c]
  return f"{name} {_vreg(inst.vdata, w)}, {vaddr}, {srsrc}, {decode_src(inst.soffset)}{' ' + ' '.join(mods) if mods else ''}"
def _mimg_vaddr_width(name: str, dim: int, a16: bool) -> int:
"""Calculate vaddr register count for MIMG sample/gather operations."""
# 1d,2d,3d,cube,1d_arr,2d_arr,2d_msaa,2d_msaa_arr
base = [1, 2, 3, 3, 2, 3, 3, 4][dim] # address coords
grad = [1, 2, 3, 2, 1, 2, 2, 2][dim] # gradient coords (for derivatives)
if 'get_resinfo' in name: return 1 # only mip level
packed, unpacked = 0, 0
if '_mip' in name: packed += 1
elif 'sample' in name or 'gather' in name:
if '_o' in name: unpacked += 1 # offset
if re.search(r'_c(_|$)', name): unpacked += 1 # compare (not _cl)
if '_d' in name: unpacked += (grad + 1) & ~1 if '_g16' in name else grad*2 # derivatives
if '_b' in name: unpacked += 1 # bias
if '_l' in name and '_cl' not in name and '_lz' not in name: packed += 1 # LOD
if '_cl' in name: packed += 1 # clamp
return (base + packed + 1) // 2 + unpacked if a16 else base + packed + unpacked
def _disasm_mimg(inst: MIMG) -> str:
  """Disassemble an MIMG (image) instruction.

  BVH ray-intersect ops use a fixed 4-VGPR vdata, a 4-SGPR resource, and a vaddr length chosen by
  '64' in the name and a16. Other ops size vdata as follows:
  - from dmask (gather4/msaa_load always 4), halved (rounded up) by d16, plus one for tfe;
  - vaddr via _mimg_vaddr_width.
  Other ops use an 8-SGPR resource, print the sampler for sample/gather/get_lod, and always print the
  dim modifier (dmask only when non-zero).
  """
  name = MIMGOp(inst.op).name.lower()
  rsrc_base = inst.srsrc * 4
  if 'bvh' in name:
    has64 = '64' in name
    if inst.a16: n_vaddr = 9 if has64 else 8
    else: n_vaddr = 12 if has64 else 11
    a16_flag = ' a16' if inst.a16 else ''
    return f"{name} {_vreg(inst.vdata, 4)}, {_vreg(inst.vaddr, n_vaddr)}, {_sreg_or_ttmp(rsrc_base, 4)}{a16_flag}"
  if 'gather4' in name or 'msaa_load' in name:
    n_vdata = 4
  else:
    n_vdata = bin(inst.dmask).count('1') or 1
  if inst.d16: n_vdata = (n_vdata + 1) // 2
  if inst.tfe: n_vdata += 1
  dims = ('1d', '2d', '3d', 'cube', '1d_array', '2d_array', '2d_msaa', '2d_msaa_array')
  dim = dims[inst.dim] if inst.dim < len(dims) else f"dim_{inst.dim}"
  n_vaddr = _mimg_vaddr_width(name, inst.dim, inst.a16)
  vaddr_str = f"v{inst.vaddr}" if n_vaddr == 1 else _vreg(inst.vaddr, n_vaddr)
  mods = [f"dmask:0x{inst.dmask:x}"] if inst.dmask else []
  mods.append(f"dim:SQ_RSRC_IMG_{dim.upper()}")
  flag_table = ((inst.unrm, "unorm"), (inst.glc, "glc"), (inst.slc, "slc"), (inst.dlc, "dlc"), (inst.r128, "r128"),
                (inst.a16, "a16"), (inst.tfe, "tfe"), (inst.lwe, "lwe"), (inst.d16, "d16"))
  mods.extend(text for enabled, text in flag_table if enabled)
  uses_sampler = 'sample' in name or 'gather' in name or 'get_lod' in name
  ssamp_str = (", " + _sreg_or_ttmp(inst.ssamp * 4, 4)) if uses_sampler else ""
  return f"{name} {_vreg(inst.vdata, n_vdata)}, {vaddr_str}, {_sreg_or_ttmp(rsrc_base, 8)}{ssamp_str} {' '.join(mods)}"
def _sop_widths(name: str) -> tuple[int, int, int]:
"""Return (dst_width, src0_width, src1_width) in register count for SOP instructions."""
if name in ('s_bitset0_b64', 's_bitset1_b64', 's_bfm_b64'): return 2, 1, 1
if name in ('s_lshl_b64', 's_lshr_b64', 's_ashr_i64', 's_bfe_u64', 's_bfe_i64'): return 2, 2, 1
if name in ('s_bitcmp0_b64', 's_bitcmp1_b64'): return 1, 2, 1
if m := re.search(r'_(b|i|u)(32|64)_(b|i|u)(32|64)$', name): return 2 if m.group(2) == '64' else 1, 2 if m.group(4) == '64' else 1, 1
if m := re.search(r'_(b|i|u)(32|64)$', name): sz = 2 if m.group(2) == '64' else 1; return sz, sz, sz
return 1, 1, 1
def _disasm_sop1(inst: SOP1) -> str:
  """Format a SOP1 scalar instruction as assembly text.

  PC-manipulation and sendmsg-with-return ops have fixed 64-bit / message operand layouts; all other ops take
  their operand widths from _sop_widths. 32-bit sources are formatted by inst.lit."""
  op = SOP1Op(inst.op)
  name = op.name.lower()
  if op == SOP1Op.S_GETPC_B64:
    return f"{name} {_fmt_sdst(inst.sdst, 2)}"
  if op in (SOP1Op.S_SETPC_B64, SOP1Op.S_RFE_B64):
    return f"{name} {_fmt_src(inst.ssrc0, 2)}"
  if op == SOP1Op.S_SWAPPC_B64:
    return f"{name} {_fmt_sdst(inst.sdst, 2)}, {_fmt_src(inst.ssrc0, 2)}"
  if op in (SOP1Op.S_SENDMSG_RTN_B32, SOP1Op.S_SENDMSG_RTN_B64):
    dst = _fmt_sdst(inst.sdst, 2 if 'b64' in name else 1)
    return f"{name} {dst}, sendmsg({MSG.get(inst.ssrc0, str(inst.ssrc0))})"
  dst_regs, src_regs, _ = _sop_widths(name)
  dst = _fmt_sdst(inst.sdst, dst_regs)
  src = inst.lit(inst.ssrc0) if src_regs == 1 else _fmt_src(inst.ssrc0, src_regs)
  return f"{name} {dst}, {src}"
def _disasm_sop2(inst: SOP2) -> str:
  """Format a SOP2 scalar instruction; a source field of 255 is an inline literal and is printed via inst.lit."""
  name = SOP2Op(inst.op).name.lower()
  dst_regs, src0_regs, src1_regs = _sop_widths(name)
  def operand(field: int, regs: int) -> str:
    return inst.lit(field) if field == 255 else _fmt_src(field, regs)
  return f"{name} {_fmt_sdst(inst.sdst, dst_regs)}, {operand(inst.ssrc0, src0_regs)}, {operand(inst.ssrc1, src1_regs)}"
def _disasm_sopc(inst: SOPC) -> str:
  """Format a SOPC scalar compare as "<name> <src0>, <src1>".

  Operand widths come from _sop_widths. A source field of 255 denotes an inline literal; like _disasm_sop2, it is
  printed through inst.lit (which has access to the literal value) rather than _fmt_src."""
  name = SOPCOp(inst.op).name.lower()
  _, s0n, s1n = _sop_widths(name)
  src0 = inst.lit(inst.ssrc0) if inst.ssrc0 == 255 else _fmt_src(inst.ssrc0, s0n)
  src1 = inst.lit(inst.ssrc1) if inst.ssrc1 == 255 else _fmt_src(inst.ssrc1, s1n)
  return f"{name} {src0}, {src1}"
def _disasm_sopk(inst: SOPK) -> str:
  """Format a SOPK instruction (16-bit immediate) as assembly text.

  The opcode enum and hardware-register name table are chosen by architecture (RDNA4 classes live in an
  'rdna4' module). s_getreg/s_setreg decode simm16 as hwreg(id, offset, size); ids 16 and 17 are printed as raw hex."""
  if 'rdna4' in inst.__class__.__module__:
    from extra.assembly.amd.autogen.rdna4.enum import SOPKOp as OpEnum
    hwreg_map = HWREG_GFX12
  else:
    OpEnum, hwreg_map = SOPKOp, HWREG
  name = OpEnum(inst.op).name.lower()
  imm = inst.simm16
  if name == 's_version':
    return f"{name} 0x{imm:x}"
  if name in ('s_setreg_b32', 's_getreg_b32'):
    # simm16 layout: id in bits [5:0], offset in [10:6], size-1 in [15:11]
    reg_id = imm & 0x3f
    reg_off = (imm >> 6) & 0x1f
    reg_size = ((imm >> 11) & 0x1f) + 1
    reg_name = hwreg_map.get(reg_id, str(reg_id))
    if reg_id in (16, 17):
      hw = f"0x{imm:x}"
    elif reg_off == 0 and reg_size == 32:
      hw = f"hwreg({reg_name})"  # whole register: short form
    else:
      hw = f"hwreg({reg_name}, {reg_off}, {reg_size})"
    sdst = _fmt_sdst(inst.sdst, 1)
    return f"{name} {hw}, {sdst}" if name == 's_setreg_b32' else f"{name} {sdst}, {hw}"
  dst_regs = _sop_widths(name)[0]
  return f"{name} {_fmt_sdst(inst.sdst, dst_regs)}, 0x{imm:x}"
def _disasm_vinterp(inst: VINTERP) -> str:
  """Format a VINTERP instruction; bits 0, 1, 2 of `neg` negate src0, src1, src2 respectively."""
  name = VINTERPOp(inst.op).name.lower()
  srcs = []
  for bit, field in enumerate((inst.src0, inst.src1, inst.src2)):
    text = inst.lit(field)
    srcs.append(f"-{text}" if (inst.neg >> bit) & 1 else f"{text}")
  mods = _mods((inst.waitexp, f"wait_exp:{inst.waitexp}"), (inst.clmp, "clamp"))
  suffix = " " + mods if mods else ""
  return f"{name} v{inst.vdst}, {', '.join(srcs)}" + suffix
# Export target id -> assembler name: mrt0-7 (0-7), mrtz (8), pos0-4 (12-16), prim (20), dual_src_blend0/1 (21-22)
_EXP_TARGETS = dict([(i, f'mrt{i}') for i in range(8)] + [(8, 'mrtz')] + [(12 + i, f'pos{i}') for i in range(5)] +
                    [(20, 'prim'), (21, 'dual_src_blend0'), (22, 'dual_src_blend1')])
def _disasm_exp(inst) -> str:
  """Format an export instruction: target name, four sources ("off" when the matching `en` bit is clear), modifiers.

  RDNA4 classes (module name contains 'rdna4') use the "export" mnemonic, RDNA3 uses "exp"."""
  target = _EXP_TARGETS.get(inst.target, f"invalid_target_{inst.target}")
  enable = inst.en
  regs = (inst.vsrc0, inst.vsrc1, inst.vsrc2, inst.vsrc3)
  srcs = ", ".join(f"v{reg}" if (enable >> idx) & 1 else "off" for idx, reg in enumerate(regs))
  mods = _mods((inst.done, "done"), (inst.row, "row_en"))
  mnemonic = "export" if 'rdna4' in inst.__class__.__module__ else "exp"
  return f"{mnemonic} {target} {srcs}" + (" " + mods if mods else "")
def _disasm_ldsdir(inst) -> str:
  """Format an LDS direct/parameter load (RDNA3 LDSDIR, RDNA4 VDSDIR).

  op 1 is the direct load and op 0 the parameter load (operand attr<N>.<x|y|z|w>). RDNA4 (module name contains
  'rdna4') uses ds_* mnemonics and always prints both wait counters; RDNA3 uses lds_* and prints wait_vdst only
  when it is nonzero. Raises ValueError for any other op."""
  if 'rdna4' in inst.__class__.__module__:
    direct, param = "ds_direct_load", "ds_param_load"
    wait = f" wait_va_vdst:{inst.wait_va} wait_vm_vsrc:{inst.wait_vm}"
  else:
    direct, param = "lds_direct_load", "lds_param_load"
    wait = f" wait_vdst:{inst.wait_va}" if inst.wait_va != 0 else ""
  if inst.op == 1:
    return f"{direct} v{inst.vdst}{wait}"
  if inst.op == 0:
    chan = ('x', 'y', 'z', 'w')[inst.attr_chan]
    return f"{param} v{inst.vdst}, attr{inst.attr}.{chan}{wait}"
  raise ValueError(f"unknown LDSDIR op: {inst.op}")
# ═══════════════════════════════════════════════════════════════════════════════
# RDNA4-specific disassemblers (GFX12)
# ═══════════════════════════════════════════════════════════════════════════════
# th (temporal hint) value -> printed modifier for RDNA4 memory instructions (based on AMDGPU ISA docs and LLVM SIDefines.h).
# An empty string means the modifier is omitted. Tables are selected by _rdna4_mem_mods (atomic > store > load).
# Load: 0=RT(default, omitted), 1=NT, 2=HT, 3=LU, 4=NT_RT, 5=RT_NT, 6=NT_HT, 7=BYPASS.
# _rdna4_mem_mods prints th=3 with scope=SYS(3) as TH_LOAD_BYPASS instead of TH_LOAD_LU.
# NOTE(review): the th=7 spelling here differs from the th=3+SYS BYPASS rule above — confirm against llvm-mc output.
_TH_LOAD = {0: '', 1: 'th:TH_LOAD_NT', 2: 'th:TH_LOAD_HT', 3: 'th:TH_LOAD_LU', 4: 'th:TH_LOAD_NT_RT', 5: 'th:TH_LOAD_RT_NT', 6: 'th:TH_LOAD_NT_HT', 7: 'th:TH_LOAD_BYPASS'}
# Store: 0=RT(default, omitted), 1=NT, 2=HT, 3=WB, 4=NT_RT, 5=RT_NT, 6=NT_HT, 7=NT_WB.
# _rdna4_mem_mods prints th=3 with scope=SYS(3) as TH_STORE_BYPASS instead of TH_STORE_WB.
_TH_STORE = {0: '', 1: 'th:TH_STORE_NT', 2: 'th:TH_STORE_HT', 3: 'th:TH_STORE_WB', 4: 'th:TH_STORE_NT_RT', 5: 'th:TH_STORE_RT_NT', 6: 'th:TH_STORE_NT_HT', 7: 'th:TH_STORE_NT_WB'}
# Atomic: bit0=RETURN, bit1=NT, bit2=CASCADE -> 0=none, 1=RETURN, 2=NT, 3=NT_RETURN, 4=CASCADE_RT, 6=CASCADE_NT.
# This table spells 5 as TH_ATOMIC_RT_RETURN and reuses TH_ATOMIC_CASCADE_NT for 7.
# NOTE(review): entries 5 and 7 do not follow the bit decomposition above — confirm against LLVM's printer.
_TH_ATOMIC = {0: '', 1: 'th:TH_ATOMIC_RETURN', 2: 'th:TH_ATOMIC_NT', 3: 'th:TH_ATOMIC_NT_RETURN', 4: 'th:TH_ATOMIC_CASCADE_RT', 5: 'th:TH_ATOMIC_RT_RETURN', 6: 'th:TH_ATOMIC_CASCADE_NT', 7: 'th:TH_ATOMIC_CASCADE_NT'}
# scope value -> printed modifier; 0 is the default and is omitted.
_SCOPE = {0: '', 1: 'scope:SCOPE_SE', 2: 'scope:SCOPE_DEV', 3: 'scope:SCOPE_SYS'}
def _rdna4_mem_mods(th: int, scope: int, is_store: bool, is_atomic: bool) -> str:
  """Build the space-separated "th:... scope:..." modifier text for an RDNA4 memory instruction.

  The th table is chosen by access kind (atomic, then store, then load). For non-atomics, th=3 together with
  scope=3 (SYS) is printed as BYPASS. Values missing from a table fall back to "th:N"/"scope:N"; zero and empty
  parts are omitted, so the result may be ""."""
  if is_atomic:
    table = _TH_ATOMIC
  elif is_store:
    table = _TH_STORE
  else:
    table = _TH_LOAD
  if not is_atomic and th == 3 and scope == 3:
    th_text = 'th:TH_STORE_BYPASS' if is_store else 'th:TH_LOAD_BYPASS'
  else:
    th_text = table.get(th, f'th:{th}' if th else '')
  scope_text = _SCOPE.get(scope, f'scope:{scope}' if scope else '')
  return ' '.join(filter(None, (th_text, scope_text)))
def _disasm_vflat(inst) -> str:
  """Disassemble RDNA4 VFLAT/VGLOBAL/VSCRATCH instructions.

  The instruction's class name selects the opcode enum and mnemonic prefix (flat/global/scratch). Data widths come
  from the mnemonic's type suffix and th/scope modifiers from _rdna4_mem_mods. For atomics, the returning form
  (with a vdst operand) is selected by the TH_ATOMIC_RETURN bit, th bit 0."""
  from extra.assembly.amd.autogen.rdna4.enum import VFLATOp, VGLOBALOp, VSCRATCHOp
  cls_name = type(inst).__name__
  if cls_name == 'VGLOBAL': op_enum, prefix = VGLOBALOp, 'global'
  elif cls_name == 'VSCRATCH': op_enum, prefix = VSCRATCHOp, 'scratch'
  else: op_enum, prefix = VFLATOp, 'flat'
  name = op_enum(inst.op).name.lower()
  # global_wb, global_wbinv, global_inv are cache control instructions with no operands
  if name in ('global_wb', 'global_wbinv', 'global_inv'):
    mods = _rdna4_mem_mods(inst.th, inst.scope, False, False)
    return f"{name}" + (f" {mods}" if mods else "")
  # addtid instructions use thread ID as address offset, no vaddr operand
  is_addtid = 'addtid' in name
  # Data width (in VGPRs) from the instruction name suffix; block loads/stores use 32 VGPRs
  suffix = name.split('_')[-1]
  if 'block' in name:
    base_w = 32
  else:
    base_w = {'b32':1,'b64':2,'b96':3,'b128':4,'u8':1,'i8':1,'u16':1,'i16':1,'u32':1,'i32':1,'u64':2,'i64':2,'f32':1,'f64':2}.get(suffix, 1)
  # For cmpswap: vsrc holds cmp+data pairs (2x base), vdst is base width
  vsrc_w = base_w * 2 if 'cmpswap' in name else base_w
  vdst_w = base_w
  # Offset: signed 24-bit (stored as unsigned, needs sign extension)
  off = inst.ioffset if inst.ioffset < 0x800000 else inst.ioffset - 0x1000000
  off_s = f" offset:{off}" if off else ""
  # Memory modifiers
  is_store, is_atomic = 'store' in name, 'atomic' in name
  mods = _rdna4_mem_mods(inst.th, inst.scope, is_store, is_atomic)
  mod_s = f" {mods}" if mods else ""
  # saddr=124 (null/off): VGLOBAL and VSCRATCH print an explicit "off", VFLAT prints nothing
  if inst.saddr == 124: saddr_s = "off" if prefix in ('global', 'scratch') else ""
  elif inst.saddr in SPECIAL_PAIRS: saddr_s = SPECIAL_PAIRS[inst.saddr]
  else: saddr_s = _sreg(inst.saddr, 2) if prefix == 'global' else decode_src(inst.saddr)
  # vaddr width: scratch always 1 VGPR; global 1 when an SGPR base is used (saddr != 124), else a pair; flat always 2
  addr_w = 1 if (prefix == 'scratch' or (inst.saddr != 124 and prefix != 'flat')) else 2
  vaddr_s = _vreg(inst.vaddr, addr_w)
  vsrc_s = _vreg(inst.vsrc, vsrc_w)
  vdst_s = _vreg(inst.vdst, vdst_w)
  # addtid instructions don't have vaddr, just vdata and saddr
  if is_addtid:
    if is_store: return f"{name} {vsrc_s}, {saddr_s}{off_s}{mod_s}"
    return f"{name} {vdst_s}, {saddr_s}{off_s}{mod_s}"
  # Regular instructions need comma before saddr
  saddr_s = f", {saddr_s}" if saddr_s else ""
  if is_atomic:
    # th bit0 is TH_ATOMIC_RETURN: every returning variant (RETURN, NT_RETURN, ...) has a vdst operand
    if inst.th & 1:
      return f"{name} {vdst_s}, {vaddr_s}, {vsrc_s}{saddr_s}{off_s}{mod_s}"
    return f"{name} {vaddr_s}, {vsrc_s}{saddr_s}{off_s}{mod_s}"
  if is_store: return f"{name} {vaddr_s}, {vsrc_s}{saddr_s}{off_s}{mod_s}"
  return f"{name} {vdst_s}, {vaddr_s}{saddr_s}{off_s}{mod_s}"
def _disasm_vbuffer(inst) -> str:
  """Disassemble RDNA4 VBUFFER instructions.

  The vdata width comes from the mnemonic. For *format* ops it is the component suffix (x/xy/xyz/xyzw), halved and
  rounded up for d16. For other ops it is the type suffix, doubled for cmpswap. tfe adds one VGPR.
  vaddr is a VGPR pair when both idxen and offen are set, one VGPR when either is set, otherwise "off".
  The resource is printed as 4 SGPRs starting at the raw rsrc field (RDNA4 stores the SGPR index directly, not /4
  like RDNA3). No format: modifier is emitted."""
  from extra.assembly.amd.autogen.rdna4.enum import VBUFFEROp
  name = VBUFFEROp(inst.op).name.lower()
  last = name.split('_')[-1]
  if 'format' in name:
    width = {'x': 1, 'xy': 2, 'xyz': 3, 'xyzw': 4}.get(last, 1)
    if 'd16' in name: width = (width + 1) // 2
  else:
    width = {'b32':1,'b64':2,'b96':3,'b128':4,'u8':1,'i8':1,'u16':1,'i16':1,'u32':1,'i32':1,'u64':2,'i64':2,'b8':1,'b16':1,'f16':1,'f32':1,'bf16':1}.get(last, 1)
    if 'cmpswap' in name: width *= 2
  if inst.tfe: width += 1
  if inst.offen and inst.idxen: vaddr_s = _vreg(inst.vaddr, 2)
  elif inst.offen or inst.idxen: vaddr_s = f"v{inst.vaddr}"
  else: vaddr_s = "off"
  # ioffset is a 24-bit two's-complement field
  offset = inst.ioffset - 0x1000000 if inst.ioffset >= 0x800000 else inst.ioffset
  mem = _rdna4_mem_mods(inst.th, inst.scope, 'store' in name, 'atomic' in name)
  flags = ((inst.idxen, "idxen"), (inst.offen, "offen"), (offset, f"offset:{offset}"), (mem, mem), (inst.tfe, "tfe"))
  trailing = [txt for keep, txt in flags if keep]
  head = f"{name} {_vreg(inst.vdata, width)}, {vaddr_s}, {_sreg_or_ttmp(inst.rsrc, 4)}, {decode_src(inst.soffset)}"
  return head + (f" {' '.join(trailing)}" if trailing else "")
def _disasm_vimage(inst) -> str:
  """Disassemble RDNA4 VIMAGE instructions.

  BVH intersect-ray ops print a bracketed list of per-field address registers followed by a 4-SGPR resource.
  All other image ops print vdata (width from dmask/d16/tfe), 1-5 individual address VGPRs, an 8-SGPR resource,
  then modifiers in the order dmask, dim, th/scope, r128, a16, tfe, d16."""
  from extra.assembly.amd.autogen.rdna4.enum import VIMAGEOp
  name = VIMAGEOp(inst.op).name.lower()
  # RDNA4 VIMAGE rsrc field stores the SGPR index directly (not /4 like RDNA3)
  if 'bvh' in name:
    # BVH intersect ray: special format with individual/range vaddr components
    # Format: [node_ptr, ray_extent, ray_origin(3), ray_dir(3), ray_inv_dir(3)]
    # bvh64 has 2-VGPR node_ptr, a16 removes ray_inv_dir
    if 'dual' in name or 'bvh8' in name:
      # dual/bvh8: [node_ptr(2), ray_extent(2), ray_origin(3), ray_dir(3), last field (1 VGPR for bvh8, 2 for dual)]
      parts = [_vreg(inst.vaddr0, 2), _vreg(inst.vaddr1, 2), _vreg(inst.vaddr2, 3), _vreg(inst.vaddr3, 3)]
      if not inst.a16: parts.append(_vreg(inst.vaddr4, 1 if 'bvh8' in name else 2))
      dst_w = 10
    elif '64' in name:
      parts = [_vreg(inst.vaddr0, 2), f"v{inst.vaddr1}", _vreg(inst.vaddr2, 3), _vreg(inst.vaddr3, 3)]
      if not inst.a16: parts.append(_vreg(inst.vaddr4, 3))
      dst_w = 4
    else:
      parts = [f"v{inst.vaddr0}", f"v{inst.vaddr1}", _vreg(inst.vaddr2, 3), _vreg(inst.vaddr3, 3)]
      if not inst.a16: parts.append(_vreg(inst.vaddr4, 3))
      dst_w = 4
    return f"{name} {_vreg(inst.vdata, dst_w)}, [{', '.join(parts)}], {_sreg_or_ttmp(inst.rsrc, 4)}{' a16' if inst.a16 else ''}"
  # vdata width: gather4/msaa_load always use 4 VGPRs; otherwise one per set dmask bit (minimum 1).
  # d16 halves the count (rounded up) and tfe adds one.
  vdata = 4 if 'gather4' in name or 'msaa_load' in name else (bin(inst.dmask).count('1') or 1)
  if inst.d16: vdata = (vdata + 1) // 2
  if inst.tfe: vdata += 1
  # dim names (a value past the table is printed as its number)
  dim_names = ['1d', '2d', '3d', 'cube', '1d_array', '2d_array', '2d_msaa', '2d_msaa_array']
  dim = dim_names[inst.dim] if inst.dim < len(dim_names) else f"{inst.dim}"
  # vaddr width calculation (RDNA4 uses vaddr0-4 for address components); any width above 4 prints all five fields
  vaddr_w = _mimg_vaddr_width(name, inst.dim, inst.a16)
  if vaddr_w == 1: vaddr_s = f"v{inst.vaddr0}"
  elif vaddr_w == 2: vaddr_s = f"[v{inst.vaddr0}, v{inst.vaddr1}]"
  elif vaddr_w == 3: vaddr_s = f"[v{inst.vaddr0}, v{inst.vaddr1}, v{inst.vaddr2}]"
  elif vaddr_w == 4: vaddr_s = f"[v{inst.vaddr0}, v{inst.vaddr1}, v{inst.vaddr2}, v{inst.vaddr3}]"
  else: vaddr_s = f"[v{inst.vaddr0}, v{inst.vaddr1}, v{inst.vaddr2}, v{inst.vaddr3}, v{inst.vaddr4}]"
  srsrc_s = _sreg_or_ttmp(inst.rsrc, 8)
  # RDNA4 always prints dmask (LLVM uses it for size calculation), even when it is 0
  mods = [f"dmask:0x{inst.dmask:x}"]
  mods.append(f"dim:SQ_RSRC_IMG_{dim.upper()}")
  # Add th/scope before other modifiers, then r128, then a16/tfe/d16 (LLVM expects this order)
  if inst.th or inst.scope:
    is_store, is_atomic = 'store' in name, 'atomic' in name
    mem_mods = _rdna4_mem_mods(inst.th, inst.scope, is_store, is_atomic)
    if mem_mods: mods.append(mem_mods)
  if inst.r128: mods.append("r128")
  mods.extend([m for c, m in [(inst.a16, "a16"), (inst.tfe, "tfe"), (inst.d16, "d16")] if c])
  return f"{name} {_vreg(inst.vdata, vdata)}, {vaddr_s}, {srsrc_s} {' '.join(mods)}"
def _disasm_vsample(inst) -> str:
  """Disassemble RDNA4 VSAMPLE instructions.

  Prints vdata, the address operand(s), an 8-SGPR resource, a 4-SGPR sampler (omitted for msaa_load), then
  modifiers in the order dmask (only if nonzero), dim, unorm, th/scope, r128, a16, tfe, lwe, d16."""
  from extra.assembly.amd.autogen.rdna4.enum import VSAMPLEOp
  name = VSAMPLEOp(inst.op).name.lower()
  # vdata width: gather4/msaa_load always 4 VGPRs, otherwise one per set dmask bit (minimum 1); d16 halves, tfe adds 1
  vdata = 4 if 'gather4' in name or 'msaa_load' in name else (bin(inst.dmask).count('1') or 1)
  if inst.d16: vdata = (vdata + 1) // 2
  if inst.tfe: vdata += 1
  dim_names = ['1d', '2d', '3d', 'cube', '1d_array', '2d_array', '2d_msaa', '2d_msaa_array']
  dim = dim_names[inst.dim] if inst.dim < len(dim_names) else f"{inst.dim}"
  vaddr = _mimg_vaddr_width(name, inst.dim, inst.a16)
  if vaddr == 1: vaddr_s = f"v{inst.vaddr0}"
  elif vaddr == 2: vaddr_s = f"[v{inst.vaddr0}, v{inst.vaddr1}]"
  elif vaddr == 3: vaddr_s = f"[v{inst.vaddr0}, v{inst.vaddr1}, v{inst.vaddr2}]"
  elif vaddr == 4: vaddr_s = f"[v{inst.vaddr0}, v{inst.vaddr1}, v{inst.vaddr2}, v{inst.vaddr3}]"
  else:
    # More than 4 vaddrs: vaddr3 becomes start of a contiguous range for the remaining coords
    extra = vaddr - 3  # vaddr0-2 are individual, vaddr3 starts range of remaining
    vaddr_s = f"[v{inst.vaddr0}, v{inst.vaddr1}, v{inst.vaddr2}, {_vreg(inst.vaddr3, extra)}]"
  # RDNA4 VSAMPLE rsrc/samp fields store the SGPR index directly (not /4 like RDNA3)
  srsrc_s = _sreg_or_ttmp(inst.rsrc, 8)
  # msaa_load doesn't use a sampler (it's a load, not a sample), but gather4h does
  uses_sampler = 'msaa_load' not in name
  ssamp_s = f", {_sreg_or_ttmp(inst.samp, 4)}" if uses_sampler else ""
  mods = [f"dmask:0x{inst.dmask:x}"] if inst.dmask else []
  mods.append(f"dim:SQ_RSRC_IMG_{dim.upper()}")
  if inst.unrm: mods.append("unorm")
  # th/scope must come before r128, a16, tfe, lwe, d16; th is always formatted with the load-type table here
  if inst.th or inst.scope:
    mem_mods = _rdna4_mem_mods(inst.th, inst.scope, False, False)
    if mem_mods: mods.append(mem_mods)
  mods.extend([m for c, m in [(inst.r128, "r128"), (inst.a16, "a16"), (inst.tfe, "tfe"), (inst.lwe, "lwe"), (inst.d16, "d16")] if c])
  return f"{name} {_vreg(inst.vdata, vdata)}, {vaddr_s}, {srsrc_s}{ssamp_s} {' '.join(mods)}"
# Disassembler dispatch keyed by the RDNA3 instruction classes imported from autogen.rdna3 at the top of this file
DISASM_HANDLERS = {VOP1: _disasm_vop1, VOP2: _disasm_vop2, VOPC: _disasm_vopc, VOP3: _disasm_vop3, VOP3SD: _disasm_vop3sd, VOPD: _disasm_vopd, VOP3P: _disasm_vop3p,
                   VINTERP: _disasm_vinterp, SOPP: _disasm_sopp, SMEM: _disasm_smem, DS: _disasm_ds, FLAT: _disasm_flat, MUBUF: _disasm_buf, MTBUF: _disasm_buf,
                   MIMG: _disasm_mimg, SOP1: _disasm_sop1, SOP2: _disasm_sop2, SOPC: _disasm_sopc, SOPK: _disasm_sopk, EXP: _disasm_exp, LDSDIR: _disasm_ldsdir}
# RDNA4 uses different class names, dispatch by name for cross-arch support.
# Some RDNA4 formats reuse RDNA3 handlers (VDS -> _disasm_ds, VEXPORT -> _disasm_exp, VDSDIR -> _disasm_ldsdir);
# VFLAT/VGLOBAL/VSCRATCH, VBUFFER, VIMAGE and VSAMPLE use the RDNA4-specific handlers above.
_DISASM_BY_NAME = {'VOP1': _disasm_vop1, 'VOP2': _disasm_vop2, 'VOPC': _disasm_vopc, 'VOP3': _disasm_vop3, 'VOP3SD': _disasm_vop3sd,
                   'VOPD': _disasm_vopd, 'VOP3P': _disasm_vop3p, 'VINTERP': _disasm_vinterp, 'SOPP': _disasm_sopp, 'SMEM': _disasm_smem,
                   'DS': _disasm_ds, 'VDS': _disasm_ds, 'FLAT': _disasm_flat, 'MUBUF': _disasm_buf, 'MTBUF': _disasm_buf, 'MIMG': _disasm_mimg,
                   'SOP1': _disasm_sop1, 'SOP2': _disasm_sop2, 'SOPC': _disasm_sopc, 'SOPK': _disasm_sopk,
                   'EXP': _disasm_exp, 'LDSDIR': _disasm_ldsdir, 'VEXPORT': _disasm_exp, 'VDSDIR': _disasm_ldsdir,
                   'VFLAT': _disasm_vflat, 'VGLOBAL': _disasm_vflat, 'VSCRATCH': _disasm_vflat,
                   'VBUFFER': _disasm_vbuffer, 'VIMAGE': _disasm_vimage, 'VSAMPLE': _disasm_vsample}
def disasm(inst: Inst, wave_size: int = 32) -> str:
  """Return the assembly text for `inst`.

  The handler is looked up by exact class first, then by class name (covers the RDNA4 classes). `wave_size` is
  forwarded only to the VOP3P handler (which covers WMMA). Raises KeyError when no handler matches."""
  cls = type(inst)
  handler = DISASM_HANDLERS.get(cls) or _DISASM_BY_NAME.get(cls.__name__)
  if handler is None:
    raise KeyError(f"no disasm handler for {cls.__name__}")
  if handler is _disasm_vop3p:
    return handler(inst, wave_size)
  return handler(inst)
# ═══════════════════════════════════════════════════════════════════════════════
# ASSEMBLER
# ═══════════════════════════════════════════════════════════════════════════════
# Special register names -> raw operand encodings (vcc/exec alias their _lo halves; null and off both encode as 124)
SPEC_REGS = {'vcc_lo': RawImm(106), 'vcc_hi': RawImm(107), 'vcc': RawImm(106), 'null': RawImm(124), 'off': RawImm(124), 'm0': RawImm(125),
             'exec_lo': RawImm(126), 'exec_hi': RawImm(127), 'exec': RawImm(126), 'scc': RawImm(253), 'src_scc': RawImm(253)}
FLOATS = {str(k): k for k in FLOAT_ENC}  # Valid float literal strings: '0.5', '-0.5', '1.0', etc.
# Register-file prefix -> register factory ('t' and 'ttmp' both map to the ttmp factory)
REG_MAP: dict[str, _RegFactory] = {'s': s, 'v': v, 't': ttmp, 'ttmp': ttmp}
# Scalar memory load mnemonics handled by the SMEM branch of get_dsl
SMEM_OPS = {'s_load_b32', 's_load_b64', 's_load_b128', 's_load_b256', 's_load_b512',
            's_buffer_load_b32', 's_buffer_load_b64', 's_buffer_load_b128', 's_buffer_load_b256', 's_buffer_load_b512'}
# Special register names -> DSL constant names emitted by _op2dsl (same vcc/exec aliasing as SPEC_REGS)
SPEC_DSL = {'vcc_lo': 'VCC_LO', 'vcc_hi': 'VCC_HI', 'vcc': 'VCC_LO', 'null': 'NULL', 'off': 'OFF', 'm0': 'M0',
            'exec_lo': 'EXEC_LO', 'exec_hi': 'EXEC_HI', 'exec': 'EXEC_LO', 'scc': 'SCC', 'src_scc': 'SCC'}
def _op2dsl(op: str) -> str:
  """Translate one assembler operand (e.g. "-|v1|", "s[0:1]", "vcc_lo", "0.5", "0x10") into DSL source text.

  A leading '-' (unless it starts a numeric literal), |x| or abs(x) wrappers, and a trailing .h/.l half selector are
  peeled off and re-applied around the translated core. Negated/abs integer literals become SrcMod(...) calls."""
  text = op.strip()
  numeric_sign = text[1:2].isdigit() or (len(text) > 2 and text[1] == '0' and text[2] in 'xX')
  negate = text.startswith('-') and not numeric_sign
  if negate:
    text = text[1:]
  if text.startswith('|') and text.endswith('|'):
    absolute, text = True, text[1:-1]
  elif text.startswith('abs(') and text.endswith(')'):
    absolute, text = True, text[4:-1]
  else:
    absolute = False
  half = next((sfx for sfx in ('.h', '.l') if text.endswith(sfx)), "")
  if half:
    text = text[:-2]
  lower = text.lower()
  def emit(core):
    if absolute:
      return f"{'-' if negate else ''}abs({core}){half}"
    return f"-{core}{half}" if negate else f"{core}{half}"
  if lower in SPEC_DSL:
    return emit(SPEC_DSL[lower])
  if text in FLOATS:
    return emit(text)
  prefixes = {'s': 's', 'v': 'v', 't': 'ttmp', 'ttmp': 'ttmp'}
  span = re.match(r'^([svt](?:tmp)?)\[(\d+):(\d+)\]$', lower)
  if span:
    return emit(f"{prefixes[span.group(1)]}[{span.group(2)}:{span.group(3)}]")
  single = re.match(r'^([svt](?:tmp)?)(\d+)$', lower)
  if single:
    return emit(f"{prefixes[single.group(1)]}[{single.group(2)}]")
  if re.match(r'^-?\d+$|^-?0x[0-9a-fA-F]+$', text):
    return f"SrcMod({text}, neg={negate}, abs_={absolute})" if negate or absolute else text
  return emit(text)
def _parse_ops(s: str) -> list[str]:
  """Split an operand string on top-level commas and strip each piece.

  Commas nested in [...] or (...) or between a pair of '|' bars do not split. An empty trailing piece is dropped;
  empty pieces between commas are kept as ''."""
  pieces: list[str] = []
  buf: list[str] = []
  nesting, in_bars = 0, False
  for ch in s:
    if ch in '[(': nesting += 1
    elif ch in '])': nesting -= 1
    elif ch == '|': in_bars = not in_bars
    if ch == ',' and nesting == 0 and not in_bars:
      pieces.append(''.join(buf).strip())
      buf = []
    else:
      buf.append(ch)
  tail = ''.join(buf).strip()
  if tail:
    pieces.append(tail)
  return pieces
def _extract(text: str, pat: str, flags=re.I):
  """Find the first match of `pat` in `text` (case-insensitive by default) and cut it out.

  Returns (match, text_without_the_match), or (None, text) unchanged when nothing matches."""
  found = re.search(pat, text, flags)
  if found is None:
    return None, text
  return found, text[:found.start()] + text[found.end():]
def get_dsl(text: str) -> str:
  """Translate one line of RDNA3 assembly into a DSL constructor-call string that asm() evaluates.

  Trailing modifiers (omod, clamp, op_sel, wait_exp, offset, dlc/glc/slc, neg_lo/neg_hi) are first stripped from the
  text and turned into keyword arguments. The mnemonic then selects a formatting path: s_waitcnt, VOPD ('::'),
  special scalar ops, SMEM, buffer, FLAT/GLOBAL/SCRATCH, DS, and finally the generic ALU path.
  Raises ValueError for an empty line or for s_setreg_imm32_b32 (unsupported)."""
  text, kw = text.strip(), []
  # Extract modifiers. omod: mul:2 -> 1, mul:4 -> 2, div:2 -> 3 (only the first match is used)
  for pat, val in [(r'\s+mul:2(?:\s|$)', 1), (r'\s+mul:4(?:\s|$)', 2), (r'\s+div:2(?:\s|$)', 3)]:
    if (m := _extract(text, pat))[0]: kw.append(f'omod={val}'); text = m[1]; break
  if (m := _extract(text, r'\s+clamp(?:\s|$)'))[0]: kw.append('clmp=1'); text = m[1]
  opsel, m, text = None, *_extract(text, r'\s+op_sel:\[([^\]]+)\]')
  if m:
    bits, mn = [int(x.strip()) for x in m.group(1).split(',')], text.split()[0].lower()
    # 3-entry op_sel: v_pk_/v_wmma_/v_dot ops pack bits 0..2; other ops place the third entry at bit 3
    is3p = mn.startswith(('v_pk_', 'v_wmma_', 'v_dot'))
    opsel = (bits[0] | (bits[1] << 1) | (bits[2] << 2)) if len(bits) == 3 and is3p else \
      (bits[0] | (bits[1] << 1) | (bits[2] << 3)) if len(bits) == 3 else sum(b << i for i, b in enumerate(bits))
  m, text = _extract(text, r'\s+wait_exp:(\d+)'); waitexp = m.group(1) if m else None
  m, text = _extract(text, r'\s+offset:(0x[0-9a-fA-F]+|-?\d+)'); off_val = m.group(1) if m else None
  m, text = _extract(text, r'\s+dlc(?:\s|$)'); dlc = 1 if m else None
  m, text = _extract(text, r'\s+glc(?:\s|$)'); glc = 1 if m else None
  m, text = _extract(text, r'\s+slc(?:\s|$)'); slc = 1 if m else None
  m, text = _extract(text, r'\s+neg_lo:\[([^\]]+)\]'); neg_lo = sum(int(x.strip()) << i for i, x in enumerate(m.group(1).split(','))) if m else None
  m, text = _extract(text, r'\s+neg_hi:\[([^\]]+)\]'); neg_hi = sum(int(x.strip()) << i for i, x in enumerate(m.group(1).split(','))) if m else None
  if waitexp: kw.append(f'waitexp={waitexp}')
  # What remains is "<mnemonic> <operands...>"
  parts = text.replace(',', ' ').split()
  if not parts: raise ValueError("empty instruction")
  mn, op_str = parts[0].lower(), text[len(parts[0]):].strip()
  # ops: raw operand text; args: the same operands translated to DSL expressions
  ops, args = _parse_ops(op_str), [_op2dsl(o) for o in _parse_ops(op_str)]
  # s_waitcnt: a raw immediate, or vmcnt()/expcnt()/lgkmcnt() counters (unmentioned counters default to 0x3f/0x7/0x3f)
  if mn == 's_waitcnt':
    vm, exp, lgkm = 0x3f, 0x7, 0x3f
    for p in op_str.replace(',', ' ').split():
      if m := re.match(r'vmcnt\((\d+)\)', p): vm = int(m.group(1))
      elif m := re.match(r'expcnt\((\d+)\)', p): exp = int(m.group(1))
      elif m := re.match(r'lgkmcnt\((\d+)\)', p): lgkm = int(m.group(1))
      # raw immediate: hex digits may be either case (int(p, 0) accepts both)
      elif re.match(r'^0x[0-9a-f]+$|^\d+$', p, re.I): return f"s_waitcnt(simm16={int(p, 0)})"
    return f"s_waitcnt(simm16={waitcnt(vm, exp, lgkm)})"
  # VOPD: "<x op> <operands> :: <y op> <operands>"; fmaak/fmamk carry a literal in their operand list
  if '::' in text:
    xp, yp = text.split('::')
    xps, yps = xp.strip().replace(',', ' ').split(), yp.strip().replace(',', ' ').split()
    xo, yo = [_op2dsl(p) for p in xps[1:]], [_op2dsl(p) for p in yps[1:]]
    vdx, sx0, vsx1 = xo[0], xo[1] if len(xo) > 1 else '0', xo[2] if len(xo) > 2 else 'v[0]'
    vdy, sy0, vsy1 = yo[0], yo[1] if len(yo) > 1 else '0', yo[2] if len(yo) > 2 else 'v[0]'
    lit = xo[3] if 'fmaak' in xps[0].lower() and len(xo) > 3 else yo[3] if 'fmaak' in yps[0].lower() and len(yo) > 3 else None
    if 'fmamk' in xps[0].lower() and len(xo) > 3: lit, vsx1 = xo[2], xo[3]
    elif 'fmamk' in yps[0].lower() and len(yo) > 3: lit, vsy1 = yo[2], yo[3]
    return f"VOPD(VOPDOp.{xps[0].upper()}, VOPDOp.{yps[0].upper()}, vdstx={vdx}, vdsty={vdy}, srcx0={sx0}, vsrcx1={vsx1}, srcy0={sy0}, vsrcy1={vsy1}{f', literal={lit}' if lit else ''})"
  # Special instructions
  if mn == 's_setreg_imm32_b32': raise ValueError(f"unsupported: {mn}")
  if mn in ('s_setpc_b64', 's_rfe_b64'): return f"{mn}(ssrc0={args[0]})"
  if mn in ('s_sendmsg_rtn_b32', 's_sendmsg_rtn_b64'): return f"{mn}(sdst={args[0]}, ssrc0=RawImm({args[1].strip()}))"
  if mn == 's_version': return f"{mn}(simm16={args[0]})"
  if mn == 's_setreg_b32': return f"{mn}(simm16={args[0]}, sdst={args[1]})"
  # SMEM: third operand is either an immediate offset or an soffset register (optionally with offset:)
  if mn in SMEM_OPS:
    gs, ds = ", glc=1" if glc else "", ", dlc=1" if dlc else ""
    if len(ops) >= 3 and re.match(r'^-?[0-9]|^-?0x', ops[2].strip().lower()):
      return f"{mn}(sdata={args[0]}, sbase={args[1]}, offset={args[2]}, soffset=RawImm(124){gs}{ds})"
    if off_val and len(ops) >= 3: return f"{mn}(sdata={args[0]}, sbase={args[1]}, offset={off_val}, soffset={args[2]}{gs}{ds})"
    if len(ops) >= 3: return f"{mn}(sdata={args[0]}, sbase={args[1]}, soffset={args[2]}{gs}{ds})"
  # Buffer with "off" vaddr
  # NOTE(review): soffset wraps the translated operand in RawImm(...); confirm this is right for register soffsets
  if mn.startswith('buffer_') and len(ops) >= 2 and ops[1].strip().lower() == 'off':
    return f"{mn}(vdata={args[0]}, vaddr=0, srsrc={args[2]}, soffset={f'RawImm({args[3].strip()})' if len(args) > 3 else 'RawImm(0)'})"
  # FLAT/GLOBAL/SCRATCH load/store/atomic - saddr needs RawImm(124) for off/null
  def _saddr(a): return 'RawImm(124)' if a in ('OFF', 'NULL') else a
  flat_mods = f"{f', offset={off_val}' if off_val else ''}{', glc=1' if glc else ''}{', slc=1' if slc else ''}{', dlc=1' if dlc else ''}"
  for pre, flds in [('flat_load','vdst,addr,saddr'), ('global_load','vdst,addr,saddr'), ('scratch_load','vdst,addr,saddr'),
                    ('flat_store','addr,data,saddr'), ('global_store','addr,data,saddr'), ('scratch_store','addr,data,saddr')]:
    if mn.startswith(pre) and len(args) >= 2:
      f0, f1, f2 = flds.split(',')
      return f"{mn}({f0}={args[0]}, {f1}={args[1]}{f', {f2}={_saddr(args[2])}' if len(args) >= 3 else ', saddr=RawImm(124)'}{flat_mods})"
  # Atomics: glc selects the returning form, which has a leading vdst operand
  for pre in ('flat_atomic', 'global_atomic', 'scratch_atomic'):
    if mn.startswith(pre):
      if glc and len(args) >= 3: return f"{mn}(vdst={args[0]}, addr={args[1]}, data={args[2]}{f', saddr={_saddr(args[3])}' if len(args) >= 4 else ', saddr=RawImm(124)'}{flat_mods})"
      if len(args) >= 2: return f"{mn}(addr={args[0]}, data={args[1]}{f', saddr={_saddr(args[2])}' if len(args) >= 3 else ', saddr=RawImm(124)'}{flat_mods})"
  # DS instructions: offset: is split into offset0 (low byte) and offset1 (next byte)
  if mn.startswith('ds_'):
    off0, off1 = (str(int(off_val, 0) & 0xff), str((int(off_val, 0) >> 8) & 0xff)) if off_val else ("0", "0")
    gds_s = ", gds=1" if 'gds' in text.lower().split()[-1:] else ""
    off_kw = f", offset0={off0}, offset1={off1}{gds_s}"
    if mn == 'ds_nop' or mn in ('ds_gws_sema_v', 'ds_gws_sema_p', 'ds_gws_sema_release_all'): return f"{mn}({off_kw.lstrip(', ')})"
    if 'gws_' in mn: return f"{mn}(addr={args[0]}{off_kw})"
    if 'consume' in mn or 'append' in mn: return f"{mn}(vdst={args[0]}{off_kw})"
    if 'gs_reg' in mn: return f"{mn}(vdst={args[0]}, data0={args[1]}{off_kw})"
    if '2addr' in mn:
      if 'load' in mn: return f"{mn}(vdst={args[0]}, addr={args[1]}{off_kw})"
      if 'store' in mn and 'xchg' not in mn: return f"{mn}(addr={args[0]}, data0={args[1]}, data1={args[2]}{off_kw})"
      return f"{mn}(vdst={args[0]}, addr={args[1]}, data0={args[2]}, data1={args[3]}{off_kw})"
    if 'load' in mn: return f"{mn}(vdst={args[0]}{off_kw})" if 'addtid' in mn else f"{mn}(vdst={args[0]}, addr={args[1]}{off_kw})"
    if 'store' in mn and not _has(mn, 'cmp', 'xchg'):
      return f"{mn}(data0={args[0]}{off_kw})" if 'addtid' in mn else f"{mn}(addr={args[0]}, data0={args[1]}{off_kw})"
    if 'swizzle' in mn or 'ordered_count' in mn: return f"{mn}(vdst={args[0]}, addr={args[1]}{off_kw})"
    if 'permute' in mn: return f"{mn}(vdst={args[0]}, addr={args[1]}, data0={args[2]}{off_kw})"
    if 'bvh' in mn: return f"{mn}(vdst={args[0]}, addr={args[1]}, data0={args[2]}, data1={args[3]}{off_kw})"
    if 'condxchg' in mn: return f"{mn}(vdst={args[0]}, addr={args[1]}, data0={args[2]}{off_kw})"
    if _has(mn, 'cmpstore', 'mskor', 'wrap'):
      return f"{mn}(vdst={args[0]}, addr={args[1]}, data0={args[2]}, data1={args[3]}{off_kw})" if '_rtn' in mn else f"{mn}(addr={args[0]}, data0={args[1]}, data1={args[2]}{off_kw})"
    return f"{mn}(vdst={args[0]}, addr={args[1]}, data0={args[2]}{off_kw})" if '_rtn' in mn else f"{mn}(addr={args[0]}, data0={args[1]}{off_kw})"
  # v_fmaak/v_fmamk literal extraction
  lit_s = ""
  if mn in ('v_fmaak_f32', 'v_fmaak_f16') and len(args) == 4: lit_s, args = f", literal={args[3].strip()}", args[:3]
  elif mn in ('v_fmamk_f32', 'v_fmamk_f16') and len(args) == 4: lit_s, args = f", literal={args[2].strip()}", [args[0], args[1], args[3]]
  # VCC ops cleanup: drop the implicit vcc operands from the e32 carry forms, and from VOPC (non-e64) compares
  vcc_ops = {'v_add_co_ci_u32', 'v_sub_co_ci_u32', 'v_subrev_co_ci_u32'}
  if mn.replace('_e32', '') in vcc_ops and len(args) >= 5: mn, args = mn.replace('_e32', '') + '_e32', [args[0], args[2], args[3]]
  if mn.replace('_e64', '') in vcc_ops and mn.endswith('_e64'): mn = mn.replace('_e64', '')
  if mn.startswith('v_cmp') and not mn.endswith('_e64') and len(args) >= 3 and ops[0].strip().lower() in ('vcc_lo', 'vcc_hi', 'vcc'): args = args[1:]
  # e64 cmpx with only two sources: supply the implicit exec_lo (126) destination
  if 'cmpx' in mn and mn.endswith('_e64') and len(args) == 2: args = ['RawImm(126)'] + args
  fn = mn.replace('.', '_')
  # with an explicit op_sel, per-operand .h/.l suffixes are redundant and removed
  if opsel is not None: args = [re.sub(r'\.[hl]$', '', a) for a in args]
  # v_fma_mix*: extract inline neg/abs modifiers (neg -> neg_lo bits, abs -> neg_hi bits)
  if 'fma_mix' in mn and neg_lo is None and neg_hi is None:
    inline_neg, inline_abs, clean_args = 0, 0, [args[0]]
    for i, op in enumerate(ops[1:4]):
      op = op.strip()
      neg = op.startswith('-') and not (op[1:2].isdigit() or (len(op) > 2 and op[1] == '0' and op[2] in 'xX'))
      if neg: op = op[1:]
      abs_ = op.startswith('|') and op.endswith('|')
      if abs_: op = op[1:-1]
      if neg: inline_neg |= (1 << i)
      if abs_: inline_abs |= (1 << i)
      clean_args.append(_op2dsl(op))
    args = clean_args + args[4:]
    if inline_neg: neg_lo = inline_neg
    if inline_abs: neg_hi = inline_abs
  all_kw = list(kw)
  if lit_s: all_kw.append(lit_s.lstrip(', '))
  if opsel is not None: all_kw.append(f'opsel={opsel}')
  if neg_lo is not None: all_kw.append(f'neg={neg_lo}')
  if neg_hi is not None: all_kw.append(f'neg_hi={neg_hi}')
  if 'bvh' in mn and 'intersect_ray' in mn: all_kw.extend(['dmask=15', 'unrm=1', 'r128=1'])
  a_str, kw_str = ', '.join(args), ', '.join(all_kw)
  return f"{fn}({a_str}, {kw_str})" if kw_str and a_str else f"{fn}({kw_str})" if kw_str else f"{fn}({a_str})"
def asm(text: str) -> Inst:
  """Assemble one line of RDNA3 assembly into an Inst.

  get_dsl() turns the text into a DSL expression, which is eval'd against a namespace of the RDNA3 instruction
  constructors plus register/operand helpers. If the name is not found (NameError) and the expression is a v_* call,
  it is retried with an _e32 suffix. Because the translated text is passed to eval, do not feed untrusted input here."""
  dsl = get_dsl(text)
  env = {name: getattr(ins, name) for name in dir(ins) if not name.startswith('_')}
  env.update(s=s, v=v, ttmp=ttmp, abs=abs, RawImm=RawImm, SrcMod=SrcMod, VGPR=VGPR, SGPR=SGPR, TTMP=TTMP,
             VCC_LO=VCC_LO, VCC_HI=VCC_HI, VCC=VCC, EXEC_LO=EXEC_LO, EXEC_HI=EXEC_HI, EXEC=EXEC, SCC=SCC, M0=M0, NULL=NULL, OFF=OFF)
  try:
    return eval(dsl, env)
  except NameError:
    call = re.match(r'^(v_\w+)(\(.*\))$', dsl)
    if call is None:
      raise
    return eval(f"{call.group(1)}_e32{call.group(2)}", env)