mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
Compare commits
8 commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
103a00d4c5 | ||
|
|
8c14d9f427 | ||
|
|
4e03b3ebef | ||
|
|
4571979fac | ||
|
|
9302f38f5b | ||
|
|
2a6904029b | ||
|
|
14bc1b0c68 | ||
|
|
c9b074639e |
6 changed files with 1204 additions and 389 deletions
|
|
@ -1,12 +1,15 @@
|
|||
# RDNA3 assembler and disassembler
|
||||
from __future__ import annotations
|
||||
import re
|
||||
from extra.assembly.amd.dsl import Inst, RawImm, Reg, SrcMod, SGPR, VGPR, TTMP, s, v, ttmp, _RegFactory
|
||||
from extra.assembly.amd.dsl import Inst, RawImm, Reg, SrcMod, SGPR, VGPR, TTMP, s, v, ttmp, _RegFactory, SRC_FIELDS, unwrap
|
||||
from extra.assembly.amd.dsl import VCC_LO, VCC_HI, VCC, EXEC_LO, EXEC_HI, EXEC, SCC, M0, NULL, OFF
|
||||
from extra.assembly.amd.dsl import SPECIAL_GPRS, SPECIAL_PAIRS, FLOAT_DEC, FLOAT_ENC, decode_src
|
||||
from extra.assembly.amd.autogen.rdna3 import ins
|
||||
from extra.assembly.amd.autogen.rdna3.ins import (VOP1, VOP2, VOP3, VOP3SD, VOP3P, VOPC, VOPD, VINTERP, SOP1, SOP2, SOPC, SOPK, SOPP, SMEM, DS, FLAT, MUBUF, MTBUF, MIMG, EXP,
|
||||
VOP1Op, VOP2Op, VOP3Op, VOP3SDOp, VOPDOp, SOP1Op, SOPKOp, SOPPOp, SMEMOp, DSOp, MUBUFOp)
|
||||
from extra.assembly.amd.autogen.rdna3.ins import (VOP1, VOP2, VOP3, VOP3SD, VOP3P, VOPC, VOPD, VINTERP, SOP1, SOP2, SOPC, SOPK, SOPP, SMEM, DS, FLAT, MUBUF, MTBUF, MIMG, EXP, LDSDIR,
|
||||
VOP1Op, VOP2Op, VOP3Op, VOP3SDOp, VOP3POp, VOPCOp, VOPDOp, VINTERPOp, SOP1Op, SOP2Op, SOPCOp, SOPKOp, SOPPOp, SMEMOp, DSOp, FLATOp, MUBUFOp, MTBUFOp, MIMGOp)
|
||||
|
||||
# VOP3SD opcodes that share VOP3 encoding
|
||||
VOP3SD_OPS = {288, 289, 290, 764, 765, 766, 767, 768, 769, 770}
|
||||
|
||||
def _matches_encoding(word: int, cls: type[Inst]) -> bool:
|
||||
"""Check if word matches the encoding pattern of an instruction class."""
|
||||
|
|
@ -26,7 +29,7 @@ def detect_format(data: bytes) -> type[Inst]:
|
|||
if (word >> 30) == 0b11:
|
||||
for cls in _FORMATS_64:
|
||||
if _matches_encoding(word, cls):
|
||||
return VOP3SD if cls is VOP3 and ((word >> 16) & 0x3ff) in Inst._VOP3SD_OPS else cls
|
||||
return VOP3SD if cls is VOP3 and ((word >> 16) & 0x3ff) in VOP3SD_OPS else cls
|
||||
raise ValueError(f"unknown 64-bit format word={word:#010x}")
|
||||
# 32-bit formats
|
||||
for cls in _FORMATS_32:
|
||||
|
|
@ -37,10 +40,15 @@ def detect_format(data: bytes) -> type[Inst]:
|
|||
# CONSTANTS
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
|
||||
# GFX11 HWREG IDs
|
||||
HWREG = {1: 'HW_REG_MODE', 2: 'HW_REG_STATUS', 3: 'HW_REG_TRAPSTS', 4: 'HW_REG_HW_ID', 5: 'HW_REG_GPR_ALLOC',
|
||||
6: 'HW_REG_LDS_ALLOC', 7: 'HW_REG_IB_STS', 15: 'HW_REG_SH_MEM_BASES', 18: 'HW_REG_PERF_SNAPSHOT_PC_LO',
|
||||
19: 'HW_REG_PERF_SNAPSHOT_PC_HI', 20: 'HW_REG_FLAT_SCR_LO', 21: 'HW_REG_FLAT_SCR_HI', 22: 'HW_REG_XNACK_MASK',
|
||||
23: 'HW_REG_HW_ID1', 24: 'HW_REG_HW_ID2', 25: 'HW_REG_POPS_PACKER', 28: 'HW_REG_IB_STS2'}
|
||||
# GFX12 HWREG IDs - use names that LLVM recognizes
|
||||
HWREG_GFX12 = {1: 'HW_REG_MODE', 2: 'HW_REG_STATUS', 5: 'HW_REG_GPR_ALLOC', 6: 'HW_REG_LDS_ALLOC', 7: 'HW_REG_IB_STS',
|
||||
18: 'HW_REG_EXCP_FLAG_USER', 19: 'HW_REG_TRAP_CTRL', 20: 'HW_REG_SCRATCH_BASE_LO', 21: 'HW_REG_SCRATCH_BASE_HI',
|
||||
23: 'HW_REG_HW_ID1', 24: 'HW_REG_HW_ID2', 29: 'HW_REG_SHADER_CYCLES_LO', 30: 'HW_REG_SHADER_CYCLES_HI'}
|
||||
HWREG_IDS = {v.lower(): k for k, v in HWREG.items()}
|
||||
MSG = {128: 'MSG_RTN_GET_DOORBELL', 129: 'MSG_RTN_GET_DDID', 130: 'MSG_RTN_GET_TMA',
|
||||
131: 'MSG_RTN_GET_REALTIME', 132: 'MSG_RTN_SAVE_WAVE', 133: 'MSG_RTN_GET_TBA'}
|
||||
|
|
@ -76,6 +84,8 @@ def waitcnt(vmcnt: int = 0x3f, expcnt: int = 0x7, lgkmcnt: int = 0x3f) -> int:
|
|||
return (expcnt & 0x7) | ((lgkmcnt & 0x3f) << 4) | ((vmcnt & 0x3f) << 10)
|
||||
|
||||
def _has(op: str, *subs) -> bool: return any(s in op for s in subs)
|
||||
def _is16(op: str) -> bool: return _has(op, 'f16', 'i16', 'u16', 'b16') and not _has(op, '_f32', '_i32')
|
||||
def _is64(op: str) -> bool: return _has(op, 'f64', 'i64', 'u64', 'b64')
|
||||
def _omod(v: int) -> str: return {1: " mul:2", 2: " mul:4", 3: " div:2"}.get(v, "")
|
||||
def _src16(inst, v: int) -> str: return _fmt_v16(v) if v >= 256 else inst.lit(v) # format 16-bit src: vgpr.h/l or literal
|
||||
def _mods(*pairs) -> str: return " ".join(m for c, m in pairs if c)
|
||||
|
|
@ -100,43 +110,72 @@ def _opsel_str(opsel: int, n: int, need: bool, is16_d: bool) -> str:
|
|||
# DISASSEMBLER
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
|
||||
_VOP1_F64 = {VOP1Op.V_CEIL_F64, VOP1Op.V_FLOOR_F64, VOP1Op.V_FRACT_F64, VOP1Op.V_FREXP_MANT_F64, VOP1Op.V_RCP_F64, VOP1Op.V_RNDNE_F64, VOP1Op.V_RSQ_F64, VOP1Op.V_SQRT_F64, VOP1Op.V_TRUNC_F64}
|
||||
|
||||
def _disasm_vop1(inst: VOP1) -> str:
|
||||
name = inst.op_name.lower()
|
||||
if inst.op in (VOP1Op.V_NOP, VOP1Op.V_PIPEFLUSH): return name
|
||||
if inst.op == VOP1Op.V_READFIRSTLANE_B32: return f"v_readfirstlane_b32 {decode_src(inst.vdst)}, v{inst.src0 - 256 if inst.src0 >= 256 else inst.src0}"
|
||||
# 16-bit dst: uses .h/.l suffix (determined by name pattern, not dtype - e.g. sat_pk_u8_i16 outputs 8-bit but uses 16-bit encoding)
|
||||
# Use architecture-specific op enum
|
||||
if 'rdna4' in inst.__class__.__module__:
|
||||
from extra.assembly.amd.autogen.rdna4.enum import VOP1Op as OpEnum
|
||||
else:
|
||||
OpEnum = VOP1Op
|
||||
op, name = OpEnum(inst.op), OpEnum(inst.op).name.lower()
|
||||
if name in ('v_nop', 'v_pipeflush'): return name
|
||||
if name == 'v_readfirstlane_b32': return f"v_readfirstlane_b32 {decode_src(inst.vdst)}, v{inst.src0 - 256 if inst.src0 >= 256 else inst.src0}"
|
||||
parts = name.split('_')
|
||||
is_f64_d = 'f64' in name and any(x in name for x in ['ceil', 'floor', 'fract', 'frexp_mant', 'rcp', 'rndne', 'rsq', 'sqrt', 'trunc', 'cvt_f64_f32', 'cvt_f64_i32', 'cvt_f64_u32'])
|
||||
is_f64_s = 'f64' in name and any(x in name for x in ['ceil', 'floor', 'fract', 'frexp_mant', 'rcp', 'rndne', 'rsq', 'sqrt', 'trunc', 'cvt_f32_f64', 'cvt_i32_f64', 'cvt_u32_f64', 'frexp_exp_i32_f64'])
|
||||
# v_cvt_pk_f32_bf8/fp8 output 2 VGPRs (packed f32x2) and take packed 8-bit (16-bit VGPR with .l/.h) source
|
||||
is_pk_f32 = 'cvt_pk_f32_bf8' in name or 'cvt_pk_f32_fp8' in name
|
||||
is_16d = any(p in ('f16','i16','u16','b16') for p in parts[-2:-1]) or (len(parts) >= 2 and parts[-1] in ('f16','i16','u16','b16') and 'cvt' not in name)
|
||||
dst = _vreg(inst.vdst, inst.dst_regs()) if inst.dst_regs() > 1 else _fmt_v16(inst.vdst, 0, 128) if is_16d else f"v{inst.vdst}"
|
||||
src = _fmt_src(inst.src0, inst.src_regs(0)) if inst.src_regs(0) > 1 else _src16(inst, inst.src0) if inst.is_src_16(0) and 'sat_pk' not in name else inst.lit(inst.src0)
|
||||
# Only packed bf8/fp8 (cvt_pk_*) use 16-bit VGPR encoding; non-packed versions use regular VGPRs
|
||||
is_16s = (parts[-1] in ('f16','i16','u16','b16') and 'sat_pk' not in name) or (parts[-1] in ('bf8', 'fp8') and 'pk' in name)
|
||||
dst = _vreg(inst.vdst, 2) if is_f64_d or is_pk_f32 else _fmt_v16(inst.vdst, 0, 128) if is_16d else f"v{inst.vdst}"
|
||||
src = _fmt_src(inst.src0, 2) if is_f64_s else _src16(inst, inst.src0) if is_16s else inst.lit(inst.src0)
|
||||
return f"{name}_e32 {dst}, {src}"
|
||||
|
||||
def _disasm_vop2(inst: VOP2) -> str:
|
||||
name = inst.op_name.lower()
|
||||
suf = "" if inst.op == VOP2Op.V_DOT2ACC_F32_F16 else "_e32"
|
||||
# Use architecture-specific op enum
|
||||
if 'rdna4' in inst.__class__.__module__:
|
||||
from extra.assembly.amd.autogen.rdna4.enum import VOP2Op as OpEnum
|
||||
else:
|
||||
OpEnum = VOP2Op
|
||||
op, name = OpEnum(inst.op), OpEnum(inst.op).name.lower()
|
||||
suf, is16 = "" if name == 'v_dot2acc_f32_f16' else "_e32", _is16(name) and 'pk_' not in name
|
||||
is64 = _is64(name)
|
||||
# For shift ops with b64, src0 is 32-bit (shift amount), dst/vsrc1 are 64-bit
|
||||
is_shift64 = 'lshlrev_b64' in name
|
||||
# fmaak: dst = src0 * vsrc1 + K, fmamk: dst = src0 * K + vsrc1
|
||||
if inst.op in (VOP2Op.V_FMAAK_F32, VOP2Op.V_FMAAK_F16): return f"{name}{suf} v{inst.vdst}, {inst.lit(inst.src0)}, v{inst.vsrc1}, 0x{inst._literal:x}"
|
||||
if inst.op in (VOP2Op.V_FMAMK_F32, VOP2Op.V_FMAMK_F16): return f"{name}{suf} v{inst.vdst}, {inst.lit(inst.src0)}, 0x{inst._literal:x}, v{inst.vsrc1}"
|
||||
if inst.is_16bit(): return f"{name}{suf} {_fmt_v16(inst.vdst, 0, 128)}, {_src16(inst, inst.src0)}, {_fmt_v16(inst.vsrc1, 0, 128)}"
|
||||
return f"{name}{suf} v{inst.vdst}, {inst.lit(inst.src0)}, v{inst.vsrc1}" + (", vcc_lo" if inst.op == VOP2Op.V_CNDMASK_B32 else "")
|
||||
if 'fmaak' in name: return f"{name}{suf} v{inst.vdst}, {inst.lit(inst.src0)}, v{inst.vsrc1}, 0x{inst._literal:x}"
|
||||
if 'fmamk' in name: return f"{name}{suf} v{inst.vdst}, {inst.lit(inst.src0)}, 0x{inst._literal:x}, v{inst.vsrc1}"
|
||||
if is16: return f"{name}{suf} {_fmt_v16(inst.vdst, 0, 128)}, {_src16(inst, inst.src0)}, {_fmt_v16(inst.vsrc1, 0, 128)}"
|
||||
if is_shift64: return f"{name}{suf} {_vreg(inst.vdst, 2)}, {inst.lit(inst.src0)}, {_vreg(inst.vsrc1, 2)}"
|
||||
if is64: return f"{name}{suf} {_vreg(inst.vdst, 2)}, {_fmt_src(inst.src0, 2)}, {_vreg(inst.vsrc1, 2)}"
|
||||
return f"{name}{suf} v{inst.vdst}, {inst.lit(inst.src0)}, v{inst.vsrc1}" + (", vcc_lo" if name == 'v_cndmask_b32' else "")
|
||||
|
||||
def _disasm_vopc(inst: VOPC) -> str:
|
||||
name = inst.op_name.lower()
|
||||
s0 = _fmt_src(inst.src0, inst.src_regs(0)) if inst.src_regs(0) > 1 else _src16(inst, inst.src0) if inst.is_16bit() else inst.lit(inst.src0)
|
||||
s1 = _vreg(inst.vsrc1, inst.src_regs(1)) if inst.src_regs(1) > 1 else _fmt_v16(inst.vsrc1, 0, 128) if inst.is_16bit() else f"v{inst.vsrc1}"
|
||||
return f"{name}_e32 {s0}, {s1}" if inst.op.value >= 128 else f"{name}_e32 vcc_lo, {s0}, {s1}"
|
||||
# Use architecture-specific op enum
|
||||
if 'rdna4' in inst.__class__.__module__:
|
||||
from extra.assembly.amd.autogen.rdna4.enum import VOPCOp as OpEnum
|
||||
else:
|
||||
OpEnum = VOPCOp
|
||||
op, name = OpEnum(inst.op), OpEnum(inst.op).name.lower()
|
||||
is64, is16 = _is64(name), _is16(name)
|
||||
is_class = 'class' in name
|
||||
s0 = _fmt_src(inst.src0, 2) if is64 else _src16(inst, inst.src0) if is16 else inst.lit(inst.src0)
|
||||
s1 = _vreg(inst.vsrc1, 2) if is64 and not is_class else _fmt_v16(inst.vsrc1, 0, 128) if is16 else f"v{inst.vsrc1}"
|
||||
return f"{name}_e32 {s0}, {s1}" if op.value >= 128 else f"{name}_e32 vcc_lo, {s0}, {s1}"
|
||||
|
||||
NO_ARG_SOPP = {SOPPOp.S_ENDPGM, SOPPOp.S_BARRIER, SOPPOp.S_WAKEUP, SOPPOp.S_ICACHE_INV,
|
||||
SOPPOp.S_WAIT_IDLE, SOPPOp.S_ENDPGM_SAVED, SOPPOp.S_CODE_END, SOPPOp.S_ENDPGM_ORDERED_PS_DONE}
|
||||
|
||||
def _disasm_sopp(inst: SOPP) -> str:
|
||||
name = inst.op_name.lower()
|
||||
if inst.op in NO_ARG_SOPP: return name
|
||||
if inst.op == SOPPOp.S_WAITCNT:
|
||||
op, name = SOPPOp(inst.op), SOPPOp(inst.op).name.lower()
|
||||
if op in NO_ARG_SOPP: return name
|
||||
if op == SOPPOp.S_WAITCNT:
|
||||
vm, exp, lgkm = (inst.simm16 >> 10) & 0x3f, inst.simm16 & 0xf, (inst.simm16 >> 4) & 0x3f
|
||||
p = [f"vmcnt({vm})" if vm != 0x3f else "", f"expcnt({exp})" if exp != 7 else "", f"lgkmcnt({lgkm})" if lgkm != 0x3f else ""]
|
||||
return f"s_waitcnt {' '.join(x for x in p if x) or '0'}"
|
||||
if inst.op == SOPPOp.S_DELAY_ALU:
|
||||
if op == SOPPOp.S_DELAY_ALU:
|
||||
deps, skips = ['VALU_DEP_1','VALU_DEP_2','VALU_DEP_3','VALU_DEP_4','TRANS32_DEP_1','TRANS32_DEP_2','TRANS32_DEP_3','FMA_ACCUM_CYCLE_1','SALU_CYCLE_1','SALU_CYCLE_2','SALU_CYCLE_3'], ['SAME','NEXT','SKIP_1','SKIP_2','SKIP_3','SKIP_4']
|
||||
id0, skip, id1 = inst.simm16 & 0xf, (inst.simm16 >> 4) & 0x7, (inst.simm16 >> 7) & 0xf
|
||||
dep = lambda v: deps[v-1] if 0 < v <= len(deps) else str(v)
|
||||
|
|
@ -145,20 +184,51 @@ def _disasm_sopp(inst: SOPP) -> str:
|
|||
return f"{name} {inst.simm16}" if name.startswith(('s_cbranch', 's_branch')) else f"{name} 0x{inst.simm16:x}"
|
||||
|
||||
def _disasm_smem(inst: SMEM) -> str:
|
||||
name = inst.op_name.lower()
|
||||
if inst.op in (SMEMOp.S_GL1_INV, SMEMOp.S_DCACHE_INV): return name
|
||||
off_s = f"{decode_src(inst.soffset)} offset:0x{inst.offset:x}" if inst.offset and inst.soffset != 124 else f"0x{inst.offset:x}" if inst.offset else decode_src(inst.soffset)
|
||||
sbase_idx, sbase_count = inst.sbase * 2, 4 if (8 <= inst.op.value <= 12 or name == 's_atc_probe_buffer') else 2
|
||||
sbase_str = _fmt_src(sbase_idx, sbase_count) if sbase_count == 2 else _sreg(sbase_idx, sbase_count) if sbase_idx <= 105 else _reg("ttmp", sbase_idx - 108, sbase_count)
|
||||
if name in ('s_atc_probe', 's_atc_probe_buffer'): return f"{name} {inst.sdata}, {sbase_str}, {off_s}"
|
||||
return f"{name} {_fmt_sdst(inst.sdata, inst.dst_regs())}, {sbase_str}, {off_s}" + _mods((inst.glc, " glc"), (inst.dlc, " dlc"))
|
||||
is_rdna4 = 'rdna4' in inst.__class__.__module__
|
||||
if is_rdna4:
|
||||
from extra.assembly.amd.autogen.rdna4.enum import SMEMOp as SMEMOp4
|
||||
op = SMEMOp4(inst.op)
|
||||
name = op.name.lower()
|
||||
if op == SMEMOp4.S_DCACHE_INV: return name
|
||||
# RDNA4: s_buffer_* uses 4-SGPR descriptor, s_load/s_prefetch uses 2-SGPR
|
||||
is_buffer = 'buffer' in name
|
||||
sbase_idx = inst.sbase * 2
|
||||
sbase_count = 4 if is_buffer else 2
|
||||
sbase_str = _fmt_src(sbase_idx, sbase_count) if sbase_count == 2 else _sreg(sbase_idx, sbase_count) if sbase_idx <= 105 else _reg("ttmp", sbase_idx - 108, sbase_count)
|
||||
# Format offset - ioffset is signed 24-bit, show as hex
|
||||
ioff = inst.ioffset if inst.ioffset < 0x800000 else inst.ioffset - 0x1000000 # sign extend
|
||||
off_hex = f"0x{ioff & 0xffffff:x}" if ioff >= 0 else f"-0x{(-ioff) & 0xffffff:x}"
|
||||
off_s = f"{decode_src(inst.soffset)} offset:{off_hex}" if inst.soffset != 124 else off_hex
|
||||
# Data width from opcode
|
||||
width_map = {0:1, 1:2, 2:4, 3:8, 4:16, 5:3, 8:1, 9:1, 10:1, 11:1, 16:1, 17:2, 18:4, 19:8, 20:16, 21:3, 24:1, 25:1, 26:1, 27:1}
|
||||
width = width_map.get(inst.op, 1)
|
||||
if 'prefetch' in name:
|
||||
# Prefetch has different format: s_prefetch_* sbase, offset, soffset, length
|
||||
# But we need to handle various prefetch types differently
|
||||
if name == 's_prefetch_inst_pc_rel' or name == 's_prefetch_data_pc_rel':
|
||||
return f"{name} {off_hex}, {decode_src(inst.soffset)}, {inst.sdata}"
|
||||
return f"{name} {sbase_str}, {off_hex}, {decode_src(inst.soffset)}, {inst.sdata}"
|
||||
return f"{name} {_fmt_sdst(inst.sdata, width)}, {sbase_str}, {off_s}"
|
||||
else:
|
||||
op = SMEMOp(inst.op)
|
||||
name = op.name.lower()
|
||||
if op in (SMEMOp.S_GL1_INV, SMEMOp.S_DCACHE_INV): return name
|
||||
off_s = f"{decode_src(inst.soffset)} offset:0x{inst.offset:x}" if inst.offset and inst.soffset != 124 else f"0x{inst.offset:x}" if inst.offset else decode_src(inst.soffset)
|
||||
sbase_idx, sbase_count = inst.sbase * 2, 4 if (8 <= inst.op <= 12 or name == 's_atc_probe_buffer') else 2
|
||||
sbase_str = _fmt_src(sbase_idx, sbase_count) if sbase_count == 2 else _sreg(sbase_idx, sbase_count) if sbase_idx <= 105 else _reg("ttmp", sbase_idx - 108, sbase_count)
|
||||
if name in ('s_atc_probe', 's_atc_probe_buffer'): return f"{name} {inst.sdata}, {sbase_str}, {off_s}"
|
||||
width = {0:1, 1:2, 2:4, 3:8, 4:16, 8:1, 9:2, 10:4, 11:8, 12:16}.get(inst.op, 1)
|
||||
return f"{name} {_fmt_sdst(inst.sdata, width)}, {sbase_str}, {off_s}" + _mods((inst.glc, " glc"), (inst.dlc, " dlc"))
|
||||
|
||||
def _disasm_flat(inst: FLAT) -> str:
|
||||
name = inst.op_name.lower()
|
||||
name = FLATOp(inst.op).name.lower()
|
||||
seg = ['flat', 'scratch', 'global'][inst.seg] if inst.seg < 3 else 'flat'
|
||||
instr = f"{seg}_{name.split('_', 1)[1] if '_' in name else name}"
|
||||
off_val = inst.offset if seg == 'flat' else (inst.offset if inst.offset < 4096 else inst.offset - 8192)
|
||||
w = inst.dst_regs() * (2 if 'cmpswap' in name else 1)
|
||||
suffix = name.split('_')[-1]
|
||||
w = {'b32':1,'b64':2,'b96':3,'b128':4,'u8':1,'i8':1,'u16':1,'i16':1,'u32':1,'i32':1,'u64':2,'i64':2,'f32':1,'f64':2}.get(suffix, 1)
|
||||
if 'cmpswap' in name: w *= 2
|
||||
if name.endswith('_x2') or 'x2' in suffix: w = max(w, 2)
|
||||
mods = f"{f' offset:{off_val}' if off_val else ''}{' glc' if inst.glc else ''}{' slc' if inst.slc else ''}{' dlc' if inst.dlc else ''}"
|
||||
# saddr
|
||||
if seg == 'flat' or inst.saddr == 0x7F: saddr_s = ""
|
||||
|
|
@ -178,11 +248,11 @@ def _disasm_flat(inst: FLAT) -> str:
|
|||
return f"{instr} {_vreg(inst.vdst, w)}, {addr_s}{saddr_s}{mods}"
|
||||
|
||||
def _disasm_ds(inst: DS) -> str:
|
||||
op, name = inst.op, inst.op_name.lower()
|
||||
op, name = DSOp(inst.op), DSOp(inst.op).name.lower()
|
||||
gds = " gds" if inst.gds else ""
|
||||
off = f" offset:{inst.offset0 | (inst.offset1 << 8)}" if inst.offset0 or inst.offset1 else ""
|
||||
off2 = f" offset0:{inst.offset0} offset1:{inst.offset1}" if inst.offset0 or inst.offset1 else ""
|
||||
w = inst.dst_regs()
|
||||
w = 4 if '128' in name else 3 if '96' in name else 2 if (name.endswith('64') or 'gs_reg' in name) else 1
|
||||
d0, d1, dst, addr = _vreg(inst.data0, w), _vreg(inst.data1, w), _vreg(inst.vdst, w), f"v{inst.addr}"
|
||||
|
||||
if op == DSOp.DS_NOP: return name
|
||||
|
|
@ -206,34 +276,56 @@ def _disasm_ds(inst: DS) -> str:
|
|||
return f"{name} {dst}, {addr}, {d0}{off}{gds}" if '_rtn' in name else f"{name} {addr}, {d0}{off}{gds}"
|
||||
|
||||
def _disasm_vop3(inst: VOP3) -> str:
|
||||
op, name = inst.op, inst.op_name.lower()
|
||||
is_rdna4 = 'rdna4' in inst.__class__.__module__
|
||||
if is_rdna4:
|
||||
from extra.assembly.amd.autogen.rdna4.enum import VOP3Op as VOP3Op4, VOP3SDOp as VOP3SDOp4
|
||||
op = VOP3SDOp4(inst.op) if inst.op in VOP3SD_OPS else VOP3Op4(inst.op)
|
||||
else:
|
||||
op = VOP3SDOp(inst.op) if inst.op in VOP3SD_OPS else VOP3Op(inst.op)
|
||||
name = op.name.lower()
|
||||
|
||||
# VOP3SD (shared encoding)
|
||||
if isinstance(op, VOP3SDOp):
|
||||
if inst.op in VOP3SD_OPS:
|
||||
sdst = (inst.clmp << 7) | (inst.opsel << 3) | inst.abs
|
||||
def src(v, neg, n): s = _fmt_src(v, n) if n > 1 else inst.lit(v); return f"-{s}" if neg else s
|
||||
s0, s1, s2 = src(inst.src0, inst.neg & 1, inst.src_regs(0)), src(inst.src1, inst.neg & 2, inst.src_regs(1)), src(inst.src2, inst.neg & 4, inst.src_regs(2))
|
||||
dst = _vreg(inst.vdst, inst.dst_regs()) if inst.dst_regs() > 1 else f"v{inst.vdst}"
|
||||
srcs = f"{s0}, {s1}, {s2}" if inst.num_srcs() == 3 else f"{s0}, {s1}"
|
||||
return f"{name} {dst}, {_fmt_sdst(sdst, 1)}, {srcs}" + _omod(inst.omod)
|
||||
is64, mad64 = 'f64' in name, _has(name, 'mad_i64_i32', 'mad_u64_u32', 'mad_co_i64_i32', 'mad_co_u64_u32')
|
||||
def src(v, neg, ext=False): s = _fmt_src(v, 2) if ext or is64 else inst.lit(v); return f"-{s}" if neg else s
|
||||
s0, s1, s2 = src(inst.src0, inst.neg & 1), src(inst.src1, inst.neg & 2), src(inst.src2, inst.neg & 4, mad64)
|
||||
dst = _vreg(inst.vdst, 2) if is64 or mad64 else f"v{inst.vdst}"
|
||||
if op in (VOP3SDOp.V_ADD_CO_U32, VOP3SDOp.V_SUB_CO_U32, VOP3SDOp.V_SUBREV_CO_U32): return f"{name} {dst}, {_fmt_sdst(sdst, 1)}, {s0}, {s1}"
|
||||
if op in (VOP3SDOp.V_ADD_CO_CI_U32, VOP3SDOp.V_SUB_CO_CI_U32, VOP3SDOp.V_SUBREV_CO_CI_U32): return f"{name} {dst}, {_fmt_sdst(sdst, 1)}, {s0}, {s1}, {s2}"
|
||||
return f"{name} {dst}, {_fmt_sdst(sdst, 1)}, {s0}, {s1}, {s2}" + _omod(inst.omod)
|
||||
|
||||
# Detect 16-bit operand sizes (for .h/.l suffix handling)
|
||||
# Detect operand sizes
|
||||
is64 = _is64(name)
|
||||
is64_src, is64_dst = False, False
|
||||
is16_d = is16_s = is16_s2 = False
|
||||
if 'cvt_pk' in name: is16_s = name.endswith('16')
|
||||
# v_cvt_pk_f32_bf8/fp8 outputs a VGPR pair (f32x2) from 16-bit packed input
|
||||
if 'cvt_pk_f32_bf8' in name or 'cvt_pk_f32_fp8' in name: is64_dst = True
|
||||
elif 'cvt_pk' in name: is16_s = name.endswith('16')
|
||||
elif m := re.match(r'v_(?:cvt|frexp_exp)_([a-z0-9_]+)_([a-z0-9]+)', name):
|
||||
is16_d, is16_s = _has(m.group(1), 'f16','i16','u16','b16'), _has(m.group(2), 'f16','i16','u16','b16')
|
||||
is16_s2 = is16_s
|
||||
is64_src, is64_dst = '64' in m.group(2), '64' in m.group(1)
|
||||
is16_s2, is64 = is16_s, False
|
||||
elif re.match(r'v_mad_[iu]32_[iu]16', name): is16_s = True
|
||||
elif 'pack_b32' in name: is16_s = is16_s2 = True
|
||||
else: is16_d = is16_s = is16_s2 = inst.is_16bit()
|
||||
else: is16_d = is16_s = is16_s2 = _is16(name) and not _has(name, 'dot2', 'pk_', 'sad', 'msad', 'qsad', 'mqsad')
|
||||
|
||||
# Source counts
|
||||
shift64 = 'rev' in name and '64' in name and name.startswith('v_')
|
||||
ldexp64 = op == VOP3Op.V_LDEXP_F64
|
||||
trig = op == VOP3Op.V_TRIG_PREOP_F64
|
||||
sad64, mqsad = _has(name, 'qsad_pk', 'mqsad_pk'), 'mqsad_u32' in name
|
||||
s0n = 2 if ((is64 and not shift64) or sad64 or mqsad or is64_src) else 1
|
||||
s1n = 2 if (is64 and not _has(name, 'class') and not ldexp64 and not trig) else 1
|
||||
s2n = 4 if mqsad else 2 if (is64 or sad64) else 1
|
||||
|
||||
any_hi = inst.opsel != 0
|
||||
s0 = _vop3_src(inst, inst.src0, inst.neg&1, inst.abs&1, inst.opsel&1, inst.src_regs(0), is16_s, any_hi)
|
||||
s1 = _vop3_src(inst, inst.src1, inst.neg&2, inst.abs&2, inst.opsel&2, inst.src_regs(1), is16_s, any_hi)
|
||||
s2 = _vop3_src(inst, inst.src2, inst.neg&4, inst.abs&4, inst.opsel&4, inst.src_regs(2), is16_s2, any_hi)
|
||||
s0 = _vop3_src(inst, inst.src0, inst.neg&1, inst.abs&1, inst.opsel&1, s0n, is16_s, any_hi)
|
||||
s1 = _vop3_src(inst, inst.src1, inst.neg&2, inst.abs&2, inst.opsel&2, s1n, is16_s, any_hi)
|
||||
s2 = _vop3_src(inst, inst.src2, inst.neg&4, inst.abs&4, inst.opsel&4, s2n, is16_s2, any_hi)
|
||||
|
||||
# Destination
|
||||
dn = inst.dst_regs()
|
||||
dn = 4 if mqsad else 2 if (is64 or sad64 or is64_dst) else 1
|
||||
if op == VOP3Op.V_READLANE_B32: dst = _fmt_sdst(inst.vdst, 1)
|
||||
elif dn > 1: dst = _vreg(inst.vdst, dn)
|
||||
elif is16_d: dst = f"v{inst.vdst}.h" if (inst.opsel & 8) else f"v{inst.vdst}.l" if any_hi else f"v{inst.vdst}"
|
||||
|
|
@ -243,54 +335,126 @@ def _disasm_vop3(inst: VOP3) -> str:
|
|||
nonvgpr_opsel = (inst.src0 < 256 and (inst.opsel & 1)) or (inst.src1 < 256 and (inst.opsel & 2)) or (inst.src2 < 256 and (inst.opsel & 4))
|
||||
need_opsel = nonvgpr_opsel or (inst.opsel and not is16_s)
|
||||
|
||||
# RDNA4 v_s_* instructions (pseudo-scalar VOP1-like) have SGPR destination
|
||||
if name.startswith('v_s_') and is_rdna4:
|
||||
return f"{name} {_fmt_sdst(inst.vdst, 1)}, {s0}{cl}{om}"
|
||||
|
||||
if inst.op < 256: # VOPC
|
||||
return f"{name}_e64 {s0}, {s1}" if name.startswith('v_cmpx') else f"{name}_e64 {_fmt_sdst(inst.vdst, 1)}, {s0}, {s1}"
|
||||
if inst.op < 384: # VOP2
|
||||
n = inst.num_srcs()
|
||||
os = _opsel_str(inst.opsel, n, need_opsel, is16_d)
|
||||
return f"{name}_e64 {dst}, {s0}, {s1}, {s2}{os}{cl}{om}" if n == 3 else f"{name}_e64 {dst}, {s0}, {s1}{os}{cl}{om}"
|
||||
os = _opsel_str(inst.opsel, 3, need_opsel, is16_d) if 'cndmask' in name else _opsel_str(inst.opsel, 2, need_opsel, is16_d)
|
||||
return f"{name}_e64 {dst}, {s0}, {s1}, {s2}{os}{cl}{om}" if 'cndmask' in name else f"{name}_e64 {dst}, {s0}, {s1}{os}{cl}{om}"
|
||||
if inst.op < 512: # VOP1
|
||||
return f"{name}_e64" if op in (VOP3Op.V_NOP, VOP3Op.V_PIPEFLUSH) else f"{name}_e64 {dst}, {s0}{_opsel_str(inst.opsel, 1, need_opsel, is16_d)}{cl}{om}"
|
||||
if name in ('v_nop', 'v_pipeflush'): return f"{name}_e64"
|
||||
# Handle byte_sel for non-pk fp8/bf8 conversions
|
||||
if ('cvt_f32_fp8' in name or 'cvt_f32_bf8' in name) and 'pk' not in name:
|
||||
byte_sel = inst.opsel & 3
|
||||
os = f" byte_sel:{byte_sel}" if byte_sel else ""
|
||||
elif 'cvt_pk_f32_bf8' in name or 'cvt_pk_f32_fp8' in name:
|
||||
os = _opsel_str(inst.opsel, 2, need_opsel, is16_d) # 2-element for pk variants
|
||||
else:
|
||||
os = _opsel_str(inst.opsel, 1, need_opsel, is16_d) # 1-element for other VOP1
|
||||
return f"{name}_e64 {dst}, {s0}{os}{cl}{om}"
|
||||
# Native VOP3
|
||||
n = inst.num_srcs()
|
||||
os = _opsel_str(inst.opsel, n, need_opsel, is16_d)
|
||||
return f"{name} {dst}, {s0}, {s1}, {s2}{os}{cl}{om}" if n == 3 else f"{name} {dst}, {s0}, {s1}{os}{cl}{om}"
|
||||
is3 = _has(name, 'fma', 'mad', 'min3', 'max3', 'med3', 'div_fix', 'div_fmas', 'sad', 'lerp', 'align', 'cube', 'bfe', 'bfi',
|
||||
'perm_b32', 'cndmask', 'xor3', 'or3', 'add3', 'lshl_or', 'and_or', 'lshl_add', 'add_lshl', 'xad', 'maxmin', 'minmax', 'dot2', 'cvt_pk_u8', 'mullit',
|
||||
'minimummaximum', 'maximumminimum', 'minimum3', 'maximum3')
|
||||
# permlane16/permlanex16 have 3 sources, but _var variants have 2
|
||||
if 'permlane' in name and 'var' not in name: is3 = True
|
||||
# Handle byte_sel for fp8/bf8 instructions (opsel encodes byte_sel, not op_sel)
|
||||
# For VOP1-encoded VOP3 (op < 512): cvt_f32_fp8, cvt_f32_bf8
|
||||
# For native VOP3: cvt_sr_fp8, cvt_sr_bf8
|
||||
if ('cvt_f32_fp8' in name or 'cvt_f32_bf8' in name or 'cvt_sr_fp8' in name or 'cvt_sr_bf8' in name) and 'pk' not in name:
|
||||
# For VOP1 encoding (op < 512), byte_sel is in bits[1:0] of opsel; for native VOP3, it's bits[3:2]
|
||||
byte_sel = (inst.opsel & 3) if inst.op < 512 else ((inst.opsel >> 2) & 3)
|
||||
os = f" byte_sel:{byte_sel}" if byte_sel else ""
|
||||
else:
|
||||
os = _opsel_str(inst.opsel, 3 if is3 else 2, need_opsel, is16_d)
|
||||
return f"{name} {dst}, {s0}, {s1}, {s2}{os}{cl}{om}" if is3 else f"{name} {dst}, {s0}, {s1}{os}{cl}{om}"
|
||||
|
||||
def _disasm_vop3sd(inst: VOP3SD) -> str:
|
||||
name = inst.op_name.lower()
|
||||
def src(v, neg, n): s = _fmt_src(v, n) if n > 1 else inst.lit(v); return f"-{s}" if neg else s
|
||||
s0, s1, s2 = src(inst.src0, inst.neg & 1, inst.src_regs(0)), src(inst.src1, inst.neg & 2, inst.src_regs(1)), src(inst.src2, inst.neg & 4, inst.src_regs(2))
|
||||
dst = _vreg(inst.vdst, inst.dst_regs()) if inst.dst_regs() > 1 else f"v{inst.vdst}"
|
||||
srcs = f"{s0}, {s1}, {s2}" if inst.num_srcs() == 3 else f"{s0}, {s1}"
|
||||
op, name = VOP3SDOp(inst.op), VOP3SDOp(inst.op).name.lower()
|
||||
is64, mad64 = 'f64' in name, _has(name, 'mad_i64_i32', 'mad_u64_u32')
|
||||
def src(v, neg, ext=False): s = _fmt_src(v, 2) if ext or is64 else inst.lit(v); return f"-{s}" if neg else s
|
||||
s0, s1, s2 = src(inst.src0, inst.neg & 1), src(inst.src1, inst.neg & 2), src(inst.src2, inst.neg & 4, mad64)
|
||||
dst, is2src = _vreg(inst.vdst, 2) if is64 or mad64 else f"v{inst.vdst}", op in (VOP3SDOp.V_ADD_CO_U32, VOP3SDOp.V_SUB_CO_U32, VOP3SDOp.V_SUBREV_CO_U32)
|
||||
suffix = "_e64" if name.startswith('v_') and 'co_' in name else ""
|
||||
return f"{name}{suffix} {dst}, {_fmt_sdst(inst.sdst, 1)}, {srcs}{' clamp' if inst.clmp else ''}{_omod(inst.omod)}"
|
||||
return f"{name}{suffix} {dst}, {_fmt_sdst(inst.sdst, 1)}, {s0}, {s1}{'' if is2src else f', {s2}'}{' clamp' if inst.clmp else ''}{_omod(inst.omod)}"
|
||||
|
||||
def _disasm_vopd(inst: VOPD) -> str:
|
||||
is_rdna4 = 'rdna4' in inst.__class__.__module__
|
||||
if is_rdna4:
|
||||
from extra.assembly.amd.autogen.rdna4.enum import VOPDOp as VOPDOp4
|
||||
OpEnum = VOPDOp4
|
||||
else:
|
||||
OpEnum = VOPDOp
|
||||
lit = inst._literal or inst.literal
|
||||
vdst_y, nx, ny = (inst.vdsty << 1) | ((inst.vdstx & 1) ^ 1), VOPDOp(inst.opx).name.lower(), VOPDOp(inst.opy).name.lower()
|
||||
vdst_y, nx, ny = (inst.vdsty << 1) | ((inst.vdstx & 1) ^ 1), OpEnum(inst.opx).name.lower(), OpEnum(inst.opy).name.lower()
|
||||
def half(n, vd, s0, vs1): return f"{n} v{vd}, {inst.lit(s0)}{f', 0x{lit:x}' if lit and _has(n, 'fmaak', 'fmamk') else ''}" if 'mov' in n else f"{n} v{vd}, {inst.lit(s0)}, v{vs1}{f', 0x{lit:x}' if lit and _has(n, 'fmaak', 'fmamk') else ''}"
|
||||
return f"{half(nx, inst.vdstx, inst.srcx0, inst.vsrcx1)} :: {half(ny, vdst_y, inst.srcy0, inst.vsrcy1)}"
|
||||
|
||||
def _disasm_vop3p(inst: VOP3P) -> str:
|
||||
name = inst.op_name.lower()
|
||||
is_wmma, n, is_fma_mix = 'wmma' in name, inst.num_srcs(), 'fma_mix' in name
|
||||
if is_wmma:
|
||||
sc = 2 if 'iu4' in name else 4 if 'iu8' in name else 8
|
||||
src0, src1, src2, dst = _fmt_src(inst.src0, sc), _fmt_src(inst.src1, sc), _fmt_src(inst.src2, 8), _vreg(inst.vdst, 8)
|
||||
def _disasm_vop3p(inst: VOP3P, wave_size: int = 32) -> str:
|
||||
is_rdna4 = 'rdna4' in inst.__class__.__module__
|
||||
if is_rdna4:
|
||||
from extra.assembly.amd.autogen.rdna4.enum import VOP3POp as OpEnum
|
||||
else:
|
||||
OpEnum = VOP3POp
|
||||
name = OpEnum(inst.op).name.lower()
|
||||
is_wmma, is_swmmac = 'wmma' in name and 'swmmac' not in name, 'swmmac' in name
|
||||
is_3src, is_fma_mix = _has(name, 'fma', 'mad', 'dot', 'wmma'), 'fma_mix' in name
|
||||
# Wave64 uses half the register widths of wave32 for WMMA
|
||||
wave_div = 2 if wave_size == 64 else 1
|
||||
if is_swmmac and is_rdna4:
|
||||
# SWMMAC (sparse WMMA): src2 is a single VGPR index, not an accumulator
|
||||
# Determine src0/src1/dst sizes based on instruction type
|
||||
if 'f16' in name or 'bf16' in name:
|
||||
if 'f16_16x16x32_f16' in name or 'bf16_16x16x32_bf16' in name:
|
||||
s0c, s1c, dc = 4, 8, 4 # f16/bf16 output
|
||||
else:
|
||||
s0c, s1c, dc = 4, 8, 8 # f32 output
|
||||
elif 'iu8' in name: s0c, s1c, dc = 2, 4, 8
|
||||
elif 'iu4' in name:
|
||||
if '16x16x64' in name: s0c, s1c, dc = 2, 4, 8
|
||||
else: s0c, s1c, dc = 1, 2, 8
|
||||
elif 'fp8' in name or 'bf8' in name: s0c, s1c, dc = 2, 4, 8
|
||||
else: s0c, s1c, dc = 4, 8, 8
|
||||
s0c, s1c, dc = max(1, s0c // wave_div), max(1, s1c // wave_div), max(1, dc // wave_div)
|
||||
src0, src1, src2, dst = _fmt_src(inst.src0, s0c), _fmt_src(inst.src1, s1c), _fmt_src(inst.src2, 1), _vreg(inst.vdst, dc)
|
||||
elif is_wmma:
|
||||
# RDNA4 WMMA uses smaller source register widths than RDNA3
|
||||
if is_rdna4:
|
||||
# RDNA4 wave32 source widths: iu4->1/2, iu8->2, fp8/bf8->2, f16/bf16->4
|
||||
if 'iu4' in name:
|
||||
sc = 2 if '16x16x32' in name else 1
|
||||
elif 'iu8' in name or 'fp8' in name or 'bf8' in name: sc = 2
|
||||
else: sc = 4 # f16/bf16
|
||||
# Destination width: f16/bf16 output->4, f32/i32 output->8
|
||||
dc = 4 if name.startswith('v_wmma_f16') or name.startswith('v_wmma_bf16') else 8
|
||||
else:
|
||||
# RDNA3: iu4->2, iu8->4, f16/bf16->8
|
||||
sc = 2 if 'iu4' in name else 4 if 'iu8' in name else 8
|
||||
dc = 8
|
||||
sc, dc = max(1, sc // wave_div), max(1, dc // wave_div)
|
||||
src0, src1, src2, dst = _fmt_src(inst.src0, sc), _fmt_src(inst.src1, sc), _fmt_src(inst.src2, dc), _vreg(inst.vdst, dc)
|
||||
else: src0, src1, src2, dst = _fmt_src(inst.src0, 1), _fmt_src(inst.src1, 1), _fmt_src(inst.src2, 1), f"v{inst.vdst}"
|
||||
opsel_hi = inst.opsel_hi | (inst.opsel_hi2 << 2)
|
||||
if is_fma_mix:
|
||||
n, opsel_hi = 3 if is_3src else 2, inst.opsel_hi | (inst.opsel_hi2 << 2)
|
||||
if is_swmmac and is_rdna4:
|
||||
# SWMMAC uses index_key instead of op_sel; opsel bits encode the key value
|
||||
mods = ([f"index_key:{inst.opsel & 7}"] if inst.opsel else []) + \
|
||||
([_fmt_bits("neg_lo", inst.neg, n)] if inst.neg else []) + ([_fmt_bits("neg_hi", inst.neg_hi, n)] if inst.neg_hi else []) + (["clamp"] if inst.clmp else [])
|
||||
elif is_fma_mix:
|
||||
def m(s, neg, abs_): return f"-{f'|{s}|' if abs_ else s}" if neg else (f"|{s}|" if abs_ else s)
|
||||
src0, src1, src2 = m(src0, inst.neg & 1, inst.neg_hi & 1), m(src1, inst.neg & 2, inst.neg_hi & 2), m(src2, inst.neg & 4, inst.neg_hi & 4)
|
||||
mods = ([_fmt_bits("op_sel", inst.opsel, n)] if inst.opsel else []) + ([_fmt_bits("op_sel_hi", opsel_hi, n)] if opsel_hi else []) + (["clamp"] if inst.clmp else [])
|
||||
else:
|
||||
mods = ([_fmt_bits("op_sel", inst.opsel, n)] if inst.opsel else []) + ([_fmt_bits("op_sel_hi", opsel_hi, n)] if opsel_hi != (7 if n == 3 else 3) else []) + \
|
||||
mods = ([_fmt_bits("op_sel", inst.opsel, n)] if inst.opsel else []) + ([_fmt_bits("op_sel_hi", opsel_hi, n)] if opsel_hi != (7 if is_3src else 3) else []) + \
|
||||
([_fmt_bits("neg_lo", inst.neg, n)] if inst.neg else []) + ([_fmt_bits("neg_hi", inst.neg_hi, n)] if inst.neg_hi else []) + (["clamp"] if inst.clmp else [])
|
||||
return f"{name} {dst}, {src0}, {src1}, {src2}{' ' + ' '.join(mods) if mods else ''}" if n == 3 else f"{name} {dst}, {src0}, {src1}{' ' + ' '.join(mods) if mods else ''}"
|
||||
return f"{name} {dst}, {src0}, {src1}, {src2}{' ' + ' '.join(mods) if mods else ''}" if is_3src else f"{name} {dst}, {src0}, {src1}{' ' + ' '.join(mods) if mods else ''}"
|
||||
|
||||
def _disasm_buf(inst: MUBUF | MTBUF) -> str:
|
||||
name = inst.op_name.lower()
|
||||
if inst.op in (MUBUFOp.BUFFER_GL0_INV, MUBUFOp.BUFFER_GL1_INV): return name
|
||||
op = MTBUFOp(inst.op) if isinstance(inst, MTBUF) else MUBUFOp(inst.op)
|
||||
name = op.name.lower()
|
||||
if op in (MUBUFOp.BUFFER_GL0_INV, MUBUFOp.BUFFER_GL1_INV): return name
|
||||
w = (2 if _has(name, 'xyz', 'xyzw') else 1) if 'd16' in name else \
|
||||
((2 if _has(name, 'b64', 'u64', 'i64') else 1) * (2 if 'cmpswap' in name else 1)) if 'atomic' in name else \
|
||||
{'b32':1,'b64':2,'b96':3,'b128':4,'b16':1,'x':1,'xy':2,'xyz':3,'xyzw':4}.get(name.split('_')[-1], 1)
|
||||
|
|
@ -318,7 +482,7 @@ def _mimg_vaddr_width(name: str, dim: int, a16: bool) -> int:
|
|||
return (base + packed + 1) // 2 + unpacked if a16 else base + packed + unpacked
|
||||
|
||||
def _disasm_mimg(inst: MIMG) -> str:
|
||||
name = inst.op_name.lower()
|
||||
name = MIMGOp(inst.op).name.lower()
|
||||
srsrc_base = inst.srsrc * 4
|
||||
srsrc_str = _sreg_or_ttmp(srsrc_base, 8)
|
||||
# BVH intersect ray: special case with 4 SGPR srsrc
|
||||
|
|
@ -334,8 +498,8 @@ def _disasm_mimg(inst: MIMG) -> str:
|
|||
dim = dim_names[inst.dim] if inst.dim < len(dim_names) else f"dim_{inst.dim}"
|
||||
vaddr = _mimg_vaddr_width(name, inst.dim, inst.a16)
|
||||
vaddr_str = f"v{inst.vaddr}" if vaddr == 1 else _vreg(inst.vaddr, vaddr)
|
||||
# modifiers
|
||||
mods = [f"dmask:0x{inst.dmask:x}"] if inst.dmask and (inst.dmask != 15 or 'atomic' in name) else []
|
||||
# modifiers - always include dmask for image load/store/atomic (LLVM uses it for vdata size validation)
|
||||
mods = [f"dmask:0x{inst.dmask:x}"] if inst.dmask else []
|
||||
mods.append(f"dim:SQ_RSRC_IMG_{dim.upper()}")
|
||||
for flag, mod in [(inst.unrm,"unorm"),(inst.glc,"glc"),(inst.slc,"slc"),(inst.dlc,"dlc"),(inst.r128,"r128"),
|
||||
(inst.a16,"a16"),(inst.tfe,"tfe"),(inst.lwe,"lwe"),(inst.d16,"d16")]:
|
||||
|
|
@ -346,38 +510,329 @@ def _disasm_mimg(inst: MIMG) -> str:
|
|||
ssamp_str = ", " + _sreg_or_ttmp(inst.ssamp * 4, 4)
|
||||
return f"{name} {_vreg(inst.vdata, vdata)}, {vaddr_str}, {srsrc_str}{ssamp_str} {' '.join(mods)}"
|
||||
|
||||
def _sop_widths(name: str) -> tuple[int, int, int]:
|
||||
"""Return (dst_width, src0_width, src1_width) in register count for SOP instructions."""
|
||||
if name in ('s_bitset0_b64', 's_bitset1_b64', 's_bfm_b64'): return 2, 1, 1
|
||||
if name in ('s_lshl_b64', 's_lshr_b64', 's_ashr_i64', 's_bfe_u64', 's_bfe_i64'): return 2, 2, 1
|
||||
if name in ('s_bitcmp0_b64', 's_bitcmp1_b64'): return 1, 2, 1
|
||||
if m := re.search(r'_(b|i|u)(32|64)_(b|i|u)(32|64)$', name): return 2 if m.group(2) == '64' else 1, 2 if m.group(4) == '64' else 1, 1
|
||||
if m := re.search(r'_(b|i|u)(32|64)$', name): sz = 2 if m.group(2) == '64' else 1; return sz, sz, sz
|
||||
return 1, 1, 1
|
||||
|
||||
def _disasm_sop1(inst: SOP1) -> str:
|
||||
op, name = inst.op, inst.op_name.lower()
|
||||
op, name = SOP1Op(inst.op), SOP1Op(inst.op).name.lower()
|
||||
if op == SOP1Op.S_GETPC_B64: return f"{name} {_fmt_sdst(inst.sdst, 2)}"
|
||||
if op in (SOP1Op.S_SETPC_B64, SOP1Op.S_RFE_B64): return f"{name} {_fmt_src(inst.ssrc0, 2)}"
|
||||
if op == SOP1Op.S_SWAPPC_B64: return f"{name} {_fmt_sdst(inst.sdst, 2)}, {_fmt_src(inst.ssrc0, 2)}"
|
||||
if op in (SOP1Op.S_SENDMSG_RTN_B32, SOP1Op.S_SENDMSG_RTN_B64): return f"{name} {_fmt_sdst(inst.sdst, inst.dst_regs())}, sendmsg({MSG.get(inst.ssrc0, str(inst.ssrc0))})"
|
||||
return f"{name} {_fmt_sdst(inst.sdst, inst.dst_regs())}, {inst.lit(inst.ssrc0) if inst.src_regs(0) == 1 else _fmt_src(inst.ssrc0, inst.src_regs(0))}"
|
||||
if op in (SOP1Op.S_SENDMSG_RTN_B32, SOP1Op.S_SENDMSG_RTN_B64): return f"{name} {_fmt_sdst(inst.sdst, 2 if 'b64' in name else 1)}, sendmsg({MSG.get(inst.ssrc0, str(inst.ssrc0))})"
|
||||
dn, s0n, _ = _sop_widths(name)
|
||||
return f"{name} {_fmt_sdst(inst.sdst, dn)}, {inst.lit(inst.ssrc0) if s0n == 1 else _fmt_src(inst.ssrc0, s0n)}"
|
||||
|
||||
def _disasm_sop2(inst: SOP2) -> str:
|
||||
return f"{inst.op_name.lower()} {_fmt_sdst(inst.sdst, inst.dst_regs())}, {inst.lit(inst.ssrc0) if inst.ssrc0 == 255 else _fmt_src(inst.ssrc0, inst.src_regs(0))}, {inst.lit(inst.ssrc1) if inst.ssrc1 == 255 else _fmt_src(inst.ssrc1, inst.src_regs(1))}"
|
||||
name = SOP2Op(inst.op).name.lower()
|
||||
dn, s0n, s1n = _sop_widths(name)
|
||||
return f"{name} {_fmt_sdst(inst.sdst, dn)}, {inst.lit(inst.ssrc0) if inst.ssrc0 == 255 else _fmt_src(inst.ssrc0, s0n)}, {inst.lit(inst.ssrc1) if inst.ssrc1 == 255 else _fmt_src(inst.ssrc1, s1n)}"
|
||||
|
||||
def _disasm_sopc(inst: SOPC) -> str:
|
||||
return f"{inst.op_name.lower()} {_fmt_src(inst.ssrc0, inst.src_regs(0))}, {_fmt_src(inst.ssrc1, inst.src_regs(1))}"
|
||||
name = SOPCOp(inst.op).name.lower()
|
||||
_, s0n, s1n = _sop_widths(name)
|
||||
return f"{name} {_fmt_src(inst.ssrc0, s0n)}, {_fmt_src(inst.ssrc1, s1n)}"
|
||||
|
||||
def _disasm_sopk(inst: SOPK) -> str:
|
||||
op, name = inst.op, inst.op_name.lower()
|
||||
if op == SOPKOp.S_VERSION: return f"{name} 0x{inst.simm16:x}"
|
||||
if op in (SOPKOp.S_SETREG_B32, SOPKOp.S_GETREG_B32):
|
||||
# Use architecture-specific SOPK op enum
|
||||
if 'rdna4' in inst.__class__.__module__:
|
||||
from extra.assembly.amd.autogen.rdna4.enum import SOPKOp as OpEnum
|
||||
hwreg_map = HWREG_GFX12
|
||||
else:
|
||||
OpEnum = SOPKOp
|
||||
hwreg_map = HWREG
|
||||
op, name = OpEnum(inst.op), OpEnum(inst.op).name.lower()
|
||||
if name == 's_version': return f"{name} 0x{inst.simm16:x}"
|
||||
if name in ('s_setreg_b32', 's_getreg_b32'):
|
||||
hid, hoff, hsz = inst.simm16 & 0x3f, (inst.simm16 >> 6) & 0x1f, ((inst.simm16 >> 11) & 0x1f) + 1
|
||||
hs = f"0x{inst.simm16:x}" if hid in (16, 17) else f"hwreg({HWREG.get(hid, str(hid))}, {hoff}, {hsz})"
|
||||
return f"{name} {hs}, {_fmt_sdst(inst.sdst, 1)}" if op == SOPKOp.S_SETREG_B32 else f"{name} {_fmt_sdst(inst.sdst, 1)}, {hs}"
|
||||
return f"{name} {_fmt_sdst(inst.sdst, inst.dst_regs())}, 0x{inst.simm16:x}"
|
||||
hreg_name = hwreg_map.get(hid, str(hid))
|
||||
# If offset=0 and size=32, use short form hwreg(NAME), otherwise hwreg(NAME, off, sz)
|
||||
if hid in (16, 17): hs = f"0x{inst.simm16:x}"
|
||||
elif hoff == 0 and hsz == 32: hs = f"hwreg({hreg_name})"
|
||||
else: hs = f"hwreg({hreg_name}, {hoff}, {hsz})"
|
||||
return f"{name} {hs}, {_fmt_sdst(inst.sdst, 1)}" if name == 's_setreg_b32' else f"{name} {_fmt_sdst(inst.sdst, 1)}, {hs}"
|
||||
dn, _, _ = _sop_widths(name)
|
||||
return f"{name} {_fmt_sdst(inst.sdst, dn)}, 0x{inst.simm16:x}"
|
||||
|
||||
def _disasm_vinterp(inst: VINTERP) -> str:
|
||||
name = VINTERPOp(inst.op).name.lower()
|
||||
src0 = f"-{inst.lit(inst.src0)}" if inst.neg & 1 else inst.lit(inst.src0)
|
||||
src1 = f"-{inst.lit(inst.src1)}" if inst.neg & 2 else inst.lit(inst.src1)
|
||||
src2 = f"-{inst.lit(inst.src2)}" if inst.neg & 4 else inst.lit(inst.src2)
|
||||
mods = _mods((inst.waitexp, f"wait_exp:{inst.waitexp}"), (inst.clmp, "clamp"))
|
||||
return f"{inst.op_name.lower()} v{inst.vdst}, {inst.lit(inst.src0, inst.neg & 1)}, {inst.lit(inst.src1, inst.neg & 2)}, {inst.lit(inst.src2, inst.neg & 4)}" + (" " + mods if mods else "")
|
||||
return f"{name} v{inst.vdst}, {src0}, {src1}, {src2}" + (" " + mods if mods else "")
|
||||
|
||||
# Export targets: mrt0-7, mrtz, pos0-4, prim, dual_src_blend0/1
|
||||
_EXP_TARGETS = {**{i: f'mrt{i}' for i in range(8)}, 8: 'mrtz', **{i+12: f'pos{i}' for i in range(5)}, 20: 'prim', 21: 'dual_src_blend0', 22: 'dual_src_blend1'}
|
||||
|
||||
def _disasm_exp(inst) -> str:
|
||||
target = _EXP_TARGETS.get(inst.target, f"invalid_target_{inst.target}")
|
||||
en = inst.en
|
||||
vsrc = lambda i, v: f"v{v}" if (en >> i) & 1 else "off"
|
||||
srcs = f"{vsrc(0, inst.vsrc0)}, {vsrc(1, inst.vsrc1)}, {vsrc(2, inst.vsrc2)}, {vsrc(3, inst.vsrc3)}"
|
||||
mods = _mods((inst.done, "done"), (inst.row, "row_en"))
|
||||
prefix = "export" if 'rdna4' in inst.__class__.__module__ else "exp"
|
||||
return f"{prefix} {target} {srcs}" + (" " + mods if mods else "")
|
||||
|
||||
def _disasm_ldsdir(inst) -> str:
|
||||
is_rdna4 = 'rdna4' in inst.__class__.__module__
|
||||
if is_rdna4:
|
||||
# RDNA4 uses ds_* prefix and wait_va_vdst/wait_vm_vsrc modifiers
|
||||
wait = f" wait_va_vdst:{inst.wait_va} wait_vm_vsrc:{inst.wait_vm}"
|
||||
if inst.op == 1: return f"ds_direct_load v{inst.vdst}{wait}"
|
||||
if inst.op == 0: return f"ds_param_load v{inst.vdst}, attr{inst.attr}.{['x','y','z','w'][inst.attr_chan]}{wait}"
|
||||
else:
|
||||
# RDNA3 uses lds_* prefix and wait_vdst modifier
|
||||
wait = f" wait_vdst:{inst.wait_va}" if inst.wait_va != 0 else ""
|
||||
if inst.op == 1: return f"lds_direct_load v{inst.vdst}{wait}"
|
||||
if inst.op == 0: return f"lds_param_load v{inst.vdst}, attr{inst.attr}.{['x','y','z','w'][inst.attr_chan]}{wait}"
|
||||
raise ValueError(f"unknown LDSDIR op: {inst.op}")
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
# RDNA4-specific disassemblers (GFX12)
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
|
||||
# th values for RDNA4 memory instructions (based on AMDGPU ISA docs and LLVM SIDefines.h)
|
||||
# Load: 0=RT(default), 1=NT, 2=HT, 3=LU, 4=NT_RT, 5=RT_NT, 6=NT_HT, 7=BYPASS(only with scope=SYS)
|
||||
_TH_LOAD = {0: '', 1: 'th:TH_LOAD_NT', 2: 'th:TH_LOAD_HT', 3: 'th:TH_LOAD_LU', 4: 'th:TH_LOAD_NT_RT', 5: 'th:TH_LOAD_RT_NT', 6: 'th:TH_LOAD_NT_HT', 7: 'th:TH_LOAD_BYPASS'}
|
||||
# Store: 0=RT(default), 1=NT, 2=HT, 3=WB, 4=NT_RT, 5=RT_NT, 6=NT_HT, 7=NT_WB (BYPASS is th=7 + scope=SYS)
|
||||
_TH_STORE = {0: '', 1: 'th:TH_STORE_NT', 2: 'th:TH_STORE_HT', 3: 'th:TH_STORE_WB', 4: 'th:TH_STORE_NT_RT', 5: 'th:TH_STORE_RT_NT', 6: 'th:TH_STORE_NT_HT', 7: 'th:TH_STORE_NT_WB'}
|
||||
# Atomic: bit0=RETURN, bit1=NT, bit2=CASCADE -> 0=none, 1=RETURN, 2=NT, 3=NT_RETURN, 4=CASCADE_RT, 5=RT_RETURN(N/A), 6=CASCADE_NT, 7=N/A
|
||||
_TH_ATOMIC = {0: '', 1: 'th:TH_ATOMIC_RETURN', 2: 'th:TH_ATOMIC_NT', 3: 'th:TH_ATOMIC_NT_RETURN', 4: 'th:TH_ATOMIC_CASCADE_RT', 5: 'th:TH_ATOMIC_RT_RETURN', 6: 'th:TH_ATOMIC_CASCADE_NT', 7: 'th:TH_ATOMIC_CASCADE_NT'}
|
||||
_SCOPE = {0: '', 1: 'scope:SCOPE_SE', 2: 'scope:SCOPE_DEV', 3: 'scope:SCOPE_SYS'}
|
||||
|
||||
def _rdna4_mem_mods(th: int, scope: int, is_store: bool, is_atomic: bool) -> str:
|
||||
th_map = _TH_ATOMIC if is_atomic else _TH_STORE if is_store else _TH_LOAD
|
||||
# Special case: th=3 with scope=SYS means BYPASS for load/store (otherwise th=3 means LU/WB)
|
||||
if th == 3 and scope == 3 and not is_atomic:
|
||||
th_s = 'th:TH_STORE_BYPASS' if is_store else 'th:TH_LOAD_BYPASS'
|
||||
else:
|
||||
th_s = th_map.get(th, f'th:{th}' if th else '')
|
||||
scope_s = _SCOPE.get(scope, f'scope:{scope}' if scope else '')
|
||||
return ' '.join(x for x in [th_s, scope_s] if x)
|
||||
|
||||
def _disasm_vflat(inst) -> str:
|
||||
"""Disassemble RDNA4 VFLAT/VGLOBAL/VSCRATCH instructions."""
|
||||
from extra.assembly.amd.autogen.rdna4.enum import VFLATOp, VGLOBALOp, VSCRATCHOp
|
||||
cls_name = type(inst).__name__
|
||||
if cls_name == 'VGLOBAL': op_enum, prefix = VGLOBALOp, 'global'
|
||||
elif cls_name == 'VSCRATCH': op_enum, prefix = VSCRATCHOp, 'scratch'
|
||||
else: op_enum, prefix = VFLATOp, 'flat'
|
||||
name = op_enum(inst.op).name.lower()
|
||||
|
||||
# global_wb, global_wbinv, global_inv are cache control instructions with no operands
|
||||
if name in ('global_wb', 'global_wbinv', 'global_inv'):
|
||||
mods = _rdna4_mem_mods(inst.th, inst.scope, False, False)
|
||||
return f"{name}" + (f" {mods}" if mods else "")
|
||||
|
||||
# addtid instructions use thread ID as address offset, no vaddr operand
|
||||
is_addtid = 'addtid' in name
|
||||
|
||||
# Data width based on instruction name suffix
|
||||
suffix = name.split('_')[-1]
|
||||
# block loads/stores use 32 VGPRs
|
||||
if 'block' in name:
|
||||
base_w = 32
|
||||
else:
|
||||
base_w = {'b32':1,'b64':2,'b96':3,'b128':4,'u8':1,'i8':1,'u16':1,'i16':1,'u32':1,'i32':1,'u64':2,'i64':2,'f32':1,'f64':2}.get(suffix, 1)
|
||||
# For cmpswap: vsrc holds cmp+data pairs (2x base), vdst is base width
|
||||
vsrc_w = base_w * 2 if 'cmpswap' in name else base_w
|
||||
vdst_w = base_w
|
||||
|
||||
# Offset: signed 24-bit (stored as unsigned, needs sign extension)
|
||||
off = inst.ioffset if inst.ioffset < 0x800000 else inst.ioffset - 0x1000000
|
||||
off_s = f" offset:{off}" if off else ""
|
||||
|
||||
# Memory modifiers
|
||||
is_store, is_atomic = 'store' in name, 'atomic' in name
|
||||
mods = _rdna4_mem_mods(inst.th, inst.scope, is_store, is_atomic)
|
||||
|
||||
# saddr handling - VGLOBAL and VSCRATCH need explicit "off" when saddr=124
|
||||
if inst.saddr == 124: saddr_s = "off" if prefix in ('global', 'scratch') else ""
|
||||
elif inst.saddr in SPECIAL_PAIRS: saddr_s = SPECIAL_PAIRS[inst.saddr]
|
||||
else: saddr_s = _sreg(inst.saddr, 2) if prefix == 'global' else decode_src(inst.saddr)
|
||||
|
||||
# Address width: 1 for scratch with saddr, 2 otherwise
|
||||
addr_w = 1 if (prefix == 'scratch' or (inst.saddr != 124 and prefix != 'flat')) else 2
|
||||
vaddr_s = _vreg(inst.vaddr, addr_w)
|
||||
vsrc_s = _vreg(inst.vsrc, vsrc_w)
|
||||
vdst_s = _vreg(inst.vdst, vdst_w)
|
||||
|
||||
# addtid instructions don't have vaddr, just vdata and saddr
|
||||
if is_addtid:
|
||||
if is_store: return f"{name} {vsrc_s}, {saddr_s}{off_s}" + (f" {mods}" if mods else "")
|
||||
return f"{name} {vdst_s}, {saddr_s}{off_s}" + (f" {mods}" if mods else "")
|
||||
|
||||
# Regular instructions need comma before saddr
|
||||
saddr_s = f", {saddr_s}" if saddr_s else ""
|
||||
|
||||
if is_atomic:
|
||||
if inst.th == 1: # TH_ATOMIC_RETURN
|
||||
return f"{name} {vdst_s}, {vaddr_s}, {vsrc_s}{saddr_s}{off_s}" + (f" {mods}" if mods else "")
|
||||
return f"{name} {vaddr_s}, {vsrc_s}{saddr_s}{off_s}" + (f" {mods}" if mods else "")
|
||||
if is_store: return f"{name} {vaddr_s}, {vsrc_s}{saddr_s}{off_s}" + (f" {mods}" if mods else "")
|
||||
return f"{name} {vdst_s}, {vaddr_s}{saddr_s}{off_s}" + (f" {mods}" if mods else "")
|
||||
|
||||
def _disasm_vbuffer(inst) -> str:
|
||||
"""Disassemble RDNA4 VBUFFER instructions."""
|
||||
from extra.assembly.amd.autogen.rdna4.enum import VBUFFEROp
|
||||
name = VBUFFEROp(inst.op).name.lower()
|
||||
|
||||
# Determine if this is a typed buffer instruction (MTBUF format)
|
||||
is_format = 'format' in name
|
||||
|
||||
# Data width based on instruction name
|
||||
if is_format:
|
||||
w = {'x': 1, 'xy': 2, 'xyz': 3, 'xyzw': 4}.get(name.split('_')[-1], 1)
|
||||
if 'd16' in name: w = (w + 1) // 2
|
||||
else:
|
||||
suffix = name.split('_')[-1]
|
||||
w = {'b32':1,'b64':2,'b96':3,'b128':4,'u8':1,'i8':1,'u16':1,'i16':1,'u32':1,'i32':1,'u64':2,'i64':2,'b8':1,'b16':1,'f16':1,'f32':1,'bf16':1}.get(suffix, 1)
|
||||
if 'cmpswap' in name: w *= 2
|
||||
if inst.tfe: w += 1
|
||||
|
||||
vdata_s = _vreg(inst.vdata, w)
|
||||
vaddr_s = _vreg(inst.vaddr, 2) if inst.offen and inst.idxen else (f"v{inst.vaddr}" if inst.offen or inst.idxen else "off")
|
||||
# RDNA4 VBUFFER rsrc field stores the SGPR index directly (not /4 like RDNA3)
|
||||
srsrc_s = _sreg_or_ttmp(inst.rsrc, 4)
|
||||
soffset_s = decode_src(inst.soffset)
|
||||
|
||||
off = inst.ioffset if inst.ioffset < 0x800000 else inst.ioffset - 0x1000000
|
||||
is_store, is_atomic = 'store' in name, 'atomic' in name
|
||||
mods = _rdna4_mem_mods(inst.th, inst.scope, is_store, is_atomic)
|
||||
|
||||
# Format field is only for MTBUF (tbuffer_*) instructions, not buffer_*_format_* instructions
|
||||
# We don't output format for buffer instructions since they use implicit format
|
||||
parts = []
|
||||
if inst.idxen: parts.append("idxen")
|
||||
if inst.offen: parts.append("offen")
|
||||
if off: parts.append(f"offset:{off}")
|
||||
if mods: parts.append(mods)
|
||||
if inst.tfe: parts.append("tfe")
|
||||
|
||||
return f"{name} {vdata_s}, {vaddr_s}, {srsrc_s}, {soffset_s}" + (f" {' '.join(parts)}" if parts else "")
|
||||
|
||||
def _disasm_vimage(inst) -> str:
|
||||
"""Disassemble RDNA4 VIMAGE instructions."""
|
||||
from extra.assembly.amd.autogen.rdna4.enum import VIMAGEOp
|
||||
name = VIMAGEOp(inst.op).name.lower()
|
||||
|
||||
# RDNA4 VIMAGE rsrc field stores the SGPR index directly (not /4 like RDNA3)
|
||||
if 'bvh' in name:
|
||||
# BVH intersect ray: special format with individual/range vaddr components
|
||||
# Format: [node_ptr, ray_extent, ray_origin(3), ray_dir(3), ray_inv_dir(3)]
|
||||
# bvh64 has 2-VGPR node_ptr, a16 removes ray_inv_dir
|
||||
if 'dual' in name or 'bvh8' in name:
|
||||
# dual/bvh8: [node_ptr(2), ray_extent(2), ray_origin(3), ray_dir(3), ...]
|
||||
parts = [_vreg(inst.vaddr0, 2), _vreg(inst.vaddr1, 2), _vreg(inst.vaddr2, 3), _vreg(inst.vaddr3, 3)]
|
||||
if not inst.a16: parts.append(_vreg(inst.vaddr4, 1 if 'bvh8' in name else 2))
|
||||
dst_w = 10
|
||||
elif '64' in name:
|
||||
parts = [_vreg(inst.vaddr0, 2), f"v{inst.vaddr1}", _vreg(inst.vaddr2, 3), _vreg(inst.vaddr3, 3)]
|
||||
if not inst.a16: parts.append(_vreg(inst.vaddr4, 3))
|
||||
dst_w = 4
|
||||
else:
|
||||
parts = [f"v{inst.vaddr0}", f"v{inst.vaddr1}", _vreg(inst.vaddr2, 3), _vreg(inst.vaddr3, 3)]
|
||||
if not inst.a16: parts.append(_vreg(inst.vaddr4, 3))
|
||||
dst_w = 4
|
||||
return f"{name} {_vreg(inst.vdata, dst_w)}, [{', '.join(parts)}], {_sreg_or_ttmp(inst.rsrc, 4)}{' a16' if inst.a16 else ''}"
|
||||
|
||||
# vdata width - msaa_load always uses 4 VGPRs per channel
|
||||
vdata = 4 if 'gather4' in name or 'msaa_load' in name else (bin(inst.dmask).count('1') or 1)
|
||||
if inst.d16: vdata = (vdata + 1) // 2
|
||||
if inst.tfe: vdata += 1
|
||||
|
||||
# dim names
|
||||
dim_names = ['1d', '2d', '3d', 'cube', '1d_array', '2d_array', '2d_msaa', '2d_msaa_array']
|
||||
dim = dim_names[inst.dim] if inst.dim < len(dim_names) else f"{inst.dim}"
|
||||
|
||||
# vaddr width calculation (RDNA4 uses vaddr0-4 for address components)
|
||||
vaddr_w = _mimg_vaddr_width(name, inst.dim, inst.a16)
|
||||
if vaddr_w == 1: vaddr_s = f"v{inst.vaddr0}"
|
||||
elif vaddr_w == 2: vaddr_s = f"[v{inst.vaddr0}, v{inst.vaddr1}]"
|
||||
elif vaddr_w == 3: vaddr_s = f"[v{inst.vaddr0}, v{inst.vaddr1}, v{inst.vaddr2}]"
|
||||
elif vaddr_w == 4: vaddr_s = f"[v{inst.vaddr0}, v{inst.vaddr1}, v{inst.vaddr2}, v{inst.vaddr3}]"
|
||||
else: vaddr_s = f"[v{inst.vaddr0}, v{inst.vaddr1}, v{inst.vaddr2}, v{inst.vaddr3}, v{inst.vaddr4}]"
|
||||
|
||||
srsrc_s = _sreg_or_ttmp(inst.rsrc, 8)
|
||||
# RDNA4 always requires dmask for size calculation (even if 0xf)
|
||||
mods = [f"dmask:0x{inst.dmask:x}"]
|
||||
mods.append(f"dim:SQ_RSRC_IMG_{dim.upper()}")
|
||||
# Add th/scope before other modifiers, then r128, then a16/tfe/d16 (LLVM expects this order)
|
||||
if inst.th or inst.scope:
|
||||
is_store, is_atomic = 'store' in name, 'atomic' in name
|
||||
mem_mods = _rdna4_mem_mods(inst.th, inst.scope, is_store, is_atomic)
|
||||
if mem_mods: mods.append(mem_mods)
|
||||
if inst.r128: mods.append("r128")
|
||||
mods.extend([m for c, m in [(inst.a16, "a16"), (inst.tfe, "tfe"), (inst.d16, "d16")] if c])
|
||||
|
||||
return f"{name} {_vreg(inst.vdata, vdata)}, {vaddr_s}, {srsrc_s} {' '.join(mods)}"
|
||||
|
||||
def _disasm_vsample(inst) -> str:
|
||||
"""Disassemble RDNA4 VSAMPLE instructions."""
|
||||
from extra.assembly.amd.autogen.rdna4.enum import VSAMPLEOp
|
||||
name = VSAMPLEOp(inst.op).name.lower()
|
||||
|
||||
# vdata width
|
||||
vdata = 4 if 'gather4' in name or 'msaa_load' in name else (bin(inst.dmask).count('1') or 1)
|
||||
if inst.d16: vdata = (vdata + 1) // 2
|
||||
if inst.tfe: vdata += 1
|
||||
|
||||
dim_names = ['1d', '2d', '3d', 'cube', '1d_array', '2d_array', '2d_msaa', '2d_msaa_array']
|
||||
dim = dim_names[inst.dim] if inst.dim < len(dim_names) else f"{inst.dim}"
|
||||
|
||||
vaddr = _mimg_vaddr_width(name, inst.dim, inst.a16)
|
||||
if vaddr == 1: vaddr_s = f"v{inst.vaddr0}"
|
||||
elif vaddr == 2: vaddr_s = f"[v{inst.vaddr0}, v{inst.vaddr1}]"
|
||||
elif vaddr == 3: vaddr_s = f"[v{inst.vaddr0}, v{inst.vaddr1}, v{inst.vaddr2}]"
|
||||
elif vaddr == 4: vaddr_s = f"[v{inst.vaddr0}, v{inst.vaddr1}, v{inst.vaddr2}, v{inst.vaddr3}]"
|
||||
else:
|
||||
# More than 4 vaddrs: vaddr3 becomes start of a contiguous range for the remaining coords
|
||||
extra = vaddr - 3 # vaddr0-2 are individual, vaddr3 starts range of remaining
|
||||
vaddr_s = f"[v{inst.vaddr0}, v{inst.vaddr1}, v{inst.vaddr2}, {_vreg(inst.vaddr3, extra)}]"
|
||||
|
||||
# RDNA4 VSAMPLE rsrc/samp fields store the SGPR index directly (not /4 like RDNA3)
|
||||
srsrc_s = _sreg_or_ttmp(inst.rsrc, 8)
|
||||
|
||||
# msaa_load doesn't use a sampler (it's a load, not a sample), but gather4h does
|
||||
uses_sampler = 'msaa_load' not in name
|
||||
ssamp_s = f", {_sreg_or_ttmp(inst.samp, 4)}" if uses_sampler else ""
|
||||
|
||||
mods = [f"dmask:0x{inst.dmask:x}"] if inst.dmask else []
|
||||
mods.append(f"dim:SQ_RSRC_IMG_{dim.upper()}")
|
||||
if inst.unrm: mods.append("unorm")
|
||||
# th/scope must come before r128, a16, tfe, lwe, d16
|
||||
if inst.th or inst.scope:
|
||||
mem_mods = _rdna4_mem_mods(inst.th, inst.scope, False, False)
|
||||
if mem_mods: mods.append(mem_mods)
|
||||
mods.extend([m for c, m in [(inst.r128, "r128"), (inst.a16, "a16"), (inst.tfe, "tfe"), (inst.lwe, "lwe"), (inst.d16, "d16")] if c])
|
||||
|
||||
return f"{name} {_vreg(inst.vdata, vdata)}, {vaddr_s}, {srsrc_s}{ssamp_s} {' '.join(mods)}"
|
||||
|
||||
DISASM_HANDLERS = {VOP1: _disasm_vop1, VOP2: _disasm_vop2, VOPC: _disasm_vopc, VOP3: _disasm_vop3, VOP3SD: _disasm_vop3sd, VOPD: _disasm_vopd, VOP3P: _disasm_vop3p,
|
||||
VINTERP: _disasm_vinterp, SOPP: _disasm_sopp, SMEM: _disasm_smem, DS: _disasm_ds, FLAT: _disasm_flat, MUBUF: _disasm_buf, MTBUF: _disasm_buf,
|
||||
MIMG: _disasm_mimg, SOP1: _disasm_sop1, SOP2: _disasm_sop2, SOPC: _disasm_sopc, SOPK: _disasm_sopk}
|
||||
MIMG: _disasm_mimg, SOP1: _disasm_sop1, SOP2: _disasm_sop2, SOPC: _disasm_sopc, SOPK: _disasm_sopk, EXP: _disasm_exp, LDSDIR: _disasm_ldsdir}
|
||||
|
||||
def disasm(inst: Inst) -> str: return DISASM_HANDLERS[type(inst)](inst)
|
||||
# RDNA4 uses different class names, dispatch by name for cross-arch support
|
||||
_DISASM_BY_NAME = {'VOP1': _disasm_vop1, 'VOP2': _disasm_vop2, 'VOPC': _disasm_vopc, 'VOP3': _disasm_vop3, 'VOP3SD': _disasm_vop3sd,
|
||||
'VOPD': _disasm_vopd, 'VOP3P': _disasm_vop3p, 'VINTERP': _disasm_vinterp, 'SOPP': _disasm_sopp, 'SMEM': _disasm_smem,
|
||||
'DS': _disasm_ds, 'VDS': _disasm_ds, 'FLAT': _disasm_flat, 'MUBUF': _disasm_buf, 'MTBUF': _disasm_buf, 'MIMG': _disasm_mimg,
|
||||
'SOP1': _disasm_sop1, 'SOP2': _disasm_sop2, 'SOPC': _disasm_sopc, 'SOPK': _disasm_sopk,
|
||||
'EXP': _disasm_exp, 'LDSDIR': _disasm_ldsdir, 'VEXPORT': _disasm_exp, 'VDSDIR': _disasm_ldsdir,
|
||||
'VFLAT': _disasm_vflat, 'VGLOBAL': _disasm_vflat, 'VSCRATCH': _disasm_vflat,
|
||||
'VBUFFER': _disasm_vbuffer, 'VIMAGE': _disasm_vimage, 'VSAMPLE': _disasm_vsample}
|
||||
|
||||
def disasm(inst: Inst, wave_size: int = 32) -> str:
|
||||
handler = DISASM_HANDLERS.get(type(inst)) or _DISASM_BY_NAME.get(type(inst).__name__)
|
||||
if handler is None: raise KeyError(f"no disasm handler for {type(inst).__name__}")
|
||||
# For VOP3P (includes WMMA), pass wave_size if handler supports it
|
||||
if handler == _disasm_vop3p: return _disasm_vop3p(inst, wave_size)
|
||||
return handler(inst)
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
# ASSEMBLER
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
# autogenerated from AMD RDNA4 ISA PDF by pdf.py - do not edit
|
||||
# ruff: noqa: F401,F403
|
||||
from typing import Annotated
|
||||
from extra.assembly.amd.dsl import bits, BitField, Inst32, Inst64, SGPR, VGPR, TTMP as TTMP, s as s, v as v, ttmp as ttmp, SSrc, Src, SImm, Imm, VDSTYEnc, SGPRField, VGPRField
|
||||
from extra.assembly.amd.dsl import bits, BitField, Inst32, Inst64, Inst96, SGPR, VGPR, TTMP as TTMP, s as s, v as v, ttmp as ttmp, SSrc, Src, SImm, Imm, VDSTYEnc, SGPRField, VGPRField
|
||||
from extra.assembly.amd.autogen.rdna4.enum import *
|
||||
import functools
|
||||
|
||||
|
|
@ -69,7 +69,7 @@ class SOPP(Inst32):
|
|||
op:Annotated[BitField, SOPPOp] = bits[22:16]
|
||||
simm16:SImm = bits[15:0]
|
||||
|
||||
class VBUFFER(Inst64):
|
||||
class VBUFFER(Inst96):
|
||||
encoding = bits[31:26] == 0b110001
|
||||
soffset:SSrc = bits[6:0]
|
||||
op:Annotated[BitField, VBUFFEROp] = bits[21:14]
|
||||
|
|
@ -94,17 +94,14 @@ class VDS(Inst64):
|
|||
data1:VGPRField = bits[55:48]
|
||||
vdst:VGPRField = bits[63:56]
|
||||
|
||||
class VDSDIR(Inst64):
|
||||
encoding = bits[31:24] == 0b11001101
|
||||
class VDSDIR(Inst32):
|
||||
encoding = bits[31:24] == 0b11001110
|
||||
vdst:VGPRField = bits[7:0]
|
||||
waitexp = bits[10:8]
|
||||
opsel = bits[14:11]
|
||||
cm = bits[15]
|
||||
op:Annotated[BitField, VDSDIROp] = bits[20:16]
|
||||
src0:Src = bits[40:32]
|
||||
src1:Src = bits[49:41]
|
||||
src2:Src = bits[58:50]
|
||||
neg = bits[63:61]
|
||||
attr_chan = bits[9:8]
|
||||
attr = bits[15:10]
|
||||
wait_va = bits[19:16]
|
||||
op:Annotated[BitField, VDSDIROp] = bits[21:20]
|
||||
wait_vm = bits[23] # single bit, bit 22 is reserved
|
||||
|
||||
class VEXPORT(Inst64):
|
||||
encoding = bits[31:26] == 0b111110
|
||||
|
|
@ -117,6 +114,49 @@ class VEXPORT(Inst64):
|
|||
vsrc2 = bits[55:48]
|
||||
vsrc3 = bits[63:56]
|
||||
|
||||
class VFLAT(Inst96):
|
||||
encoding = bits[31:24] == 0b11101100
|
||||
saddr:SSrc = bits[6:0]
|
||||
op:Annotated[BitField, VFLATOp] = bits[20:14]
|
||||
vdst:VGPRField = bits[39:32]
|
||||
sve = bits[49]
|
||||
scope = bits[51:50]
|
||||
th = bits[54:52]
|
||||
vsrc = bits[62:55]
|
||||
vaddr:VGPRField = bits[71:64]
|
||||
ioffset = bits[95:72]
|
||||
|
||||
class VGLOBAL(Inst96):
|
||||
encoding = bits[31:24] == 0b11101110
|
||||
saddr:SSrc = bits[6:0]
|
||||
op:Annotated[BitField, VGLOBALOp] = bits[20:14]
|
||||
vdst:VGPRField = bits[39:32]
|
||||
sve = bits[49]
|
||||
scope = bits[51:50]
|
||||
th = bits[54:52]
|
||||
vsrc = bits[62:55]
|
||||
vaddr:VGPRField = bits[71:64]
|
||||
ioffset = bits[95:72]
|
||||
|
||||
class VIMAGE(Inst96):
|
||||
encoding = bits[31:26] == 0b110100
|
||||
dim = bits[2:0]
|
||||
r128 = bits[4]
|
||||
d16 = bits[5]
|
||||
a16 = bits[6]
|
||||
op:Annotated[BitField, VIMAGEOp] = bits[21:14]
|
||||
dmask = bits[25:22]
|
||||
vdata:VGPRField = bits[39:32]
|
||||
rsrc = bits[49:41]
|
||||
scope = bits[51:50]
|
||||
th = bits[54:52]
|
||||
tfe = bits[55]
|
||||
vaddr4 = bits[63:56]
|
||||
vaddr0 = bits[71:64]
|
||||
vaddr1 = bits[79:72]
|
||||
vaddr2 = bits[87:80]
|
||||
vaddr3 = bits[95:88]
|
||||
|
||||
class VINTERP(Inst64):
|
||||
encoding = bits[31:24] == 0b11001101
|
||||
op:Annotated[BitField, VINTERPOp] = bits[20:16]
|
||||
|
|
@ -199,6 +239,39 @@ class VOPD(Inst64):
|
|||
srcy0:Src = bits[40:32]
|
||||
vsrcy1:VGPRField = bits[48:41]
|
||||
|
||||
class VSAMPLE(Inst96):
|
||||
encoding = bits[31:26] == 0b111001
|
||||
dim = bits[2:0]
|
||||
tfe = bits[3]
|
||||
r128 = bits[4]
|
||||
d16 = bits[5]
|
||||
a16 = bits[6]
|
||||
unrm = bits[13]
|
||||
op:Annotated[BitField, VSAMPLEOp] = bits[21:14]
|
||||
dmask = bits[25:22]
|
||||
vdata:VGPRField = bits[39:32]
|
||||
lwe = bits[40]
|
||||
rsrc = bits[49:41]
|
||||
scope = bits[51:50]
|
||||
th = bits[54:52]
|
||||
samp = bits[63:55]
|
||||
vaddr0 = bits[71:64]
|
||||
vaddr1 = bits[79:72]
|
||||
vaddr2 = bits[87:80]
|
||||
vaddr3 = bits[95:88]
|
||||
|
||||
class VSCRATCH(Inst96):
|
||||
encoding = bits[31:24] == 0b11101101
|
||||
saddr:SSrc = bits[6:0]
|
||||
op:Annotated[BitField, VSCRATCHOp] = bits[20:14]
|
||||
vdst:VGPRField = bits[39:32]
|
||||
sve = bits[49]
|
||||
scope = bits[51:50]
|
||||
th = bits[54:52]
|
||||
vsrc = bits[62:55]
|
||||
vaddr:VGPRField = bits[71:64]
|
||||
ioffset = bits[95:72]
|
||||
|
||||
# instruction helpers
|
||||
s_load_b32 = functools.partial(SMEM, SMEMOp.S_LOAD_B32)
|
||||
s_load_b64 = functools.partial(SMEM, SMEMOp.S_LOAD_B64)
|
||||
|
|
@ -571,6 +644,159 @@ tbuffer_store_d16_format_xyz = functools.partial(VBUFFER, VBUFFEROp.TBUFFER_STOR
|
|||
tbuffer_store_d16_format_xyzw = functools.partial(VBUFFER, VBUFFEROp.TBUFFER_STORE_D16_FORMAT_XYZW)
|
||||
ds_param_load = functools.partial(VDSDIR, VDSDIROp.DS_PARAM_LOAD)
|
||||
ds_direct_load = functools.partial(VDSDIR, VDSDIROp.DS_DIRECT_LOAD)
|
||||
flat_load_u8 = functools.partial(VFLAT, VFLATOp.FLAT_LOAD_U8)
|
||||
flat_load_i8 = functools.partial(VFLAT, VFLATOp.FLAT_LOAD_I8)
|
||||
flat_load_u16 = functools.partial(VFLAT, VFLATOp.FLAT_LOAD_U16)
|
||||
flat_load_i16 = functools.partial(VFLAT, VFLATOp.FLAT_LOAD_I16)
|
||||
flat_load_b32 = functools.partial(VFLAT, VFLATOp.FLAT_LOAD_B32)
|
||||
flat_load_b64 = functools.partial(VFLAT, VFLATOp.FLAT_LOAD_B64)
|
||||
flat_load_b96 = functools.partial(VFLAT, VFLATOp.FLAT_LOAD_B96)
|
||||
flat_load_b128 = functools.partial(VFLAT, VFLATOp.FLAT_LOAD_B128)
|
||||
flat_store_b8 = functools.partial(VFLAT, VFLATOp.FLAT_STORE_B8)
|
||||
flat_store_b16 = functools.partial(VFLAT, VFLATOp.FLAT_STORE_B16)
|
||||
flat_store_b32 = functools.partial(VFLAT, VFLATOp.FLAT_STORE_B32)
|
||||
flat_store_b64 = functools.partial(VFLAT, VFLATOp.FLAT_STORE_B64)
|
||||
flat_store_b96 = functools.partial(VFLAT, VFLATOp.FLAT_STORE_B96)
|
||||
flat_store_b128 = functools.partial(VFLAT, VFLATOp.FLAT_STORE_B128)
|
||||
flat_load_d16_u8 = functools.partial(VFLAT, VFLATOp.FLAT_LOAD_D16_U8)
|
||||
flat_load_d16_i8 = functools.partial(VFLAT, VFLATOp.FLAT_LOAD_D16_I8)
|
||||
flat_load_d16_b16 = functools.partial(VFLAT, VFLATOp.FLAT_LOAD_D16_B16)
|
||||
flat_load_d16_hi_u8 = functools.partial(VFLAT, VFLATOp.FLAT_LOAD_D16_HI_U8)
|
||||
flat_load_d16_hi_i8 = functools.partial(VFLAT, VFLATOp.FLAT_LOAD_D16_HI_I8)
|
||||
flat_load_d16_hi_b16 = functools.partial(VFLAT, VFLATOp.FLAT_LOAD_D16_HI_B16)
|
||||
flat_store_d16_hi_b8 = functools.partial(VFLAT, VFLATOp.FLAT_STORE_D16_HI_B8)
|
||||
flat_store_d16_hi_b16 = functools.partial(VFLAT, VFLATOp.FLAT_STORE_D16_HI_B16)
|
||||
flat_atomic_swap_b32 = functools.partial(VFLAT, VFLATOp.FLAT_ATOMIC_SWAP_B32)
|
||||
flat_atomic_cmpswap_b32 = functools.partial(VFLAT, VFLATOp.FLAT_ATOMIC_CMPSWAP_B32)
|
||||
flat_atomic_add_u32 = functools.partial(VFLAT, VFLATOp.FLAT_ATOMIC_ADD_U32)
|
||||
flat_atomic_sub_u32 = functools.partial(VFLAT, VFLATOp.FLAT_ATOMIC_SUB_U32)
|
||||
flat_atomic_sub_clamp_u32 = functools.partial(VFLAT, VFLATOp.FLAT_ATOMIC_SUB_CLAMP_U32)
|
||||
flat_atomic_min_i32 = functools.partial(VFLAT, VFLATOp.FLAT_ATOMIC_MIN_I32)
|
||||
flat_atomic_min_u32 = functools.partial(VFLAT, VFLATOp.FLAT_ATOMIC_MIN_U32)
|
||||
flat_atomic_max_i32 = functools.partial(VFLAT, VFLATOp.FLAT_ATOMIC_MAX_I32)
|
||||
flat_atomic_max_u32 = functools.partial(VFLAT, VFLATOp.FLAT_ATOMIC_MAX_U32)
|
||||
flat_atomic_and_b32 = functools.partial(VFLAT, VFLATOp.FLAT_ATOMIC_AND_B32)
|
||||
flat_atomic_or_b32 = functools.partial(VFLAT, VFLATOp.FLAT_ATOMIC_OR_B32)
|
||||
flat_atomic_xor_b32 = functools.partial(VFLAT, VFLATOp.FLAT_ATOMIC_XOR_B32)
|
||||
flat_atomic_inc_u32 = functools.partial(VFLAT, VFLATOp.FLAT_ATOMIC_INC_U32)
|
||||
flat_atomic_dec_u32 = functools.partial(VFLAT, VFLATOp.FLAT_ATOMIC_DEC_U32)
|
||||
flat_atomic_swap_b64 = functools.partial(VFLAT, VFLATOp.FLAT_ATOMIC_SWAP_B64)
|
||||
flat_atomic_cmpswap_b64 = functools.partial(VFLAT, VFLATOp.FLAT_ATOMIC_CMPSWAP_B64)
|
||||
flat_atomic_add_u64 = functools.partial(VFLAT, VFLATOp.FLAT_ATOMIC_ADD_U64)
|
||||
flat_atomic_sub_u64 = functools.partial(VFLAT, VFLATOp.FLAT_ATOMIC_SUB_U64)
|
||||
flat_atomic_min_i64 = functools.partial(VFLAT, VFLATOp.FLAT_ATOMIC_MIN_I64)
|
||||
flat_atomic_min_u64 = functools.partial(VFLAT, VFLATOp.FLAT_ATOMIC_MIN_U64)
|
||||
flat_atomic_max_i64 = functools.partial(VFLAT, VFLATOp.FLAT_ATOMIC_MAX_I64)
|
||||
flat_atomic_max_u64 = functools.partial(VFLAT, VFLATOp.FLAT_ATOMIC_MAX_U64)
|
||||
flat_atomic_and_b64 = functools.partial(VFLAT, VFLATOp.FLAT_ATOMIC_AND_B64)
|
||||
flat_atomic_or_b64 = functools.partial(VFLAT, VFLATOp.FLAT_ATOMIC_OR_B64)
|
||||
flat_atomic_xor_b64 = functools.partial(VFLAT, VFLATOp.FLAT_ATOMIC_XOR_B64)
|
||||
flat_atomic_inc_u64 = functools.partial(VFLAT, VFLATOp.FLAT_ATOMIC_INC_U64)
|
||||
flat_atomic_dec_u64 = functools.partial(VFLAT, VFLATOp.FLAT_ATOMIC_DEC_U64)
|
||||
flat_atomic_cond_sub_u32 = functools.partial(VFLAT, VFLATOp.FLAT_ATOMIC_COND_SUB_U32)
|
||||
flat_atomic_min_num_f32 = functools.partial(VFLAT, VFLATOp.FLAT_ATOMIC_MIN_NUM_F32)
|
||||
flat_atomic_max_num_f32 = functools.partial(VFLAT, VFLATOp.FLAT_ATOMIC_MAX_NUM_F32)
|
||||
flat_atomic_add_f32 = functools.partial(VFLAT, VFLATOp.FLAT_ATOMIC_ADD_F32)
|
||||
flat_atomic_pk_add_f16 = functools.partial(VFLAT, VFLATOp.FLAT_ATOMIC_PK_ADD_F16)
|
||||
flat_atomic_pk_add_bf16 = functools.partial(VFLAT, VFLATOp.FLAT_ATOMIC_PK_ADD_BF16)
|
||||
global_load_u8 = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_LOAD_U8)
|
||||
global_load_i8 = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_LOAD_I8)
|
||||
global_load_u16 = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_LOAD_U16)
|
||||
global_load_i16 = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_LOAD_I16)
|
||||
global_load_b32 = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_LOAD_B32)
|
||||
global_load_b64 = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_LOAD_B64)
|
||||
global_load_b96 = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_LOAD_B96)
|
||||
global_load_b128 = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_LOAD_B128)
|
||||
global_store_b8 = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_STORE_B8)
|
||||
global_store_b16 = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_STORE_B16)
|
||||
global_store_b32 = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_STORE_B32)
|
||||
global_store_b64 = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_STORE_B64)
|
||||
global_store_b96 = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_STORE_B96)
|
||||
global_store_b128 = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_STORE_B128)
|
||||
global_load_d16_u8 = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_LOAD_D16_U8)
|
||||
global_load_d16_i8 = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_LOAD_D16_I8)
|
||||
global_load_d16_b16 = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_LOAD_D16_B16)
|
||||
global_load_d16_hi_u8 = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_LOAD_D16_HI_U8)
|
||||
global_load_d16_hi_i8 = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_LOAD_D16_HI_I8)
|
||||
global_load_d16_hi_b16 = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_LOAD_D16_HI_B16)
|
||||
global_store_d16_hi_b8 = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_STORE_D16_HI_B8)
|
||||
global_store_d16_hi_b16 = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_STORE_D16_HI_B16)
|
||||
global_load_addtid_b32 = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_LOAD_ADDTID_B32)
|
||||
global_store_addtid_b32 = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_STORE_ADDTID_B32)
|
||||
global_inv = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_INV)
|
||||
global_wb = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_WB)
|
||||
global_atomic_swap_b32 = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_ATOMIC_SWAP_B32)
|
||||
global_atomic_cmpswap_b32 = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_ATOMIC_CMPSWAP_B32)
|
||||
global_atomic_add_u32 = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_ATOMIC_ADD_U32)
|
||||
global_atomic_sub_u32 = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_ATOMIC_SUB_U32)
|
||||
global_atomic_sub_clamp_u32 = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_ATOMIC_SUB_CLAMP_U32)
|
||||
global_atomic_min_i32 = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_ATOMIC_MIN_I32)
|
||||
global_atomic_min_u32 = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_ATOMIC_MIN_U32)
|
||||
global_atomic_max_i32 = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_ATOMIC_MAX_I32)
|
||||
global_atomic_max_u32 = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_ATOMIC_MAX_U32)
|
||||
global_atomic_and_b32 = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_ATOMIC_AND_B32)
|
||||
global_atomic_or_b32 = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_ATOMIC_OR_B32)
|
||||
global_atomic_xor_b32 = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_ATOMIC_XOR_B32)
|
||||
global_atomic_inc_u32 = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_ATOMIC_INC_U32)
|
||||
global_atomic_dec_u32 = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_ATOMIC_DEC_U32)
|
||||
global_atomic_swap_b64 = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_ATOMIC_SWAP_B64)
|
||||
global_atomic_cmpswap_b64 = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_ATOMIC_CMPSWAP_B64)
|
||||
global_atomic_add_u64 = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_ATOMIC_ADD_U64)
|
||||
global_atomic_sub_u64 = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_ATOMIC_SUB_U64)
|
||||
global_atomic_min_i64 = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_ATOMIC_MIN_I64)
|
||||
global_atomic_min_u64 = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_ATOMIC_MIN_U64)
|
||||
global_atomic_max_i64 = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_ATOMIC_MAX_I64)
|
||||
global_atomic_max_u64 = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_ATOMIC_MAX_U64)
|
||||
global_atomic_and_b64 = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_ATOMIC_AND_B64)
|
||||
global_atomic_or_b64 = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_ATOMIC_OR_B64)
|
||||
global_atomic_xor_b64 = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_ATOMIC_XOR_B64)
|
||||
global_atomic_inc_u64 = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_ATOMIC_INC_U64)
|
||||
global_atomic_dec_u64 = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_ATOMIC_DEC_U64)
|
||||
global_wbinv = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_WBINV)
|
||||
global_atomic_cond_sub_u32 = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_ATOMIC_COND_SUB_U32)
|
||||
global_atomic_min_num_f32 = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_ATOMIC_MIN_NUM_F32)
|
||||
global_atomic_max_num_f32 = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_ATOMIC_MAX_NUM_F32)
|
||||
global_load_block = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_LOAD_BLOCK)
|
||||
global_store_block = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_STORE_BLOCK)
|
||||
global_atomic_add_f32 = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_ATOMIC_ADD_F32)
|
||||
global_load_tr_b128 = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_LOAD_TR_B128)
|
||||
global_load_tr_b64 = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_LOAD_TR_B64)
|
||||
global_atomic_pk_add_f16 = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_ATOMIC_PK_ADD_F16)
|
||||
global_atomic_pk_add_bf16 = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_ATOMIC_PK_ADD_BF16)
|
||||
global_atomic_ordered_add_b64 = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_ATOMIC_ORDERED_ADD_B64)
|
||||
image_load = functools.partial(VIMAGE, VIMAGEOp.IMAGE_LOAD)
|
||||
image_load_mip = functools.partial(VIMAGE, VIMAGEOp.IMAGE_LOAD_MIP)
|
||||
image_load_pck = functools.partial(VIMAGE, VIMAGEOp.IMAGE_LOAD_PCK)
|
||||
image_load_pck_sgn = functools.partial(VIMAGE, VIMAGEOp.IMAGE_LOAD_PCK_SGN)
|
||||
image_load_mip_pck = functools.partial(VIMAGE, VIMAGEOp.IMAGE_LOAD_MIP_PCK)
|
||||
image_load_mip_pck_sgn = functools.partial(VIMAGE, VIMAGEOp.IMAGE_LOAD_MIP_PCK_SGN)
|
||||
image_store = functools.partial(VIMAGE, VIMAGEOp.IMAGE_STORE)
|
||||
image_store_mip = functools.partial(VIMAGE, VIMAGEOp.IMAGE_STORE_MIP)
|
||||
image_store_pck = functools.partial(VIMAGE, VIMAGEOp.IMAGE_STORE_PCK)
|
||||
image_store_mip_pck = functools.partial(VIMAGE, VIMAGEOp.IMAGE_STORE_MIP_PCK)
|
||||
image_atomic_swap = functools.partial(VIMAGE, VIMAGEOp.IMAGE_ATOMIC_SWAP)
|
||||
image_atomic_cmpswap = functools.partial(VIMAGE, VIMAGEOp.IMAGE_ATOMIC_CMPSWAP)
|
||||
image_atomic_add_uint = functools.partial(VIMAGE, VIMAGEOp.IMAGE_ATOMIC_ADD_UINT)
|
||||
image_atomic_sub_uint = functools.partial(VIMAGE, VIMAGEOp.IMAGE_ATOMIC_SUB_UINT)
|
||||
image_atomic_min_int = functools.partial(VIMAGE, VIMAGEOp.IMAGE_ATOMIC_MIN_INT)
|
||||
image_atomic_min_uint = functools.partial(VIMAGE, VIMAGEOp.IMAGE_ATOMIC_MIN_UINT)
|
||||
image_atomic_max_int = functools.partial(VIMAGE, VIMAGEOp.IMAGE_ATOMIC_MAX_INT)
|
||||
image_atomic_max_uint = functools.partial(VIMAGE, VIMAGEOp.IMAGE_ATOMIC_MAX_UINT)
|
||||
image_atomic_and = functools.partial(VIMAGE, VIMAGEOp.IMAGE_ATOMIC_AND)
|
||||
image_atomic_or = functools.partial(VIMAGE, VIMAGEOp.IMAGE_ATOMIC_OR)
|
||||
image_atomic_xor = functools.partial(VIMAGE, VIMAGEOp.IMAGE_ATOMIC_XOR)
|
||||
image_atomic_inc_uint = functools.partial(VIMAGE, VIMAGEOp.IMAGE_ATOMIC_INC_UINT)
|
||||
image_atomic_dec_uint = functools.partial(VIMAGE, VIMAGEOp.IMAGE_ATOMIC_DEC_UINT)
|
||||
image_get_resinfo = functools.partial(VIMAGE, VIMAGEOp.IMAGE_GET_RESINFO)
|
||||
image_bvh_intersect_ray = functools.partial(VIMAGE, VIMAGEOp.IMAGE_BVH_INTERSECT_RAY)
|
||||
image_bvh64_intersect_ray = functools.partial(VIMAGE, VIMAGEOp.IMAGE_BVH64_INTERSECT_RAY)
|
||||
image_bvh_dual_intersect_ray = functools.partial(VIMAGE, VIMAGEOp.IMAGE_BVH_DUAL_INTERSECT_RAY)
|
||||
image_bvh8_intersect_ray = functools.partial(VIMAGE, VIMAGEOp.IMAGE_BVH8_INTERSECT_RAY)
|
||||
image_atomic_add_flt = functools.partial(VIMAGE, VIMAGEOp.IMAGE_ATOMIC_ADD_FLT)
|
||||
image_atomic_min_flt = functools.partial(VIMAGE, VIMAGEOp.IMAGE_ATOMIC_MIN_FLT)
|
||||
image_atomic_max_flt = functools.partial(VIMAGE, VIMAGEOp.IMAGE_ATOMIC_MAX_FLT)
|
||||
image_atomic_pk_add_f16 = functools.partial(VIMAGE, VIMAGEOp.IMAGE_ATOMIC_PK_ADD_F16)
|
||||
image_atomic_pk_add_bf16 = functools.partial(VIMAGE, VIMAGEOp.IMAGE_ATOMIC_PK_ADD_BF16)
|
||||
v_interp_p10_f32 = functools.partial(VINTERP, VINTERPOp.V_INTERP_P10_F32)
|
||||
v_interp_p2_f32 = functools.partial(VINTERP, VINTERPOp.V_INTERP_P2_F32)
|
||||
v_interp_p10_f16_f32 = functools.partial(VINTERP, VINTERPOp.V_INTERP_P10_F16_F32)
|
||||
|
|
@ -1396,6 +1622,88 @@ v_dual_dot2acc_f32_bf16 = functools.partial(VOPD, VOPDOp.V_DUAL_DOT2ACC_F32_BF16
|
|||
v_dual_add_nc_u32 = functools.partial(VOPD, VOPDOp.V_DUAL_ADD_NC_U32)
|
||||
v_dual_lshlrev_b32 = functools.partial(VOPD, VOPDOp.V_DUAL_LSHLREV_B32)
|
||||
v_dual_and_b32 = functools.partial(VOPD, VOPDOp.V_DUAL_AND_B32)
|
||||
image_msaa_load = functools.partial(VSAMPLE, VSAMPLEOp.IMAGE_MSAA_LOAD)
|
||||
image_sample = functools.partial(VSAMPLE, VSAMPLEOp.IMAGE_SAMPLE)
|
||||
image_sample_d = functools.partial(VSAMPLE, VSAMPLEOp.IMAGE_SAMPLE_D)
|
||||
image_sample_l = functools.partial(VSAMPLE, VSAMPLEOp.IMAGE_SAMPLE_L)
|
||||
image_sample_b = functools.partial(VSAMPLE, VSAMPLEOp.IMAGE_SAMPLE_B)
|
||||
image_sample_lz = functools.partial(VSAMPLE, VSAMPLEOp.IMAGE_SAMPLE_LZ)
|
||||
image_sample_c = functools.partial(VSAMPLE, VSAMPLEOp.IMAGE_SAMPLE_C)
|
||||
image_sample_c_d = functools.partial(VSAMPLE, VSAMPLEOp.IMAGE_SAMPLE_C_D)
|
||||
image_sample_c_l = functools.partial(VSAMPLE, VSAMPLEOp.IMAGE_SAMPLE_C_L)
|
||||
image_sample_c_b = functools.partial(VSAMPLE, VSAMPLEOp.IMAGE_SAMPLE_C_B)
|
||||
image_sample_c_lz = functools.partial(VSAMPLE, VSAMPLEOp.IMAGE_SAMPLE_C_LZ)
|
||||
image_sample_o = functools.partial(VSAMPLE, VSAMPLEOp.IMAGE_SAMPLE_O)
|
||||
image_sample_d_o = functools.partial(VSAMPLE, VSAMPLEOp.IMAGE_SAMPLE_D_O)
|
||||
image_sample_l_o = functools.partial(VSAMPLE, VSAMPLEOp.IMAGE_SAMPLE_L_O)
|
||||
image_sample_b_o = functools.partial(VSAMPLE, VSAMPLEOp.IMAGE_SAMPLE_B_O)
|
||||
image_sample_lz_o = functools.partial(VSAMPLE, VSAMPLEOp.IMAGE_SAMPLE_LZ_O)
|
||||
image_sample_c_o = functools.partial(VSAMPLE, VSAMPLEOp.IMAGE_SAMPLE_C_O)
|
||||
image_sample_c_d_o = functools.partial(VSAMPLE, VSAMPLEOp.IMAGE_SAMPLE_C_D_O)
|
||||
image_sample_c_l_o = functools.partial(VSAMPLE, VSAMPLEOp.IMAGE_SAMPLE_C_L_O)
|
||||
image_sample_c_b_o = functools.partial(VSAMPLE, VSAMPLEOp.IMAGE_SAMPLE_C_B_O)
|
||||
image_sample_c_lz_o = functools.partial(VSAMPLE, VSAMPLEOp.IMAGE_SAMPLE_C_LZ_O)
|
||||
image_gather4 = functools.partial(VSAMPLE, VSAMPLEOp.IMAGE_GATHER4)
|
||||
image_gather4_l = functools.partial(VSAMPLE, VSAMPLEOp.IMAGE_GATHER4_L)
|
||||
image_gather4_b = functools.partial(VSAMPLE, VSAMPLEOp.IMAGE_GATHER4_B)
|
||||
image_gather4_lz = functools.partial(VSAMPLE, VSAMPLEOp.IMAGE_GATHER4_LZ)
|
||||
image_gather4_c = functools.partial(VSAMPLE, VSAMPLEOp.IMAGE_GATHER4_C)
|
||||
image_gather4_c_lz = functools.partial(VSAMPLE, VSAMPLEOp.IMAGE_GATHER4_C_LZ)
|
||||
image_gather4_o = functools.partial(VSAMPLE, VSAMPLEOp.IMAGE_GATHER4_O)
|
||||
image_gather4_lz_o = functools.partial(VSAMPLE, VSAMPLEOp.IMAGE_GATHER4_LZ_O)
|
||||
image_gather4_c_lz_o = functools.partial(VSAMPLE, VSAMPLEOp.IMAGE_GATHER4_C_LZ_O)
|
||||
image_get_lod = functools.partial(VSAMPLE, VSAMPLEOp.IMAGE_GET_LOD)
|
||||
image_sample_d_g16 = functools.partial(VSAMPLE, VSAMPLEOp.IMAGE_SAMPLE_D_G16)
|
||||
image_sample_c_d_g16 = functools.partial(VSAMPLE, VSAMPLEOp.IMAGE_SAMPLE_C_D_G16)
|
||||
image_sample_d_o_g16 = functools.partial(VSAMPLE, VSAMPLEOp.IMAGE_SAMPLE_D_O_G16)
|
||||
image_sample_c_d_o_g16 = functools.partial(VSAMPLE, VSAMPLEOp.IMAGE_SAMPLE_C_D_O_G16)
|
||||
image_sample_cl = functools.partial(VSAMPLE, VSAMPLEOp.IMAGE_SAMPLE_CL)
|
||||
image_sample_d_cl = functools.partial(VSAMPLE, VSAMPLEOp.IMAGE_SAMPLE_D_CL)
|
||||
image_sample_b_cl = functools.partial(VSAMPLE, VSAMPLEOp.IMAGE_SAMPLE_B_CL)
|
||||
image_sample_c_cl = functools.partial(VSAMPLE, VSAMPLEOp.IMAGE_SAMPLE_C_CL)
|
||||
image_sample_c_d_cl = functools.partial(VSAMPLE, VSAMPLEOp.IMAGE_SAMPLE_C_D_CL)
|
||||
image_sample_c_b_cl = functools.partial(VSAMPLE, VSAMPLEOp.IMAGE_SAMPLE_C_B_CL)
|
||||
image_sample_cl_o = functools.partial(VSAMPLE, VSAMPLEOp.IMAGE_SAMPLE_CL_O)
|
||||
image_sample_d_cl_o = functools.partial(VSAMPLE, VSAMPLEOp.IMAGE_SAMPLE_D_CL_O)
|
||||
image_sample_b_cl_o = functools.partial(VSAMPLE, VSAMPLEOp.IMAGE_SAMPLE_B_CL_O)
|
||||
image_sample_c_cl_o = functools.partial(VSAMPLE, VSAMPLEOp.IMAGE_SAMPLE_C_CL_O)
|
||||
image_sample_c_d_cl_o = functools.partial(VSAMPLE, VSAMPLEOp.IMAGE_SAMPLE_C_D_CL_O)
|
||||
image_sample_c_b_cl_o = functools.partial(VSAMPLE, VSAMPLEOp.IMAGE_SAMPLE_C_B_CL_O)
|
||||
image_sample_c_d_cl_g16 = functools.partial(VSAMPLE, VSAMPLEOp.IMAGE_SAMPLE_C_D_CL_G16)
|
||||
image_sample_d_cl_o_g16 = functools.partial(VSAMPLE, VSAMPLEOp.IMAGE_SAMPLE_D_CL_O_G16)
|
||||
image_sample_c_d_cl_o_g16 = functools.partial(VSAMPLE, VSAMPLEOp.IMAGE_SAMPLE_C_D_CL_O_G16)
|
||||
image_sample_d_cl_g16 = functools.partial(VSAMPLE, VSAMPLEOp.IMAGE_SAMPLE_D_CL_G16)
|
||||
image_gather4_cl = functools.partial(VSAMPLE, VSAMPLEOp.IMAGE_GATHER4_CL)
|
||||
image_gather4_b_cl = functools.partial(VSAMPLE, VSAMPLEOp.IMAGE_GATHER4_B_CL)
|
||||
image_gather4_c_cl = functools.partial(VSAMPLE, VSAMPLEOp.IMAGE_GATHER4_C_CL)
|
||||
image_gather4_c_l = functools.partial(VSAMPLE, VSAMPLEOp.IMAGE_GATHER4_C_L)
|
||||
image_gather4_c_b = functools.partial(VSAMPLE, VSAMPLEOp.IMAGE_GATHER4_C_B)
|
||||
image_gather4_c_b_cl = functools.partial(VSAMPLE, VSAMPLEOp.IMAGE_GATHER4_C_B_CL)
|
||||
image_gather4h = functools.partial(VSAMPLE, VSAMPLEOp.IMAGE_GATHER4H)
|
||||
scratch_load_u8 = functools.partial(VSCRATCH, VSCRATCHOp.SCRATCH_LOAD_U8)
|
||||
scratch_load_i8 = functools.partial(VSCRATCH, VSCRATCHOp.SCRATCH_LOAD_I8)
|
||||
scratch_load_u16 = functools.partial(VSCRATCH, VSCRATCHOp.SCRATCH_LOAD_U16)
|
||||
scratch_load_i16 = functools.partial(VSCRATCH, VSCRATCHOp.SCRATCH_LOAD_I16)
|
||||
scratch_load_b32 = functools.partial(VSCRATCH, VSCRATCHOp.SCRATCH_LOAD_B32)
|
||||
scratch_load_b64 = functools.partial(VSCRATCH, VSCRATCHOp.SCRATCH_LOAD_B64)
|
||||
scratch_load_b96 = functools.partial(VSCRATCH, VSCRATCHOp.SCRATCH_LOAD_B96)
|
||||
scratch_load_b128 = functools.partial(VSCRATCH, VSCRATCHOp.SCRATCH_LOAD_B128)
|
||||
scratch_store_b8 = functools.partial(VSCRATCH, VSCRATCHOp.SCRATCH_STORE_B8)
|
||||
scratch_store_b16 = functools.partial(VSCRATCH, VSCRATCHOp.SCRATCH_STORE_B16)
|
||||
scratch_store_b32 = functools.partial(VSCRATCH, VSCRATCHOp.SCRATCH_STORE_B32)
|
||||
scratch_store_b64 = functools.partial(VSCRATCH, VSCRATCHOp.SCRATCH_STORE_B64)
|
||||
scratch_store_b96 = functools.partial(VSCRATCH, VSCRATCHOp.SCRATCH_STORE_B96)
|
||||
scratch_store_b128 = functools.partial(VSCRATCH, VSCRATCHOp.SCRATCH_STORE_B128)
|
||||
scratch_load_d16_u8 = functools.partial(VSCRATCH, VSCRATCHOp.SCRATCH_LOAD_D16_U8)
|
||||
scratch_load_d16_i8 = functools.partial(VSCRATCH, VSCRATCHOp.SCRATCH_LOAD_D16_I8)
|
||||
scratch_load_d16_b16 = functools.partial(VSCRATCH, VSCRATCHOp.SCRATCH_LOAD_D16_B16)
|
||||
scratch_load_d16_hi_u8 = functools.partial(VSCRATCH, VSCRATCHOp.SCRATCH_LOAD_D16_HI_U8)
|
||||
scratch_load_d16_hi_i8 = functools.partial(VSCRATCH, VSCRATCHOp.SCRATCH_LOAD_D16_HI_I8)
|
||||
scratch_load_d16_hi_b16 = functools.partial(VSCRATCH, VSCRATCHOp.SCRATCH_LOAD_D16_HI_B16)
|
||||
scratch_store_d16_hi_b8 = functools.partial(VSCRATCH, VSCRATCHOp.SCRATCH_STORE_D16_HI_B8)
|
||||
scratch_store_d16_hi_b16 = functools.partial(VSCRATCH, VSCRATCHOp.SCRATCH_STORE_D16_HI_B16)
|
||||
scratch_load_block = functools.partial(VSCRATCH, VSCRATCHOp.SCRATCH_LOAD_BLOCK)
|
||||
scratch_store_block = functools.partial(VSCRATCH, VSCRATCHOp.SCRATCH_STORE_BLOCK)
|
||||
|
||||
VCC_LO = SrcEnum.VCC_LO
|
||||
VCC_HI = SrcEnum.VCC_HI
|
||||
|
|
|
|||
|
|
@ -62,8 +62,10 @@ _SPECIAL_REGS = {
|
|||
'V_CMP_CLASS_F64': (1, 2, 1, 1), 'V_CMPX_CLASS_F64': (1, 2, 1, 1),
|
||||
'V_CMP_CLASS_F32': (1, 1, 1, 1), 'V_CMPX_CLASS_F32': (1, 1, 1, 1),
|
||||
'V_CMP_CLASS_F16': (1, 1, 1, 1), 'V_CMPX_CLASS_F16': (1, 1, 1, 1),
|
||||
'V_MAD_U64_U32': (2, 1, 1, 2), 'V_MAD_I64_I32': (2, 1, 1, 2),
|
||||
'V_MAD_U64_U32': (2, 1, 1, 2), 'V_MAD_I64_I32': (2, 1, 1, 2), 'V_MAD_CO_U64_U32': (2, 1, 1, 2), 'V_MAD_CO_I64_I32': (2, 1, 1, 2),
|
||||
'V_QSAD_PK_U16_U8': (2, 2, 1, 2), 'V_MQSAD_PK_U16_U8': (2, 2, 1, 2), 'V_MQSAD_U32_U8': (4, 2, 1, 4),
|
||||
# RDNA4 CVT_PK_F32 instructions output 2 F32 values (64-bit)
|
||||
'V_CVT_PK_F32_BF8': (2, 1, 1, 1), 'V_CVT_PK_F32_FP8': (2, 1, 1, 1),
|
||||
}
|
||||
_SPECIAL_DTYPE = {
|
||||
'V_LSHLREV_B64': ('B64', 'U32', 'B64', None), 'V_LSHRREV_B64': ('B64', 'U32', 'B64', None), 'V_ASHRREV_I64': ('I64', 'U32', 'I64', None),
|
||||
|
|
@ -78,6 +80,8 @@ _SPECIAL_DTYPE = {
|
|||
'V_MAD_U64_U32': ('U64', 'U32', 'U32', 'U64'), 'V_MAD_I64_I32': ('I64', 'I32', 'I32', 'I64'),
|
||||
'V_QSAD_PK_U16_U8': ('B64', 'B64', 'B64', 'B64'), 'V_MQSAD_PK_U16_U8': ('B64', 'B64', 'B64', 'B64'),
|
||||
'V_MQSAD_U32_U8': ('B128', 'B64', 'B64', 'B128'),
|
||||
# RDNA4 CVT_PK_F32 instructions: source is 8-bit packed as 16-bit operand
|
||||
'V_CVT_PK_F32_BF8': ('F32', 'B16', None, None), 'V_CVT_PK_F32_FP8': ('F32', 'B16', None, None),
|
||||
}
|
||||
@cache
|
||||
def spec_regs(name: str) -> tuple[int, int, int, int]:
|
||||
|
|
@ -108,7 +112,7 @@ def spec_is_16bit(name: str) -> bool:
|
|||
def spec_is_64bit(name: str) -> bool: return bool(_F64_RE.search(name.upper()))
|
||||
_3SRC = {'FMA', 'MAD', 'MIN3', 'MAX3', 'MED3', 'DIV_FIX', 'DIV_FMAS', 'DIV_SCALE', 'SAD', 'LERP', 'ALIGN', 'CUBE', 'BFE', 'BFI',
|
||||
'PERM_B32', 'PERMLANE', 'CNDMASK', 'XOR3', 'OR3', 'ADD3', 'LSHL_OR', 'AND_OR', 'LSHL_ADD', 'ADD_LSHL', 'XAD', 'MAXMIN',
|
||||
'MINMAX', 'DOT2', 'DOT4', 'DOT8', 'WMMA', 'CVT_PK_U8', 'MULLIT', 'CO_CI'}
|
||||
'MINMAX', 'MAXIMUMMINIMUM', 'MINIMUMMAXIMUM', 'MAXIMUM3', 'MINIMUM3', 'DOT2', 'DOT4', 'DOT8', 'WMMA', 'CVT_PK_U8', 'MULLIT', 'CO_CI'}
|
||||
_2SRC = {'FMAC'} # FMAC uses dst as implicit accumulator, so only 2 explicit sources
|
||||
def spec_num_srcs(name: str) -> int:
|
||||
name = name.upper()
|
||||
|
|
@ -138,17 +142,35 @@ class BitField:
|
|||
def __get__(self, obj: None, objtype: type) -> BitField: ...
|
||||
@overload
|
||||
def __get__(self, obj: object, objtype: type | None = None) -> int: ...
|
||||
# Map RDNA4 class names to their corresponding enum names for op field dynamic lookup
|
||||
_RDNA4_OP_ENUMS = {'VDS': 'DSOp', 'VBUFFER': 'VBUFFEROp', 'VEXPORT': 'EXPOp', 'VFLAT': 'VFLATOp', 'VGLOBAL': 'VGLOBALOp',
|
||||
'VSCRATCH': 'VSCRATCHOp', 'VIMAGE': 'VIMAGEOp', 'VSAMPLE': 'VSAMPLEOp', 'VDSDIR': 'VDSDIROp'}
|
||||
|
||||
def __get__(self, obj, objtype=None):
|
||||
if obj is None: return self
|
||||
val = unwrap(obj._values.get(self.name, 0))
|
||||
# Convert to IntEnum if marker is an IntEnum subclass
|
||||
if self.marker and isinstance(self.marker, type) and issubclass(self.marker, IntEnum):
|
||||
# VOP3 with VOPC opcodes (0-255) -> VOPCOp, VOP3SD opcodes -> VOP3SDOp
|
||||
if self.marker is VOP3Op:
|
||||
if val < 256: return VOPCOp(val)
|
||||
if val in Inst._VOP3SD_OPS: return VOP3SDOp(val)
|
||||
# Check by name to handle both RDNA3 and RDNA4 enums
|
||||
if self.marker.__name__ == 'VOP3Op':
|
||||
# Get the appropriate enums from the same module as the marker
|
||||
marker_mod = self.marker.__module__
|
||||
import importlib
|
||||
enum_mod = importlib.import_module(marker_mod)
|
||||
if val < 256: return enum_mod.VOPCOp(val)
|
||||
if val in Inst._VOP3SD_OPS: return enum_mod.VOP3SDOp(val)
|
||||
try: return self.marker(val)
|
||||
except ValueError: pass
|
||||
# For RDNA4 op fields without type annotations, dynamically look up enum
|
||||
elif self.name == 'op' and 'rdna4' in obj.__class__.__module__:
|
||||
import importlib
|
||||
enum_mod = importlib.import_module('extra.assembly.amd.autogen.rdna4.enum')
|
||||
cls_name = obj.__class__.__name__
|
||||
enum_name = self._RDNA4_OP_ENUMS.get(cls_name, cls_name + 'Op')
|
||||
if hasattr(enum_mod, enum_name):
|
||||
try: return getattr(enum_mod, enum_name)(val)
|
||||
except ValueError: pass
|
||||
return val
|
||||
|
||||
class _Bits:
|
||||
|
|
@ -429,7 +451,7 @@ class Inst:
|
|||
return result + (lit32 & MASK32).to_bytes(4, 'little')
|
||||
|
||||
@classmethod
|
||||
def _size(cls) -> int: return 4 if issubclass(cls, Inst32) else 8
|
||||
def _size(cls) -> int: return 4 if issubclass(cls, Inst32) else 12 if issubclass(cls, Inst96) else 8
|
||||
def size(self) -> int:
|
||||
# Literal is always 4 bytes in the binary (for 64-bit ops, it's in high 32 bits)
|
||||
return self._size() + (4 if self._literal is not None else 0)
|
||||
|
|
@ -489,9 +511,9 @@ class Inst:
|
|||
|
||||
def __hash__(self): return hash((self.__class__.__name__, tuple(sorted((k, repr(v)) for k, v in self._values.items())), self._literal))
|
||||
|
||||
def disasm(self) -> str:
|
||||
def disasm(self, wave_size: int = 32) -> str:
|
||||
from extra.assembly.amd.asm import disasm
|
||||
return disasm(self)
|
||||
return disasm(self, wave_size)
|
||||
|
||||
_enum_map = {'VOP1': VOP1Op, 'VOP2': VOP2Op, 'VOP3': VOP3Op, 'VOP3SD': VOP3SDOp, 'VOP3P': VOP3POp, 'VOPC': VOPCOp,
|
||||
'SOP1': SOP1Op, 'SOP2': SOP2Op, 'SOPC': SOPCOp, 'SOPK': SOPKOp, 'SOPP': SOPPOp,
|
||||
|
|
@ -499,6 +521,10 @@ class Inst:
|
|||
'VOPD': VOPDOp, 'VINTERP': VINTERPOp}
|
||||
_VOP3SD_OPS = {288, 289, 290, 764, 765, 766, 767, 768, 769, 770}
|
||||
|
||||
# Map RDNA4 class names to their corresponding enum names
|
||||
_rdna4_enum_names = {'VDS': 'DSOp', 'VBUFFER': 'VBUFFEROp', 'VEXPORT': 'EXPOp', 'VFLAT': 'VFLATOp', 'VGLOBAL': 'VGLOBALOp',
|
||||
'VSCRATCH': 'VSCRATCHOp', 'VIMAGE': 'VIMAGEOp', 'VSAMPLE': 'VSAMPLEOp', 'VDSDIR': 'VDSDIROp'}
|
||||
|
||||
@property
|
||||
def op(self):
|
||||
"""Return the op as an enum (e.g., VOP1Op.V_MOV_B32). VOP3 returns VOPCOp/VOP3SDOp for those op ranges."""
|
||||
|
|
@ -506,6 +532,20 @@ class Inst:
|
|||
if val is None: return None
|
||||
if hasattr(val, 'name'): return val # already an enum
|
||||
cls_name = self.__class__.__name__
|
||||
# First check if op field has an annotated enum type
|
||||
import typing
|
||||
if 'op' in self.__class__.__annotations__:
|
||||
ann = self.__class__.__annotations__['op']
|
||||
if hasattr(ann, '__metadata__'):
|
||||
for m in typing.get_args(ann)[1:]:
|
||||
if isinstance(m, type) and issubclass(m, IntEnum): return m(val)
|
||||
# Check if this is an RDNA4 class (module path contains rdna4) and get enum from its module
|
||||
if 'rdna4' in self.__class__.__module__:
|
||||
import importlib
|
||||
enum_mod = importlib.import_module('extra.assembly.amd.autogen.rdna4.enum')
|
||||
enum_name = self._rdna4_enum_names.get(cls_name, cls_name + 'Op')
|
||||
if hasattr(enum_mod, enum_name): return getattr(enum_mod, enum_name)(val)
|
||||
# Fall back to static enum map
|
||||
assert cls_name in self._enum_map, f"no enum map for {cls_name}"
|
||||
return self._enum_map[cls_name](val)
|
||||
|
||||
|
|
@ -531,3 +571,4 @@ class Inst:
|
|||
|
||||
class Inst32(Inst): pass
|
||||
class Inst64(Inst): pass
|
||||
class Inst96(Inst): pass
|
||||
|
|
|
|||
|
|
@ -220,8 +220,16 @@ def _parse_fields_table(table: list, fmt: str, enums: set[str]) -> list[tuple]:
|
|||
if not (bits := _parse_bits(bits_str)): continue
|
||||
enc_val, hi, lo = None, bits[0], bits[1]
|
||||
if name == 'ENCODING' and row[2]:
|
||||
if m := re.search(r"(?:'b|Must be:\s*)([01_]+)", row[2]):
|
||||
desc = row[2]
|
||||
# Handle shared FLAT/GLOBAL/SCRATCH table: look for format-specific encoding
|
||||
fmt_key = fmt.lstrip('V').lower().capitalize() # VFLAT -> Flat, VGLOBAL -> Global
|
||||
if m := re.search(rf"{fmt_key}='b([01_]+)", desc):
|
||||
enc_bits = m.group(1).replace('_', '')
|
||||
elif m := re.search(r"(?:'b|Must be:\s*)([01_]+)", desc):
|
||||
enc_bits = m.group(1).replace('_', '')
|
||||
else:
|
||||
enc_bits = None
|
||||
if enc_bits:
|
||||
enc_val, declared_width, actual_width = int(enc_bits, 2), hi - lo + 1, len(enc_bits)
|
||||
if actual_width > declared_width: lo = hi - actual_width + 1
|
||||
ftype = f"{fmt}Op" if name == 'OP' and f"{fmt}Op" in enums else FIELD_TYPES.get(name.upper())
|
||||
|
|
@ -292,6 +300,16 @@ def _parse_single_pdf(url: str):
|
|||
next_text = pdf.text(microcode_start + i + 1).lstrip()
|
||||
if next_text.startswith('Description') or (next_text.startswith('"RDNA') and 'Description' in next_text[:200]):
|
||||
format_headers.append((fmt_name, i, m.start()))
|
||||
# RDNA4: Look for "Table X. Y Fields" patterns (e.g., VIMAGE, VSAMPLE, or shared FLAT/GLOBAL/SCRATCH)
|
||||
for m in re.finditer(r'Table \d+\.\s+([\w,\s]+?)\s+Fields', text):
|
||||
table_name = m.group(1).strip()
|
||||
# Handle shared table like "FLAT, GLOBAL and SCRATCH"
|
||||
if ',' in table_name or ' and ' in table_name:
|
||||
for part in re.split(r',\s*|\s+and\s+', table_name):
|
||||
fmt_name = 'V' + part.strip()
|
||||
if fmt_name not in [h[0] for h in format_headers]: format_headers.append((fmt_name, i, m.start()))
|
||||
elif table_name.startswith('V'):
|
||||
if table_name not in [h[0] for h in format_headers]: format_headers.append((table_name, i, m.start()))
|
||||
|
||||
formats: dict[str, list] = {}
|
||||
for fmt_name, rel_idx, header_pos in format_headers:
|
||||
|
|
@ -324,6 +342,10 @@ def _parse_single_pdf(url: str):
|
|||
if 'SMEM' in formats:
|
||||
formats['SMEM'] = [(n, 13 if n == 'DLC' else 14 if n == 'GLC' else h, 13 if n == 'DLC' else 14 if n == 'GLC' else l, e, t)
|
||||
for n, h, l, e, t in formats['SMEM']]
|
||||
# RDNA4: VFLAT/VGLOBAL/VSCRATCH OP field is [20:14] not [20:13] (PDF documentation error)
|
||||
for fmt_name in ['VFLAT', 'VGLOBAL', 'VSCRATCH']:
|
||||
if fmt_name in formats:
|
||||
formats[fmt_name] = [(n, h, 14 if n == 'OP' else l, e, t) for n, h, l, e, t in formats[fmt_name]]
|
||||
if doc_name in ('RDNA3', 'RDNA3.5'):
|
||||
if 'SOPPOp' in enums: assert 8 not in enums['SOPPOp']; enums['SOPPOp'][8] = 'S_WAITCNT_DEPCTR'
|
||||
if 'DSOp' in enums:
|
||||
|
|
@ -409,13 +431,14 @@ def _generate_ins_py(formats, enums, src_enum, doc_name) -> str:
|
|||
def field_key(f, order): return order.index(f[0].lower()) if f[0].lower() in order else 1000
|
||||
lines = [f"# autogenerated from AMD {doc_name} ISA PDF by pdf.py - do not edit",
|
||||
"# ruff: noqa: F401,F403", "from typing import Annotated",
|
||||
"from extra.assembly.amd.dsl import bits, BitField, Inst32, Inst64, SGPR, VGPR, TTMP as TTMP, s as s, v as v, ttmp as ttmp, SSrc, Src, SImm, Imm, VDSTYEnc, SGPRField, VGPRField",
|
||||
"from extra.assembly.amd.dsl import bits, BitField, Inst32, Inst64, Inst96, SGPR, VGPR, TTMP as TTMP, s as s, v as v, ttmp as ttmp, SSrc, Src, SImm, Imm, VDSTYEnc, SGPRField, VGPRField",
|
||||
"from extra.assembly.amd.autogen.{arch}.enum import *",
|
||||
"import functools", ""]
|
||||
format_defaults = {'VOP3P': {'opsel_hi': 3, 'opsel_hi2': 1}}
|
||||
lines.append("# instruction formats")
|
||||
for fmt_name, fields in sorted(formats.items()):
|
||||
base = "Inst64" if max(f[1] for f in fields) > 31 or fmt_name == 'VOP3SD' else "Inst32"
|
||||
max_bit = max(f[1] for f in fields)
|
||||
base = "Inst96" if max_bit > 63 else "Inst64" if max_bit > 31 or fmt_name == 'VOP3SD' else "Inst32"
|
||||
order = FIELD_ORDER.get(fmt_name, [])
|
||||
lines.append(f"class {fmt_name}({base}):")
|
||||
if enc := next((f for f in fields if f[0] == 'ENCODING'), None):
|
||||
|
|
|
|||
|
|
@ -1,98 +1,76 @@
|
|||
#!/usr/bin/env python3
|
||||
"""Test RDNA3 assembler/disassembler against LLVM test vectors."""
|
||||
"""Test RDNA3/RDNA4 assembler/disassembler against LLVM test vectors."""
|
||||
import unittest, re, subprocess
|
||||
from tinygrad.helpers import fetch
|
||||
from extra.assembly.amd.autogen.rdna3.ins import *
|
||||
from extra.assembly.amd.asm import asm
|
||||
from extra.assembly.amd.test.helpers import get_llvm_mc
|
||||
|
||||
LLVM_BASE = "https://raw.githubusercontent.com/llvm/llvm-project/main/llvm/test/MC/AMDGPU"
|
||||
|
||||
# Format info: (filename, format_class, op_enum)
|
||||
LLVM_TEST_FILES = {
|
||||
# Scalar ALU
|
||||
'sop1': ('gfx11_asm_sop1.s', SOP1, SOP1Op),
|
||||
'sop2': ('gfx11_asm_sop2.s', SOP2, SOP2Op),
|
||||
'sopp': ('gfx11_asm_sopp.s', SOPP, SOPPOp),
|
||||
'sopk': ('gfx11_asm_sopk.s', SOPK, SOPKOp),
|
||||
'sopc': ('gfx11_asm_sopc.s', SOPC, SOPCOp),
|
||||
# Vector ALU
|
||||
'vop1': ('gfx11_asm_vop1.s', VOP1, VOP1Op),
|
||||
'vop2': ('gfx11_asm_vop2.s', VOP2, VOP2Op),
|
||||
'vopc': ('gfx11_asm_vopc.s', VOPC, VOPCOp),
|
||||
'vop3': ('gfx11_asm_vop3.s', VOP3, VOP3Op),
|
||||
'vop3p': ('gfx11_asm_vop3p.s', VOP3P, VOP3POp),
|
||||
'vop3sd': ('gfx11_asm_vop3.s', VOP3SD, VOP3SDOp), # VOP3SD shares file with VOP3
|
||||
'vinterp': ('gfx11_asm_vinterp.s', VINTERP, VINTERPOp),
|
||||
'vopd': ('gfx11_asm_vopd.s', VOPD, VOPDOp),
|
||||
'vopcx': ('gfx11_asm_vopcx.s', VOPC, VOPCOp), # VOPCX uses VOPC format
|
||||
# VOP3 promotions (VOP1/VOP2/VOPC promoted to VOP3 encoding)
|
||||
'vop3_from_vop1': ('gfx11_asm_vop3_from_vop1.s', VOP3, VOP3Op),
|
||||
'vop3_from_vop2': ('gfx11_asm_vop3_from_vop2.s', VOP3, VOP3Op),
|
||||
'vop3_from_vopc': ('gfx11_asm_vop3_from_vopc.s', VOP3, VOP3Op),
|
||||
'vop3_from_vopcx': ('gfx11_asm_vop3_from_vopcx.s', VOP3, VOP3Op),
|
||||
# Memory
|
||||
'ds': ('gfx11_asm_ds.s', DS, DSOp),
|
||||
'smem': ('gfx11_asm_smem.s', SMEM, SMEMOp),
|
||||
'flat': ('gfx11_asm_flat.s', FLAT, FLATOp),
|
||||
'mubuf': ('gfx11_asm_mubuf.s', MUBUF, MUBUFOp),
|
||||
'mtbuf': ('gfx11_asm_mtbuf.s', MTBUF, MTBUFOp),
|
||||
'mimg': ('gfx11_asm_mimg.s', MIMG, MIMGOp),
|
||||
# WMMA (matrix multiply)
|
||||
'wmma': ('gfx11_asm_wmma.s', VOP3P, VOP3POp),
|
||||
# Additional features
|
||||
'vop3_features': ('gfx11_asm_vop3_features.s', VOP3, VOP3Op),
|
||||
'vop3p_features': ('gfx11_asm_vop3p_features.s', VOP3P, VOP3POp),
|
||||
'vopd_features': ('gfx11_asm_vopd_features.s', VOPD, VOPDOp),
|
||||
# Alias files (alternative mnemonics)
|
||||
'vop3_alias': ('gfx11_asm_vop3_alias.s', VOP3, VOP3Op),
|
||||
'vop3p_alias': ('gfx11_asm_vop3p_alias.s', VOP3P, VOP3POp),
|
||||
'vopc_alias': ('gfx11_asm_vopc_alias.s', VOPC, VOPCOp),
|
||||
'vopcx_alias': ('gfx11_asm_vopcx_alias.s', VOPC, VOPCOp),
|
||||
'vinterp_alias': ('gfx11_asm_vinterp_alias.s', VINTERP, VINTERPOp),
|
||||
'smem_alias': ('gfx11_asm_smem_alias.s', SMEM, SMEMOp),
|
||||
'mubuf_alias': ('gfx11_asm_mubuf_alias.s', MUBUF, MUBUFOp),
|
||||
'mtbuf_alias': ('gfx11_asm_mtbuf_alias.s', MTBUF, MTBUFOp),
|
||||
RDNA3_TEST_FILES = {
|
||||
'sop1': 'gfx11_asm_sop1.s', 'sop2': 'gfx11_asm_sop2.s', 'sopp': 'gfx11_asm_sopp.s', 'sopk': 'gfx11_asm_sopk.s', 'sopc': 'gfx11_asm_sopc.s',
|
||||
'vop1': 'gfx11_asm_vop1.s', 'vop2': 'gfx11_asm_vop2.s', 'vopc': 'gfx11_asm_vopc.s', 'vop3': 'gfx11_asm_vop3.s', 'vop3p': 'gfx11_asm_vop3p.s',
|
||||
'vinterp': 'gfx11_asm_vinterp.s', 'vopd': 'gfx11_asm_vopd.s', 'vopcx': 'gfx11_asm_vopcx.s',
|
||||
'vop3_from_vop1': 'gfx11_asm_vop3_from_vop1.s', 'vop3_from_vop2': 'gfx11_asm_vop3_from_vop2.s',
|
||||
'vop3_from_vopc': 'gfx11_asm_vop3_from_vopc.s', 'vop3_from_vopcx': 'gfx11_asm_vop3_from_vopcx.s',
|
||||
'ds': 'gfx11_asm_ds.s', 'smem': 'gfx11_asm_smem.s', 'flat': 'gfx11_asm_flat.s',
|
||||
'mubuf': 'gfx11_asm_mubuf.s', 'mtbuf': 'gfx11_asm_mtbuf.s', 'mimg': 'gfx11_asm_mimg.s', 'mimg_features': 'gfx11_asm_mimg_features.s', 'ldsdir': 'gfx11_asm_ldsdir.s',
|
||||
'exp': 'gfx11_asm_exp.s', 'wmma': 'gfx11_asm_wmma.s',
|
||||
'vop3_features': 'gfx11_asm_vop3_features.s', 'vop3p_features': 'gfx11_asm_vop3p_features.s', 'vopd_features': 'gfx11_asm_vopd_features.s',
|
||||
'vop3_alias': 'gfx11_asm_vop3_alias.s', 'vop3p_alias': 'gfx11_asm_vop3p_alias.s', 'vopc_alias': 'gfx11_asm_vopc_alias.s',
|
||||
'vopcx_alias': 'gfx11_asm_vopcx_alias.s', 'vinterp_alias': 'gfx11_asm_vinterp_alias.s',
|
||||
'smem_alias': 'gfx11_asm_smem_alias.s', 'mubuf_alias': 'gfx11_asm_mubuf_alias.s', 'mtbuf_alias': 'gfx11_asm_mtbuf_alias.s',
|
||||
}
|
||||
|
||||
def parse_llvm_tests(text: str) -> list[tuple[str, bytes]]:
|
||||
RDNA4_TEST_FILES = {
|
||||
'sop1': 'gfx12_asm_sop1.s', 'sop2': 'gfx12_asm_sop2.s', 'sop2_alias': 'gfx12_asm_sop2_alias.s',
|
||||
'sopp': 'gfx12_asm_sopp.s', 'sopk': 'gfx12_asm_sopk.s', 'sopk_alias': 'gfx12_asm_sopk_alias.s', 'sopc': 'gfx12_asm_sopc.s',
|
||||
'vop1': 'gfx12_asm_vop1.s', 'vop2': 'gfx12_asm_vop2.s', 'vop2_aliases': 'gfx12_asm_vop2_aliases.s',
|
||||
'vopc': 'gfx12_asm_vopc.s', 'vopcx': 'gfx12_asm_vopcx.s',
|
||||
'vop3': 'gfx12_asm_vop3.s', 'vop3_aliases': 'gfx12_asm_vop3_aliases.s', 'vop3c': 'gfx12_asm_vop3c.s', 'vop3cx': 'gfx12_asm_vop3cx.s',
|
||||
'vop3p': 'gfx12_asm_vop3p.s', 'vop3p_aliases': 'gfx12_asm_vop3p_aliases.s', 'vop3p_features': 'gfx12_asm_vop3p_features.s',
|
||||
'vopd': 'gfx12_asm_vopd.s', 'vopd_features': 'gfx12_asm_vopd_features.s',
|
||||
'vop3_from_vop1': 'gfx12_asm_vop3_from_vop1.s', 'vop3_from_vop2': 'gfx12_asm_vop3_from_vop2.s',
|
||||
'ds': 'gfx12_asm_ds.s', 'ds_alias': 'gfx12_asm_ds_alias.s', 'smem': 'gfx12_asm_smem.s',
|
||||
'vflat': 'gfx12_asm_vflat.s', 'vflat_alias': 'gfx12_asm_vflat_alias.s',
|
||||
'vglobal': 'gfx12_asm_vflat.s', 'vglobal_alias': 'gfx12_asm_vflat_alias.s', # global instructions in vflat files
|
||||
'vscratch': 'gfx12_asm_vflat.s', # scratch instructions in vflat file
|
||||
'vbuffer_mubuf': 'gfx12_asm_vbuffer_mubuf.s', 'vbuffer_mubuf_alias': 'gfx12_asm_vbuffer_mubuf_alias.s',
|
||||
'vbuffer_mtbuf': 'gfx12_asm_vbuffer_mtbuf.s', 'vbuffer_mtbuf_alias': 'gfx12_asm_vbuffer_mtbuf_alias.s',
|
||||
'vimage': 'gfx12_asm_vimage.s', 'vimage_alias': 'gfx12_asm_vimage_alias.s', 'vsample': 'gfx12_asm_vsample.s',
|
||||
'vdsdir': 'gfx12_asm_vdsdir.s', 'vdsdir_alias': 'gfx12_asm_vdsdir_alias.s',
|
||||
'exp': 'gfx12_asm_exp.s', 'wmma_w32': 'gfx12_asm_wmma_w32.s', 'wmma_w64': 'gfx12_asm_wmma_w64.s',
|
||||
'global_load_tr': 'gfx12_asm_global_load_tr.s',
|
||||
# NOTE: 'features' (gfx12_asm_features.s) tests DPP instruction variants which require separate format decoders
|
||||
}
|
||||
|
||||
def parse_llvm_tests(text: str, gfx_prefix: str) -> list[tuple[str, bytes]]:
|
||||
"""Parse LLVM test format into (asm, expected_bytes) pairs."""
|
||||
tests, lines = [], text.split('\n')
|
||||
pattern = rf'(?:{gfx_prefix}|W32|W64)[^:]*:.*?encoding:\s*\[(.*?)\]'
|
||||
pattern2 = rf'(?:{gfx_prefix}|W32|W64)[^:]*:\s*\[(0x[0-9a-fA-F,x\s]+)\]'
|
||||
for i, line in enumerate(lines):
|
||||
line = line.strip()
|
||||
if not line or line.startswith(('//', '.', ';')): continue
|
||||
asm_text = line.split('//')[0].strip()
|
||||
if not asm_text: continue
|
||||
for j in range(i, min(i + 3, len(lines))):
|
||||
# Match GFX11, W32, or W64 encodings (all valid for gfx11)
|
||||
# Format 1: "// GFX11: v_foo ... ; encoding: [0x01,0x02,...]"
|
||||
# Format 2: "// GFX11: [0x01,0x02,...]" (used by DS, older files)
|
||||
if m := re.search(r'(?:GFX11|W32|W64)[^:]*:.*?encoding:\s*\[(.*?)\]', lines[j]):
|
||||
if m := re.search(pattern, lines[j]):
|
||||
hex_bytes = m.group(1).replace('0x', '').replace(',', '').replace(' ', '')
|
||||
elif m := re.search(r'(?:GFX11|W32|W64)[^:]*:\s*\[(0x[0-9a-fA-F,x\s]+)\]', lines[j]):
|
||||
elif m := re.search(pattern2, lines[j]):
|
||||
hex_bytes = m.group(1).replace('0x', '').replace(',', '').replace(' ', '')
|
||||
else:
|
||||
continue
|
||||
else: continue
|
||||
if hex_bytes:
|
||||
try: tests.append((asm_text, bytes.fromhex(hex_bytes)))
|
||||
except ValueError: pass
|
||||
break
|
||||
return tests
|
||||
|
||||
def try_assemble(text: str):
|
||||
"""Try to assemble instruction text, return bytes or None on failure."""
|
||||
try: return asm(text).to_bytes()
|
||||
except: return None
|
||||
|
||||
def compile_asm_batch(instrs: list[str]) -> list[bytes]:
|
||||
def compile_asm_batch(instrs: list[str], mcpu: str, mattr: str = '+real-true16,+wavefrontsize32') -> list[bytes]:
|
||||
"""Compile multiple instructions with a single llvm-mc call."""
|
||||
if not instrs: return []
|
||||
asm_text = ".text\n" + "\n".join(instrs) + "\n"
|
||||
result = subprocess.run(
|
||||
[get_llvm_mc(), '-triple=amdgcn', '-mcpu=gfx1100', '-mattr=+real-true16,+wavefrontsize32', '-show-encoding'],
|
||||
input=asm_text, capture_output=True, text=True, timeout=30)
|
||||
if result.returncode != 0: raise RuntimeError(f"llvm-mc batch failed: {result.stderr.strip()}")
|
||||
# Parse all encodings from output
|
||||
result = subprocess.run([get_llvm_mc(), '-triple=amdgcn', f'-mcpu={mcpu}', f'-mattr={mattr}', '-show-encoding'],
|
||||
input=".text\n" + "\n".join(instrs) + "\n", capture_output=True, text=True, timeout=30)
|
||||
if result.returncode != 0: raise RuntimeError(f"llvm-mc failed: {result.stderr.strip()}")
|
||||
results = []
|
||||
for line in result.stdout.split('\n'):
|
||||
if 'encoding:' not in line: continue
|
||||
|
|
@ -102,100 +80,127 @@ def compile_asm_batch(instrs: list[str]) -> list[bytes]:
|
|||
if len(results) != len(instrs): raise RuntimeError(f"expected {len(instrs)} encodings, got {len(results)}")
|
||||
return results
|
||||
|
||||
class TestLLVM(unittest.TestCase):
|
||||
"""Test assembler and disassembler against all LLVM test vectors."""
|
||||
def matches_encoding(data: bytes, fmt) -> bool:
|
||||
"""Check if instruction bytes match format's expected encoding bits."""
|
||||
if not hasattr(fmt, '_encoding') or fmt._encoding is None: return True
|
||||
bf, expected = fmt._encoding
|
||||
val = int.from_bytes(data[:fmt._size()], 'little')
|
||||
return ((val >> bf.lo) & bf.mask()) == expected
|
||||
|
||||
class TestLLVMBase(unittest.TestCase):
|
||||
"""Base class for LLVM assembler tests."""
|
||||
tests: dict[str, list[tuple[str, bytes]]] = {}
|
||||
formats: dict[str, type] = {}
|
||||
gfx_prefix: str = ""
|
||||
mcpu: str = ""
|
||||
arch_name: str = ""
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
for name, (filename, _, _) in LLVM_TEST_FILES.items():
|
||||
def _load_tests(cls, test_files: dict[str, str]):
|
||||
for name, filename in test_files.items():
|
||||
try:
|
||||
data = fetch(f"{LLVM_BASE}/{filename}").read_bytes()
|
||||
cls.tests[name] = parse_llvm_tests(data.decode('utf-8', errors='ignore'))
|
||||
cls.tests[name] = parse_llvm_tests(data.decode('utf-8', errors='ignore'), cls.gfx_prefix)
|
||||
except Exception as e:
|
||||
print(f"Warning: couldn't fetch {filename}: {e}")
|
||||
cls.tests[name] = []
|
||||
|
||||
# Generate test methods dynamically for each format
|
||||
def _make_asm_test(name):
|
||||
def test(self):
|
||||
passed, failed, skipped = 0, 0, 0
|
||||
for asm_text, expected in self.tests.get(name, []):
|
||||
result = try_assemble(asm_text)
|
||||
if result is None: skipped += 1
|
||||
elif result == expected: passed += 1
|
||||
else: failed += 1
|
||||
print(f"{name.upper()} asm: {passed} passed, {failed} failed, {skipped} skipped")
|
||||
self.assertEqual(failed, 0)
|
||||
return test
|
||||
def _test_disasm(self, name: str):
|
||||
"""Test decoding instructions and verify disassembly produces correct bytes."""
|
||||
if name not in self.tests or not self.tests[name]: self.skipTest(f"No test data for {name}")
|
||||
fmt_cls = self.formats.get(name)
|
||||
if fmt_cls is None: self.skipTest(f"No format class for {name}")
|
||||
|
||||
def _make_disasm_test(name):
|
||||
def test(self):
|
||||
_, fmt_cls, op_enum = LLVM_TEST_FILES[name]
|
||||
# VOP3SD opcodes that share encoding with VOP3 (only for vop3sd test, not vopc promotions)
|
||||
vop3sd_opcodes = {288, 289, 290, 764, 765, 766, 767, 768, 769, 770}
|
||||
is_vopc_promotion = name in ('vop3_from_vopc', 'vop3_from_vopcx')
|
||||
undocumented = {'smem': {34, 35}, 'sopk': {22, 23}, 'sopp': {8, 58, 59}}
|
||||
# Determine wave size from test name (w64 = wave64, otherwise wave32)
|
||||
wave_size = 64 if 'w64' in name else 32
|
||||
mattr = f'+real-true16,+wavefrontsize{wave_size}'
|
||||
|
||||
# First pass: decode all instructions and collect disasm strings
|
||||
to_test: list[tuple[str, bytes, str | None, str | None]] = [] # (asm_text, data, disasm_str, error)
|
||||
skipped = 0
|
||||
to_test: list[tuple[str, bytes, str | None, str | None]] = []
|
||||
for asm_text, data in self.tests.get(name, []):
|
||||
if len(data) > fmt_cls._size(): continue
|
||||
temp_inst = fmt_cls.from_bytes(data)
|
||||
temp_op = temp_inst._values.get('op', 0)
|
||||
temp_op = temp_op.val if hasattr(temp_op, 'val') else temp_op
|
||||
if temp_op in undocumented.get(name, set()): skipped += 1; continue
|
||||
if name == 'sopp':
|
||||
simm16 = temp_inst._values.get('simm16', 0)
|
||||
simm16 = simm16.val if hasattr(simm16, 'val') else simm16
|
||||
sopp_no_imm = {48, 54, 53, 55, 60, 61, 62}
|
||||
if temp_op in sopp_no_imm and simm16 != 0: skipped += 1; continue
|
||||
if not matches_encoding(data, fmt_cls): continue
|
||||
try:
|
||||
if fmt_cls.__name__ in ('VOP3', 'VOP3SD'):
|
||||
temp = VOP3.from_bytes(data)
|
||||
op_val = temp._values.get('op', 0)
|
||||
op_val = op_val.val if hasattr(op_val, 'val') else op_val
|
||||
is_vop3sd = (op_val in vop3sd_opcodes) and not is_vopc_promotion
|
||||
decoded = VOP3SD.from_bytes(data) if is_vop3sd else VOP3.from_bytes(data)
|
||||
if is_vop3sd: VOP3SDOp(op_val)
|
||||
else: VOP3Op(op_val)
|
||||
else:
|
||||
decoded = fmt_cls.from_bytes(data)
|
||||
op_val = decoded._values.get('op', 0)
|
||||
op_val = op_val.val if hasattr(op_val, 'val') else op_val
|
||||
op_enum(op_val)
|
||||
decoded = fmt_cls.from_bytes(data)
|
||||
if decoded.to_bytes()[:len(data)] != data:
|
||||
to_test.append((asm_text, data, None, "decode roundtrip failed"))
|
||||
continue
|
||||
to_test.append((asm_text, data, decoded.disasm(), None))
|
||||
to_test.append((asm_text, data, decoded.disasm(wave_size), None))
|
||||
except Exception as e:
|
||||
to_test.append((asm_text, data, None, f"exception: {e}"))
|
||||
|
||||
# Batch compile all disasm strings with single llvm-mc call
|
||||
disasm_strs = [(i, t[2]) for i, t in enumerate(to_test) if t[2] is not None]
|
||||
llvm_results = compile_asm_batch([s for _, s in disasm_strs]) if disasm_strs else []
|
||||
llvm_map = {i: llvm_results[j] for j, (i, _) in enumerate(disasm_strs)}
|
||||
llvm_map = {}
|
||||
if disasm_strs:
|
||||
llvm_results = compile_asm_batch([s for _, s in disasm_strs], self.mcpu, mattr)
|
||||
llvm_map = {i: llvm_results[j] for j, (i, _) in enumerate(disasm_strs)}
|
||||
|
||||
# Match results back
|
||||
passed, failed = 0, 0
|
||||
failures: list[str] = []
|
||||
passed, failed, failures = 0, 0, []
|
||||
for idx, (asm_text, data, disasm_str, error) in enumerate(to_test):
|
||||
if error:
|
||||
failed += 1; failures.append(f"{error} for {data.hex()}")
|
||||
elif disasm_str is not None and idx in llvm_map:
|
||||
llvm_bytes = llvm_map[idx]
|
||||
if llvm_bytes is not None and llvm_bytes == data: passed += 1
|
||||
elif llvm_bytes is not None: failed += 1; failures.append(f"'{disasm_str}': expected={data.hex()} got={llvm_bytes.hex()}")
|
||||
if llvm_bytes == data: passed += 1
|
||||
else: failed += 1; failures.append(f"'{disasm_str}': expected={data.hex()} got={llvm_bytes.hex()}")
|
||||
|
||||
print(f"{name.upper()} disasm: {passed} passed, {failed} failed" + (f", {skipped} skipped" if skipped else ""))
|
||||
if failures[:10]: print(" " + "\n ".join(failures[:10]))
|
||||
self.assertEqual(failed, 0)
|
||||
print(f"{self.arch_name} {name.upper()} disasm: {passed} passed, {failed} failed")
|
||||
if failures[:5]: print(" " + "\n ".join(failures[:5]))
|
||||
self.assertGreater(passed, 0, f"No tests passed for {name}")
|
||||
|
||||
class TestLLVMRDNA3(TestLLVMBase):
|
||||
"""Test RDNA3 assembler against LLVM test vectors."""
|
||||
gfx_prefix, mcpu, arch_name = "GFX11", "gfx1100", "RDNA3"
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
from extra.assembly.amd.autogen.rdna3.ins import SOP1, SOP2, SOPC, SOPK, SOPP, VOP1, VOP2, VOP3, VOP3P, VOPC, VOPD, VINTERP, DS, SMEM, FLAT, MUBUF, MTBUF, MIMG, LDSDIR, EXP
|
||||
cls.formats = {
|
||||
'sop1': SOP1, 'sop2': SOP2, 'sopc': SOPC, 'sopk': SOPK, 'sopp': SOPP,
|
||||
'vop1': VOP1, 'vop2': VOP2, 'vopc': VOPC, 'vopcx': VOPC, 'vop3': VOP3, 'vop3p': VOP3P,
|
||||
'vinterp': VINTERP, 'vopd': VOPD, 'ds': DS, 'smem': SMEM, 'flat': FLAT,
|
||||
'mubuf': MUBUF, 'mtbuf': MTBUF, 'mimg': MIMG, 'mimg_features': MIMG, 'wmma': VOP3P, 'ldsdir': LDSDIR, 'exp': EXP,
|
||||
'vop3_from_vop1': VOP3, 'vop3_from_vop2': VOP3, 'vop3_from_vopc': VOP3, 'vop3_from_vopcx': VOP3,
|
||||
'vop3_features': VOP3, 'vop3p_features': VOP3P, 'vopd_features': VOPD,
|
||||
'vop3_alias': VOP3, 'vop3p_alias': VOP3P, 'vopc_alias': VOPC, 'vopcx_alias': VOPC,
|
||||
'vinterp_alias': VINTERP, 'smem_alias': SMEM, 'mubuf_alias': MUBUF, 'mtbuf_alias': MTBUF,
|
||||
}
|
||||
cls._load_tests(RDNA3_TEST_FILES)
|
||||
|
||||
class TestLLVMRDNA4(TestLLVMBase):
|
||||
"""Test RDNA4 assembler against LLVM test vectors."""
|
||||
gfx_prefix, mcpu, arch_name = "GFX12", "gfx1200", "RDNA4"
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
import extra.assembly.amd.autogen.rdna4.ins as rdna4
|
||||
get = lambda n: getattr(rdna4, n, None)
|
||||
cls.formats = {
|
||||
'sop1': get('SOP1'), 'sop2': get('SOP2'), 'sop2_alias': get('SOP2'), 'sopc': get('SOPC'),
|
||||
'sopk': get('SOPK'), 'sopk_alias': get('SOPK'), 'sopp': get('SOPP'),
|
||||
'vop1': get('VOP1'), 'vop2': get('VOP2'), 'vop2_aliases': get('VOP2'), 'vopc': get('VOPC'), 'vopcx': get('VOPC'),
|
||||
'vop3': get('VOP3'), 'vop3_aliases': get('VOP3'), 'vop3c': get('VOP3'), 'vop3cx': get('VOP3'),
|
||||
'vop3p': get('VOP3P'), 'vop3p_aliases': get('VOP3P'), 'vop3p_features': get('VOP3P'),
|
||||
'vopd': get('VOPD'), 'vopd_features': get('VOPD'),
|
||||
'vop3_from_vop1': get('VOP3'), 'vop3_from_vop2': get('VOP3'),
|
||||
'ds': get('VDS'), 'ds_alias': get('VDS'), 'smem': get('SMEM'), 'vinterp': get('VINTERP'), 'exp': get('VEXPORT'),
|
||||
'vbuffer_mubuf': get('VBUFFER'), 'vbuffer_mubuf_alias': get('VBUFFER'),
|
||||
'vbuffer_mtbuf': get('VBUFFER'), 'vbuffer_mtbuf_alias': get('VBUFFER'),
|
||||
'vdsdir': get('VDSDIR'), 'vdsdir_alias': get('VDSDIR'),
|
||||
'vflat': get('VFLAT'), 'vflat_alias': get('VFLAT'),
|
||||
'vglobal': get('VGLOBAL'), 'vglobal_alias': get('VGLOBAL'),
|
||||
'vscratch': get('VSCRATCH'),
|
||||
'vimage': get('VIMAGE'), 'vimage_alias': get('VIMAGE'), 'vsample': get('VSAMPLE'),
|
||||
'wmma_w32': get('VOP3P'), 'wmma_w64': get('VOP3P'),
|
||||
'global_load_tr': get('VGLOBAL'),
|
||||
}
|
||||
cls._load_tests(RDNA4_TEST_FILES)
|
||||
|
||||
# Generate test methods dynamically
|
||||
def _make_test(name):
|
||||
def test(self): self._test_disasm(name)
|
||||
return test
|
||||
|
||||
for name in LLVM_TEST_FILES:
|
||||
setattr(TestLLVM, f'test_{name}_asm', _make_asm_test(name))
|
||||
setattr(TestLLVM, f'test_{name}_disasm', _make_disasm_test(name))
|
||||
for name in RDNA3_TEST_FILES: setattr(TestLLVMRDNA3, f'test_{name}_disasm', _make_test(name))
|
||||
for name in RDNA4_TEST_FILES: setattr(TestLLVMRDNA4, f'test_{name}_disasm', _make_test(name))
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
if __name__ == "__main__": unittest.main()
|
||||
|
|
|
|||
|
|
@ -1,90 +1,37 @@
|
|||
#!/usr/bin/env python3
|
||||
"""Roundtrip tests: generate tinygrad kernels, decode instructions, re-encode, verify match."""
|
||||
import unittest, io, sys, re, subprocess, os
|
||||
from extra.assembly.amd.autogen.rdna3.ins import *
|
||||
from extra.assembly.amd.dsl import Inst
|
||||
from extra.assembly.amd.asm import asm
|
||||
from extra.assembly.amd.asm import detect_format
|
||||
from extra.assembly.amd.test.helpers import get_llvm_mc, get_llvm_objdump
|
||||
|
||||
def disassemble_lib(lib: bytes, compiler) -> list[tuple[str, bytes]]:
|
||||
"""Disassemble ELF binary and return list of (instruction_text, machine_code_bytes)."""
|
||||
old_stdout = sys.stdout
|
||||
sys.stdout = io.StringIO()
|
||||
compiler.disassemble(lib)
|
||||
output = sys.stdout.getvalue()
|
||||
sys.stdout = old_stdout
|
||||
|
||||
results = []
|
||||
for line in output.splitlines():
|
||||
if '//' not in line: continue
|
||||
instr = line.split('//')[0].strip()
|
||||
if not instr: continue
|
||||
comment = line.split('//')[1].strip()
|
||||
if ':' not in comment: continue
|
||||
hex_str = comment.split(':')[1].strip().split()[0]
|
||||
try:
|
||||
machine_bytes = bytes.fromhex(hex_str)[::-1] # big-endian to little-endian
|
||||
results.append((instr, machine_bytes))
|
||||
except ValueError:
|
||||
continue
|
||||
return results
|
||||
|
||||
def compile_asm(instr: str, compiler=None) -> bytes:
|
||||
"""Compile a single instruction with llvm-mc and return the machine code bytes."""
|
||||
llvm_mc = get_llvm_mc()
|
||||
result = subprocess.run(
|
||||
[llvm_mc, '-triple=amdgcn', '-mcpu=gfx1100', '-mattr=+real-true16,+wavefrontsize32', '-show-encoding'],
|
||||
input=f".text\n{instr}\n", capture_output=True, text=True)
|
||||
if result.returncode != 0: raise RuntimeError(f"llvm-mc failed for '{instr}': {result.stderr.strip()}")
|
||||
# Parse encoding: [0x01,0x39,0x0a,0x7e]
|
||||
for line in result.stdout.split('\n'):
|
||||
if 'encoding:' in line:
|
||||
enc = line.split('encoding:')[1].strip()
|
||||
if enc.startswith('[') and enc.endswith(']'):
|
||||
hex_vals = enc[1:-1].replace('0x', '').replace(',', '').replace(' ', '')
|
||||
return bytes.fromhex(hex_vals)
|
||||
raise RuntimeError(f"no encoding found in llvm-mc output for: {instr}")
|
||||
|
||||
def compile_asm_batch(instrs: list[str]) -> list[bytes]:
|
||||
def compile_asm_batch(instrs: list[str], mcpu: str = 'gfx1100') -> list[bytes]:
|
||||
"""Compile multiple instructions with a single llvm-mc call."""
|
||||
if not instrs: return []
|
||||
llvm_mc = get_llvm_mc()
|
||||
src = ".text\n" + "\n".join(instrs) + "\n"
|
||||
result = subprocess.run(
|
||||
[llvm_mc, '-triple=amdgcn', '-mcpu=gfx1100', '-mattr=+real-true16,+wavefrontsize32', '-show-encoding'],
|
||||
input=src, capture_output=True, text=True)
|
||||
result = subprocess.run([get_llvm_mc(), '-triple=amdgcn', f'-mcpu={mcpu}', '-mattr=+real-true16,+wavefrontsize32', '-show-encoding'],
|
||||
input=".text\n" + "\n".join(instrs) + "\n", capture_output=True, text=True)
|
||||
if result.returncode != 0: raise RuntimeError(f"llvm-mc batch failed: {result.stderr.strip()}")
|
||||
# Parse all encodings in order
|
||||
encodings = []
|
||||
for line in result.stdout.split('\n'):
|
||||
if 'encoding:' in line:
|
||||
enc = line.split('encoding:')[1].strip()
|
||||
if enc.startswith('[') and enc.endswith(']'):
|
||||
hex_vals = enc[1:-1].replace('0x', '').replace(',', '').replace(' ', '')
|
||||
encodings.append(bytes.fromhex(hex_vals))
|
||||
encodings.append(bytes.fromhex(enc[1:-1].replace('0x', '').replace(',', '').replace(' ', '')))
|
||||
if len(encodings) != len(instrs): raise RuntimeError(f"expected {len(instrs)} encodings, got {len(encodings)}")
|
||||
return encodings
|
||||
|
||||
def compile_and_disasm_batch(instrs: list[str], compiler) -> list[str]:
|
||||
def compile_and_disasm_batch(instrs: list[str], mcpu: str = 'gfx1100') -> list[str]:
|
||||
"""Compile instructions with LLVM and get LLVM's disassembly."""
|
||||
import tempfile, os
|
||||
import tempfile
|
||||
if not instrs: return []
|
||||
# Build assembly source with all instructions
|
||||
src = ".text\n.globl test\n.p2align 8\n.type test,@function\ntest:\n"
|
||||
src += "\n".join(f" {instr}" for instr in instrs) + "\n"
|
||||
# Use llvm-mc to assemble to object file
|
||||
src = ".text\n.globl test\n.p2align 8\n.type test,@function\ntest:\n" + "\n".join(f" {instr}" for instr in instrs) + "\n"
|
||||
with tempfile.NamedTemporaryFile(suffix='.o', delete=False) as f:
|
||||
obj_path = f.name
|
||||
try:
|
||||
result = subprocess.run(
|
||||
[get_llvm_mc(), '-triple=amdgcn', '-mcpu=gfx1100', '-mattr=+real-true16,+wavefrontsize32', '-filetype=obj', '-o', obj_path],
|
||||
input=src, capture_output=True, text=True)
|
||||
result = subprocess.run([get_llvm_mc(), '-triple=amdgcn', f'-mcpu={mcpu}', '-mattr=+real-true16,+wavefrontsize32', '-filetype=obj', '-o', obj_path],
|
||||
input=src, capture_output=True, text=True)
|
||||
if result.returncode != 0: raise RuntimeError(f"llvm-mc failed: {result.stderr.strip()}")
|
||||
# Disassemble with llvm-objdump
|
||||
result = subprocess.run([get_llvm_objdump(), '-d', '--mcpu=gfx1100', obj_path], capture_output=True, text=True)
|
||||
result = subprocess.run([get_llvm_objdump(), '-d', f'--mcpu={mcpu}', obj_path], capture_output=True, text=True)
|
||||
if result.returncode != 0: raise RuntimeError(f"llvm-objdump failed: {result.stderr.strip()}")
|
||||
# Parse disassembly output
|
||||
results: list[str] = []
|
||||
for line in result.stdout.splitlines():
|
||||
if '//' not in line: continue
|
||||
|
|
@ -94,127 +41,143 @@ def compile_and_disasm_batch(instrs: list[str], compiler) -> list[str]:
|
|||
finally:
|
||||
os.unlink(obj_path)
|
||||
|
||||
class TestTinygradKernelRoundtrip(unittest.TestCase):
|
||||
"""Test roundtrip on real tinygrad-generated kernels using get_kernels_from_tinygrad pattern."""
|
||||
class TestRoundtripBase(unittest.TestCase):
|
||||
"""Base class for roundtrip tests."""
|
||||
mcpu: str = 'gfx1100'
|
||||
arch: str = 'rdna3'
|
||||
|
||||
@classmethod
|
||||
def _get_modules(cls):
|
||||
if cls.arch == 'rdna3':
|
||||
from extra.assembly.amd.autogen.rdna3 import ins
|
||||
from extra.assembly.amd.asm import detect_format, asm
|
||||
else:
|
||||
import extra.assembly.amd.autogen.rdna4.ins as ins
|
||||
from extra.assembly.amd.asm import asm
|
||||
detect_format = None # RDNA4 uses different detection
|
||||
return ins, detect_format, asm
|
||||
|
||||
def _test_kernel_roundtrip(self, op_fn):
|
||||
"""Generate kernel from op_fn, test:
|
||||
1. decode -> reencode matches original bytes
|
||||
2. asm(disasm()) matches LLVM output
|
||||
3. our disasm() matches LLVM's disassembly string exactly
|
||||
"""
|
||||
"""Generate kernel from op_fn, test decode -> reencode and asm(disasm()) matches LLVM."""
|
||||
from extra.assembly.amd.test.test_compare_emulators import get_kernels_from_tinygrad
|
||||
from tinygrad.runtime.support.compiler_amd import HIPCompiler
|
||||
|
||||
ins, detect_format, asm = self._get_modules()
|
||||
kernels, _, _ = get_kernels_from_tinygrad(op_fn)
|
||||
compiler = HIPCompiler('gfx1100')
|
||||
compiler = HIPCompiler(self.mcpu)
|
||||
|
||||
# First pass: decode all instructions and collect info
|
||||
decoded_instrs: list[tuple] = [] # list of (ki, offset, orig_bytes, decoded, our_disasm, decode_ok, decode_err)
|
||||
# First pass: decode all instructions
|
||||
decoded_instrs: list[tuple] = []
|
||||
for ki, kernel in enumerate(kernels):
|
||||
offset = 0
|
||||
while offset < len(kernel.code):
|
||||
remaining = kernel.code[offset:]
|
||||
fmt = detect_format(remaining)
|
||||
if fmt is None:
|
||||
decoded_instrs.append((ki, offset, None, None, None, False, "no format"))
|
||||
offset += 4
|
||||
continue
|
||||
if len(remaining) < 4: break
|
||||
|
||||
# Try to detect format
|
||||
if detect_format is not None:
|
||||
try:
|
||||
fmt = detect_format(remaining)
|
||||
except ValueError:
|
||||
decoded_instrs.append((ki, offset, None, None, None, False, "no format"))
|
||||
offset += 4
|
||||
continue
|
||||
else:
|
||||
# For RDNA4, try formats in order
|
||||
fmt = None
|
||||
from extra.assembly.amd.autogen.rdna4.ins import SOP1, SOP2, SOPC, SOPK, SOPP, VOP1, VOP2, VOP3, VOP3P, VOPC, VOPD, VDS, SMEM, VFLAT, VBUFFER, VIMAGE, VSAMPLE, VEXPORT, VDSDIR
|
||||
word = int.from_bytes(remaining[:4], 'little')
|
||||
for cls in [VOPD, VOP3P, VOP3, VDS, VFLAT, VBUFFER, VIMAGE, VSAMPLE, SMEM, VEXPORT, SOP1, SOPC, SOPP, SOPK, VOPC, VOP1, SOP2, VOP2, VDSDIR]:
|
||||
if cls._encoding is not None:
|
||||
bf, val = cls._encoding
|
||||
if ((word >> bf.lo) & bf.mask()) == val:
|
||||
fmt = cls
|
||||
break
|
||||
if fmt is None:
|
||||
decoded_instrs.append((ki, offset, None, None, None, False, "no format"))
|
||||
offset += 4
|
||||
continue
|
||||
|
||||
base_size = fmt._size()
|
||||
if len(remaining) < base_size:
|
||||
break
|
||||
if len(remaining) < base_size: break
|
||||
|
||||
try:
|
||||
decoded = fmt.from_bytes(remaining) # pass all remaining bytes so from_bytes can read literal
|
||||
size = decoded.size() # actual size including literal
|
||||
decoded = fmt.from_bytes(remaining)
|
||||
size = decoded.size()
|
||||
orig_bytes = remaining[:size]
|
||||
reencoded = decoded.to_bytes()
|
||||
our_disasm = decoded.disasm()
|
||||
decode_ok = reencoded == orig_bytes
|
||||
decode_err: str | None = None if decode_ok else f"orig={orig_bytes.hex()} reenc={reencoded.hex()}"
|
||||
decode_err = None if decode_ok else f"orig={orig_bytes.hex()} reenc={reencoded.hex()}"
|
||||
decoded_instrs.append((ki, offset, orig_bytes, decoded, our_disasm, decode_ok, decode_err))
|
||||
except Exception as e:
|
||||
decoded_instrs.append((ki, offset, remaining[:base_size], None, None, False, str(e)))
|
||||
size = base_size
|
||||
|
||||
offset += size
|
||||
|
||||
# Collect disasm strings for batched LLVM calls - skip unknown opcodes (op_X) that LLVM can't compile
|
||||
asm_test_instrs: list[tuple[int, str]] = [] # (idx, our_disasm) for asm test
|
||||
disasm_test_instrs: list[tuple[int, str]] = [] # (idx, our_disasm) for disasm comparison test
|
||||
|
||||
# Collect disasm strings for batched LLVM calls
|
||||
asm_test_instrs: list[tuple[int, str]] = []
|
||||
for idx, (ki, offset, orig_bytes, decoded, our_disasm, decode_ok, decode_err) in enumerate(decoded_instrs):
|
||||
if our_disasm is None: continue
|
||||
# Skip unknown opcodes and malformed instructions for both tests
|
||||
if our_disasm.startswith('op_') or re.search(r', \d+, \d+, \d+,', our_disasm): continue
|
||||
asm_test_instrs.append((idx, our_disasm))
|
||||
disasm_test_instrs.append((idx, our_disasm))
|
||||
|
||||
# Batch compile for asm test
|
||||
asm_llvm_results = compile_asm_batch([d for _, d in asm_test_instrs])
|
||||
asm_llvm_results = compile_asm_batch([d for _, d in asm_test_instrs], self.mcpu)
|
||||
asm_llvm_map = {idx: result for (idx, _), result in zip(asm_test_instrs, asm_llvm_results)}
|
||||
|
||||
# Batch compile+disasm for disasm comparison test
|
||||
disasm_llvm_results = compile_and_disasm_batch([d for _, d in disasm_test_instrs], compiler)
|
||||
disasm_llvm_map = {idx: result for (idx, _), result in zip(disasm_test_instrs, disasm_llvm_results)}
|
||||
disasm_llvm_results = compile_and_disasm_batch([d for _, d in asm_test_instrs], self.mcpu)
|
||||
disasm_llvm_map = {idx: result for (idx, _), result in zip(asm_test_instrs, disasm_llvm_results)}
|
||||
|
||||
# Now evaluate results
|
||||
# Evaluate results
|
||||
decode_passed, decode_failed, decode_skipped = 0, 0, 0
|
||||
asm_passed, asm_failed, asm_skipped = 0, 0, 0
|
||||
disasm_passed, disasm_failed, disasm_skipped = 0, 0, 0
|
||||
decode_failures: list[str] = []
|
||||
asm_failures: list[str] = []
|
||||
disasm_failures: list[str] = []
|
||||
decode_failures, asm_failures, disasm_failures = [], [], []
|
||||
|
||||
for idx, (ki, offset, orig_bytes, decoded, our_disasm, decode_ok, decode_err) in enumerate(decoded_instrs):
|
||||
# Decode test
|
||||
if decode_ok:
|
||||
decode_passed += 1
|
||||
elif decode_err == "no format":
|
||||
decode_skipped += 1
|
||||
if decode_ok: decode_passed += 1
|
||||
elif decode_err == "no format": decode_skipped += 1
|
||||
else:
|
||||
decode_failed += 1
|
||||
decode_failures.append(f"K{ki}@{offset}: {our_disasm}: {decode_err}")
|
||||
|
||||
# Asm test
|
||||
if our_disasm is None:
|
||||
asm_skipped += 1
|
||||
disasm_skipped += 1
|
||||
elif idx in asm_llvm_map:
|
||||
llvm_bytes = asm_llvm_map[idx]
|
||||
try:
|
||||
our_bytes = asm(our_disasm).to_bytes()
|
||||
if our_bytes[:len(llvm_bytes)] == llvm_bytes:
|
||||
asm_passed += 1
|
||||
if our_bytes[:len(llvm_bytes)] == llvm_bytes: asm_passed += 1
|
||||
else:
|
||||
asm_failed += 1
|
||||
asm_failures.append(f"K{ki}@{offset}: '{our_disasm}': ours={our_bytes[:len(llvm_bytes)].hex()} llvm={llvm_bytes.hex()}")
|
||||
except Exception:
|
||||
asm_skipped += 1
|
||||
|
||||
if idx in disasm_llvm_map:
|
||||
if our_disasm == disasm_llvm_map[idx]: disasm_passed += 1
|
||||
else:
|
||||
disasm_failed += 1
|
||||
disasm_failures.append(f"K{ki}@{offset}: ours='{our_disasm}' llvm='{disasm_llvm_map[idx]}'")
|
||||
else:
|
||||
disasm_skipped += 1
|
||||
else:
|
||||
asm_skipped += 1
|
||||
|
||||
# Disasm comparison test
|
||||
if our_disasm is None:
|
||||
disasm_skipped += 1
|
||||
elif idx in disasm_llvm_map:
|
||||
llvm_disasm = disasm_llvm_map[idx]
|
||||
if our_disasm == llvm_disasm:
|
||||
disasm_passed += 1
|
||||
else:
|
||||
disasm_failed += 1
|
||||
disasm_failures.append(f"K{ki}@{offset}: ours='{our_disasm}' llvm='{llvm_disasm}'")
|
||||
else:
|
||||
disasm_skipped += 1
|
||||
|
||||
print(f"decode roundtrip: {decode_passed} passed, {decode_failed} failed, {decode_skipped} skipped")
|
||||
print(f"asm vs llvm: {asm_passed} passed, {asm_failed} failed, {asm_skipped} skipped")
|
||||
print(f"disasm vs llvm: {disasm_passed} passed, {disasm_failed} failed, {disasm_skipped} skipped")
|
||||
print(f"{self.arch.upper()} decode roundtrip: {decode_passed} passed, {decode_failed} failed, {decode_skipped} skipped")
|
||||
print(f"{self.arch.upper()} asm vs llvm: {asm_passed} passed, {asm_failed} failed, {asm_skipped} skipped")
|
||||
print(f"{self.arch.upper()} disasm vs llvm: {disasm_passed} passed, {disasm_failed} failed, {disasm_skipped} skipped")
|
||||
self.assertEqual(decode_failed, 0, f"Decode failures:\n" + "\n".join(decode_failures[:20]))
|
||||
self.assertEqual(asm_failed, 0, f"Asm failures:\n" + "\n".join(asm_failures[:20]))
|
||||
# Note: disasm string comparison is informational only - formatting differences between LLVM versions are expected
|
||||
|
||||
# Basic unary ops
|
||||
class TestRoundtripRDNA3(TestRoundtripBase):
|
||||
"""Roundtrip tests for RDNA3 (gfx1100)."""
|
||||
mcpu, arch = 'gfx1100', 'rdna3'
|
||||
|
||||
def test_neg(self): self._test_kernel_roundtrip(lambda T: -T([1.0, -2.0, 3.0, -4.0]))
|
||||
def test_relu(self): self._test_kernel_roundtrip(lambda T: T([-1.0, 0.0, 1.0, 2.0]).relu())
|
||||
def test_exp(self): self._test_kernel_roundtrip(lambda T: T([0.0, 1.0, 2.0]).exp())
|
||||
|
|
@ -222,42 +185,62 @@ class TestTinygradKernelRoundtrip(unittest.TestCase):
|
|||
def test_sin(self): self._test_kernel_roundtrip(lambda T: T([0.0, 1.0, 2.0]).sin())
|
||||
def test_sqrt(self): self._test_kernel_roundtrip(lambda T: T([1.0, 4.0, 9.0]).sqrt())
|
||||
def test_recip(self): self._test_kernel_roundtrip(lambda T: T([1.0, 2.0, 4.0]).reciprocal())
|
||||
|
||||
# Binary ops
|
||||
def test_add(self): self._test_kernel_roundtrip(lambda T: T([1.0, 2.0]) + T([3.0, 4.0]))
|
||||
def test_sub(self): self._test_kernel_roundtrip(lambda T: T([5.0, 6.0]) - T([1.0, 2.0]))
|
||||
def test_mul(self): self._test_kernel_roundtrip(lambda T: T([2.0, 3.0]) * T([4.0, 5.0]))
|
||||
def test_div(self): self._test_kernel_roundtrip(lambda T: T([10.0, 20.0]) / T([2.0, 4.0]))
|
||||
def test_max_binary(self): self._test_kernel_roundtrip(lambda T: T([1.0, 5.0]).maximum(T([3.0, 2.0])))
|
||||
|
||||
# Reductions
|
||||
def test_sum_reduce(self): self._test_kernel_roundtrip(lambda T: T.empty(64).sum())
|
||||
def test_max_reduce(self): self._test_kernel_roundtrip(lambda T: T.empty(64).max())
|
||||
def test_mean_reduce(self): self._test_kernel_roundtrip(lambda T: T.empty(32).mean())
|
||||
|
||||
# Matmul
|
||||
def test_gemm_4x4(self): self._test_kernel_roundtrip(lambda T: T.empty(4, 4) @ T.empty(4, 4))
|
||||
def test_gemv(self): self._test_kernel_roundtrip(lambda T: T.empty(1, 16) @ T.empty(16, 16))
|
||||
|
||||
# Complex ops
|
||||
def test_softmax(self): self._test_kernel_roundtrip(lambda T: T.empty(16).softmax())
|
||||
def test_layernorm(self): self._test_kernel_roundtrip(lambda T: T.empty(8, 8).layernorm())
|
||||
|
||||
# Memory patterns
|
||||
def test_contiguous(self): self._test_kernel_roundtrip(lambda T: T.empty(4, 4).permute(1, 0).contiguous())
|
||||
def test_reshape(self): self._test_kernel_roundtrip(lambda T: (T.empty(16) + 1).reshape(4, 4).contiguous())
|
||||
def test_expand(self): self._test_kernel_roundtrip(lambda T: T.empty(4, 1).expand(4, 4).contiguous())
|
||||
|
||||
# Cast ops
|
||||
def test_cast_int(self): self._test_kernel_roundtrip(lambda T: T.empty(16).int().float())
|
||||
def test_cast_half(self): self._test_kernel_roundtrip(lambda T: T.empty(16).half().float())
|
||||
|
||||
# Comparison ops
|
||||
def test_cmp_lt(self): self._test_kernel_roundtrip(lambda T: (T.empty(64) < T.empty(64)).where(T.empty(64), T.empty(64)))
|
||||
def test_where(self): self._test_kernel_roundtrip(lambda T: (T.empty(64) > 0).where(T.empty(64), T.empty(64)))
|
||||
|
||||
# Fused ops
|
||||
def test_fma(self): self._test_kernel_roundtrip(lambda T: (T([1.0, 2.0]) * T([3.0, 4.0]) + T([5.0, 6.0])))
|
||||
|
||||
@unittest.skipUnless(os.environ.get("TEST_RDNA4"), "RDNA4 roundtrip tests require TEST_RDNA4=1 and gfx1200 hardware")
|
||||
class TestRoundtripRDNA4(TestRoundtripBase):
|
||||
"""Roundtrip tests for RDNA4 (gfx1200)."""
|
||||
mcpu, arch = 'gfx1200', 'rdna4'
|
||||
|
||||
def test_neg(self): self._test_kernel_roundtrip(lambda T: -T([1.0, -2.0, 3.0, -4.0]))
|
||||
def test_relu(self): self._test_kernel_roundtrip(lambda T: T([-1.0, 0.0, 1.0, 2.0]).relu())
|
||||
def test_exp(self): self._test_kernel_roundtrip(lambda T: T([0.0, 1.0, 2.0]).exp())
|
||||
def test_log(self): self._test_kernel_roundtrip(lambda T: T([1.0, 2.0, 3.0]).log())
|
||||
def test_sin(self): self._test_kernel_roundtrip(lambda T: T([0.0, 1.0, 2.0]).sin())
|
||||
def test_sqrt(self): self._test_kernel_roundtrip(lambda T: T([1.0, 4.0, 9.0]).sqrt())
|
||||
def test_recip(self): self._test_kernel_roundtrip(lambda T: T([1.0, 2.0, 4.0]).reciprocal())
|
||||
def test_add(self): self._test_kernel_roundtrip(lambda T: T([1.0, 2.0]) + T([3.0, 4.0]))
|
||||
def test_sub(self): self._test_kernel_roundtrip(lambda T: T([5.0, 6.0]) - T([1.0, 2.0]))
|
||||
def test_mul(self): self._test_kernel_roundtrip(lambda T: T([2.0, 3.0]) * T([4.0, 5.0]))
|
||||
def test_div(self): self._test_kernel_roundtrip(lambda T: T([10.0, 20.0]) / T([2.0, 4.0]))
|
||||
def test_max_binary(self): self._test_kernel_roundtrip(lambda T: T([1.0, 5.0]).maximum(T([3.0, 2.0])))
|
||||
def test_sum_reduce(self): self._test_kernel_roundtrip(lambda T: T.empty(64).sum())
|
||||
def test_max_reduce(self): self._test_kernel_roundtrip(lambda T: T.empty(64).max())
|
||||
def test_mean_reduce(self): self._test_kernel_roundtrip(lambda T: T.empty(32).mean())
|
||||
def test_gemm_4x4(self): self._test_kernel_roundtrip(lambda T: T.empty(4, 4) @ T.empty(4, 4))
|
||||
def test_gemv(self): self._test_kernel_roundtrip(lambda T: T.empty(1, 16) @ T.empty(16, 16))
|
||||
def test_softmax(self): self._test_kernel_roundtrip(lambda T: T.empty(16).softmax())
|
||||
def test_layernorm(self): self._test_kernel_roundtrip(lambda T: T.empty(8, 8).layernorm())
|
||||
def test_contiguous(self): self._test_kernel_roundtrip(lambda T: T.empty(4, 4).permute(1, 0).contiguous())
|
||||
def test_reshape(self): self._test_kernel_roundtrip(lambda T: (T.empty(16) + 1).reshape(4, 4).contiguous())
|
||||
def test_expand(self): self._test_kernel_roundtrip(lambda T: T.empty(4, 1).expand(4, 4).contiguous())
|
||||
def test_cast_int(self): self._test_kernel_roundtrip(lambda T: T.empty(16).int().float())
|
||||
def test_cast_half(self): self._test_kernel_roundtrip(lambda T: T.empty(16).half().float())
|
||||
def test_cmp_lt(self): self._test_kernel_roundtrip(lambda T: (T.empty(64) < T.empty(64)).where(T.empty(64), T.empty(64)))
|
||||
def test_where(self): self._test_kernel_roundtrip(lambda T: (T.empty(64) > 0).where(T.empty(64), T.empty(64)))
|
||||
def test_fma(self): self._test_kernel_roundtrip(lambda T: (T([1.0, 2.0]) * T([3.0, 4.0]) + T([5.0, 6.0])))
|
||||
|
||||
# Keep old class name for backwards compatibility
|
||||
TestTinygradKernelRoundtrip = TestRoundtripRDNA3
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue