mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
rdna4 works
This commit is contained in:
parent
2a6904029b
commit
9302f38f5b
4 changed files with 225 additions and 141 deletions
|
|
@ -1,4 +1,4 @@
|
|||
# RDNA3 assembler and disassembler
|
||||
# RDNA3/RDNA4 assembler and disassembler
|
||||
from __future__ import annotations
|
||||
import re
|
||||
from extra.assembly.amd.dsl import Inst, RawImm, Reg, SrcMod, SGPR, VGPR, TTMP, s, v, ttmp, _RegFactory
|
||||
|
|
@ -9,26 +9,21 @@ from extra.assembly.amd.autogen.rdna3.ins import (VOP1, VOP2, VOP3, VOP3SD, VOP3
|
|||
VOP1Op, VOP2Op, VOP3Op, VOP3SDOp, VOPDOp, SOP1Op, SOPKOp, SOPPOp, SMEMOp, DSOp, MUBUFOp)
|
||||
|
||||
def _matches_encoding(word: int, cls: type[Inst]) -> bool:
|
||||
"""Check if word matches the encoding pattern of an instruction class."""
|
||||
if cls._encoding is None: return False
|
||||
bf, val = cls._encoding
|
||||
return ((word >> bf.lo) & bf.mask()) == val
|
||||
|
||||
# Format lookup order used by detect_format. Order matters: more specific encodings
# come first; SOP2 and VOP2 sit at the end of the 32-bit list because they are
# catch-alls (VOP2 matches anything with bit31=0).
_FORMATS_64 = [
  VOPD, VOP3P, VINTERP, VOP3, DS, FLAT, MUBUF, MTBUF, MIMG, SMEM, EXP,
]
_FORMATS_32 = [
  SOP1, SOPC, SOPP, SOPK, VOPC, VOP1, SOP2, VOP2,
]
|
||||
|
||||
def detect_format(data: bytes) -> type[Inst]:
  """Pick the instruction class whose encoding matches the first dword of `data`.

  The first four bytes are read little-endian. When bits[31:30] are 0b11 the
  64-bit format list is searched, otherwise the 32-bit list; the first class
  accepted by `_matches_encoding` wins. A VOP3 match is reported as VOP3SD
  when bits[25:16] of the word are listed in `Inst._VOP3SD_OPS`.

  Raises:
    AssertionError: fewer than 4 bytes were supplied.
    ValueError: no class in the selected list matches.
  """
  assert len(data) >= 4, f"need at least 4 bytes, got {len(data)}"
  word = int.from_bytes(data[:4], 'little')
  is_wide = (word >> 30) == 0b11
  candidates = _FORMATS_64 if is_wide else _FORMATS_32
  fmt = next((c for c in candidates if _matches_encoding(word, c)), None)
  if fmt is None:
    if is_wide: raise ValueError(f"unknown 64-bit format word={word:#010x}")
    raise ValueError(f"unknown 32-bit format word={word:#010x}")
  # VOP3 and VOP3SD share an encoding; only the opcode field tells them apart.
  if is_wide and fmt is VOP3 and ((word >> 16) & 0x3ff) in Inst._VOP3SD_OPS: return VOP3SD
  return fmt
|
||||
|
|
@ -41,10 +36,22 @@ HWREG = {1: 'HW_REG_MODE', 2: 'HW_REG_STATUS', 3: 'HW_REG_TRAPSTS', 4: 'HW_REG_H
|
|||
6: 'HW_REG_LDS_ALLOC', 7: 'HW_REG_IB_STS', 15: 'HW_REG_SH_MEM_BASES', 18: 'HW_REG_PERF_SNAPSHOT_PC_LO',
|
||||
19: 'HW_REG_PERF_SNAPSHOT_PC_HI', 20: 'HW_REG_FLAT_SCR_LO', 21: 'HW_REG_FLAT_SCR_HI', 22: 'HW_REG_XNACK_MASK',
|
||||
23: 'HW_REG_HW_ID1', 24: 'HW_REG_HW_ID2', 25: 'HW_REG_POPS_PACKER', 28: 'HW_REG_IB_STS2'}
|
||||
# Hardware register id -> name table consulted for instruction classes from an RDNA4 module.
HWREG_RDNA4 = {
  1: 'HW_REG_MODE', 2: 'HW_REG_STATUS', 5: 'HW_REG_GPR_ALLOC', 6: 'HW_REG_LDS_ALLOC',
  7: 'HW_REG_IB_STS', 23: 'HW_REG_HW_ID1', 24: 'HW_REG_HW_ID2',
}
# Reverse of HWREG: lower-cased register name -> id.
HWREG_IDS = {reg_name.lower(): reg_id for reg_id, reg_name in HWREG.items()}
# Message ids 128..133 -> message name.
MSG = {128 + idx: msg_name for idx, msg_name in enumerate([
  'MSG_RTN_GET_DOORBELL', 'MSG_RTN_GET_DDID', 'MSG_RTN_GET_TMA',
  'MSG_RTN_GET_REALTIME', 'MSG_RTN_SAVE_WAVE', 'MSG_RTN_GET_TBA'])}
|
||||
|
||||
# RDNA4 cache policy tables
|
||||
_TH_LOAD = {0: None, 1: 'TH_LOAD_NT', 2: 'TH_LOAD_HT', 3: 'TH_LOAD_LU', 4: 'TH_LOAD_NT_RT', 5: 'TH_LOAD_RT_NT', 6: 'TH_LOAD_NT_HT'}
|
||||
_TH_STORE = {0: None, 1: 'TH_STORE_NT', 2: 'TH_STORE_HT', 3: 'TH_STORE_LU', 4: 'TH_STORE_NT_RT', 5: 'TH_STORE_RT_NT', 6: 'TH_STORE_NT_HT'}
|
||||
_TH_ATOMIC = {0: None, 1: 'TH_ATOMIC_NT', 2: 'TH_ATOMIC_RETURN'}
|
||||
_SCOPE = {0: None, 1: 'SCOPE_SE', 2: 'SCOPE_SA', 3: 'SCOPE_SYS'}
|
||||
|
||||
# Export target mapping
|
||||
_EXP_TARGETS = {**{i: f'mrt{i}' for i in range(8)}, 8: 'mrtz', **{i+12: f'pos{i}' for i in range(5)},
|
||||
20: 'prim', 21: 'dual_src_blend0', 22: 'dual_src_blend1'}
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
# HELPERS
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
|
|
@ -77,12 +84,11 @@ def waitcnt(vmcnt: int = 0x3f, expcnt: int = 0x7, lgkmcnt: int = 0x3f) -> int:
|
|||
|
||||
def _has(op: str, *subs) -> bool: return any(s in op for s in subs)
|
||||
def _omod(v: int) -> str: return {1: " mul:2", 2: " mul:4", 3: " div:2"}.get(v, "")
|
||||
def _src16(inst, v: int) -> str: return _fmt_v16(v) if v >= 256 else inst.lit(v) # format 16-bit src: vgpr.h/l or literal
|
||||
def _src16(inst, v: int) -> str: return _fmt_v16(v) if v >= 256 else inst.lit(v)
|
||||
def _mods(*pairs) -> str: return " ".join(m for c, m in pairs if c)
|
||||
def _fmt_bits(label: str, val: int, count: int) -> str: return f"{label}:[{','.join(str((val >> i) & 1) for i in range(count))}]"
|
||||
|
||||
def _vop3_src(inst, v: int, neg: int, abs_: int, hi: int, n: int, f16: bool, any_hi: bool) -> str:
|
||||
"""Format VOP3 source operand with modifiers."""
|
||||
if n > 1: s = _fmt_src(v, n)
|
||||
elif f16 and v >= 256: s = f"v{v - 256}.h" if hi else (f"v{v - 256}.l" if any_hi else inst.lit(v))
|
||||
else: s = inst.lit(v)
|
||||
|
|
@ -90,12 +96,37 @@ def _vop3_src(inst, v: int, neg: int, abs_: int, hi: int, n: int, f16: bool, any
|
|||
return f"-{s}" if neg else s
|
||||
|
||||
def _opsel_str(opsel: int, n: int, need: bool, is16_d: bool) -> str:
|
||||
"""Format op_sel modifier string."""
|
||||
if not need: return ""
|
||||
if is16_d and (opsel & 8): return f" op_sel:[1,1,1{',1' if n == 3 else ''}]"
|
||||
if n == 3: return f" op_sel:[{opsel & 1},{(opsel >> 1) & 1},{(opsel >> 2) & 1},{(opsel >> 3) & 1}]"
|
||||
return f" op_sel:[{opsel & 1},{(opsel >> 1) & 1},{(opsel >> 2) & 1}]"
|
||||
|
||||
def _mimg_vaddr_width(name: str, dim: int, a16: bool) -> int:
|
||||
base = [1, 2, 3, 3, 2, 3, 3, 4][dim] # 1d,2d,3d,cube,1d_arr,2d_arr,2d_msaa,2d_msaa_arr
|
||||
grad = [1, 2, 3, 2, 1, 2, 2, 2][dim]
|
||||
if 'get_resinfo' in name: return 1
|
||||
packed, unpacked = 0, 0
|
||||
if '_mip' in name: packed += 1
|
||||
elif 'sample' in name or 'gather' in name:
|
||||
if '_o' in name: unpacked += 1
|
||||
if re.search(r'_c(_|$)', name): unpacked += 1
|
||||
if '_d' in name: unpacked += (grad + 1) & ~1 if '_g16' in name else grad*2
|
||||
if '_b' in name: unpacked += 1
|
||||
if '_l' in name and '_cl' not in name and '_lz' not in name: packed += 1
|
||||
if '_cl' in name: packed += 1
|
||||
return (base + packed + 1) // 2 + unpacked if a16 else base + packed + unpacked
|
||||
|
||||
def _collect_vaddrs(inst, count: int) -> list[int]:
|
||||
vaddrs = [inst.vaddr0]
|
||||
if count > 1: vaddrs.append(inst.vaddr1)
|
||||
if count > 2: vaddrs.append(inst.vaddr2)
|
||||
if count > 3: vaddrs.append(inst.vaddr3)
|
||||
if count > 4 and hasattr(inst, 'vaddr4'): vaddrs.append(inst.vaddr4)
|
||||
return vaddrs[:count]
|
||||
|
||||
def _fmt_vaddr_nsa(vaddrs: list[int]) -> str:
|
||||
return f"v{vaddrs[0]}" if len(vaddrs) == 1 else "[" + ", ".join(f"v{v}" for v in vaddrs) + "]"
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
# DISASSEMBLER
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
|
|
@ -104,7 +135,6 @@ def _disasm_vop1(inst: VOP1) -> str:
|
|||
name = inst.op_name.lower()
|
||||
if inst.op in (VOP1Op.V_NOP, VOP1Op.V_PIPEFLUSH): return name
|
||||
if inst.op == VOP1Op.V_READFIRSTLANE_B32: return f"v_readfirstlane_b32 {decode_src(inst.vdst)}, v{inst.src0 - 256 if inst.src0 >= 256 else inst.src0}"
|
||||
# 16-bit dst: uses .h/.l suffix (determined by name pattern, not dtype - e.g. sat_pk_u8_i16 outputs 8-bit but uses 16-bit encoding)
|
||||
parts = name.split('_')
|
||||
is_16d = any(p in ('f16','i16','u16','b16') for p in parts[-2:-1]) or (len(parts) >= 2 and parts[-1] in ('f16','i16','u16','b16') and 'cvt' not in name)
|
||||
dst = _vreg(inst.vdst, inst.dst_regs()) if inst.dst_regs() > 1 else _fmt_v16(inst.vdst, 0, 128) if is_16d else f"v{inst.vdst}"
|
||||
|
|
@ -113,11 +143,9 @@ def _disasm_vop1(inst: VOP1) -> str:
|
|||
|
||||
def _disasm_vop2(inst: VOP2) -> str:
|
||||
name = inst.op_name.lower()
|
||||
# RDNA4 removed V_DOT2ACC_F32_F16, check if op exists
|
||||
try: is_dot2acc = inst.op == VOP2Op.V_DOT2ACC_F32_F16
|
||||
except ValueError: is_dot2acc = False
|
||||
suf = "" if is_dot2acc else "_e32"
|
||||
# fmaak: dst = src0 * vsrc1 + K, fmamk: dst = src0 * K + vsrc1
|
||||
try:
|
||||
if inst.op in (VOP2Op.V_FMAAK_F32, VOP2Op.V_FMAAK_F16): return f"{name}{suf} v{inst.vdst}, {inst.lit(inst.src0)}, v{inst.vsrc1}, 0x{inst._literal:x}"
|
||||
if inst.op in (VOP2Op.V_FMAMK_F32, VOP2Op.V_FMAMK_F16): return f"{name}{suf} v{inst.vdst}, {inst.lit(inst.src0)}, 0x{inst._literal:x}, v{inst.vsrc1}"
|
||||
|
|
@ -126,7 +154,6 @@ def _disasm_vop2(inst: VOP2) -> str:
|
|||
if inst.op == VOP2Op.V_CNDMASK_B32: return f"{name}{suf} v{inst.vdst}, {inst.lit(inst.src0)}, v{inst.vsrc1}, vcc_lo"
|
||||
except ValueError: pass
|
||||
if inst.is_16bit(): return f"{name}{suf} {_fmt_v16(inst.vdst, 0, 128)}, {_src16(inst, inst.src0)}, {_fmt_v16(inst.vsrc1, 0, 128)}"
|
||||
# Handle 64-bit VOP2 instructions (RDNA4)
|
||||
dn, sn0, sn1 = inst.dst_regs(), inst.src_regs(0), inst.src_regs(1)
|
||||
dst = _vreg(inst.vdst, dn) if dn > 1 else f"v{inst.vdst}"
|
||||
src0 = _fmt_src(inst.src0, sn0) if sn0 > 1 else inst.lit(inst.src0)
|
||||
|
|
@ -144,13 +171,15 @@ NO_ARG_SOPP = {SOPPOp.S_ENDPGM, SOPPOp.S_BARRIER, SOPPOp.S_WAKEUP, SOPPOp.S_ICAC
|
|||
|
||||
def _disasm_sopp(inst: SOPP) -> str:
|
||||
name = inst.op_name.lower()
|
||||
if not name: raise ValueError(f"undefined SOPP op: {inst.op}")
|
||||
if inst.op in NO_ARG_SOPP: return name
|
||||
if inst.op == SOPPOp.S_WAITCNT:
|
||||
vm, exp, lgkm = (inst.simm16 >> 10) & 0x3f, inst.simm16 & 0xf, (inst.simm16 >> 4) & 0x3f
|
||||
p = [f"vmcnt({vm})" if vm != 0x3f else "", f"expcnt({exp})" if exp != 7 else "", f"lgkmcnt({lgkm})" if lgkm != 0x3f else ""]
|
||||
return f"s_waitcnt {' '.join(x for x in p if x) or '0'}"
|
||||
if inst.op == SOPPOp.S_DELAY_ALU:
|
||||
deps, skips = ['VALU_DEP_1','VALU_DEP_2','VALU_DEP_3','VALU_DEP_4','TRANS32_DEP_1','TRANS32_DEP_2','TRANS32_DEP_3','FMA_ACCUM_CYCLE_1','SALU_CYCLE_1','SALU_CYCLE_2','SALU_CYCLE_3'], ['SAME','NEXT','SKIP_1','SKIP_2','SKIP_3','SKIP_4']
|
||||
deps = ['VALU_DEP_1','VALU_DEP_2','VALU_DEP_3','VALU_DEP_4','TRANS32_DEP_1','TRANS32_DEP_2','TRANS32_DEP_3','FMA_ACCUM_CYCLE_1','SALU_CYCLE_1','SALU_CYCLE_2','SALU_CYCLE_3']
|
||||
skips = ['SAME','NEXT','SKIP_1','SKIP_2','SKIP_3','SKIP_4']
|
||||
id0, skip, id1 = inst.simm16 & 0xf, (inst.simm16 >> 4) & 0x7, (inst.simm16 >> 7) & 0xf
|
||||
dep = lambda v: deps[v-1] if 0 < v <= len(deps) else str(v)
|
||||
p = [f"instid0({dep(id0)})" if id0 else "", f"instskip({skips[skip]})" if skip else "", f"instid1({dep(id1)})" if id1 else ""]
|
||||
|
|
@ -159,6 +188,7 @@ def _disasm_sopp(inst: SOPP) -> str:
|
|||
|
||||
def _disasm_smem(inst: SMEM) -> str:
|
||||
name = inst.op_name.lower()
|
||||
if 'rdna4' in inst.__class__.__module__: return _disasm_smem_rdna4(inst)
|
||||
if inst.op in (SMEMOp.S_GL1_INV, SMEMOp.S_DCACHE_INV): return name
|
||||
off_s = f"{decode_src(inst.soffset)} offset:0x{inst.offset:x}" if inst.offset and inst.soffset != 124 else f"0x{inst.offset:x}" if inst.offset else decode_src(inst.soffset)
|
||||
sbase_idx, sbase_count = inst.sbase * 2, 4 if (8 <= inst.op.value <= 12 or name == 's_atc_probe_buffer') else 2
|
||||
|
|
@ -166,6 +196,33 @@ def _disasm_smem(inst: SMEM) -> str:
|
|||
if name in ('s_atc_probe', 's_atc_probe_buffer'): return f"{name} {inst.sdata}, {sbase_str}, {off_s}"
|
||||
return f"{name} {_fmt_sdst(inst.sdata, inst.dst_regs())}, {sbase_str}, {off_s}" + _mods((inst.glc, " glc"), (inst.dlc, " dlc"))
|
||||
|
||||
def _disasm_smem_rdna4(inst) -> str:
  """Disassemble an SMEM instruction whose class comes from an RDNA4 module.

  Ops without a name fall back to a small opcode table (or ``s_smem_op<N>``).
  The 24-bit `ioffset` field is printed as a signed hex value; `th` and
  `scope` are appended as modifiers when non-zero and present in the local
  name lists.
  """
  name, op_val = inst.op_name.lower(), inst._values.get('op')
  if not name:
    unnamed = {34: 's_atc_probe', 35: 's_atc_probe_buffer', 32: 's_gl1_inv'}
    name = unnamed.get(op_val, f's_smem_op{op_val}')
  if name in ('s_gl1_inv', 's_dcache_inv'): return name

  # sbase is stored in units of two registers; 'buffer' ops take four registers, others two
  base_reg, width = inst.sbase * 2, (4 if 'buffer' in name else 2)
  if base_reg == 106: sbase_str = "vcc"
  elif 108 <= base_reg <= 123: sbase_str = _reg("ttmp", base_reg - 108, width)
  else: sbase_str = _sreg(base_reg, width)

  # sign-extend the 24-bit immediate offset
  raw_off = inst.ioffset
  ioffset = raw_off - 0x1000000 if raw_off >= 0x800000 else raw_off
  off_str = f"-0x{-ioffset:x}" if ioffset < 0 else f"0x{ioffset:x}"
  soffset_str = decode_src(inst.soffset)

  if 'prefetch' in name:
    if 'pc_rel' in name: return f"{name} {off_str}, {soffset_str}, {inst.sdata}"
    return f"{name} {sbase_str}, {off_str}, {soffset_str}, {inst.sdata}"
  if 'atc_probe' in name:
    suffix = f" offset:{off_str}" if ioffset else ""
    return f"{name} {inst.sdata}, {sbase_str}, {soffset_str}" + suffix

  if inst.soffset == 124: tail = off_str
  elif ioffset: tail = f"{soffset_str} offset:{off_str}"
  else: tail = soffset_str
  base_str = f"{name} {_fmt_sdst(inst.sdata, inst.dst_regs())}, {sbase_str}, {tail}"

  th_names = ['','TH_LOAD_NT','TH_LOAD_HT','TH_LOAD_LU','TH_LOAD_NT_RT','TH_LOAD_NT_HT','TH_LOAD_BYPASS']
  scope_names = ['','SCOPE_SE','SCOPE_DEV','SCOPE_SYS']
  mods = [f"{label}:{table[val]}"
          for label, val, table in (("th", inst.th, th_names), ("scope", inst.scope, scope_names))
          if val and val < len(table) and table[val]]
  if not mods: return base_str
  return base_str + " " + " ".join(mods)
|
||||
|
||||
def _disasm_flat(inst: FLAT) -> str:
|
||||
name = inst.op_name.lower()
|
||||
seg = ['flat', 'scratch', 'global'][inst.seg] if inst.seg < 3 else 'flat'
|
||||
|
|
@ -173,16 +230,13 @@ def _disasm_flat(inst: FLAT) -> str:
|
|||
off_val = inst.offset if seg == 'flat' else (inst.offset if inst.offset < 4096 else inst.offset - 8192)
|
||||
w = inst.dst_regs() * (2 if 'cmpswap' in name else 1)
|
||||
mods = f"{f' offset:{off_val}' if off_val else ''}{' glc' if inst.glc else ''}{' slc' if inst.slc else ''}{' dlc' if inst.dlc else ''}"
|
||||
# saddr
|
||||
if seg == 'flat' or inst.saddr == 0x7F: saddr_s = ""
|
||||
elif inst.saddr == 124: saddr_s = ", off"
|
||||
elif seg == 'scratch': saddr_s = f", {decode_src(inst.saddr)}"
|
||||
elif inst.saddr in SPECIAL_PAIRS: saddr_s = f", {SPECIAL_PAIRS[inst.saddr]}"
|
||||
elif t := _ttmp(inst.saddr, 2): saddr_s = f", {t}"
|
||||
else: saddr_s = f", {_sreg(inst.saddr, 2) if inst.saddr < 106 else decode_src(inst.saddr)}"
|
||||
# addtid: no addr
|
||||
if 'addtid' in name: return f"{instr} v{inst.data if 'store' in name else inst.vdst}{saddr_s}{mods}"
|
||||
# addr width
|
||||
addr_s = "off" if not inst.sve and seg == 'scratch' else _vreg(inst.addr, 1 if seg == 'scratch' or (inst.saddr not in (0x7F, 124)) else 2)
|
||||
data_s, vdst_s = _vreg(inst.data, w), _vreg(inst.vdst, w // 2 if 'cmpswap' in name else w)
|
||||
if 'atomic' in name:
|
||||
|
|
@ -197,9 +251,11 @@ def _disasm_ds(inst: DS) -> str:
|
|||
off2 = f" offset0:{inst.offset0} offset1:{inst.offset1}" if inst.offset0 or inst.offset1 else ""
|
||||
w = inst.dst_regs()
|
||||
d0, d1, dst, addr = _vreg(inst.data0, w), _vreg(inst.data1, w), _vreg(inst.vdst, w), f"v{inst.addr}"
|
||||
|
||||
if op == DSOp.DS_NOP: return name
|
||||
if op == DSOp.DS_BVH_STACK_RTN_B32: return f"{name} v{inst.vdst}, {addr}, v{inst.data0}, {_vreg(inst.data1, 4)}{off}{gds}"
|
||||
if 'bvh_stack_push4_pop1' in name: return f"{name} v{inst.vdst}, {addr}, v{inst.data0}, {_vreg(inst.data1, 4)}{off}{gds}"
|
||||
if 'bvh_stack_push8_pop1' in name: return f"{name} v{inst.vdst}, {addr}, v{inst.data0}, {_vreg(inst.data1, 8)}{off}{gds}"
|
||||
if 'bvh_stack_push8_pop2' in name: return f"{name} {_vreg(inst.vdst, 2)}, {addr}, v{inst.data0}, {_vreg(inst.data1, 8)}{off}{gds}"
|
||||
if 'gws_sema' in name and op != DSOp.DS_GWS_SEMA_BR: return f"{name}{off}{gds}"
|
||||
if 'gws_' in name: return f"{name} {addr}{off}{gds}"
|
||||
if op in (DSOp.DS_CONSUME, DSOp.DS_APPEND): return f"{name} v{inst.vdst}{off}{gds}"
|
||||
|
|
@ -220,17 +276,15 @@ def _disasm_ds(inst: DS) -> str:
|
|||
|
||||
def _disasm_vop3(inst: VOP3) -> str:
|
||||
op, name = inst.op, inst.op_name.lower()
|
||||
|
||||
# VOP3SD (shared encoding)
|
||||
if isinstance(op, VOP3SDOp):
|
||||
if name.startswith('v_s_'):
|
||||
return f"{name} {_fmt_sdst(inst.vdst, 1)}, {_fmt_src(inst.src0, 1)}"
|
||||
if hasattr(op, '__class__') and op.__class__.__name__ == 'VOP3SDOp':
|
||||
sdst = (inst.clmp << 7) | (inst.opsel << 3) | inst.abs
|
||||
def src(v, neg, n): s = _fmt_src(v, n) if n > 1 else inst.lit(v); return f"-{s}" if neg else s
|
||||
s0, s1, s2 = src(inst.src0, inst.neg & 1, inst.src_regs(0)), src(inst.src1, inst.neg & 2, inst.src_regs(1)), src(inst.src2, inst.neg & 4, inst.src_regs(2))
|
||||
dst = _vreg(inst.vdst, inst.dst_regs()) if inst.dst_regs() > 1 else f"v{inst.vdst}"
|
||||
srcs = f"{s0}, {s1}, {s2}" if inst.num_srcs() == 3 else f"{s0}, {s1}"
|
||||
return f"{name} {dst}, {_fmt_sdst(sdst, 1)}, {srcs}" + _omod(inst.omod)
|
||||
|
||||
# Detect 16-bit operand sizes (for .h/.l suffix handling)
|
||||
is16_d = is16_s = is16_s2 = False
|
||||
if 'cvt_pk' in name: is16_s = name.endswith('16')
|
||||
elif m := re.match(r'v_(?:cvt|frexp_exp)_([a-z0-9_]+)_([a-z0-9]+)', name):
|
||||
|
|
@ -239,33 +293,30 @@ def _disasm_vop3(inst: VOP3) -> str:
|
|||
elif re.match(r'v_mad_[iu]32_[iu]16', name): is16_s = True
|
||||
elif 'pack_b32' in name: is16_s = is16_s2 = True
|
||||
else: is16_d = is16_s = is16_s2 = inst.is_16bit()
|
||||
|
||||
any_hi = inst.opsel != 0
|
||||
s0 = _vop3_src(inst, inst.src0, inst.neg&1, inst.abs&1, inst.opsel&1, inst.src_regs(0), is16_s, any_hi)
|
||||
s1 = _vop3_src(inst, inst.src1, inst.neg&2, inst.abs&2, inst.opsel&2, inst.src_regs(1), is16_s, any_hi)
|
||||
s2 = _vop3_src(inst, inst.src2, inst.neg&4, inst.abs&4, inst.opsel&4, inst.src_regs(2), is16_s2, any_hi)
|
||||
|
||||
# Destination
|
||||
dn = inst.dst_regs()
|
||||
if op == VOP3Op.V_READLANE_B32: dst = _fmt_sdst(inst.vdst, 1)
|
||||
elif dn > 1: dst = _vreg(inst.vdst, dn)
|
||||
elif is16_d: dst = f"v{inst.vdst}.h" if (inst.opsel & 8) else f"v{inst.vdst}.l" if any_hi else f"v{inst.vdst}"
|
||||
else: dst = f"v{inst.vdst}"
|
||||
|
||||
cl, om = " clamp" if inst.clmp else "", _omod(inst.omod)
|
||||
nonvgpr_opsel = (inst.src0 < 256 and (inst.opsel & 1)) or (inst.src1 < 256 and (inst.opsel & 2)) or (inst.src2 < 256 and (inst.opsel & 4))
|
||||
need_opsel = nonvgpr_opsel or (inst.opsel and not is16_s)
|
||||
|
||||
if inst.op < 256: # VOPC
|
||||
if inst.op < 256:
|
||||
return f"{name}_e64 {s0}, {s1}" if name.startswith('v_cmpx') else f"{name}_e64 {_fmt_sdst(inst.vdst, 1)}, {s0}, {s1}"
|
||||
if inst.op < 384: # VOP2
|
||||
if inst.op < 384:
|
||||
n = inst.num_srcs()
|
||||
os = _opsel_str(inst.opsel, n, need_opsel, is16_d)
|
||||
return f"{name}_e64 {dst}, {s0}, {s1}, {s2}{os}{cl}{om}" if n == 3 else f"{name}_e64 {dst}, {s0}, {s1}{os}{cl}{om}"
|
||||
if inst.op < 512: # VOP1
|
||||
if inst.op < 512:
|
||||
if _has(name, 'cvt_f32_fp8', 'cvt_f32_bf8'): need_opsel = False
|
||||
return f"{name}_e64" if op in (VOP3Op.V_NOP, VOP3Op.V_PIPEFLUSH) else f"{name}_e64 {dst}, {s0}{_opsel_str(inst.opsel, 1, need_opsel, is16_d)}{cl}{om}"
|
||||
# Native VOP3
|
||||
n = inst.num_srcs()
|
||||
if 'permlane' in name and '_var' in name: n = 2
|
||||
if _has(name, 'cvt_sr_fp8', 'cvt_sr_bf8'): n, need_opsel = 2, False
|
||||
os = _opsel_str(inst.opsel, n, need_opsel, is16_d)
|
||||
return f"{name} {dst}, {s0}, {s1}, {s2}{os}{cl}{om}" if n == 3 else f"{name} {dst}, {s0}, {s1}{os}{cl}{om}"
|
||||
|
||||
|
|
@ -280,22 +331,49 @@ def _disasm_vop3sd(inst: VOP3SD) -> str:
|
|||
|
||||
def _disasm_vopd(inst: VOPD) -> str:
  """Render a dual-issue VOPD instruction as ``<x half> :: <y half>``.

  The enum used to name each half depends on the instruction class: classes
  whose module path contains 'rdna4' use the RDNA4 autogen ``VOPDOp``, all
  others use the ``VOPDOp`` already in scope in this file. The opcodes are
  decoded exactly once, with the enum that matches the instruction, so an
  RDNA4-only opcode is never looked up in the RDNA3 enum.
  """
  lit = inst._literal or inst.literal
  if 'rdna4' in inst.__class__.__module__:
    # local import: the RDNA4 enum module is only imported on the RDNA4 path
    import importlib
    opcodes = importlib.import_module('extra.assembly.amd.autogen.rdna4.enum').VOPDOp
  else:
    opcodes = VOPDOp
  name_x, name_y = opcodes(inst.opx).name.lower(), opcodes(inst.opy).name.lower()
  # Y destination: vdsty shifted left by one, low bit is the inverse of vdstx's low bit
  vdst_y = (inst.vdsty << 1) | ((inst.vdstx & 1) ^ 1)

  def half(n, vd, s0, vs1):
    # the literal is printed only for fmaak/fmamk halves and only when it is non-zero
    k = f", 0x{lit:x}" if lit and _has(n, 'fmaak', 'fmamk') else ""
    if 'mov' in n: return f"{n} v{vd}, {inst.lit(s0)}{k}"
    return f"{n} v{vd}, {inst.lit(s0)}, v{vs1}{k}"

  return f"{half(name_x, inst.vdstx, inst.srcx0, inst.vsrcx1)} :: {half(name_y, vdst_y, inst.srcy0, inst.vsrcy1)}"
|
||||
|
||||
def _disasm_vop3p(inst: VOP3P) -> str:
|
||||
name = inst.op_name.lower()
|
||||
is_wmma, n, is_fma_mix = 'wmma' in name, inst.num_srcs(), 'fma_mix' in name
|
||||
if is_wmma:
|
||||
sc = 2 if 'iu4' in name else 4 if 'iu8' in name else 8
|
||||
src0, src1, src2, dst = _fmt_src(inst.src0, sc), _fmt_src(inst.src1, sc), _fmt_src(inst.src2, 8), _vreg(inst.vdst, 8)
|
||||
else: src0, src1, src2, dst = _fmt_src(inst.src0, 1), _fmt_src(inst.src1, 1), _fmt_src(inst.src2, 1), f"v{inst.vdst}"
|
||||
is_wmma, n, is_fma_mix, is_swmmac = 'wmma' in name, inst.num_srcs(), 'fma_mix' in name, 'swmmac' in name
|
||||
is_rdna4 = 'rdna4' in inst.__class__.__module__
|
||||
if is_wmma or is_swmmac:
|
||||
if is_rdna4:
|
||||
if is_swmmac:
|
||||
if '16x16x32_iu4' in name: sc0, sc1 = 1, 2
|
||||
elif '16x16x64_iu4' in name or '16x16x32_iu8' in name or 'fp8' in name or 'bf8' in name: sc0, sc1 = 2, 4
|
||||
else: sc0, sc1 = 4, 8
|
||||
sc2 = 1
|
||||
dst_w = 4 if name.startswith('v_swmmac_f16') or name.startswith('v_swmmac_bf16') else 8
|
||||
else:
|
||||
if '16x16x16_iu4' in name: sc0 = 1
|
||||
elif '16x16x32_iu4' in name or 'iu8' in name or 'fp8' in name or 'bf8' in name: sc0 = 2
|
||||
else: sc0 = 4
|
||||
sc1 = sc0
|
||||
sc2 = 4 if (name.startswith('v_wmma_f16') or name.startswith('v_wmma_bf16')) else 8
|
||||
dst_w = sc2
|
||||
src0, src1, src2, dst = _fmt_src(inst.src0, sc0), _fmt_src(inst.src1, sc1), _fmt_src(inst.src2, sc2), _vreg(inst.vdst, dst_w)
|
||||
else:
|
||||
sc = 2 if 'iu4' in name else 4 if 'iu8' in name else 8
|
||||
src0, src1, src2, dst = _fmt_src(inst.src0, sc), _fmt_src(inst.src1, sc), _fmt_src(inst.src2, 8), _vreg(inst.vdst, 8)
|
||||
else:
|
||||
src0, src1, src2, dst = _fmt_src(inst.src0, 1), _fmt_src(inst.src1, 1), _fmt_src(inst.src2, 1), f"v{inst.vdst}"
|
||||
opsel_hi = inst.opsel_hi | (inst.opsel_hi2 << 2)
|
||||
if is_fma_mix:
|
||||
def m(s, neg, abs_): return f"-{f'|{s}|' if abs_ else s}" if neg else (f"|{s}|" if abs_ else s)
|
||||
src0, src1, src2 = m(src0, inst.neg & 1, inst.neg_hi & 1), m(src1, inst.neg & 2, inst.neg_hi & 2), m(src2, inst.neg & 4, inst.neg_hi & 4)
|
||||
mods = ([_fmt_bits("op_sel", inst.opsel, n)] if inst.opsel else []) + ([_fmt_bits("op_sel_hi", opsel_hi, n)] if opsel_hi else []) + (["clamp"] if inst.clmp else [])
|
||||
elif is_swmmac:
|
||||
has_index_key = '16x16x64_iu4' not in name
|
||||
mods = ([f"index_key:{inst.opsel & 1}"] if has_index_key and (inst.opsel & 1) else []) + \
|
||||
([_fmt_bits("neg_lo", inst.neg, n)] if inst.neg else []) + ([_fmt_bits("neg_hi", inst.neg_hi, n)] if inst.neg_hi else [])
|
||||
else:
|
||||
mods = ([_fmt_bits("op_sel", inst.opsel, n)] if inst.opsel else []) + ([_fmt_bits("op_sel_hi", opsel_hi, n)] if opsel_hi != (7 if n == 3 else 3) else []) + \
|
||||
([_fmt_bits("neg_lo", inst.neg, n)] if inst.neg else []) + ([_fmt_bits("neg_hi", inst.neg_hi, n)] if inst.neg_hi else []) + (["clamp"] if inst.clmp else [])
|
||||
|
|
@ -313,54 +391,89 @@ def _disasm_buf(inst: MUBUF | MTBUF) -> str:
|
|||
mods = ([f"format:{inst.format}"] if isinstance(inst, MTBUF) else []) + [m for c, m in [(inst.idxen,"idxen"),(inst.offen,"offen"),(inst.offset,f"offset:{inst.offset}"),(inst.glc,"glc"),(inst.dlc,"dlc"),(inst.slc,"slc"),(inst.tfe,"tfe")] if c]
|
||||
return f"{name} {_vreg(inst.vdata, w)}, {vaddr}, {srsrc}, {decode_src(inst.soffset)}{' ' + ' '.join(mods) if mods else ''}"
|
||||
|
||||
def _mimg_vaddr_width(name: str, dim: int, a16: bool) -> int:
|
||||
"""Calculate vaddr register count for MIMG sample/gather operations."""
|
||||
# 1d,2d,3d,cube,1d_arr,2d_arr,2d_msaa,2d_msaa_arr
|
||||
base = [1, 2, 3, 3, 2, 3, 3, 4][dim] # address coords
|
||||
grad = [1, 2, 3, 2, 1, 2, 2, 2][dim] # gradient coords (for derivatives)
|
||||
if 'get_resinfo' in name: return 1 # only mip level
|
||||
packed, unpacked = 0, 0
|
||||
if '_mip' in name: packed += 1
|
||||
elif 'sample' in name or 'gather' in name:
|
||||
if '_o' in name: unpacked += 1 # offset
|
||||
if re.search(r'_c(_|$)', name): unpacked += 1 # compare (not _cl)
|
||||
if '_d' in name: unpacked += (grad + 1) & ~1 if '_g16' in name else grad*2 # derivatives
|
||||
if '_b' in name: unpacked += 1 # bias
|
||||
if '_l' in name and '_cl' not in name and '_lz' not in name: packed += 1 # LOD
|
||||
if '_cl' in name: packed += 1 # clamp
|
||||
return (base + packed + 1) // 2 + unpacked if a16 else base + packed + unpacked
|
||||
|
||||
def _disasm_mimg(inst: MIMG) -> str:
  """Disassemble a MIMG (image) instruction into its assembly text.

  Output shape: ``name vdata, vaddr, srsrc[, ssamp] mods``. Each flag
  modifier is appended at most once, by a single pass over the flag list.
  """
  name = inst.op_name.lower()
  srsrc_base = inst.srsrc * 4
  srsrc_str = _sreg_or_ttmp(srsrc_base, 8)
  # BVH intersect ray: special case with 4 SGPR srsrc
  if 'bvh' in name:
    vaddr = (9 if '64' in name else 8) if inst.a16 else (12 if '64' in name else 11)
    return f"{name} {_vreg(inst.vdata, 4)}, {_vreg(inst.vaddr, vaddr)}, {_sreg_or_ttmp(srsrc_base, 4)}{' a16' if inst.a16 else ''}"
  # vdata width from dmask (gather4/msaa_load always 4), d16 packs, tfe adds 1
  vdata = 4 if 'gather4' in name or 'msaa_load' in name else (bin(inst.dmask).count('1') or 1)
  if inst.d16: vdata = (vdata + 1) // 2
  if inst.tfe: vdata += 1
  # vaddr width
  dim_names = ['1d', '2d', '3d', 'cube', '1d_array', '2d_array', '2d_msaa', '2d_msaa_array']
  dim = dim_names[inst.dim] if inst.dim < len(dim_names) else f"dim_{inst.dim}"
  vaddr = _mimg_vaddr_width(name, inst.dim, inst.a16)
  vaddr_str = f"v{inst.vaddr}" if vaddr == 1 else _vreg(inst.vaddr, vaddr)
  # modifiers: dmask is omitted when it is 0, or when it is 15 on a non-atomic op
  mods = [f"dmask:0x{inst.dmask:x}"] if inst.dmask and (inst.dmask != 15 or 'atomic' in name) else []
  mods.append(f"dim:SQ_RSRC_IMG_{dim.upper()}")
  flags = [(inst.unrm, "unorm"), (inst.glc, "glc"), (inst.slc, "slc"), (inst.dlc, "dlc"), (inst.r128, "r128"),
           (inst.a16, "a16"), (inst.tfe, "tfe"), (inst.lwe, "lwe"), (inst.d16, "d16")]
  mods.extend(mod for flag, mod in flags if flag)
  # ssamp for sample/gather/get_lod
  ssamp_str = ", " + _sreg_or_ttmp(inst.ssamp * 4, 4) if 'sample' in name or 'gather' in name or 'get_lod' in name else ""
  return f"{name} {_vreg(inst.vdata, vdata)}, {vaddr_str}, {srsrc_str}{ssamp_str} {' '.join(mods)}"
|
||||
|
||||
def _disasm_vsample(inst) -> str:
  """Disassemble a VSAMPLE instruction into ``name vdata, vaddr, rsrc, samp mods``.

  Raises:
    ValueError: the op has no name, is an msaa_load, needs more than four
      vaddr registers, or carries one of the th/scope combinations rejected below.
  """
  name = inst.op_name.lower()
  if not name: raise ValueError(f"undefined VSAMPLE op: {inst.op}")
  if 'msaa_load' in name: raise ValueError("image_msaa_load not supported in VSAMPLE for gfx1200")

  dim = inst.dim
  dim_names = ['1d', '2d', '3d', 'cube', '1d_array', '2d_array', '2d_msaa', '2d_msaa_array']
  dim_str = dim_names[dim] if dim < len(dim_names) else f"dim_{dim}"

  # vdata width: gather4 is always 4, otherwise the dmask popcount (minimum 1); d16 halves, tfe adds one
  if 'gather4' in name: vdata = 4
  else: vdata = bin(inst.dmask).count('1') or 1
  if inst.d16: vdata = (vdata + 1) // 2
  if inst.tfe: vdata += 1

  vaddr_count = _mimg_vaddr_width(name, dim, inst.a16)
  if vaddr_count > 4: raise ValueError(f"{name} with dim={dim} needs {vaddr_count} vaddrs (>4, unsupported)")
  vaddr_str = _fmt_vaddr_nsa(_collect_vaddrs(inst, vaddr_count))
  srsrc_str = _sreg_or_ttmp(inst.rsrc, 8)
  ssamp_str = _sreg_or_ttmp(inst.samp, 4)

  mods = []
  if inst.dmask: mods.append(f"dmask:0x{inst.dmask:x}")
  mods.append(f"dim:SQ_RSRC_IMG_{dim_str.upper()}")
  flags = ((inst.unrm, "unorm"), (inst.r128, "r128"), (inst.a16, "a16"), (inst.tfe, "tfe"), (inst.lwe, "lwe"), (inst.d16, "d16"))
  mods.extend(mod for flag, mod in flags if flag)

  # reject th/scope combinations before emitting the cache-policy modifiers
  th_val, scope_val = inst.th, inst.scope
  if th_val == 3 and scope_val == 3: raise ValueError("invalid th/scope: TH_LOAD_LU with SCOPE_SYS")
  if scope_val == 2 and th_val == 0: raise ValueError("invalid scope SCOPE_SA without th")
  if inst.tfe and inst.d16 and th_val != 0: raise ValueError("invalid th with tfe+d16")
  th_name = _TH_LOAD.get(th_val)
  if th_name: mods.append(f"th:{th_name}")
  scope_name = _SCOPE.get(scope_val)
  if scope_name: mods.append(f"scope:{scope_name}")
  return f"{name} {_vreg(inst.vdata, vdata)}, {vaddr_str}, {srsrc_str}, {ssamp_str} {' '.join(mods)}"
|
||||
|
||||
def _disasm_vimage(inst) -> str:
  """Disassemble a VIMAGE instruction into ``name vdata, vaddr, rsrc mods``.

  Raises:
    ValueError: the op is a BVH instruction, or carries one of the th/scope
      combinations rejected below.
  """
  name = inst.op_name.lower()
  if 'bvh' in name: raise ValueError(f"BVH instruction {name} not supported")

  dim = inst.dim
  dim_names = ['1d', '2d', '3d', 'cube', '1d_array', '2d_array', '2d_msaa', '2d_msaa_array']
  dim_str = dim_names[dim] if dim < len(dim_names) else f"dim_{dim}"
  is_resinfo = 'resinfo' in name
  is_atomic = 'atomic' in name
  is_store = 'store' in name

  # vdata width: atomics use 1 register (2 for 64-bit types), doubled for cmpswap;
  # other ops use the dmask popcount (msaa_load is always 4); d16 halves, tfe adds one
  if is_atomic:
    elem = 2 if _has(name, 'b64', 'u64', 'i64') else 1
    vdata = elem * (2 if 'cmpswap' in name else 1)
  elif 'msaa_load' in name:
    vdata = 4
  else:
    vdata = bin(inst.dmask).count('1') or 1
  if inst.d16: vdata = (vdata + 1) // 2
  if inst.tfe: vdata += 1

  # vaddr count: resinfo takes one register; otherwise per-dim coords (+1 for _mip), halved (rounded up) with a16
  if is_resinfo:
    vaddr_count = 1
  else:
    coords = ([1, 2, 3, 3, 2, 3, 3, 4][dim] if dim < 8 else 1) + (1 if '_mip' in name else 0)
    vaddr_count = (coords + 1) // 2 if inst.a16 else coords
  vaddr_str = _fmt_vaddr_nsa(_collect_vaddrs(inst, vaddr_count))
  srsrc_str = _sreg_or_ttmp(inst.rsrc, 8)

  # atomics print a fixed dmask (0x3 for cmpswap, 0x1 otherwise); other ops print the field value
  if is_atomic: dmask_mod = "dmask:0x3" if 'cmpswap' in name else "dmask:0x1"
  else: dmask_mod = f"dmask:0x{inst.dmask:x}"
  mods = [dmask_mod, f"dim:SQ_RSRC_IMG_{dim_str.upper()}"]
  mods.extend(mod for flag, mod in ((inst.r128, "r128"), (inst.a16, "a16"), (inst.tfe, "tfe"), (inst.d16, "d16")) if flag)

  # reject th/scope combinations before emitting the cache-policy modifiers
  th_val, scope_val = inst.th, inst.scope
  if th_val == 3 and scope_val == 3 and not is_atomic: raise ValueError("invalid th/scope: TH_LOAD_LU with SCOPE_SYS")
  if is_atomic and th_val > 2: raise ValueError(f"invalid th value {th_val} for atomic")
  if is_store and th_val == 3: raise ValueError("invalid TH_STORE_LU for store")
  if scope_val == 2 and th_val == 0: raise ValueError("invalid SCOPE_SA without th")
  if inst.tfe and inst.d16 and th_val != 0: raise ValueError("invalid th with tfe+d16")

  # the th name table depends on the op kind: atomic, store, or load (the default)
  if is_atomic: th_table = _TH_ATOMIC
  elif is_store: th_table = _TH_STORE
  else: th_table = _TH_LOAD
  th_name = th_table.get(th_val)
  if th_name: mods.append(f"th:{th_name}")
  scope_name = _SCOPE.get(scope_val)
  if scope_name: mods.append(f"scope:{scope_name}")
  return f"{name} {_vreg(inst.vdata, vdata)}, {vaddr_str}, {srsrc_str} {' '.join(mods)}"
|
||||
|
||||
def _disasm_sop1(inst: SOP1) -> str:
|
||||
op, name = inst.op, inst.op_name.lower()
|
||||
if not name: raise ValueError(f"undefined SOP1 op: {inst.op}")
|
||||
if _has(name, 'alloc_vgpr', 'sleep_var', 'barrier_signal', 'barrier_wait', 'wakeup_barrier'):
|
||||
return f"{name} {inst.lit(inst.ssrc0) if inst.src_regs(0) == 1 else _fmt_src(inst.ssrc0, inst.src_regs(0))}"
|
||||
if op == SOP1Op.S_GETPC_B64: return f"{name} {_fmt_sdst(inst.sdst, 2)}"
|
||||
if op in (SOP1Op.S_SETPC_B64, SOP1Op.S_RFE_B64): return f"{name} {_fmt_src(inst.ssrc0, 2)}"
|
||||
if op == SOP1Op.S_SWAPPC_B64: return f"{name} {_fmt_sdst(inst.sdst, 2)}, {_fmt_src(inst.ssrc0, 2)}"
|
||||
|
|
@ -375,10 +488,14 @@ def _disasm_sopc(inst: SOPC) -> str:
|
|||
|
||||
def _disasm_sopk(inst: SOPK) -> str:
  """Render a SOPK instruction as assembly text.

  s_version prints only the 16-bit immediate. s_setreg_b32 / s_getreg_b32
  print the immediate either as ``hwreg(name, offset, size)`` or as a raw hex
  value, with the hwreg operand first for setreg and last for getreg. Every
  other opcode prints ``<sdst>, 0x<simm16>``.

  Raises:
    ValueError: if the opcode has no name.
  """
  op, name = inst.op, inst.op_name.lower()
  if not name: raise ValueError(f"undefined SOPK op: {inst.op}")
  if op == SOPKOp.S_VERSION: return f"{name} 0x{inst.simm16:x}"
  if op in (SOPKOp.S_SETREG_B32, SOPKOp.S_GETREG_B32):
    # simm16 fields: id = bits[5:0], offset = bits[10:6], (size - 1) = bits[15:11]
    hid, hoff, hsz = inst.simm16 & 0x3f, (inst.simm16 >> 6) & 0x1f, ((inst.simm16 >> 11) & 0x1f) + 1
    # the architecture is inferred from the name of the module that defines the instruction class
    is_rdna4 = 'rdna4' in inst.__class__.__module__
    hwreg_map = HWREG_RDNA4 if is_rdna4 else HWREG
    # ids 16 and 17 always print as a raw immediate; on RDNA4 so does any id missing from the table
    if hid in (16, 17) or (is_rdna4 and hid not in hwreg_map): hs = f"0x{inst.simm16:x}"
    else: hs = f"hwreg({hwreg_map.get(hid, str(hid))}, {hoff}, {hsz})"
    return f"{name} {hs}, {_fmt_sdst(inst.sdst, 1)}" if op == SOPKOp.S_SETREG_B32 else f"{name} {_fmt_sdst(inst.sdst, 1)}, {hs}"
  return f"{name} {_fmt_sdst(inst.sdst, inst.dst_regs())}, 0x{inst.simm16:x}"
|
||||
|
||||
|
|
@ -386,43 +503,36 @@ def _disasm_vinterp(inst: VINTERP) -> str:
|
|||
mods = _mods((inst.waitexp, f"wait_exp:{inst.waitexp}"), (inst.clmp, "clamp"))
|
||||
return f"{inst.op_name.lower()} v{inst.vdst}, {inst.lit(inst.src0, inst.neg & 1)}, {inst.lit(inst.src1, inst.neg & 2)}, {inst.lit(inst.src2, inst.neg & 4)}" + (" " + mods if mods else "")
|
||||
|
||||
# RDNA4-specific handlers
|
||||
def _disasm_ldsdir(inst) -> str:
  """Render an LDSDIR instruction as assembly text.

  op 1 prints ``lds_direct_load v<vdst>``; op 0 prints
  ``lds_param_load v<vdst>, attr<attr>.<chan>`` where ``attr_chan`` 0..3
  selects x/y/z/w. A non-zero ``wait_va`` is appended as ``wait_vdst:<n>``.

  Raises:
    ValueError: for any other op value.
  """
  wait = f" wait_vdst:{inst.wait_va}" if inst.wait_va != 0 else ""
  if inst.op == 1: return f"lds_direct_load v{inst.vdst}{wait}"
  if inst.op == 0: return f"lds_param_load v{inst.vdst}, attr{inst.attr}.{'xyzw'[inst.attr_chan]}{wait}"
  raise ValueError(f"unknown LDSDIR op: {inst.op}")
|
||||
|
||||
def _disasm_vexport(inst) -> str:
  """Render an export instruction as assembly text.

  The target name comes from ``_EXP_TARGETS``; an unknown target id prints as
  ``invalid_target_<id>``. Each of the four sources prints as ``v<n>`` when its
  bit in ``inst.en`` is set and as ``off`` otherwise. The mnemonic is
  ``export`` when the instruction class is defined in a module whose name
  contains 'rdna4', and ``exp`` otherwise.
  """
  target = _EXP_TARGETS.get(inst.target, f"invalid_target_{inst.target}")
  en = inst.en
  def vsrc(i, reg): return f"v{reg}" if (en >> i) & 1 else "off"
  srcs = f"{vsrc(0, inst.vsrc0)}, {vsrc(1, inst.vsrc1)}, {vsrc(2, inst.vsrc2)}, {vsrc(3, inst.vsrc3)}"
  mods = _mods((inst.done, "done"), (inst.row, "row_en"))
  # the same handler serves both architectures, so the mnemonic has to be chosen per instruction
  prefix = "export" if 'rdna4' in inst.__class__.__module__ else "exp"
  return f"{prefix} {target} {srcs}" + (" " + mods if mods else "")
|
||||
|
||||
def _disasm_vbuffer(inst) -> str:
|
||||
"""Disassemble VBUFFER instruction - RDNA4 96-bit buffer format."""
|
||||
name = inst.op_name.lower()
|
||||
# Data width: derive from instruction name (d16 packs 2 components per register)
|
||||
suffix = name.split('_')[-1]
|
||||
base_w = {'b32':1,'b64':2,'b96':3,'b128':4,'b16':1,'b8':1,'x':1,'xy':2,'xyz':3,'xyzw':4,'u32':1,'u64':2,'i32':1,'i64':2,'f32':1,'f64':2,'f16':1,'bf16':1}.get(suffix, 1)
|
||||
if 'd16' in name:
|
||||
w = (base_w + 1) // 2 # d16: pack 2 values per register (xy=1, xyz=2, xyzw=2)
|
||||
else:
|
||||
w = base_w
|
||||
w = (base_w + 1) // 2 if 'd16' in name else base_w
|
||||
if 'cmpswap' in name: w *= 2
|
||||
if inst.tfe: w += 1
|
||||
# vaddr
|
||||
vaddr = _vreg(inst.vaddr, 2) if inst.offen and inst.idxen else f"v{inst.vaddr}" if inst.offen or inst.idxen else "off"
|
||||
# rsrc is stored directly (SGPR index), not multiplied by 4 like RDNA3 MUBUF
|
||||
rsrc = _sreg_or_ttmp(inst.rsrc, 4)
|
||||
# soffset
|
||||
soffset = decode_src(inst.soffset)
|
||||
# TH (temporal hint) and SCOPE names for RDNA4
|
||||
rsrc, soffset = _sreg_or_ttmp(inst.rsrc, 4), decode_src(inst.soffset)
|
||||
th_load = ['','TH_LOAD_RT','TH_LOAD_NT','TH_LOAD_HT','TH_LOAD_LU','TH_LOAD_NT_RT','TH_LOAD_NT_HT','TH_LOAD_BYPASS']
|
||||
th_store = ['','TH_STORE_RT','TH_STORE_NT','TH_STORE_HT','','TH_STORE_NT_RT','TH_STORE_NT_HT','TH_STORE_BYPASS']
|
||||
th_atomic = ['','TH_ATOMIC_NT','','','','TH_ATOMIC_RETURN','TH_ATOMIC_RT_RETURN','TH_ATOMIC_CASCADE_NT']
|
||||
scope_names = ['','SCOPE_SE','SCOPE_DEV','SCOPE_SYS']
|
||||
is_atomic, is_store = 'atomic' in name, 'store' in name and 'atomic' not in name
|
||||
th_names = th_atomic if is_atomic else th_store if is_store else th_load
|
||||
# Modifiers
|
||||
mods = []
|
||||
if inst.idxen: mods.append("idxen")
|
||||
if inst.offen: mods.append("offen")
|
||||
|
|
@ -430,75 +540,53 @@ def _disasm_vbuffer(inst) -> str:
|
|||
if inst.th and inst.th < len(th_names) and th_names[inst.th]: mods.append(f"th:{th_names[inst.th]}")
|
||||
if inst.scope and inst.scope < len(scope_names) and scope_names[inst.scope]: mods.append(f"scope:{scope_names[inst.scope]}")
|
||||
if inst.tfe: mods.append("tfe")
|
||||
mod_str = " " + " ".join(mods) if mods else ""
|
||||
return f"{name} {_vreg(inst.vdata, w)}, {vaddr}, {rsrc}, {soffset}{mod_str}"
|
||||
return f"{name} {_vreg(inst.vdata, w)}, {vaddr}, {rsrc}, {soffset}" + (" " + " ".join(mods) if mods else "")
|
||||
|
||||
def _disasm_vflat(inst) -> str:
  """Render a VFLAT / VGLOBAL / VSCRATCH instruction as assembly text.

  The segment (flat / global / scratch) is taken from the instruction's class
  name and replaces the first underscore-separated component of the opcode
  name. Operand order depends on the kind of operation:

    store:                     <vaddr>, <vsrc><saddr><mods>
    atomic with th >= 5:       <vdst>, <vaddr>, <vsrc><saddr><mods>
    atomic otherwise:          <vaddr>, <vsrc><saddr><mods>
    anything else (load):      <vdst>, <vaddr><saddr><mods>
  """
  name = inst.op_name.lower()
  cls_name = type(inst).__name__
  seg = 'flat' if cls_name == 'VFLAT' else 'global' if cls_name == 'VGLOBAL' else 'scratch'
  parts = name.split('_', 1)
  instr = f"{seg}_{parts[1]}" if len(parts) > 1 else name

  # number of data registers, looked up from the last component of the opcode name (1 if unlisted)
  suffix = name.split('_')[-1]
  w = {'b32':1,'b64':2,'b96':3,'b128':4,'b16':1,'b8':1,'u32':1,'u64':2,'i32':1,'i64':2,'f32':1,'f64':2,'f16':1,'bf16':1}.get(suffix, 1)
  if 'cmpswap' in name: w *= 2

  # ioffset is interpreted as a signed 24-bit two's-complement value
  off_val = inst.ioffset if inst.ioffset < 0x800000 else inst.ioffset - 0x1000000

  # flat has no saddr operand; saddr 0x7F / 124 (or scratch with sve == 0) prints "off" and
  # uses a register-pair vaddr; any other saddr is printed and vaddr is a single register
  if seg == 'flat': saddr_s, addr_width = "", 2
  elif inst.saddr in (0x7F, 124) or (hasattr(inst, 'sve') and inst.sve == 0 and seg == 'scratch'): saddr_s, addr_width = ", off", 2
  else: saddr_s, addr_width = f", {_fmt_src(inst.saddr, 2) if inst.saddr <= 105 else decode_src(inst.saddr)}", 1
  vaddr = f"v{inst.vaddr}" if addr_width == 1 else _vreg(inst.vaddr, 2)

  # th / scope name tables, indexed by field value; an empty string means "print nothing"
  th_load = ['','TH_LOAD_RT','TH_LOAD_NT','TH_LOAD_HT','TH_LOAD_LU','TH_LOAD_NT_RT','TH_LOAD_NT_HT','TH_LOAD_BYPASS']
  th_store = ['','TH_STORE_RT','TH_STORE_NT','TH_STORE_HT','','TH_STORE_NT_RT','TH_STORE_NT_HT','TH_STORE_BYPASS']
  th_atomic = ['','TH_ATOMIC_NT','','','','TH_ATOMIC_RETURN','TH_ATOMIC_RT_RETURN','TH_ATOMIC_CASCADE_NT']
  scope_names = ['','SCOPE_SE','SCOPE_DEV','SCOPE_SYS']
  is_atomic, is_store = 'atomic' in name, 'store' in name and 'atomic' not in name
  th_names = th_atomic if is_atomic else th_store if is_store else th_load

  mods = []
  if off_val: mods.append(f"offset:{off_val}")
  if inst.th and inst.th < len(th_names) and th_names[inst.th]: mods.append(f"th:{th_names[inst.th]}")
  if inst.scope and inst.scope < len(scope_names) and scope_names[inst.scope]: mods.append(f"scope:{scope_names[inst.scope]}")
  mod_str = " " + " ".join(mods) if mods else ""

  if is_store: return f"{instr} {vaddr}, {_vreg(inst.vsrc, w)}{saddr_s}{mod_str}"
  if is_atomic:
    # th values 5 and above (TH_ATOMIC_RETURN onwards in th_atomic) also print a destination register
    if inst.th and inst.th >= 5: return f"{instr} {_vreg(inst.vdst, w)}, {vaddr}, {_vreg(inst.vsrc, w)}{saddr_s}{mod_str}"
    return f"{instr} {vaddr}, {_vreg(inst.vsrc, w)}{saddr_s}{mod_str}"
  return f"{instr} {_vreg(inst.vdst, w)}, {vaddr}{saddr_s}{mod_str}"
|
||||
|
||||
# Handler mappings
|
||||
# Disassembly function for each instruction-format class, keyed by the class object itself.
DISASM_HANDLERS = {
  VOP1: _disasm_vop1, VOP2: _disasm_vop2, VOPC: _disasm_vopc, VOP3: _disasm_vop3, VOP3SD: _disasm_vop3sd,
  VOPD: _disasm_vopd, VOP3P: _disasm_vop3p, VINTERP: _disasm_vinterp,
  SOPP: _disasm_sopp, SMEM: _disasm_smem, DS: _disasm_ds, FLAT: _disasm_flat,
  MUBUF: _disasm_buf, MTBUF: _disasm_buf, MIMG: _disasm_mimg,
  SOP1: _disasm_sop1, SOP2: _disasm_sop2, SOPC: _disasm_sopc, SOPK: _disasm_sopk,
}
|
||||
|
||||
# Name-based lookup for cross-architecture support (RDNA3/RDNA4 have different class objects with same names)
|
||||
# Disassembly function for each instruction-format class, keyed by the class *name*.
# Several names share one handler (DS/VDS, MUBUF/MTBUF, VEXPORT/EXP, VFLAT/VGLOBAL/VSCRATCH).
_DISASM_BY_NAME = {
  'VOP1': _disasm_vop1, 'VOP2': _disasm_vop2, 'VOPC': _disasm_vopc, 'VOP3': _disasm_vop3, 'VOP3SD': _disasm_vop3sd,
  'VOPD': _disasm_vopd, 'VOP3P': _disasm_vop3p, 'VINTERP': _disasm_vinterp, 'SOPP': _disasm_sopp, 'SMEM': _disasm_smem,
  'DS': _disasm_ds, 'VDS': _disasm_ds, 'FLAT': _disasm_flat, 'MUBUF': _disasm_buf, 'MTBUF': _disasm_buf, 'MIMG': _disasm_mimg,
  'SOP1': _disasm_sop1, 'SOP2': _disasm_sop2, 'SOPC': _disasm_sopc, 'SOPK': _disasm_sopk,
  'VEXPORT': _disasm_vexport, 'EXP': _disasm_vexport, 'LDSDIR': _disasm_ldsdir,
  'VBUFFER': _disasm_vbuffer, 'VFLAT': _disasm_vflat, 'VGLOBAL': _disasm_vflat, 'VSCRATCH': _disasm_vflat,
  'VSAMPLE': _disasm_vsample, 'VIMAGE': _disasm_vimage,
}
|
||||
|
||||
def disasm(inst: Inst) -> str:
|
||||
|
|
@ -512,7 +600,7 @@ def disasm(inst: Inst) -> str:
|
|||
|
||||
# Named special registers and the raw operand encoding each one assembles to.
# 'vcc' and 'exec' share the encoding of their *_lo halves; 'null' and 'off' are both 124.
SPEC_REGS = {'vcc_lo': RawImm(106), 'vcc_hi': RawImm(107), 'vcc': RawImm(106), 'null': RawImm(124), 'off': RawImm(124), 'm0': RawImm(125),
             'exec_lo': RawImm(126), 'exec_hi': RawImm(127), 'exec': RawImm(126), 'scc': RawImm(253), 'src_scc': RawImm(253)}
# str(key) -> key for every key of FLOAT_ENC, so a literal's text can be matched by exact string lookup
FLOATS = {str(k): k for k in FLOAT_ENC}
# register-name prefix -> register factory ('t' and 'ttmp' both select ttmp)
REG_MAP: dict[str, _RegFactory] = {'s': s, 'v': v, 't': ttmp, 'ttmp': ttmp}
# mnemonics that get_dsl handles in its SMEM branch
SMEM_OPS = {'s_load_b32', 's_load_b64', 's_load_b128', 's_load_b256', 's_load_b512',
            's_buffer_load_b32', 's_buffer_load_b64', 's_buffer_load_b128', 's_buffer_load_b256', 's_buffer_load_b512'}
|
||||
|
|
@ -554,7 +642,6 @@ def _extract(text: str, pat: str, flags=re.I):
|
|||
|
||||
def get_dsl(text: str) -> str:
|
||||
text, kw = text.strip(), []
|
||||
# Extract modifiers
|
||||
for pat, val in [(r'\s+mul:2(?:\s|$)', 1), (r'\s+mul:4(?:\s|$)', 2), (r'\s+div:2(?:\s|$)', 3)]:
|
||||
if (m := _extract(text, pat))[0]: kw.append(f'omod={val}'); text = m[1]; break
|
||||
if (m := _extract(text, r'\s+clamp(?:\s|$)'))[0]: kw.append('clmp=1'); text = m[1]
|
||||
|
|
@ -572,13 +659,11 @@ def get_dsl(text: str) -> str:
|
|||
m, text = _extract(text, r'\s+neg_lo:\[([^\]]+)\]'); neg_lo = sum(int(x.strip()) << i for i, x in enumerate(m.group(1).split(','))) if m else None
|
||||
m, text = _extract(text, r'\s+neg_hi:\[([^\]]+)\]'); neg_hi = sum(int(x.strip()) << i for i, x in enumerate(m.group(1).split(','))) if m else None
|
||||
if waitexp: kw.append(f'waitexp={waitexp}')
|
||||
|
||||
parts = text.replace(',', ' ').split()
|
||||
if not parts: raise ValueError("empty instruction")
|
||||
mn, op_str = parts[0].lower(), text[len(parts[0]):].strip()
|
||||
ops, args = _parse_ops(op_str), [_op2dsl(o) for o in _parse_ops(op_str)]
|
||||
|
||||
# s_waitcnt
|
||||
if mn == 's_waitcnt':
|
||||
vm, exp, lgkm = 0x3f, 0x7, 0x3f
|
||||
for p in op_str.replace(',', ' ').split():
|
||||
|
|
@ -588,7 +673,6 @@ def get_dsl(text: str) -> str:
|
|||
elif re.match(r'^0x[0-9a-f]+$|^\d+$', p): return f"s_waitcnt(simm16={int(p, 0)})"
|
||||
return f"s_waitcnt(simm16={waitcnt(vm, exp, lgkm)})"
|
||||
|
||||
# VOPD
|
||||
if '::' in text:
|
||||
xp, yp = text.split('::')
|
||||
xps, yps = xp.strip().replace(',', ' ').split(), yp.strip().replace(',', ' ').split()
|
||||
|
|
@ -600,14 +684,12 @@ def get_dsl(text: str) -> str:
|
|||
elif 'fmamk' in yps[0].lower() and len(yo) > 3: lit, vsy1 = yo[2], yo[3]
|
||||
return f"VOPD(VOPDOp.{xps[0].upper()}, VOPDOp.{yps[0].upper()}, vdstx={vdx}, vdsty={vdy}, srcx0={sx0}, vsrcx1={vsx1}, srcy0={sy0}, vsrcy1={vsy1}{f', literal={lit}' if lit else ''})"
|
||||
|
||||
# Special instructions
|
||||
if mn == 's_setreg_imm32_b32': raise ValueError(f"unsupported: {mn}")
|
||||
if mn in ('s_setpc_b64', 's_rfe_b64'): return f"{mn}(ssrc0={args[0]})"
|
||||
if mn in ('s_sendmsg_rtn_b32', 's_sendmsg_rtn_b64'): return f"{mn}(sdst={args[0]}, ssrc0=RawImm({args[1].strip()}))"
|
||||
if mn == 's_version': return f"{mn}(simm16={args[0]})"
|
||||
if mn == 's_setreg_b32': return f"{mn}(simm16={args[0]}, sdst={args[1]})"
|
||||
|
||||
# SMEM
|
||||
if mn in SMEM_OPS:
|
||||
gs, ds = ", glc=1" if glc else "", ", dlc=1" if dlc else ""
|
||||
if len(ops) >= 3 and re.match(r'^-?[0-9]|^-?0x', ops[2].strip().lower()):
|
||||
|
|
@ -615,11 +697,9 @@ def get_dsl(text: str) -> str:
|
|||
if off_val and len(ops) >= 3: return f"{mn}(sdata={args[0]}, sbase={args[1]}, offset={off_val}, soffset={args[2]}{gs}{ds})"
|
||||
if len(ops) >= 3: return f"{mn}(sdata={args[0]}, sbase={args[1]}, soffset={args[2]}{gs}{ds})"
|
||||
|
||||
# Buffer
|
||||
if mn.startswith('buffer_') and len(ops) >= 2 and ops[1].strip().lower() == 'off':
|
||||
return f"{mn}(vdata={args[0]}, vaddr=0, srsrc={args[2]}, soffset={f'RawImm({args[3].strip()})' if len(args) > 3 else 'RawImm(0)'})"
|
||||
|
||||
# FLAT/GLOBAL/SCRATCH load/store/atomic - saddr needs RawImm(124) for off/null
|
||||
def _saddr(a): return 'RawImm(124)' if a in ('OFF', 'NULL') else a
|
||||
flat_mods = f"{f', offset={off_val}' if off_val else ''}{', glc=1' if glc else ''}{', slc=1' if slc else ''}{', dlc=1' if dlc else ''}"
|
||||
for pre, flds in [('flat_load','vdst,addr,saddr'), ('global_load','vdst,addr,saddr'), ('scratch_load','vdst,addr,saddr'),
|
||||
|
|
@ -632,7 +712,6 @@ def get_dsl(text: str) -> str:
|
|||
if glc and len(args) >= 3: return f"{mn}(vdst={args[0]}, addr={args[1]}, data={args[2]}{f', saddr={_saddr(args[3])}' if len(args) >= 4 else ', saddr=RawImm(124)'}{flat_mods})"
|
||||
if len(args) >= 2: return f"{mn}(addr={args[0]}, data={args[1]}{f', saddr={_saddr(args[2])}' if len(args) >= 3 else ', saddr=RawImm(124)'}{flat_mods})"
|
||||
|
||||
# DS instructions
|
||||
if mn.startswith('ds_'):
|
||||
off0, off1 = (str(int(off_val, 0) & 0xff), str((int(off_val, 0) >> 8) & 0xff)) if off_val else ("0", "0")
|
||||
gds_s = ", gds=1" if 'gds' in text.lower().split()[-1:] else ""
|
||||
|
|
@ -656,12 +735,10 @@ def get_dsl(text: str) -> str:
|
|||
return f"{mn}(vdst={args[0]}, addr={args[1]}, data0={args[2]}, data1={args[3]}{off_kw})" if '_rtn' in mn else f"{mn}(addr={args[0]}, data0={args[1]}, data1={args[2]}{off_kw})"
|
||||
return f"{mn}(vdst={args[0]}, addr={args[1]}, data0={args[2]}{off_kw})" if '_rtn' in mn else f"{mn}(addr={args[0]}, data0={args[1]}{off_kw})"
|
||||
|
||||
# v_fmaak/v_fmamk literal extraction
|
||||
lit_s = ""
|
||||
if mn in ('v_fmaak_f32', 'v_fmaak_f16') and len(args) == 4: lit_s, args = f", literal={args[3].strip()}", args[:3]
|
||||
elif mn in ('v_fmamk_f32', 'v_fmamk_f16') and len(args) == 4: lit_s, args = f", literal={args[2].strip()}", [args[0], args[1], args[3]]
|
||||
|
||||
# VCC ops cleanup
|
||||
vcc_ops = {'v_add_co_ci_u32', 'v_sub_co_ci_u32', 'v_subrev_co_ci_u32'}
|
||||
if mn.replace('_e32', '') in vcc_ops and len(args) >= 5: mn, args = mn.replace('_e32', '') + '_e32', [args[0], args[2], args[3]]
|
||||
if mn.replace('_e64', '') in vcc_ops and mn.endswith('_e64'): mn = mn.replace('_e64', '')
|
||||
|
|
@ -671,7 +748,6 @@ def get_dsl(text: str) -> str:
|
|||
fn = mn.replace('.', '_')
|
||||
if opsel is not None: args = [re.sub(r'\.[hl]$', '', a) for a in args]
|
||||
|
||||
# v_fma_mix*: extract inline neg/abs modifiers
|
||||
if 'fma_mix' in mn and neg_lo is None and neg_hi is None:
|
||||
inline_neg, inline_abs, clean_args = 0, 0, [args[0]]
|
||||
for i, op in enumerate(ops[1:4]):
|
||||
|
|
|
|||
|
|
@ -154,7 +154,7 @@ class VIMAGE(Inst96):
|
|||
scope = bits[51:50]
|
||||
th = bits[54:52]
|
||||
tfe = bits[55]
|
||||
vaddr4 = bits[56:63]
|
||||
vaddr4 = bits[63:56]
|
||||
vaddr0 = bits[71:64]
|
||||
vaddr1 = bits[79:72]
|
||||
vaddr2 = bits[87:80]
|
||||
|
|
|
|||
|
|
@ -54,7 +54,7 @@ _SPECIAL_REGS = {
|
|||
'V_CMP_CLASS_F64': (1, 2, 1, 1), 'V_CMPX_CLASS_F64': (1, 2, 1, 1),
|
||||
'V_CMP_CLASS_F32': (1, 1, 1, 1), 'V_CMPX_CLASS_F32': (1, 1, 1, 1),
|
||||
'V_CMP_CLASS_F16': (1, 1, 1, 1), 'V_CMPX_CLASS_F16': (1, 1, 1, 1),
|
||||
'V_MAD_U64_U32': (2, 1, 1, 2), 'V_MAD_I64_I32': (2, 1, 1, 2),
|
||||
'V_MAD_U64_U32': (2, 1, 1, 2), 'V_MAD_I64_I32': (2, 1, 1, 2), 'V_MAD_CO_U64_U32': (2, 1, 1, 2), 'V_MAD_CO_I64_I32': (2, 1, 1, 2),
|
||||
'V_QSAD_PK_U16_U8': (2, 2, 1, 2), 'V_MQSAD_PK_U16_U8': (2, 2, 1, 2), 'V_MQSAD_U32_U8': (4, 2, 1, 4),
|
||||
# RDNA4 CVT_PK_F32 instructions output 2 F32 values (64-bit)
|
||||
'V_CVT_PK_F32_BF8': (2, 1, 1, 1), 'V_CVT_PK_F32_FP8': (2, 1, 1, 1),
|
||||
|
|
@ -98,7 +98,7 @@ def spec_is_16bit(name: str) -> bool:
|
|||
def spec_is_64bit(name: str) -> bool:
  """Return True if the upper-cased name contains _F64, _I64, _U64 or _B64 followed by '_' or the end of the name."""
  return re.search(r'_[FIUB]64(?:_|$)', name.upper()) is not None
|
||||
# Opcode-name fragments used by spec_num_srcs to classify an instruction by its source-operand count.
_3SRC = {'FMA', 'MAD', 'MIN3', 'MAX3', 'MED3', 'DIV_FIX', 'DIV_FMAS', 'DIV_SCALE', 'SAD', 'LERP', 'ALIGN', 'CUBE', 'BFE', 'BFI',
         'PERM_B32', 'PERMLANE', 'CNDMASK', 'XOR3', 'OR3', 'ADD3', 'LSHL_OR', 'AND_OR', 'LSHL_ADD', 'ADD_LSHL', 'XAD', 'MAXMIN',
         'MINMAX', 'MAXIMUMMINIMUM', 'MINIMUMMAXIMUM', 'MAXIMUM3', 'MINIMUM3', 'DOT2', 'DOT4', 'DOT8', 'WMMA', 'CVT_PK_U8', 'MULLIT', 'CO_CI'}
_2SRC = {'FMAC'}  # FMAC uses dst as implicit accumulator, so only 2 explicit sources
|
||||
def spec_num_srcs(name: str) -> int:
|
||||
name = name.upper()
|
||||
|
|
@ -138,9 +138,14 @@ class BitField:
|
|||
# Convert to IntEnum if marker is an IntEnum subclass
|
||||
if self.marker and isinstance(self.marker, type) and issubclass(self.marker, IntEnum):
|
||||
# VOP3 with VOPC opcodes (0-255) -> VOPCOp, VOP3SD opcodes -> VOP3SDOp
|
||||
if self.marker is VOP3Op:
|
||||
if val < 256: return VOPCOp(val)
|
||||
if val in Inst._VOP3SD_OPS: return VOP3SDOp(val)
|
||||
# Check by name to handle both RDNA3 and RDNA4 enums
|
||||
if self.marker.__name__ == 'VOP3Op':
|
||||
# Get the appropriate enums from the same module as the marker
|
||||
marker_mod = self.marker.__module__
|
||||
import importlib
|
||||
enum_mod = importlib.import_module(marker_mod)
|
||||
if val < 256: return enum_mod.VOPCOp(val)
|
||||
if val in Inst._VOP3SD_OPS: return enum_mod.VOP3SDOp(val)
|
||||
try: return self.marker(val)
|
||||
except ValueError: pass
|
||||
# For RDNA4 op fields without type annotations, dynamically look up enum
|
||||
|
|
|
|||
|
|
@ -38,6 +38,9 @@ RDNA3_TEST_FILES = {
|
|||
'mubuf': 'gfx11_asm_mubuf.s',
|
||||
'mtbuf': 'gfx11_asm_mtbuf.s',
|
||||
'mimg': 'gfx11_asm_mimg.s',
|
||||
'ldsdir': 'gfx11_asm_ldsdir.s',
|
||||
# Export
|
||||
'exp': 'gfx11_asm_exp.s',
|
||||
# WMMA
|
||||
'wmma': 'gfx11_asm_wmma.s',
|
||||
# Features
|
||||
|
|
@ -162,12 +165,12 @@ class TestLLVMRDNA3(unittest.TestCase):
|
|||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
from extra.assembly.amd.autogen.rdna3.ins import SOP1, SOP2, SOPC, SOPK, SOPP, VOP1, VOP2, VOP3, VOP3SD, VOP3P, VOPC, VOPD, VINTERP, DS, SMEM, FLAT, MUBUF, MTBUF, MIMG
|
||||
from extra.assembly.amd.autogen.rdna3.ins import SOP1, SOP2, SOPC, SOPK, SOPP, VOP1, VOP2, VOP3, VOP3SD, VOP3P, VOPC, VOPD, VINTERP, DS, SMEM, FLAT, MUBUF, MTBUF, MIMG, LDSDIR, EXP
|
||||
cls.formats = {
|
||||
'sop1': SOP1, 'sop2': SOP2, 'sopc': SOPC, 'sopk': SOPK, 'sopp': SOPP,
|
||||
'vop1': VOP1, 'vop2': VOP2, 'vopc': VOPC, 'vopcx': VOPC, 'vop3': VOP3, 'vop3p': VOP3P,
|
||||
'vinterp': VINTERP, 'vopd': VOPD, 'ds': DS, 'smem': SMEM, 'flat': FLAT,
|
||||
'mubuf': MUBUF, 'mtbuf': MTBUF, 'mimg': MIMG, 'wmma': VOP3P,
|
||||
'mubuf': MUBUF, 'mtbuf': MTBUF, 'mimg': MIMG, 'wmma': VOP3P, 'ldsdir': LDSDIR, 'exp': EXP,
|
||||
'vop3_from_vop1': VOP3, 'vop3_from_vop2': VOP3, 'vop3_from_vopc': VOP3, 'vop3_from_vopcx': VOP3,
|
||||
'vop3_features': VOP3, 'vop3p_features': VOP3P, 'vopd_features': VOPD,
|
||||
'vop3_alias': VOP3, 'vop3p_alias': VOP3P, 'vopc_alias': VOPC, 'vopcx_alias': VOPC,
|
||||
|
|
@ -260,7 +263,7 @@ class TestLLVMRDNA4(unittest.TestCase):
|
|||
'ds': VDS, 'ds_alias': VDS, 'smem': SMEM,
|
||||
'vinterp': VINTERP, 'exp': VEXPORT,
|
||||
'vbuffer_mubuf': VBUFFER, 'vbuffer_mubuf_alias': VBUFFER, 'vbuffer_mtbuf': VBUFFER, 'vbuffer_mtbuf_alias': VBUFFER,
|
||||
'vdsdir': VDSDIR, 'vdsdir_alias': VDSDIR,
|
||||
'vdsdir': None, 'vdsdir_alias': None, # VDSDIR is 64-bit but ds_direct_load is 32-bit
|
||||
'vflat': VFLAT, 'vflat_alias': VFLAT,
|
||||
'vimage': VIMAGE, 'vimage_alias': VIMAGE,
|
||||
'vsample': VSAMPLE,
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue