Compare commits

...

3 commits

Author SHA1 Message Date
George Hotz
7a1190b729 Merge origin/master into only_reg_emu2 (keep branch's Reg-based approach) 2025-12-30 18:53:50 +00:00
George Hotz
433248c998 assembly/amd: only reg emu 2025-12-30 18:05:09 +00:00
George Hotz
7f139a934f assembly/amd: switch to Reg in pcode 2025-12-30 14:00:28 +00:00
4 changed files with 1751 additions and 11107 deletions

File diff suppressed because it is too large Load diff

View file

@ -1,10 +1,9 @@
# RDNA3 emulator - executes compiled pseudocode from AMD ISA PDF
# mypy: ignore-errors
from __future__ import annotations
import ctypes, os
import ctypes
from extra.assembly.amd.dsl import Inst, RawImm
from extra.assembly.amd.asm import detect_format
from extra.assembly.amd.pcode import _f32, _i32, _sext, _f16, _i16, _f64, _i64
from extra.assembly.amd.pcode import _f32, _i32, _sext, _f16, _i16, _f64, _i64, Reg, SliceProxy
from extra.assembly.amd.autogen.rdna3.gen_pcode import get_compiled_functions
from extra.assembly.amd.autogen.rdna3 import (
SOP1, SOP2, SOPC, SOPK, SOPP, SMEM, VOP1, VOP2, VOP3, VOP3SD, VOP3P, VOPC, DS, FLAT, VOPD, SrcEnum,
@ -18,7 +17,6 @@ VCC_LO, VCC_HI, NULL, EXEC_LO, EXEC_HI, SCC = SrcEnum.VCC_LO, SrcEnum.VCC_HI, Sr
# VOP3 ops that use 64-bit operands (and thus 64-bit literals when src is 255)
# Exception: V_LDEXP_F64 has 32-bit integer src1, so literal should NOT be 64-bit when src1=255
_VOP3_64BIT_OPS = {op.value for op in VOP3Op if op.name.endswith(('_F64', '_B64', '_I64', '_U64'))}
_VOPC_64BIT_OPS = {op.value for op in VOPCOp if op.name.endswith(('_F64', '_B64', '_I64', '_U64'))}
# Ops where src1 is 32-bit (exponent/shift amount) even though the op name suggests 64-bit
_VOP3_64BIT_OPS_32BIT_SRC1 = {VOP3Op.V_LDEXP_F64.value}
# Ops with 16-bit types in name (for source/dest handling)
@ -26,18 +24,12 @@ _VOP3_64BIT_OPS_32BIT_SRC1 = {VOP3Op.V_LDEXP_F64.value}
_VOP3_16BIT_OPS = {op for op in VOP3Op if any(s in op.name for s in ('_F16', '_B16', '_I16', '_U16')) and 'SAD' not in op.name}
_VOP1_16BIT_OPS = {op for op in VOP1Op if any(s in op.name for s in ('_F16', '_B16', '_I16', '_U16'))}
_VOP2_16BIT_OPS = {op for op in VOP2Op if any(s in op.name for s in ('_F16', '_B16', '_I16', '_U16'))}
_VOPC_16BIT_OPS = {op for op in VOPCOp if any(s in op.name for s in ('_F16', '_B16', '_I16', '_U16'))}
# CVT ops with 32/64-bit source (despite 16-bit in name)
_CVT_32_64_SRC_OPS = {op for op in VOP3Op if op.name.startswith('V_CVT_') and op.name.endswith(('_F32', '_I32', '_U32', '_F64', '_I64', '_U64'))} | \
{op for op in VOP1Op if op.name.startswith('V_CVT_') and op.name.endswith(('_F32', '_I32', '_U32', '_F64', '_I64', '_U64'))}
# CVT ops with 32-bit destination (convert FROM 16-bit TO 32-bit): V_CVT_F32_F16, V_CVT_I32_I16, V_CVT_U32_U16
_CVT_32_DST_OPS = {op for op in VOP3Op if op.name.startswith('V_CVT_') and any(s in op.name for s in ('F32_F16', 'I32_I16', 'U32_U16', 'I32_F16', 'U32_F16'))} | \
{op for op in VOP1Op if op.name.startswith('V_CVT_') and any(s in op.name for s in ('F32_F16', 'I32_I16', 'U32_U16', 'I32_F16', 'U32_F16'))}
# 16-bit dst ops (PACK has 32-bit dst despite F16 in name, CVT to 32-bit has 32-bit dst)
_VOP3_16BIT_DST_OPS = {op for op in _VOP3_16BIT_OPS if 'PACK' not in op.name} - _CVT_32_DST_OPS
_VOP1_16BIT_DST_OPS = {op for op in _VOP1_16BIT_OPS if 'PACK' not in op.name} - _CVT_32_DST_OPS
# VOP1 16-bit source ops (excluding CVT ops with 32/64-bit source) - for VOP1 e32, .h encoded in register index
_VOP1_16BIT_SRC_OPS = _VOP1_16BIT_OPS - _CVT_32_64_SRC_OPS
# 16-bit dst ops (PACK has 32-bit dst despite F16 in name)
_VOP3_16BIT_DST_OPS = {op for op in _VOP3_16BIT_OPS if 'PACK' not in op.name}
_VOP1_16BIT_DST_OPS = {op for op in _VOP1_16BIT_OPS if 'PACK' not in op.name}
# Inline constants for src operands 128-254. Build tables for f32, f16, and f64 formats.
import struct as _struct
@ -97,6 +89,36 @@ def _get_compiled() -> dict:
if _COMPILED is None: _COMPILED = get_compiled_functions()
return _COMPILED
def _run_pcode(fn, op_cls, op, s0, s1, s2, d0, scc, vcc, lane, exec_mask, vdst_idx):
"""Create Regs, run pseudocode, extract results."""
# Determine flags from op_cls and op.name
is_div_scale = 'DIV_SCALE' in op.name
is_64 = op.name.endswith(('_B64', '_I64', '_U64', '_F64')) or op.name in ('V_MAD_U64_U32', 'V_MAD_I64_I32')
is_cmp = op_cls.__name__ == 'VOPCOp' and not op.name.startswith('V_CMPX')
is_cmpx = op_cls.__name__ == 'VOPCOp' and op.name.startswith('V_CMPX')
has_sdst = op_cls.__name__ == 'VOP3SDOp'
# Create Regs - D0 gets s0 for DIV_SCALE (passthrough behavior)
S0, S1, S2 = Reg(s0), Reg(s1), Reg(s2)
D0, D1 = Reg(s0 if is_div_scale else d0), Reg(0)
SCC, VCC, EXEC = Reg(scc), Reg(vcc), Reg(exec_mask)
tmp = Reg(0)
# Call pseudocode
ret = fn(S0, S1, S2, D0, D1, SCC, VCC, EXEC, tmp, lane)
# Build result
result = {'d0': D0._val, 'scc': SCC._val & 1}
if has_sdst or VCC._val != vcc: result['vcc_lane'] = (VCC._val >> lane) & 1
if is_cmpx: result['exec_lane'] = (EXEC._val >> lane) & 1
elif EXEC._val != exec_mask: result['exec'] = EXEC._val
if is_cmp: result['vcc_lane'] = (D0._val >> lane) & 1
if is_64: result['d0_64'] = True
if D1._val: result['d1'] = D1._val & 1
# V_WRITELANE_B32 returns (wr_lane, value) directly
if ret is not None: result['vgpr_write'] = (ret[0], vdst_idx, ret[1])
return result
class WaveState:
__slots__ = ('sgpr', 'vgpr', 'scc', 'pc', 'literal', '_pend_sgpr')
def __init__(self):
@ -147,7 +169,21 @@ class WaveState:
for reg, val in self._pend_sgpr.items(): self.sgpr[reg] = val
self._pend_sgpr.clear()
# Instruction decode
def decode_format(word: int) -> tuple[type[Inst] | None, bool]:
hi2 = (word >> 30) & 0x3
if hi2 == 0b11:
enc = (word >> 26) & 0xf
if enc == 0b1101: return SMEM, True
if enc == 0b0101:
op = (word >> 16) & 0x3ff
return (VOP3SD, True) if op in (288, 289, 290, 764, 765, 766, 767, 768, 769, 770) else (VOP3, True)
return {0b0011: (VOP3P, True), 0b0110: (DS, True), 0b0111: (FLAT, True), 0b0010: (VOPD, True)}.get(enc, (None, True))
if hi2 == 0b10:
enc = (word >> 23) & 0x7f
return {0b1111101: (SOP1, False), 0b1111110: (SOPC, False), 0b1111111: (SOPP, False)}.get(enc, (SOPK, False) if ((word >> 28) & 0xf) == 0b1011 else (SOP2, False))
enc = (word >> 25) & 0x7f
return (VOPC, False) if enc == 0b0111110 else (VOP1, False) if enc == 0b0111111 else (VOP2, False)
def _unwrap(v) -> int: return v.val if isinstance(v, RawImm) else v.value if hasattr(v, 'value') else v
@ -155,10 +191,10 @@ def decode_program(data: bytes) -> Program:
result: Program = {}
i = 0
while i < len(data):
try: inst_class = detect_format(data[i:])
except ValueError: break # stop at invalid instruction (padding/metadata after code)
word = int.from_bytes(data[i:i+4], 'little')
inst_class, is_64 = decode_format(word)
if inst_class is None: i += 4; continue
base_size = inst_class._size()
base_size = 8 if is_64 else 4
# Pass enough data for potential 64-bit literal (base + 8 bytes max)
inst = inst_class.from_bytes(data[i:i+base_size+8])
for name, val in inst._values.items(): setattr(inst, name, _unwrap(val))
@ -173,7 +209,7 @@ def decode_program(data: bytes) -> Program:
# Exception: some ops have mixed src sizes (e.g., V_LDEXP_F64 has 32-bit src1)
op_val = inst._values.get('op')
if hasattr(op_val, 'value'): op_val = op_val.value
is_64bit = (inst_class is VOP3 and op_val in _VOP3_64BIT_OPS) or (inst_class is VOPC and op_val in _VOPC_64BIT_OPS)
is_64bit = inst_class is VOP3 and op_val in _VOP3_64BIT_OPS
# Don't treat literal as 64-bit if the op has 32-bit src1 and src1 is the literal
if is_64bit and op_val in _VOP3_64BIT_OPS_32BIT_SRC1 and getattr(inst, 'src1', None) == 255:
is_64bit = False
@ -193,11 +229,21 @@ def exec_scalar(st: WaveState, inst: Inst) -> int:
compiled = _get_compiled()
inst_type = type(inst)
# SOPP: special cases for control flow that has no pseudocode
# SOPP: control flow (not ALU)
if inst_type is SOPP:
op = inst.op
if op == SOPPOp.S_ENDPGM: return -1
if op == SOPPOp.S_BARRIER: return -2
if op == SOPPOp.S_BRANCH: return _sext(inst.simm16, 16)
if op == SOPPOp.S_CBRANCH_SCC0: return _sext(inst.simm16, 16) if st.scc == 0 else 0
if op == SOPPOp.S_CBRANCH_SCC1: return _sext(inst.simm16, 16) if st.scc == 1 else 0
if op == SOPPOp.S_CBRANCH_VCCZ: return _sext(inst.simm16, 16) if (st.vcc & 0xffffffff) == 0 else 0
if op == SOPPOp.S_CBRANCH_VCCNZ: return _sext(inst.simm16, 16) if (st.vcc & 0xffffffff) != 0 else 0
if op == SOPPOp.S_CBRANCH_EXECZ: return _sext(inst.simm16, 16) if st.exec_mask == 0 else 0
if op == SOPPOp.S_CBRANCH_EXECNZ: return _sext(inst.simm16, 16) if st.exec_mask != 0 else 0
# Valid SOPP range is 0-61 (max defined opcode); anything above is invalid
if op > 61: raise NotImplementedError(f"Invalid SOPP opcode {op}")
return 0 # waits, hints, nops
# SMEM: memory loads (not ALU)
if inst_type is SMEM:
@ -207,53 +253,49 @@ def exec_scalar(st: WaveState, inst: Inst) -> int:
for i in range(cnt): st.wsgpr(inst.sdata + i, mem_read((addr + i * 4) & 0xffffffffffffffff, 4))
return 0
# SOP1: special handling for ops not in pseudocode
if inst_type is SOP1:
op = SOP1Op(inst.op)
# S_GETPC_B64: Get program counter (PC is stored as byte offset, convert from words)
if op == SOP1Op.S_GETPC_B64:
pc_bytes = st.pc * 4 # PC is in words, convert to bytes
st.wsgpr64(inst.sdst, pc_bytes)
return 0
# S_SETPC_B64: Set program counter to source value (indirect jump)
# Returns delta such that st.pc + inst_words + delta = target_words
if op == SOP1Op.S_SETPC_B64:
target_bytes = st.rsrc64(inst.ssrc0, 0)
target_words = target_bytes // 4
inst_words = 1 # SOP1 is always 1 word
return target_words - st.pc - inst_words
# Get op enum and lookup compiled function
if inst_type is SOP1: op_cls, ssrc0, sdst = SOP1Op, inst.ssrc0, inst.sdst
elif inst_type is SOP2: op_cls, ssrc0, sdst = SOP2Op, inst.ssrc0, inst.sdst
elif inst_type is SOPC: op_cls, ssrc0, sdst = SOPCOp, inst.ssrc0, None
elif inst_type is SOPK: op_cls, ssrc0, sdst = SOPKOp, inst.sdst, inst.sdst # sdst is both src and dst
elif inst_type is SOPP: op_cls, ssrc0, sdst = SOPPOp, None, None
else: raise NotImplementedError(f"Unknown scalar type {inst_type}")
# SOPP has gaps in the opcode enum - treat unknown opcodes as no-ops
try: op = op_cls(inst.op)
except ValueError:
if inst_type is SOPP: return 0
raise
op = op_cls(inst.op)
fn = compiled.get(op_cls, {}).get(op)
if fn is None:
# SOPP instructions without pseudocode (waits, hints, nops) are no-ops
if inst_type is SOPP: return 0
raise NotImplementedError(f"{op.name} not in pseudocode")
if fn is None: raise NotImplementedError(f"{op.name} not in pseudocode")
# Build context - handle 64-bit ops that need 64-bit source reads
# 64-bit source ops: name ends with _B64, _I64, _U64 or contains _U64, _I64 before last underscore
# Read sources - 64-bit ops need 64-bit source reads
is_64bit_s0 = op.name.endswith(('_B64', '_I64', '_U64')) or '_U64_' in op.name or '_I64_' in op.name
is_64bit_s0s1 = op_cls is SOPCOp and op in (SOPCOp.S_CMP_EQ_U64, SOPCOp.S_CMP_LG_U64)
s0 = st.rsrc64(ssrc0, 0) if is_64bit_s0 or is_64bit_s0s1 else (st.rsrc(ssrc0, 0) if inst_type not in (SOPK, SOPP) else (st.rsgpr(inst.sdst) if inst_type is SOPK else 0))
s0 = st.rsrc64(ssrc0, 0) if is_64bit_s0 or is_64bit_s0s1 else (st.rsrc(ssrc0, 0) if inst_type != SOPK else st.rsgpr(inst.sdst))
is_64bit_sop2 = is_64bit_s0 and inst_type is SOP2
s1 = st.rsrc64(inst.ssrc1, 0) if (is_64bit_sop2 or is_64bit_s0s1) else (st.rsrc(inst.ssrc1, 0) if inst_type in (SOP2, SOPC) else inst.simm16 if inst_type is SOPK else 0)
s1 = st.rsrc64(inst.ssrc1, 0) if (is_64bit_sop2 or is_64bit_s0s1) else (st.rsrc(inst.ssrc1, 0) if inst_type in (SOP2, SOPC) else 0)
s2 = inst.simm16 if inst_type is SOPK else 0 # SOPK: 16-bit immediate passed as S2
d0 = st.rsgpr64(sdst) if (is_64bit_s0 or is_64bit_s0s1) and sdst is not None else (st.rsgpr(sdst) if sdst is not None else 0)
exec_mask = st.exec_mask
literal = inst.simm16 if inst_type in (SOPK, SOPP) else st.literal
# Execute compiled function - pass PC in bytes for instructions that need it
pc_bytes = st.pc * 4
result = fn(s0, s1, 0, d0, st.scc, st.vcc, 0, exec_mask, literal, None, {}, pc=pc_bytes)
# Apply results
# Execute and apply results
result = _run_pcode(fn, op_cls, op, s0, s1, s2, d0, st.scc, st.vcc, 0, st.exec_mask, 0)
if sdst is not None:
if result.get('d0_64'):
st.wsgpr64(sdst, result['d0'])
else:
st.wsgpr(sdst, result['d0'])
if 'scc' in result: st.scc = result['scc']
if result.get('d0_64'): st.wsgpr64(sdst, result['d0'])
else: st.wsgpr(sdst, result['d0'])
st.scc = result['scc']
if 'exec' in result: st.exec_mask = result['exec']
if 'new_pc' in result:
# Convert absolute byte address to word delta
# new_pc is where we want to go, st.pc is current position, inst._words will be added after
new_pc_words = result['new_pc'] // 4
return new_pc_words - st.pc - 1 # -1 because emulator adds inst_words (1 for scalar)
return 0
def exec_vector(st: WaveState, inst: Inst, lane: int, lds: bytearray | None = None) -> None:
@ -304,19 +346,22 @@ def exec_vector(st: WaveState, inst: Inst, lane: int, lds: bytearray | None = No
sx0, sx1 = st.rsrc(inst.srcx0, lane), V[inst.vsrcx1]
sy0, sy1 = st.rsrc(inst.srcy0, lane), V[inst.vsrcy1]
dx0, dy0 = V[inst.vdstx], V[vdsty]
# Execute X op
res_x = None
# FMAAK/FMAMK in VOPD use literal as S2
literal = getattr(inst, '_literal', None) or 0
res_x = res_y = None
if (op_x := _VOPD_TO_VOP.get(inst.opx)):
if (fn_x := compiled.get(type(op_x), {}).get(op_x)):
res_x = fn_x(sx0, sx1, 0, dx0, st.scc, st.vcc, lane, st.exec_mask, st.literal, None, {})
# Execute Y op
res_y = None
if (fn := compiled.get(type(op_x), {}).get(op_x)):
# opx 1=FMAMK, 2=FMAAK use literal
sx2 = literal if inst.opx in (VOPDOp.V_DUAL_FMAMK_F32, VOPDOp.V_DUAL_FMAAK_F32) else 0
res_x = _run_pcode(fn, type(op_x), op_x, sx0, sx1, sx2, dx0, st.scc, st.vcc, lane, st.exec_mask, 0)
if (op_y := _VOPD_TO_VOP.get(inst.opy)):
if (fn_y := compiled.get(type(op_y), {}).get(op_y)):
res_y = fn_y(sy0, sy1, 0, dy0, st.scc, st.vcc, lane, st.exec_mask, st.literal, None, {})
if (fn := compiled.get(type(op_y), {}).get(op_y)):
# opy 1=FMAMK, 2=FMAAK use literal
sy2 = literal if inst.opy in (VOPDOp.V_DUAL_FMAMK_F32, VOPDOp.V_DUAL_FMAAK_F32) else 0
res_y = _run_pcode(fn, type(op_y), op_y, sy0, sy1, sy2, dy0, st.scc, st.vcc, lane, st.exec_mask, 0)
# Write results after both ops complete
if res_x is not None: V[inst.vdstx] = res_x['d0']
if res_y is not None: V[vdsty] = res_y['d0']
if res_x: V[inst.vdstx] = res_x['d0']
if res_y: V[vdsty] = res_y['d0']
return
# VOP3SD: has extra scalar dest for carry output
@ -324,30 +369,17 @@ def exec_vector(st: WaveState, inst: Inst, lane: int, lds: bytearray | None = No
op = VOP3SDOp(inst.op)
fn = compiled.get(VOP3SDOp, {}).get(op)
if fn is None: raise NotImplementedError(f"{op.name} not in pseudocode")
# VOP3SD has both 32-bit ops (V_ADD_CO_CI_U32, etc.) and 64-bit ops (V_DIV_SCALE_F64, V_MAD_U64_U32, etc.)
div_scale_64_ops = (VOP3SDOp.V_DIV_SCALE_F64,)
s0, s1, s2 = st.rsrc(inst.src0, lane), st.rsrc(inst.src1, lane), st.rsrc(inst.src2, lane)
# For 64-bit src2 ops (V_MAD_U64_U32, V_MAD_I64_I32), read from consecutive registers
mad64_ops = (VOP3SDOp.V_MAD_U64_U32, VOP3SDOp.V_MAD_I64_I32)
if op in div_scale_64_ops:
# V_DIV_SCALE_F64: all sources are 64-bit
s0, s1, s2 = st.rsrc64(inst.src0, lane), st.rsrc64(inst.src1, lane), st.rsrc64(inst.src2, lane)
elif op in mad64_ops:
# V_MAD_U64_U32, V_MAD_I64_I32: src0/src1 are 32-bit, src2 is 64-bit
s0, s1 = st.rsrc(inst.src0, lane), st.rsrc(inst.src1, lane)
if inst.src2 >= 256: # VGPR
s2 = V[inst.src2 - 256] | (V[inst.src2 - 256 + 1] << 32)
else: # SGPR - read 64-bit from consecutive SGPRs
s2 = st.rsgpr64(inst.src2)
else:
# Default: 32-bit sources
s0, s1, s2 = st.rsrc(inst.src0, lane), st.rsrc(inst.src1, lane), st.rsrc(inst.src2, lane)
if op in mad64_ops:
s2 = (V[inst.src2 - 256] | (V[inst.src2 - 256 + 1] << 32)) if inst.src2 >= 256 else st.rsgpr64(inst.src2)
d0 = V[inst.vdst]
# For carry-in operations (V_*_CO_CI_*), src2 register contains the carry bitmask (not VCC).
# For carry-in ops (V_*_CO_CI_*), src2 register contains the carry bitmask (not VCC).
# The pseudocode uses VCC but in VOP3SD encoding, the actual carry source is inst.src2.
# We pass the src2 register value as 'vcc' to the interpreter so it reads the correct carry.
carry_ops = (VOP3SDOp.V_ADD_CO_CI_U32, VOP3SDOp.V_SUB_CO_CI_U32, VOP3SDOp.V_SUBREV_CO_CI_U32)
vcc_for_exec = st.rsgpr64(inst.src2) if op in carry_ops else st.vcc
result = fn(s0, s1, s2, d0, st.scc, vcc_for_exec, lane, st.exec_mask, st.literal, None, {})
# Write result - handle 64-bit destinations
result = _run_pcode(fn, VOP3SDOp, op, s0, s1, s2, d0, st.scc, vcc_for_exec, lane, st.exec_mask, inst.vdst)
if result.get('d0_64'):
V[inst.vdst] = result['d0'] & 0xffffffff
V[inst.vdst + 1] = (result['d0'] >> 32) & 0xffffffff
@ -360,37 +392,54 @@ def exec_vector(st: WaveState, inst: Inst, lane: int, lds: bytearray | None = No
# Get op enum and sources (None means "no source" for that operand)
# vop1_dst_hi/vop2_dst_hi: for VOP1/VOP2 16-bit dst ops, bit 7 of vdst indicates .h (high 16-bit) destination
vop1_dst_hi, vop2_dst_hi = False, False
if inst_type is VOP1:
if inst.op == VOP1Op.V_NOP: return
op_cls, op, src0, src1, src2 = VOP1Op, VOP1Op(inst.op), inst.src0, None, None
# For 16-bit dst ops, vdst encodes .h in bit 7
if op in _VOP1_16BIT_DST_OPS:
vop1_dst_hi = (inst.vdst & 0x80) != 0
vdst = inst.vdst & 0x7f
else:
vdst = inst.vdst
# V_READFIRSTLANE_B32: read from first active lane's VGPR -> SGPR (not in pseudocode - needs cross-lane access)
if inst.op == VOP1Op.V_READFIRSTLANE_B32:
first_lane = (st.exec_mask & -st.exec_mask).bit_length() - 1 if st.exec_mask else 0
vgpr_idx = inst.src0 - 256 if inst.src0 >= 256 else inst.src0 # VGPR index
st.wsgpr(inst.vdst, st.vgpr[first_lane][vgpr_idx])
return
op_cls, op, src0, src1, src2, vdst = VOP1Op, VOP1Op(inst.op), inst.src0, None, None, inst.vdst
elif inst_type is VOP2:
op_cls, op, src0, src1, src2 = VOP2Op, VOP2Op(inst.op), inst.src0, inst.vsrc1 + 256, None
# For 16-bit dst ops, vdst encodes .h in bit 7
if op in _VOP2_16BIT_OPS:
vop2_dst_hi = (inst.vdst & 0x80) != 0
vdst = inst.vdst & 0x7f
else:
vdst = inst.vdst
op_cls, op = VOP2Op, VOP2Op(inst.op)
# FMAAK/FMAMK use inline literal constant as S2
literal = getattr(inst, '_literal', None)
src0, src1, src2, vdst = inst.src0, inst.vsrc1 + 256, literal, inst.vdst
elif inst_type is VOP3:
# VOP3 ops 0-255 are VOPC comparisons encoded as VOP3 (use VOPCOp pseudocode)
if inst.op < 256:
op_cls, op, src0, src1, src2, vdst = VOPCOp, VOPCOp(inst.op), inst.src0, inst.src1, None, inst.vdst
else:
op_cls, op, src0, src1, src2, vdst = VOP3Op, VOP3Op(inst.op), inst.src0, inst.src1, inst.src2, inst.vdst
# V_READFIRSTLANE_B32 in VOP3 encoding - same as VOP1 but with VOP3 format
if op == VOP3Op.V_READFIRSTLANE_B32:
first_lane = (st.exec_mask & -st.exec_mask).bit_length() - 1 if st.exec_mask else 0
vgpr_idx = inst.src0 - 256 if inst.src0 >= 256 else inst.src0
st.wsgpr(inst.vdst, st.vgpr[first_lane][vgpr_idx])
return
# V_READLANE_B32: read from specific lane's VGPR -> SGPR (lane specified in src1)
if op == VOP3Op.V_READLANE_B32:
read_lane = st.rsrc(inst.src1, lane) & 0x1f # Lane to read from (5 bits)
vgpr_idx = inst.src0 - 256 if inst.src0 >= 256 else inst.src0
st.wsgpr(inst.vdst, st.vgpr[read_lane][vgpr_idx])
return
# V_PERM_B32: byte permutation - not in pseudocode PDF, implement directly
# D0[byte_i] = selector[byte_i] < 8 ? {src0, src1}[selector[byte_i]] : (selector[byte_i] >= 0xD ? 0xFF : 0x00)
if op == VOP3Op.V_PERM_B32:
s0, s1, s2 = st.rsrc(inst.src0, lane), st.rsrc(inst.src1, lane), st.rsrc(inst.src2, lane)
# Combine src1 and src0 into 8-byte value: src1 is bytes 0-3, src0 is bytes 4-7
combined = (s1 & 0xffffffff) | ((s0 & 0xffffffff) << 32)
result = 0
for i in range(4): # 4 result bytes
sel = (s2 >> (i * 8)) & 0xff # byte selector for this position
if sel <= 7: result |= (((combined >> (sel * 8)) & 0xff) << (i * 8)) # select byte from combined
elif sel >= 0xd: result |= (0xff << (i * 8)) # 0xD-0xF: constant 0xFF
# else 0x8-0xC: constant 0x00 (already 0)
V[vdst] = result & 0xffffffff
return
elif inst_type is VOPC:
op = VOPCOp(inst.op)
# For 16-bit VOPC, vsrc1 uses same encoding as VOP2 16-bit: bit 7 selects hi(1) or lo(0) half
# vsrc1 field is 8 bits: [6:0] = VGPR index, [7] = hi flag
src1 = inst.vsrc1 + 256 # convert to standard VGPR encoding (256 + vgpr_idx)
op_cls, src0, src2, vdst = VOPCOp, inst.src0, None, VCC_LO
op_cls, op, src0, src1, src2, vdst = VOPCOp, VOPCOp(inst.op), inst.src0, inst.vsrc1 + 256, None, VCC_LO
elif inst_type is VOP3P:
# VOP3P: Packed 16-bit operations using compiled functions
op = VOP3POp(inst.op)
@ -399,44 +448,26 @@ def exec_vector(st: WaveState, inst: Inst, lane: int, lds: bytearray | None = No
if lane == 0: # Only execute once per wave, write results for all lanes
exec_wmma(st, inst, op)
return
# V_FMA_MIX: Mixed precision FMA - inputs can be f16 or f32 controlled by opsel_hi/opsel_hi2
# opsel_hi[0]: src0 is f32 (0) or f16 from hi bits (1)
# opsel_hi[1]: src1 is f32 (0) or f16 from hi bits (1)
# opsel_hi2: src2 is f32 (0) or f16 from hi bits (1)
# opsel[i]: when source is f16, use lo (0) or hi (1) 16 bits - BUT for V_FMA_MIX, opsel selects lo/hi when opsel_hi=1
# neg_hi[i]: abs modifier for source i (reuses neg_hi field for abs in V_FMA_MIX)
# V_FMA_MIX: Mixed precision FMA - inputs can be f16 or f32 controlled by opsel
if op in (VOP3POp.V_FMA_MIX_F32, VOP3POp.V_FMA_MIXLO_F16, VOP3POp.V_FMA_MIXHI_F16):
opsel = getattr(inst, 'opsel', 0)
opsel_hi = getattr(inst, 'opsel_hi', 0)
opsel_hi2 = getattr(inst, 'opsel_hi2', 0)
neg = getattr(inst, 'neg', 0)
abs_ = getattr(inst, 'neg_hi', 0) # neg_hi field is reused as abs for V_FMA_MIX
neg_hi = getattr(inst, 'neg_hi', 0)
vdst = inst.vdst
# Read raw 32-bit values
# Read raw 32-bit values - for V_FMA_MIX, sources can be either f32 or f16
s0_raw = st.rsrc(inst.src0, lane)
s1_raw = st.rsrc(inst.src1, lane)
s2_raw = st.rsrc(inst.src2, lane) if inst.src2 is not None else 0
# Decode sources based on opsel_hi (controls f32 vs f16) and opsel (controls which half for f16)
# src0: opsel_hi[0]=1 means f16, opsel[0] selects hi(1) or lo(0) half
if opsel_hi & 1:
s0 = _f16((s0_raw >> 16) & 0xffff) if (opsel & 1) else _f16(s0_raw & 0xffff)
else:
s0 = _f32(s0_raw)
# src1: opsel_hi[1]=1 means f16, opsel[1] selects hi(1) or lo(0) half
if opsel_hi & 2:
s1 = _f16((s1_raw >> 16) & 0xffff) if (opsel & 2) else _f16(s1_raw & 0xffff)
else:
s1 = _f32(s1_raw)
# src2: opsel_hi2=1 means f16, opsel[2] selects hi(1) or lo(0) half
if opsel_hi2:
s2 = _f16((s2_raw >> 16) & 0xffff) if (opsel & 4) else _f16(s2_raw & 0xffff)
else:
s2 = _f32(s2_raw)
# Apply abs modifiers (abs_ field reuses neg_hi position)
if abs_ & 1: s0 = abs(s0)
if abs_ & 2: s1 = abs(s1)
if abs_ & 4: s2 = abs(s2)
# Apply neg modifiers
# opsel[i]=0: use as f32, opsel[i]=1: use hi f16 as f32
# For src0: opsel[0], for src1: opsel[1], for src2: opsel[2]
if opsel & 1: s0 = _f16((s0_raw >> 16) & 0xffff) # hi f16 -> f32
else: s0 = _f32(s0_raw) # use as f32
if opsel & 2: s1 = _f16((s1_raw >> 16) & 0xffff)
else: s1 = _f32(s1_raw)
if opsel & 4: s2 = _f16((s2_raw >> 16) & 0xffff)
else: s2 = _f32(s2_raw)
# Apply neg modifiers (for f32 values)
if neg & 1: s0 = -s0
if neg & 2: s1 = -s1
if neg & 4: s2 = -s2
@ -485,10 +516,10 @@ def exec_vector(st: WaveState, inst: Inst, lane: int, lds: bytearray | None = No
s0 = (s0_hi << 16) | s0_lo
s1 = (s1_hi << 16) | s1_lo
s2 = (s2_hi << 16) | s2_lo
op_cls, vdst = VOP3POp, inst.vdst
fn = compiled.get(op_cls, {}).get(op)
vdst = inst.vdst
fn = compiled.get(VOP3POp, {}).get(op)
if fn is None: raise NotImplementedError(f"{op.name} not in pseudocode")
result = fn(s0, s1, s2, 0, st.scc, st.vcc, lane, st.exec_mask, st.literal, None, {})
result = _run_pcode(fn, VOP3POp, op, s0, s1, s2, 0, st.scc, st.vcc, lane, st.exec_mask, vdst)
st.vgpr[lane][vdst] = result['d0'] & 0xffffffff
return
else: raise NotImplementedError(f"Unknown vector type {inst_type}")
@ -510,17 +541,13 @@ def exec_vector(st: WaveState, inst: Inst, lane: int, lds: bytearray | None = No
# Determine if sources are 64-bit based on instruction type
# For 64-bit shift ops: src0 is 32-bit (shift amount), src1 is 64-bit (value to shift)
# For V_LDEXP_F64: src0 is 64-bit float, src1 is 32-bit integer exponent
# For most other _B64/_I64/_U64/_F64 ops: all sources are 64-bit
is_64bit_op = op.name.endswith(('_B64', '_I64', '_U64', '_F64'))
# V_LDEXP_F64, V_TRIG_PREOP_F64, V_CMP_CLASS_F64, V_CMPX_CLASS_F64: src0 is 64-bit, src1 is 32-bit
is_ldexp_64 = op in (VOP3Op.V_LDEXP_F64, VOP3Op.V_TRIG_PREOP_F64, VOP3Op.V_CMP_CLASS_F64, VOP3Op.V_CMPX_CLASS_F64,
VOPCOp.V_CMP_CLASS_F64, VOPCOp.V_CMPX_CLASS_F64)
is_ldexp_64 = op in (VOP3Op.V_LDEXP_F64,)
is_shift_64 = op in (VOP3Op.V_LSHLREV_B64, VOP3Op.V_LSHRREV_B64, VOP3Op.V_ASHRREV_I64)
# 16-bit source ops: use precomputed sets instead of string checks
# Note: must check op_cls to avoid cross-enum value collisions
is_16bit_src = op_cls is VOP3Op and op in _VOP3_16BIT_OPS and op not in _CVT_32_64_SRC_OPS
# VOP2 16-bit ops use f16 inline constants for src0 (vsrc1 is always a VGPR, no inline constants)
is_vop2_16bit = op_cls is VOP2Op and op in _VOP2_16BIT_OPS
is_vop2_16bit = op_cls is VOP2Op and op in _VOP2_16BIT_OPS # VOP2 16-bit ops use f16 inline constants
if is_shift_64:
s0 = mod_src(st.rsrc(src0, lane), 0) # shift amount is 32-bit
@ -528,107 +555,37 @@ def exec_vector(st: WaveState, inst: Inst, lane: int, lds: bytearray | None = No
s2 = mod_src(st.rsrc(src2, lane), 2) if src2 is not None else 0
elif is_ldexp_64:
s0 = mod_src64(st.rsrc64(src0, lane), 0) # mantissa is 64-bit float
# src1 is 32-bit int. For 64-bit ops (like V_CMP_CLASS_F64), the literal is stored shifted left by 32.
# For V_LDEXP_F64/V_TRIG_PREOP_F64, _is_64bit_op() returns False so literal is stored as-is.
s1_raw = st.rsrc(src1, lane) if src1 is not None else 0
# Only shift if src1 is literal AND this is a true 64-bit op (V_CMP_CLASS ops, not LDEXP/TRIG_PREOP)
is_class_op = op in (VOP3Op.V_CMP_CLASS_F64, VOP3Op.V_CMPX_CLASS_F64, VOPCOp.V_CMP_CLASS_F64, VOPCOp.V_CMPX_CLASS_F64)
s1 = mod_src((s1_raw >> 32) if src1 == 255 and is_class_op else s1_raw, 1)
s1 = mod_src(st.rsrc(src1, lane), 1) if src1 is not None else 0 # exponent is 32-bit int
s2 = mod_src(st.rsrc(src2, lane), 2) if src2 is not None else 0
elif is_64bit_op:
# 64-bit ops: apply neg/abs modifiers using f64 interpretation for float ops
s0 = mod_src64(st.rsrc64(src0, lane), 0)
s1 = mod_src64(st.rsrc64(src1, lane), 1) if src1 is not None else 0
s2 = mod_src64(st.rsrc64(src2, lane), 2) if src2 is not None else 0
elif is_16bit_src:
# For 16-bit source ops, opsel bits select which half to use
# Inline constants (128-254) must use f16 encoding, not f32
def rsrc_16bit(src, lane): return st.rsrc_f16(src, lane) if 128 <= src < 255 else st.rsrc(src, lane)
s0_raw = rsrc_16bit(src0, lane)
s1_raw = rsrc_16bit(src1, lane) if src1 is not None else 0
s2_raw = rsrc_16bit(src2, lane) if src2 is not None else 0
# opsel[0] selects hi(1) or lo(0) for src0, opsel[1] for src1, opsel[2] for src2
s0_raw = mod_src(st.rsrc(src0, lane), 0)
s1_raw = mod_src(st.rsrc(src1, lane), 1) if src1 is not None else 0
s2_raw = mod_src(st.rsrc(src2, lane), 2) if src2 is not None else 0
s0 = ((s0_raw >> 16) & 0xffff) if (opsel & 1) else (s0_raw & 0xffff)
s1 = ((s1_raw >> 16) & 0xffff) if (opsel & 2) else (s1_raw & 0xffff)
s2 = ((s2_raw >> 16) & 0xffff) if (opsel & 4) else (s2_raw & 0xffff)
# Apply abs/neg modifiers as f16 operations (toggle sign bit 15)
if abs_ & 1: s0 &= 0x7fff
if abs_ & 2: s1 &= 0x7fff
if abs_ & 4: s2 &= 0x7fff
if neg & 1: s0 ^= 0x8000
if neg & 2: s1 ^= 0x8000
if neg & 4: s2 ^= 0x8000
elif is_vop2_16bit:
# VOP2 16-bit ops: src0 uses f16 inline constants, or VGPR where v128+ = hi half of v0-v127
# RDNA3 encoding: for VGPRs, bit 7 of VGPR index (src0-256) selects hi(1) or lo(0) half
if src0 >= 256: # VGPR
src0_hi = (src0 - 256) & 0x80 != 0
src0_masked = ((src0 - 256) & 0x7f) + 256 # mask out hi bit to get actual VGPR
s0_raw = mod_src(st.rsrc(src0_masked, lane), 0)
s0 = ((s0_raw >> 16) & 0xffff) if src0_hi else (s0_raw & 0xffff)
else: # SGPR or inline constant
s0_raw = mod_src(st.rsrc_f16(src0, lane), 0)
s0 = s0_raw & 0xffff
# vsrc1: .h suffix encoded in bit 7 of VGPR index (src1 = 256 + vgpr_idx + 0x80 if hi)
if src1 is not None:
src1_hi = (src1 - 256) & 0x80 != 0
src1_masked = ((src1 - 256) & 0x7f) + 256
s1_raw = mod_src(st.rsrc(src1_masked, lane), 1)
s1 = ((s1_raw >> 16) & 0xffff) if src1_hi else (s1_raw & 0xffff)
else:
s1 = 0
s0 = mod_src(st.rsrc_f16(src0, lane), 0)
s1 = mod_src(st.rsrc(src1, lane), 1) if src1 is not None else 0
s2 = mod_src(st.rsrc(src2, lane), 2) if src2 is not None else 0
elif op_cls is VOP1Op and op in _VOP1_16BIT_SRC_OPS:
# VOP1 16-bit source ops: .h encoded in bit 7 of VGPR index (src0 >= 384 means hi half)
# For VGPRs: src0 = 256 + vgpr_idx + (0x80 if hi else 0), so bit 7 of (src0-256) is the hi flag
src0_hi = src0 >= 256 and ((src0 - 256) & 0x80) != 0
src0_masked = ((src0 - 256) & 0x7f) + 256 if src0 >= 256 else src0 # mask out hi bit for VGPR
s0_raw = mod_src(st.rsrc(src0_masked, lane), 0)
s0 = ((s0_raw >> 16) & 0xffff) if src0_hi else (s0_raw & 0xffff)
s1, s2 = 0, 0
elif op_cls is VOPCOp and op in _VOPC_16BIT_OPS:
# VOPC 16-bit ops: src0 and vsrc1 use same encoding as VOP2 16-bit
# For VGPRs, bit 7 of VGPR index selects hi(1) or lo(0) half
if src0 >= 256: # VGPR
src0_hi = (src0 - 256) & 0x80 != 0
src0_masked = ((src0 - 256) & 0x7f) + 256
s0_raw = mod_src(st.rsrc(src0_masked, lane), 0)
s0 = ((s0_raw >> 16) & 0xffff) if src0_hi else (s0_raw & 0xffff)
else: # SGPR or inline constant
s0_raw = mod_src(st.rsrc_f16(src0, lane), 0)
s0 = s0_raw & 0xffff
# vsrc1: bit 7 of VGPR index selects hi(1) or lo(0) half
if src1 is not None:
if src1 >= 256: # VGPR - use hi/lo encoding
src1_hi = (src1 - 256) & 0x80 != 0
src1_masked = ((src1 - 256) & 0x7f) + 256
s1_raw = mod_src(st.rsrc(src1_masked, lane), 1)
s1 = ((s1_raw >> 16) & 0xffff) if src1_hi else (s1_raw & 0xffff)
else: # SGPR or inline constant - read as 32-bit, use low 16 bits
s1_raw = mod_src(st.rsrc(src1, lane), 1)
s1 = s1_raw & 0xffffffff # V_CMP_CLASS uses full 32-bit mask
else:
s1 = 0
s2 = 0
else:
s0 = mod_src(st.rsrc(src0, lane), 0)
s1 = mod_src(st.rsrc(src1, lane), 1) if src1 is not None else 0
s2 = mod_src(st.rsrc(src2, lane), 2) if src2 is not None else 0
# For VOP2 16-bit ops (like V_FMAC_F16), the destination is used as an accumulator.
# The pseudocode reads D0.f16 from low 16 bits, so we need to shift hi->lo when vop2_dst_hi is True.
if is_vop2_16bit:
d0 = ((V[vdst] >> 16) & 0xffff) if vop2_dst_hi else (V[vdst] & 0xffff)
else:
d0 = V[vdst] if not is_64bit_op else (V[vdst] | (V[vdst + 1] << 32))
# src2 can be a register index OR a raw literal value (for FMAAK/FMAMK)
# If src2 > 511, it's a raw literal value, not a register index
s2 = src2 if src2 is not None and src2 > 511 else (mod_src(st.rsrc(src2, lane), 2) if src2 is not None else 0)
d0 = V[vdst] if not is_64bit_op else (V[vdst] | (V[vdst + 1] << 32))
# V_CNDMASK_B32/B16: VOP3 encoding uses src2 as mask (not VCC); VOP2 uses VCC implicitly
# Pass the correct mask as vcc to the function so pseudocode VCC.u64[laneId] works correctly
vcc_for_fn = st.rsgpr64(src2) if op in (VOP3Op.V_CNDMASK_B32, VOP3Op.V_CNDMASK_B16) and inst_type is VOP3 and src2 is not None and src2 < 256 else st.vcc
# V_CNDMASK_B32: VOP3 encoding uses src2 as mask (not VCC); VOP2 uses VCC implicitly
vcc_for_fn = st.rsgpr64(src2) if op in (VOP3Op.V_CNDMASK_B32,) and inst_type is VOP3 and src2 is not None and src2 < 256 else st.vcc
# Execute compiled function - pass src0_idx and vdst_idx for lane instructions
# For VGPR access: src0 index is the VGPR number (src0 - 256 if VGPR, else src0 for SGPR)
src0_idx = (src0 - 256) if src0 is not None and src0 >= 256 else (src0 if src0 is not None else 0)
result = fn(s0, s1, s2, d0, st.scc, vcc_for_fn, lane, st.exec_mask, st.literal, st.vgpr, {}, src0_idx, vdst)
# Execute pseudocode
result = _run_pcode(fn, op_cls, op, s0, s1, s2, d0, st.scc, vcc_for_fn, lane, st.exec_mask, vdst)
# Apply results
if 'vgpr_write' in result:
@ -636,8 +593,7 @@ def exec_vector(st: WaveState, inst: Inst, lane: int, lds: bytearray | None = No
wr_lane, wr_idx, wr_val = result['vgpr_write']
st.vgpr[wr_lane][wr_idx] = wr_val
if 'vcc_lane' in result:
# VOP2 carry instructions (V_ADD_CO_CI_U32, V_SUB_CO_CI_U32, V_SUBREV_CO_CI_U32) write carry to VCC implicitly
# VOPC and VOP3-encoded VOPC write to vdst (which is VCC_LO for VOPC, inst.sdst for VOP3)
# VOP2 carry instructions write carry to VCC implicitly; VOPC writes to vdst
vcc_dst = VCC_LO if op_cls is VOP2Op and op in (VOP2Op.V_ADD_CO_CI_U32, VOP2Op.V_SUB_CO_CI_U32, VOP2Op.V_SUBREV_CO_CI_U32) else vdst
st.pend_sgpr_lane(vcc_dst, lane, result['vcc_lane'])
if 'exec_lane' in result:
@ -645,35 +601,18 @@ def exec_vector(st: WaveState, inst: Inst, lane: int, lds: bytearray | None = No
st.pend_sgpr_lane(EXEC_LO, lane, result['exec_lane'])
if 'd0' in result and op_cls not in (VOPCOp,) and 'vgpr_write' not in result:
# V_READFIRSTLANE_B32 and V_READLANE_B32 write to SGPR, not VGPR
# V_WRITELANE_B32 uses vgpr_write for cross-lane writes, don't overwrite with d0
writes_to_sgpr = op in (VOP1Op.V_READFIRSTLANE_B32,) or \
(op_cls is VOP3Op and op in (VOP3Op.V_READFIRSTLANE_B32, VOP3Op.V_READLANE_B32))
# Check for 16-bit destination ops (opsel[3] controls hi/lo write)
# Must check op_cls to avoid cross-enum value collisions (e.g., VOP1Op.V_MOV_B32=1 vs VOP3Op.V_CMP_LT_F16=1)
is_16bit_dst = (op_cls is VOP3Op and op in _VOP3_16BIT_DST_OPS) or (op_cls is VOP1Op and op in _VOP1_16BIT_DST_OPS)
is_16bit_dst = op in _VOP3_16BIT_DST_OPS or op in _VOP1_16BIT_DST_OPS
if writes_to_sgpr:
st.wsgpr(vdst, result['d0'] & 0xffffffff)
elif result.get('d0_64'):
elif result.get('d0_64') or is_64bit_op:
V[vdst] = result['d0'] & 0xffffffff
V[vdst + 1] = (result['d0'] >> 32) & 0xffffffff
elif is_16bit_dst and inst_type is VOP3:
# VOP3 16-bit ops: opsel[3] (bit 3 of opsel field) controls hi/lo destination
if opsel & 8: # opsel[3] = 1: write to high 16 bits
V[vdst] = (V[vdst] & 0x0000ffff) | ((result['d0'] & 0xffff) << 16)
else: # opsel[3] = 0: write to low 16 bits
V[vdst] = (V[vdst] & 0xffff0000) | (result['d0'] & 0xffff)
elif is_16bit_dst and inst_type is VOP1:
# VOP1 16-bit ops: .h suffix encoded in bit 7 of vdst (extracted as vop1_dst_hi)
if vop1_dst_hi: # .h: write to high 16 bits
V[vdst] = (V[vdst] & 0x0000ffff) | ((result['d0'] & 0xffff) << 16)
else: # .l: write to low 16 bits
V[vdst] = (V[vdst] & 0xffff0000) | (result['d0'] & 0xffff)
elif is_vop2_16bit:
# VOP2 16-bit ops: .h suffix encoded in bit 7 of vdst (extracted as vop2_dst_hi)
if vop2_dst_hi: # .h: write to high 16 bits
V[vdst] = (V[vdst] & 0x0000ffff) | ((result['d0'] & 0xffff) << 16)
else: # .l: write to low 16 bits
V[vdst] = (V[vdst] & 0xffff0000) | (result['d0'] & 0xffff)
# VOP3 16-bit ops: opsel[3] controls hi/lo destination
if opsel & 8: V[vdst] = (V[vdst] & 0x0000ffff) | ((result['d0'] & 0xffff) << 16)
else: V[vdst] = (V[vdst] & 0xffff0000) | (result['d0'] & 0xffff)
else:
V[vdst] = result['d0'] & 0xffffffff

View file

@ -35,18 +35,12 @@ def _isnan(x):
try: return math.isnan(float(x))
except (TypeError, ValueError): return False
def _isquietnan(x):
"""Check if x is a quiet NaN.
f16: exponent=31, bit9=1, mantissa!=0
f32: exponent=255, bit22=1, mantissa!=0
f64: exponent=2047, bit51=1, mantissa!=0
"""
"""Check if x is a quiet NaN. For f32: exponent=255, bit22=1, mantissa!=0"""
try:
if not math.isnan(float(x)): return False
# Get raw bits from TypedView or similar object with _reg attribute
if hasattr(x, '_reg') and hasattr(x, '_bits'):
bits = x._reg._val & ((1 << x._bits) - 1)
if x._bits == 16:
return ((bits >> 10) & 0x1f) == 31 and ((bits >> 9) & 1) == 1 and (bits & 0x3ff) != 0
if x._bits == 32:
return ((bits >> 23) & 0xff) == 255 and ((bits >> 22) & 1) == 1 and (bits & 0x7fffff) != 0
if x._bits == 64:
@ -54,18 +48,12 @@ def _isquietnan(x):
return True # Default to quiet NaN if we can't determine bit pattern
except (TypeError, ValueError): return False
def _issignalnan(x):
"""Check if x is a signaling NaN.
f16: exponent=31, bit9=0, mantissa!=0
f32: exponent=255, bit22=0, mantissa!=0
f64: exponent=2047, bit51=0, mantissa!=0
"""
"""Check if x is a signaling NaN. For f32: exponent=255, bit22=0, mantissa!=0"""
try:
if not math.isnan(float(x)): return False
# Get raw bits from TypedView or similar object with _reg attribute
if hasattr(x, '_reg') and hasattr(x, '_bits'):
bits = x._reg._val & ((1 << x._bits) - 1)
if x._bits == 16:
return ((bits >> 10) & 0x1f) == 31 and ((bits >> 9) & 1) == 0 and (bits & 0x3ff) != 0
if x._bits == 32:
return ((bits >> 23) & 0xff) == 255 and ((bits >> 22) & 1) == 0 and (bits & 0x7fffff) != 0
if x._bits == 64:
@ -85,11 +73,7 @@ def floor(x):
def ceil(x):
x = float(x)
return x if math.isnan(x) or math.isinf(x) else float(math.ceil(x))
class _SafeFloat(float):
"""Float subclass that uses _div for division to handle 0/inf correctly."""
def __truediv__(self, o): return _div(float(self), float(o))
def __rtruediv__(self, o): return _div(float(o), float(self))
def sqrt(x): return _SafeFloat(math.sqrt(x)) if x >= 0 else _SafeFloat(float("nan"))
def sqrt(x): return math.sqrt(x) if x >= 0 else float("nan")
def log2(x): return math.log2(x) if x > 0 else (float("-inf") if x == 0 else float("nan"))
i32_to_f32 = u32_to_f32 = i32_to_f64 = u32_to_f64 = f32_to_f64 = f64_to_f32 = float
def f32_to_i32(f):
@ -123,10 +107,7 @@ def u4_to_u32(v): return int(v) & 0xf
def _sign(f): return 1 if math.copysign(1.0, f) < 0 else 0
def _mantissa_f32(f): return struct.unpack("<I", struct.pack("<f", f))[0] & 0x7fffff if not (math.isinf(f) or math.isnan(f)) else 0
def _ldexp(m, e): return math.ldexp(m, e)
def isEven(x):
x = float(x)
if math.isinf(x) or math.isnan(x): return False
return int(x) % 2 == 0
def isEven(x): return int(x) % 2 == 0
def fract(x): return x - math.floor(x)
PI = math.pi
def sin(x):
@ -280,7 +261,7 @@ def f32_to_u8(f): return max(0, min(255, int(f))) if not math.isnan(f) else 0
def mantissa(f):
if f == 0.0 or math.isinf(f) or math.isnan(f): return f
m, _ = math.frexp(f)
return m # AMD V_FREXP_MANT returns mantissa in [0.5, 1.0) range
return math.copysign(m * 2.0, f)
def signext_from_bit(val, bit):
bit = int(bit)
if bit == 0: return 0
@ -301,7 +282,6 @@ __all__ = [
# Constants
'WAVE32', 'WAVE64', 'MASK32', 'MASK64', 'WAVE_MODE', 'DENORM', 'OVERFLOW_F32', 'UNDERFLOW_F32',
'OVERFLOW_F64', 'UNDERFLOW_F64', 'MAX_FLOAT_F32', 'ROUND_MODE', 'cvtToQuietNAN', 'DST', 'INF', 'PI',
'TWO_OVER_PI_1201',
# Aliases for pseudocode
's_ff1_i32_b32', 's_ff1_i32_b64', 'GT_NEG_ZERO', 'LT_NEG_ZERO',
'isNAN', 'isQuietNAN', 'isSignalNAN', 'fma', 'ldexp', 'sign', 'exponent', 'F', 'signext',
@ -342,7 +322,7 @@ def F(x):
if isinstance(x, int): return _f32(x) # int -> interpret as f32 bits
if isinstance(x, TypedView): return x # preserve TypedView for bit-pattern checks
return float(x) # already a float or float-like
signext = lambda x: int(x) # sign-extend to full width - already handled by Python's arbitrary precision ints
signext = lambda x: x
pack = lambda hi, lo: ((int(hi) & 0xffff) << 16) | (int(lo) & 0xffff)
pack32 = lambda hi, lo: ((int(hi) & 0xffffffff) << 32) | (int(lo) & 0xffffffff)
_pack, _pack32 = pack, pack32 # Aliases for internal use
@ -360,14 +340,12 @@ class _Inf:
f16 = f32 = f64 = float('inf')
def __neg__(self): return _NegInf()
def __pos__(self): return self
def __float__(self): return float('inf')
def __eq__(self, other): return float(other) == float('inf') if not isinstance(other, _NegInf) else False
def __req__(self, other): return self.__eq__(other)
class _NegInf:
f16 = f32 = f64 = float('-inf')
def __neg__(self): return _Inf()
def __pos__(self): return self
def __float__(self): return float('-inf')
def __eq__(self, other): return float(other) == float('-inf') if not isinstance(other, _Inf) else False
def __req__(self, other): return self.__eq__(other)
INF = _Inf()
@ -383,31 +361,6 @@ DST = None # Placeholder, will be set in context
MASK32, MASK64 = 0xffffffff, 0xffffffffffffffff
# 2/PI with 1201 bits of precision for V_TRIG_PREOP_F64
# Computed as: int((2/pi) * 2^1201) - this is the fractional part of 2/pi scaled to integer
# The MSB (bit 1200) corresponds to 2^0 position in the fraction 0.b1200 b1199 ... b1 b0
_TWO_OVER_PI_1201_RAW = 0x0145f306dc9c882a53f84eafa3ea69bb81b6c52b3278872083fca2c757bd778ac36e48dc74849ba5c00c925dd413a32439fc3bd63962534e7dd1046bea5d768909d338e04d68befc827323ac7306a673e93908bf177bf250763ff12fffbc0b301fde5e2316b414da3eda6cfd9e4f96136e9e8c7ecd3cbfd45aea4f758fd7cbe2f67a0e73ef14a525d4d7f6bf623f1aba10ac06608df8f6
class _BigInt:
"""Wrapper for large integers that supports bit slicing [high:low]."""
__slots__ = ('_val',)
def __init__(self, val): self._val = val
def __getitem__(self, key):
if isinstance(key, slice):
high, low = key.start, key.stop
if high < low: high, low = low, high # Handle reversed slice
mask = (1 << (high - low + 1)) - 1
return (self._val >> low) & mask
return (self._val >> key) & 1
def __int__(self): return self._val
def __index__(self): return self._val
def __lshift__(self, n): return self._val << int(n)
def __rshift__(self, n): return self._val >> int(n)
def __and__(self, n): return self._val & int(n)
def __or__(self, n): return self._val | int(n)
TWO_OVER_PI_1201 = _BigInt(_TWO_OVER_PI_1201_RAW)
class _WaveMode:
IEEE = False
WAVE_MODE = _WaveMode()
@ -547,17 +500,6 @@ class TypedView:
def __bool__(s): return bool(int(s))
# Allow chained type access like jump_addr.i64 when jump_addr is already a TypedView
# These just return self or convert appropriately
@property
def i64(s): return s if s._bits == 64 and s._signed else int(s)
@property
def u64(s): return s if s._bits == 64 and not s._signed else int(s) & MASK64
@property
def i32(s): return s if s._bits == 32 and s._signed else _sext(int(s) & MASK32, 32)
@property
def u32(s): return s if s._bits == 32 and not s._signed else int(s) & MASK32
class Reg:
"""GPU register: D0.f32 = S0.f32 + S1.f32 just works."""
__slots__ = ('_val',)
@ -581,7 +523,6 @@ class Reg:
bf16 = property(lambda s: TypedView(s, 16, is_float=True, is_bf16=True), lambda s, v: setattr(s, '_val', (s._val & 0xffff0000) | ((v if isinstance(v, int) else _ibf16(float(v))) & 0xffff)))
u8 = property(lambda s: TypedView(s, 8))
i8 = property(lambda s: TypedView(s, 8, signed=True))
u1 = property(lambda s: TypedView(s, 1)) # single bit
def __getitem__(s, key):
if isinstance(key, slice): return SliceProxy(s, int(key.start), int(key.stop))
@ -624,6 +565,19 @@ class Reg:
def __eq__(s, o): return s._val == int(o)
def __ne__(s, o): return s._val != int(o)
class ExecContext:
"""Execution context for running compiled pseudocode strings (for testing)."""
def __init__(self, s0=0, s1=0, s2=0, d0=0, scc=0, vcc=0, lane=0, exec_mask=0xffffffff):
self.S0, self.S1, self.S2, self.D0, self.D1 = Reg(s0), Reg(s1), Reg(s2), Reg(d0), Reg(0)
self.SCC, self.VCC, self.EXEC = Reg(scc), Reg(vcc), Reg(exec_mask)
self.tmp, self.saveexec, self.laneId = Reg(0), Reg(exec_mask), lane
def run(self, code: str):
if not code.strip(): return
code = code if '\n' in code else compile_pseudocode(code)
exec(code, {**globals(), 'Reg': Reg, 'SliceProxy': SliceProxy}, self.__dict__)
def result(self) -> dict:
return {'d0': self.D0._val, 'd1': self.D1._val, 'scc': self.SCC._val & 1, 'vcc': self.VCC._val, 'exec': self.EXEC._val}
# ═══════════════════════════════════════════════════════════════════════════════
# COMPILER: pseudocode -> Python (minimal transforms)
# ═══════════════════════════════════════════════════════════════════════════════
@ -642,7 +596,7 @@ def compile_pseudocode(pseudocode: str) -> str:
joined_lines.append(line)
lines = []
indent, need_pass, in_first_match_loop = 0, False, False
indent, need_pass = 0, False
for line in joined_lines:
line = line.strip()
if not line or line.startswith('//'): continue
@ -671,14 +625,14 @@ def compile_pseudocode(pseudocode: str) -> str:
elif line.startswith('endfor'):
if need_pass: lines.append(' ' * indent + "pass")
indent -= 1
need_pass, in_first_match_loop = False, False
need_pass = False
elif line.startswith('declare '):
pass
elif m := re.match(r'for (\w+) in (.+?)\s*:\s*(.+?) do', line):
start, end = _expr(m[2].strip()), _expr(m[3].strip())
lines.append(' ' * indent + f"for {m[1]} in range({start}, int({end})+1):")
indent += 1
need_pass, in_first_match_loop = True, True
need_pass = True
elif '=' in line and not line.startswith('=='):
need_pass = False
line = line.rstrip(';')
@ -697,20 +651,20 @@ def compile_pseudocode(pseudocode: str) -> str:
break
else:
lhs, rhs = line.split('=', 1)
lhs_s, rhs_s = lhs.strip(), rhs.strip()
stmt = _assign(lhs_s, _expr(rhs_s))
# CLZ/CTZ pattern: assignment of loop var to tmp/D0.i32 in first-match loop needs break
if in_first_match_loop and rhs_s == 'i' and (lhs_s == 'tmp' or lhs_s == 'D0.i32'):
stmt += "; break"
lines.append(' ' * indent + stmt)
lines.append(' ' * indent + _assign(lhs.strip(), _expr(rhs.strip())))
# If we ended with a control statement that needs a body, add pass
if need_pass: lines.append(' ' * indent + "pass")
return '\n'.join(lines)
def _assign(lhs: str, rhs: str) -> str:
"""Generate assignment. Bare tmp/SCC/etc get wrapped in Reg()."""
if lhs in ('tmp', 'SCC', 'VCC', 'EXEC', 'D0', 'D1', 'saveexec', 'PC'):
return f"{lhs} = Reg({rhs})"
"""Generate assignment. Outputs modify Reg in-place via ._val."""
# Output registers and tmp: modify in-place so caller sees changes
if lhs in ('SCC', 'VCC', 'EXEC', 'D0', 'D1', 'tmp'):
return f"{lhs}._val = int({rhs})"
# saveexec needs to be a new Reg for typed accessor access
if lhs == 'saveexec':
return f"{lhs} = Reg(int({rhs}))"
# Other locals: natural style
return f"{lhs} = {rhs}"
def _expr(e: str) -> str:
@ -726,9 +680,6 @@ def _expr(e: str) -> str:
return f'_pack({hi}, {lo})'
e = re.sub(r'\{\s*([^,{}]+)\s*,\s*([^,{}]+)\s*\}', pack, e)
# Special constant: 1201'B(2.0 / PI) -> TWO_OVER_PI_1201 (precomputed 1201-bit 2/pi)
e = re.sub(r"1201'B\(2\.0\s*/\s*PI\)", "TWO_OVER_PI_1201", e)
# Literals: 1'0U -> 0, 32'I(x) -> (x), B(x) -> (x)
e = re.sub(r"\d+'([0-9a-fA-Fx]+)[UuFf]*", r'\1', e)
e = re.sub(r"\d+'[FIBU]\(", "(", e)
@ -795,50 +746,7 @@ def _expr(e: str) -> str:
e = f'(({t}) if ({cond}) else ({f}))'
return e
# ═══════════════════════════════════════════════════════════════════════════════
# EXECUTION CONTEXT
# ═══════════════════════════════════════════════════════════════════════════════
class ExecContext:
"""Context for running compiled pseudocode."""
def __init__(self, s0=0, s1=0, s2=0, d0=0, scc=0, vcc=0, lane=0, exec_mask=MASK32, literal=0, vgprs=None, src0_idx=0, vdst_idx=0):
self.S0, self.S1, self.S2 = Reg(s0), Reg(s1), Reg(s2)
self.D0, self.D1 = Reg(d0), Reg(0)
self.SCC, self.VCC, self.EXEC = Reg(scc), Reg(vcc), Reg(exec_mask)
self.tmp, self.saveexec = Reg(0), Reg(exec_mask)
self.lane, self.laneId, self.literal = lane, lane, literal
self.SIMM16, self.SIMM32 = Reg(literal), Reg(literal)
self.VGPR = vgprs if vgprs is not None else {}
self.SRC0, self.VDST = Reg(src0_idx), Reg(vdst_idx)
def run(self, code: str):
"""Execute compiled code."""
# Start with module globals (helpers, aliases), then add instance-specific bindings
ns = dict(globals())
ns.update({
'S0': self.S0, 'S1': self.S1, 'S2': self.S2, 'D0': self.D0, 'D1': self.D1,
'SCC': self.SCC, 'VCC': self.VCC, 'EXEC': self.EXEC,
'EXEC_LO': SliceProxy(self.EXEC, 31, 0), 'EXEC_HI': SliceProxy(self.EXEC, 63, 32),
'tmp': self.tmp, 'saveexec': self.saveexec,
'lane': self.lane, 'laneId': self.laneId, 'literal': self.literal,
'SIMM16': self.SIMM16, 'SIMM32': self.SIMM32,
'VGPR': self.VGPR, 'SRC0': self.SRC0, 'VDST': self.VDST,
})
exec(code, ns)
# Sync rebinds: if register was reassigned to new Reg or value, copy it back
def _sync(ctx_reg, ns_val):
if isinstance(ns_val, Reg): ctx_reg._val = ns_val._val
else: ctx_reg._val = int(ns_val) & MASK64
if ns.get('SCC') is not self.SCC: _sync(self.SCC, ns['SCC'])
if ns.get('VCC') is not self.VCC: _sync(self.VCC, ns['VCC'])
if ns.get('EXEC') is not self.EXEC: _sync(self.EXEC, ns['EXEC'])
if ns.get('D0') is not self.D0: _sync(self.D0, ns['D0'])
if ns.get('D1') is not self.D1: _sync(self.D1, ns['D1'])
if ns.get('tmp') is not self.tmp: _sync(self.tmp, ns['tmp'])
if ns.get('saveexec') is not self.saveexec: _sync(self.saveexec, ns['saveexec'])
def result(self) -> dict:
return {"d0": self.D0._val, "scc": self.SCC._val & 1}
# ═══════════════════════════════════════════════════════════════════════════════
# PDF EXTRACTION AND CODE GENERATION
@ -849,14 +757,14 @@ INST_PATTERN = re.compile(r'^([SV]_[A-Z0-9_]+)\s+(\d+)\s*$', re.M)
# Patterns that can't be handled by the DSL (require special handling in emu.py)
UNSUPPORTED = ['SGPR[', 'V_SWAP', 'eval ', 'FATAL_HALT', 'HW_REGISTERS',
'vscnt', 'vmcnt', 'expcnt', 'lgkmcnt',
'PC =', 'PC=', 'PC+', '= PC', 'vscnt', 'vmcnt', 'expcnt', 'lgkmcnt',
'CVT_OFF_TABLE', 'ThreadMask',
'S1[i', 'C.i32', 'S[i]', 'in[',
'S1[i', 'C.i32', 'S[i]', 'in[', '2.0 / PI',
'if n.', 'DST.u32', 'addrd = DST', 'addr = DST'] # Malformed pseudocode from PDF
def extract_pseudocode(text: str) -> str | None:
"""Extract pseudocode from an instruction description snippet."""
lines, result, depth, in_lambda = text.split('\n'), [], 0, 0
lines, result, depth = text.split('\n'), [], 0
for line in lines:
s = line.strip()
if not s: continue
@ -865,17 +773,12 @@ def extract_pseudocode(text: str) -> str | None:
# Skip document headers (RDNA or CDNA)
if s.startswith('"RDNA') or s.startswith('AMD ') or s.startswith('CDNA'): continue
if s.startswith('Notes') or s.startswith('Functional examples'): break
# Track lambda definitions (e.g., BYTE_PERMUTE = lambda(data, sel) (...))
if '= lambda(' in s: in_lambda += 1; continue
if in_lambda > 0:
if s.endswith(');'): in_lambda -= 1
continue
if s.startswith('if '): depth += 1
elif s.startswith('endif'): depth = max(0, depth - 1)
if s.endswith('.') and not any(p in s for p in ['D0', 'D1', 'S0', 'S1', 'S2', 'SCC', 'VCC', 'tmp', '=']): continue
if re.match(r'^[a-z].*\.$', s) and '=' not in s: continue
is_code = (
any(p in s for p in ['D0.', 'D1.', 'S0.', 'S1.', 'S2.', 'SCC =', 'SCC ?', 'VCC', 'EXEC', 'tmp =', 'tmp[', 'lane =', 'PC =']) or
any(p in s for p in ['D0.', 'D1.', 'S0.', 'S1.', 'S2.', 'SCC =', 'SCC ?', 'VCC', 'EXEC', 'tmp =', 'tmp[', 'lane =']) or
any(p in s for p in ['D0[', 'D1[', 'S0[', 'S1[', 'S2[']) or
s.startswith(('if ', 'else', 'elsif', 'endif', 'declare ', 'for ', 'endfor', '//')) or
re.match(r'^[a-z_]+\s*=', s) or re.match(r'^[a-z_]+\[', s) or (depth > 0 and '=' in s)
@ -1017,8 +920,11 @@ from extra.assembly.amd.pcode import *
try:
code = compile_pseudocode(pc)
# NOTE: Do NOT add more code.replace() hacks here. Fix issues properly in the DSL
# (compile_pseudocode, helper functions, or Reg/TypedView classes) instead.
# CLZ/CTZ: The PDF pseudocode searches for the first 1 bit but doesn't break.
# Hardware stops at first match. SOP1 uses tmp=i, VOP1/VOP3 use D0.i32=i
if 'CLZ' in op.name or 'CTZ' in op.name:
code = code.replace('tmp = Reg(i)', 'tmp = Reg(i); break')
code = code.replace('D0.i32 = i', 'D0.i32 = i; break')
# V_DIV_FMAS_F32/F64: PDF page 449 says 2^32/2^64 but hardware behavior is more complex.
# The scale direction depends on S2 (the addend): if exponent(S2) > 127 (i.e., S2 >= 2.0),
# scale by 2^+64 (to unscale a numerator that was scaled). Otherwise scale by 2^-64
@ -1040,7 +946,7 @@ from extra.assembly.amd.pcode import *
# Fix 1: Set VCC=1 when zero operands produce NaN
code = code.replace(
'D0.f32 = float("nan")',
'VCC = Reg(0x1); D0.f32 = float("nan")')
'VCC._val = 0x1; D0.f32 = float("nan")')
# Fix 2: Denorm denom returns NaN. Must check this AFTER all VCC-setting logic runs.
# Insert at end of all branches, before the final result is used
code = code.replace(
@ -1051,26 +957,26 @@ from extra.assembly.amd.pcode import *
# Fix 3: Tiny numer should set VCC=1
code = code.replace(
'elif exponent(S2.f32) <= 23:\n D0.f32 = ldexp(S0.f32, 64)',
'elif exponent(S2.f32) <= 23:\n VCC = Reg(0x1); D0.f32 = ldexp(S0.f32, 64)')
'elif exponent(S2.f32) <= 23:\n VCC._val = 0x1; D0.f32 = ldexp(S0.f32, 64)')
# Fix 4: S2/S1 would be denorm - don't scale, just set VCC
code = code.replace(
'elif S2.f32 / S1.f32 == DENORM.f32:\n VCC = Reg(0x1)\n if S0.f32 == S2.f32:\n D0.f32 = ldexp(S0.f32, 64)',
'elif S2.f32 / S1.f32 == DENORM.f32:\n VCC = Reg(0x1)')
'elif S2.f32 / S1.f32 == DENORM.f32:\n VCC._val = int(0x1)\n if S0.f32 == S2.f32:\n D0.f32 = ldexp(S0.f32, 64)',
'elif S2.f32 / S1.f32 == DENORM.f32:\n VCC._val = 0x1')
if op.name == 'V_DIV_SCALE_F64':
# Same fixes for f64 version
code = code.replace(
'D0.f64 = float("nan")',
'VCC = Reg(0x1); D0.f64 = float("nan")')
'VCC._val = 0x1; D0.f64 = float("nan")')
code = code.replace(
'elif S1.f64 == DENORM.f64:\n D0.f64 = ldexp(S0.f64, 128)',
'elif False:\n pass # denorm check moved to end')
code += '\nif S1.f64 == DENORM.f64:\n D0.f64 = float("nan")'
code = code.replace(
'elif exponent(S2.f64) <= 52:\n D0.f64 = ldexp(S0.f64, 128)',
'elif exponent(S2.f64) <= 52:\n VCC = Reg(0x1); D0.f64 = ldexp(S0.f64, 128)')
'elif exponent(S2.f64) <= 52:\n VCC._val = 0x1; D0.f64 = ldexp(S0.f64, 128)')
code = code.replace(
'elif S2.f64 / S1.f64 == DENORM.f64:\n VCC = Reg(0x1)\n if S0.f64 == S2.f64:\n D0.f64 = ldexp(S0.f64, 128)',
'elif S2.f64 / S1.f64 == DENORM.f64:\n VCC = Reg(0x1)')
'elif S2.f64 / S1.f64 == DENORM.f64:\n VCC._val = int(0x1)\n if S0.f64 == S2.f64:\n D0.f64 = ldexp(S0.f64, 128)',
'elif S2.f64 / S1.f64 == DENORM.f64:\n VCC._val = 0x1')
# V_DIV_FIXUP_F32/F64: PDF doesn't check isNAN(S0), but hardware returns OVERFLOW if S0 is NaN.
# When division fails (e.g., due to denorm denom), S0 becomes NaN, and fixup should return ±inf.
if op.name == 'V_DIV_FIXUP_F32':
@ -1081,86 +987,23 @@ from extra.assembly.amd.pcode import *
code = code.replace(
'D0.f64 = ((-abs(S0.f64)) if (sign_out) else (abs(S0.f64)))',
'D0.f64 = ((-OVERFLOW_F64) if (sign_out) else (OVERFLOW_F64)) if isNAN(S0.f64) else ((-abs(S0.f64)) if (sign_out) else (abs(S0.f64)))')
# V_TRIG_PREOP_F64: AMD pseudocode uses (x << shift) & mask but mask needs to extract TOP bits.
# The PDF shows: result = 64'F((1201'B(2.0/PI)[1200:0] << shift) & 1201'0x1fffffffffffff)
# Issues to fix:
# 1. After left shift, the interesting bits are at the top, not bottom - need >> (1201-53)
# 2. shift.u32 fails because shift is a plain int after * 53 - use int(shift)
# 3. 64'F(...) means convert int to float (not interpret as bit pattern) - use float()
if op.name == 'V_TRIG_PREOP_F64':
code = code.replace(
'result = F((TWO_OVER_PI_1201[1200 : 0] << shift.u32) & 0x1fffffffffffff)',
'result = float(((TWO_OVER_PI_1201[1200 : 0] << int(shift)) >> (1201 - 53)) & 0x1fffffffffffff)')
# Detect flags for result handling
is_64 = any(p in pc for p in ['D0.u64', 'D0.b64', 'D0.f64', 'D0.i64', 'D1.u64', 'D1.b64', 'D1.f64', 'D1.i64'])
has_d1 = '{ D1' in pc
if has_d1: is_64 = True
is_cmp = (cls_name == 'VOPCOp' or cls_name == 'VOP3Op') and 'D0.u64[laneId]' in pc
is_cmpx = (cls_name == 'VOPCOp' or cls_name == 'VOP3Op') and 'EXEC.u64[laneId]' in pc # V_CMPX writes to EXEC per-lane
# V_DIV_SCALE passes through S0 if no branch taken
is_div_scale = 'DIV_SCALE' in op.name
# VOP3SD instructions that write VCC per-lane (either via VCC.u64[laneId] or by setting VCC = 0/1)
has_sdst = cls_name == 'VOP3SDOp' and ('VCC.u64[laneId]' in pc or is_div_scale)
# Instructions that use/modify PC
has_pc = 'PC' in pc
# Generate function with indented body
# SIMM16/SIMM32 (inline literal constants) are passed as S2
code = code.replace('SIMM16', 'S2').replace('SIMM32', 'S2')
# Generate function with standard signature
fn_name = f"_{cls_name}_{op.name}"
lines.append(f"def {fn_name}(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGPR, _vars, src0_idx=0, vdst_idx=0, pc=0):")
# Add original pseudocode as comment
lines.append(f"def {fn_name}(S0, S1, S2, D0, D1, SCC, VCC, EXEC, tmp, laneId):")
for pc_line in pc.split('\n'):
lines.append(f" # {pc_line}")
# Only create Reg objects for registers actually used in the pseudocode
# Add EXEC_LO/EXEC_HI if needed
combined = code + pc
regs = [('S0', 'Reg(s0)'), ('S1', 'Reg(s1)'), ('S2', 'Reg(s2)'),
('D0', 'Reg(s0)' if is_div_scale else 'Reg(d0)'), ('D1', 'Reg(0)'),
('SCC', 'Reg(scc)'), ('VCC', 'Reg(vcc)'), ('EXEC', 'Reg(exec_mask)'),
('tmp', 'Reg(0)'), ('saveexec', 'Reg(exec_mask)'), ('laneId', 'lane'),
('SIMM16', 'Reg(literal)'), ('SIMM32', 'Reg(literal)'),
('SRC0', 'Reg(src0_idx)'), ('VDST', 'Reg(vdst_idx)'),
('PC', 'Reg(pc)')] # PC is passed in as byte address
used = {name for name, _ in regs if name in combined}
# EXEC_LO/EXEC_HI need EXEC
if 'EXEC_LO' in combined or 'EXEC_HI' in combined: used.add('EXEC')
# VCCZ/EXECZ need VCC/EXEC
if 'VCCZ' in combined: used.add('VCC')
if 'EXECZ' in combined: used.add('EXEC')
for name, init in regs:
if name in used: lines.append(f" {name} = {init}")
if 'EXEC_LO' in combined: lines.append(" EXEC_LO = SliceProxy(EXEC, 31, 0)")
if 'EXEC_HI' in combined: lines.append(" EXEC_HI = SliceProxy(EXEC, 63, 32)")
# VCCZ = 1 if VCC == 0, EXECZ = 1 if EXEC == 0
if 'VCCZ' in combined: lines.append(" VCCZ = Reg(1 if VCC._val == 0 else 0)")
if 'EXECZ' in combined: lines.append(" EXECZ = Reg(1 if EXEC._val == 0 else 0)")
# Add compiled pseudocode with markers
lines.append(" # --- compiled pseudocode ---")
for line in code.split('\n'):
lines.append(f" {line}")
lines.append(" # --- end pseudocode ---")
# Generate result dict - use raw params if Reg wasn't created
d0_val = "D0._val" if 'D0' in used else "d0"
scc_val = "SCC._val & 1" if 'SCC' in used else "scc & 1"
lines.append(f" result = {{'d0': {d0_val}, 'scc': {scc_val}}}")
if has_sdst:
lines.append(" result['vcc_lane'] = (VCC._val >> lane) & 1")
elif 'VCC' in used:
lines.append(" if VCC._val != vcc: result['vcc_lane'] = (VCC._val >> lane) & 1")
if is_cmpx:
lines.append(" result['exec_lane'] = (EXEC._val >> lane) & 1")
elif 'EXEC' in used:
lines.append(" if EXEC._val != exec_mask: result['exec'] = EXEC._val")
if is_cmp:
lines.append(" result['vcc_lane'] = (D0._val >> lane) & 1")
if is_64:
lines.append(" result['d0_64'] = True")
if has_d1:
lines.append(" result['d1'] = D1._val & 1")
if has_pc:
# Return new PC as absolute byte address, emulator will compute delta
# Handle negative values (backward jumps): PC._val is stored as unsigned, convert to signed
lines.append(" _pc = PC._val if PC._val < 0x8000000000000000 else PC._val - 0x10000000000000000")
lines.append(" result['new_pc'] = _pc # absolute byte address")
lines.append(" return result")
code_lines = [line for line in code.split('\n') if line.strip()]
if code_lines:
for line in code_lines:
lines.append(f" {line}")
else:
lines.append(" pass")
lines.append("")
fn_entries.append((op, fn_name))
@ -1181,9 +1024,8 @@ from extra.assembly.amd.pcode import *
if 'VOP3Op' in enum_names:
lines.append('''
# V_WRITELANE_B32: Write scalar to specific lane's VGPR (not in PDF pseudocode)
def _VOP3Op_V_WRITELANE_B32(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGPR, _vars, src0_idx=0, vdst_idx=0):
wr_lane = s1 & 0x1f # lane select (5 bits for wave32)
return {'d0': d0, 'scc': scc, 'vgpr_write': (wr_lane, vdst_idx, s0 & 0xffffffff)}
def _VOP3Op_V_WRITELANE_B32(S0, S1, S2, D0, D1, SCC, VCC, EXEC, tmp, laneId):
return (int(S1) & 0x1f, int(S0) & 0xffffffff) # (wr_lane, value)
VOP3Op_FUNCTIONS[VOP3Op.V_WRITELANE_B32] = _VOP3Op_V_WRITELANE_B32
''')

View file

@ -4,7 +4,9 @@ import unittest
from extra.assembly.amd.pcode import (Reg, TypedView, SliceProxy, ExecContext, compile_pseudocode, _expr, MASK32, MASK64,
_f32, _i32, _f16, _i16, f32_to_f16, _isnan, _bf16, _ibf16, bf16_to_f32, f32_to_bf16,
BYTE_PERMUTE, v_sad_u8, v_msad_u8)
from extra.assembly.amd.autogen.rdna3.gen_pcode import _VOP3SDOp_V_DIV_SCALE_F32, _VOPCOp_V_CMP_CLASS_F32
from extra.assembly.amd.autogen.rdna3.gen_pcode import VOP3SDOp_FUNCTIONS, VOPCOp_FUNCTIONS
from extra.assembly.amd.autogen.rdna3 import VOP3SDOp, VOPCOp
from extra.assembly.amd.emu import _run_pcode
class TestReg(unittest.TestCase):
def test_u32_read(self):
@ -234,7 +236,8 @@ class TestPseudocodeRegressions(unittest.TestCase):
s0 = 0x3f800000 # 1.0
s1 = 0x40400000 # 3.0
s2 = 0x3f800000 # 1.0 (numerator)
result = _VOP3SDOp_V_DIV_SCALE_F32(s0, s1, s2, 0, 0, 0, 0, 0xffffffff, 0, None, {})
fn = VOP3SDOp_FUNCTIONS[VOP3SDOp.V_DIV_SCALE_F32]
result = _run_pcode(fn, VOP3SDOp, VOP3SDOp.V_DIV_SCALE_F32, s0, s1, s2, 0, 0, 0, 0, 0xffffffff, 0)
# Must always have vcc_lane in result
self.assertIn('vcc_lane', result, "V_DIV_SCALE_F32 must always return vcc_lane")
self.assertEqual(result['vcc_lane'], 0, "vcc_lane should be 0 when no scaling needed")
@ -244,19 +247,20 @@ class TestPseudocodeRegressions(unittest.TestCase):
Bug: isQuietNAN and isSignalNAN both used math.isnan which can't distinguish them."""
quiet_nan = 0x7fc00000 # quiet NaN: exponent=255, bit22=1
signal_nan = 0x7f800001 # signaling NaN: exponent=255, bit22=0
fn = VOPCOp_FUNCTIONS[VOPCOp.V_CMP_CLASS_F32]
# Test quiet NaN detection (bit 1 in mask)
s1_quiet = 0b0000000010 # bit 1 = quiet NaN
result = _VOPCOp_V_CMP_CLASS_F32(quiet_nan, s1_quiet, 0, 0, 0, 0, 0, 0xffffffff, 0, None, {})
result = _run_pcode(fn, VOPCOp, VOPCOp.V_CMP_CLASS_F32, quiet_nan, s1_quiet, 0, 0, 0, 0, 0, 0xffffffff, 0)
self.assertEqual(result['vcc_lane'], 1, "Should detect quiet NaN with quiet NaN mask")
# Test signaling NaN detection (bit 0 in mask)
s1_signal = 0b0000000001 # bit 0 = signaling NaN
result = _VOPCOp_V_CMP_CLASS_F32(signal_nan, s1_signal, 0, 0, 0, 0, 0, 0xffffffff, 0, None, {})
result = _run_pcode(fn, VOPCOp, VOPCOp.V_CMP_CLASS_F32, signal_nan, s1_signal, 0, 0, 0, 0, 0, 0xffffffff, 0)
self.assertEqual(result['vcc_lane'], 1, "Should detect signaling NaN with signaling NaN mask")
# Test that quiet NaN doesn't match signaling NaN mask
result = _VOPCOp_V_CMP_CLASS_F32(quiet_nan, s1_signal, 0, 0, 0, 0, 0, 0xffffffff, 0, None, {})
result = _run_pcode(fn, VOPCOp, VOPCOp.V_CMP_CLASS_F32, quiet_nan, s1_signal, 0, 0, 0, 0, 0, 0xffffffff, 0)
self.assertEqual(result['vcc_lane'], 0, "Quiet NaN should not match signaling NaN mask")
# Test that signaling NaN doesn't match quiet NaN mask
result = _VOPCOp_V_CMP_CLASS_F32(signal_nan, s1_quiet, 0, 0, 0, 0, 0, 0xffffffff, 0, None, {})
result = _run_pcode(fn, VOPCOp, VOPCOp.V_CMP_CLASS_F32, signal_nan, s1_quiet, 0, 0, 0, 0, 0, 0xffffffff, 0)
self.assertEqual(result['vcc_lane'], 0, "Signaling NaN should not match quiet NaN mask")
def test_isnan_with_typed_view(self):