mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
Compare commits
3 commits
master
...
only_reg_e
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
7a1190b729 | ||
|
|
433248c998 | ||
|
|
7f139a934f |
4 changed files with 1751 additions and 11107 deletions
File diff suppressed because it is too large
Load diff
|
|
@ -1,10 +1,9 @@
|
|||
# RDNA3 emulator - executes compiled pseudocode from AMD ISA PDF
|
||||
# mypy: ignore-errors
|
||||
from __future__ import annotations
|
||||
import ctypes, os
|
||||
import ctypes
|
||||
from extra.assembly.amd.dsl import Inst, RawImm
|
||||
from extra.assembly.amd.asm import detect_format
|
||||
from extra.assembly.amd.pcode import _f32, _i32, _sext, _f16, _i16, _f64, _i64
|
||||
from extra.assembly.amd.pcode import _f32, _i32, _sext, _f16, _i16, _f64, _i64, Reg, SliceProxy
|
||||
from extra.assembly.amd.autogen.rdna3.gen_pcode import get_compiled_functions
|
||||
from extra.assembly.amd.autogen.rdna3 import (
|
||||
SOP1, SOP2, SOPC, SOPK, SOPP, SMEM, VOP1, VOP2, VOP3, VOP3SD, VOP3P, VOPC, DS, FLAT, VOPD, SrcEnum,
|
||||
|
|
@ -18,7 +17,6 @@ VCC_LO, VCC_HI, NULL, EXEC_LO, EXEC_HI, SCC = SrcEnum.VCC_LO, SrcEnum.VCC_HI, Sr
|
|||
# VOP3 ops that use 64-bit operands (and thus 64-bit literals when src is 255)
|
||||
# Exception: V_LDEXP_F64 has 32-bit integer src1, so literal should NOT be 64-bit when src1=255
|
||||
_VOP3_64BIT_OPS = {op.value for op in VOP3Op if op.name.endswith(('_F64', '_B64', '_I64', '_U64'))}
|
||||
_VOPC_64BIT_OPS = {op.value for op in VOPCOp if op.name.endswith(('_F64', '_B64', '_I64', '_U64'))}
|
||||
# Ops where src1 is 32-bit (exponent/shift amount) even though the op name suggests 64-bit
|
||||
_VOP3_64BIT_OPS_32BIT_SRC1 = {VOP3Op.V_LDEXP_F64.value}
|
||||
# Ops with 16-bit types in name (for source/dest handling)
|
||||
|
|
@ -26,18 +24,12 @@ _VOP3_64BIT_OPS_32BIT_SRC1 = {VOP3Op.V_LDEXP_F64.value}
|
|||
_VOP3_16BIT_OPS = {op for op in VOP3Op if any(s in op.name for s in ('_F16', '_B16', '_I16', '_U16')) and 'SAD' not in op.name}
|
||||
_VOP1_16BIT_OPS = {op for op in VOP1Op if any(s in op.name for s in ('_F16', '_B16', '_I16', '_U16'))}
|
||||
_VOP2_16BIT_OPS = {op for op in VOP2Op if any(s in op.name for s in ('_F16', '_B16', '_I16', '_U16'))}
|
||||
_VOPC_16BIT_OPS = {op for op in VOPCOp if any(s in op.name for s in ('_F16', '_B16', '_I16', '_U16'))}
|
||||
# CVT ops with 32/64-bit source (despite 16-bit in name)
|
||||
_CVT_32_64_SRC_OPS = {op for op in VOP3Op if op.name.startswith('V_CVT_') and op.name.endswith(('_F32', '_I32', '_U32', '_F64', '_I64', '_U64'))} | \
|
||||
{op for op in VOP1Op if op.name.startswith('V_CVT_') and op.name.endswith(('_F32', '_I32', '_U32', '_F64', '_I64', '_U64'))}
|
||||
# CVT ops with 32-bit destination (convert FROM 16-bit TO 32-bit): V_CVT_F32_F16, V_CVT_I32_I16, V_CVT_U32_U16
|
||||
_CVT_32_DST_OPS = {op for op in VOP3Op if op.name.startswith('V_CVT_') and any(s in op.name for s in ('F32_F16', 'I32_I16', 'U32_U16', 'I32_F16', 'U32_F16'))} | \
|
||||
{op for op in VOP1Op if op.name.startswith('V_CVT_') and any(s in op.name for s in ('F32_F16', 'I32_I16', 'U32_U16', 'I32_F16', 'U32_F16'))}
|
||||
# 16-bit dst ops (PACK has 32-bit dst despite F16 in name, CVT to 32-bit has 32-bit dst)
|
||||
_VOP3_16BIT_DST_OPS = {op for op in _VOP3_16BIT_OPS if 'PACK' not in op.name} - _CVT_32_DST_OPS
|
||||
_VOP1_16BIT_DST_OPS = {op for op in _VOP1_16BIT_OPS if 'PACK' not in op.name} - _CVT_32_DST_OPS
|
||||
# VOP1 16-bit source ops (excluding CVT ops with 32/64-bit source) - for VOP1 e32, .h encoded in register index
|
||||
_VOP1_16BIT_SRC_OPS = _VOP1_16BIT_OPS - _CVT_32_64_SRC_OPS
|
||||
# 16-bit dst ops (PACK has 32-bit dst despite F16 in name)
|
||||
_VOP3_16BIT_DST_OPS = {op for op in _VOP3_16BIT_OPS if 'PACK' not in op.name}
|
||||
_VOP1_16BIT_DST_OPS = {op for op in _VOP1_16BIT_OPS if 'PACK' not in op.name}
|
||||
|
||||
# Inline constants for src operands 128-254. Build tables for f32, f16, and f64 formats.
|
||||
import struct as _struct
|
||||
|
|
@ -97,6 +89,36 @@ def _get_compiled() -> dict:
|
|||
if _COMPILED is None: _COMPILED = get_compiled_functions()
|
||||
return _COMPILED
|
||||
|
||||
def _run_pcode(fn, op_cls, op, s0, s1, s2, d0, scc, vcc, lane, exec_mask, vdst_idx):
  """Wrap raw operand values in Regs, call the compiled pseudocode function, and gather its outputs.

  The returned dict always carries 'd0' and 'scc'. Depending on the op and on which registers the
  pseudocode modified, it may also carry 'vcc_lane', 'exec_lane', 'exec', 'd0_64', 'd1' and 'vgpr_write'.
  """
  op_name, cls_name = op.name, op_cls.__name__
  is_vopc = cls_name == 'VOPCOp'
  name_is_cmpx = op_name.startswith('V_CMPX')

  # Register objects handed to the pseudocode. For DIV_SCALE ops the destination starts out as s0 (passthrough).
  src0, src1, src2 = Reg(s0), Reg(s1), Reg(s2)
  dst = Reg(s0) if 'DIV_SCALE' in op_name else Reg(d0)
  dst1 = Reg(0)
  scc_reg, vcc_reg, exec_reg = Reg(scc), Reg(vcc), Reg(exec_mask)
  scratch = Reg(0)

  ret = fn(src0, src1, src2, dst, dst1, scc_reg, vcc_reg, exec_reg, scratch, lane)

  out = {'d0': dst._val, 'scc': scc_reg._val & 1}

  # VOP3SD ops always report this lane's VCC bit; other ops only when the pseudocode changed VCC.
  if cls_name == 'VOP3SDOp' or vcc_reg._val != vcc:
    out['vcc_lane'] = (vcc_reg._val >> lane) & 1

  # V_CMPX reports this lane's EXEC bit; anything else reports the whole mask, and only if it changed.
  if is_vopc and name_is_cmpx:
    out['exec_lane'] = (exec_reg._val >> lane) & 1
  elif exec_reg._val != exec_mask:
    out['exec'] = exec_reg._val

  # Plain VOPC compares take the lane bit from D0; this overrides any 'vcc_lane' set above.
  if is_vopc and not name_is_cmpx:
    out['vcc_lane'] = (dst._val >> lane) & 1

  # Flag results whose destination is 64 bits wide, judged from the op name.
  if op_name.endswith(('_B64', '_I64', '_U64', '_F64')) or op_name in ('V_MAD_U64_U32', 'V_MAD_I64_I32'):
    out['d0_64'] = True

  if dst1._val:
    out['d1'] = dst1._val & 1

  # A non-None return value is a direct (wr_lane, value) pair, as produced by V_WRITELANE_B32.
  if ret is not None:
    out['vgpr_write'] = (ret[0], vdst_idx, ret[1])
  return out
|
||||
|
||||
class WaveState:
|
||||
__slots__ = ('sgpr', 'vgpr', 'scc', 'pc', 'literal', '_pend_sgpr')
|
||||
def __init__(self):
|
||||
|
|
@ -147,7 +169,21 @@ class WaveState:
|
|||
for reg, val in self._pend_sgpr.items(): self.sgpr[reg] = val
|
||||
self._pend_sgpr.clear()
|
||||
|
||||
|
||||
# Instruction decode
|
||||
def decode_format(word: int) -> tuple[type[Inst] | None, bool]:
  """Pick the instruction class for the first 32-bit word of an instruction.

  Returns (inst_class, is_64). inst_class is None when the word falls in the 0b11 group
  but matches no known encoding; is_64 is True for every result in that group and False otherwise.
  """
  group = (word >> 30) & 0x3

  if group == 0b11:
    sel = (word >> 26) & 0xf
    if sel == 0b1101: return SMEM, True
    if sel == 0b0101:
      # A fixed list of opcodes within this encoding selects VOP3SD rather than VOP3.
      opcode = (word >> 16) & 0x3ff
      if opcode in (288, 289, 290, 764, 765, 766, 767, 768, 769, 770): return VOP3SD, True
      return VOP3, True
    if sel == 0b0011: return VOP3P, True
    if sel == 0b0110: return DS, True
    if sel == 0b0111: return FLAT, True
    if sel == 0b0010: return VOPD, True
    return None, True

  if group == 0b10:
    sel = (word >> 23) & 0x7f
    if sel == 0b1111101: return SOP1, False
    if sel == 0b1111110: return SOPC, False
    if sel == 0b1111111: return SOPP, False
    # Neither SOP1, SOPC nor SOPP: the top nibble separates SOPK from SOP2.
    if ((word >> 28) & 0xf) == 0b1011: return SOPK, False
    return SOP2, False

  # Remaining words: VOPC and VOP1 have dedicated 7-bit prefixes, everything else is VOP2.
  sel = (word >> 25) & 0x7f
  if sel == 0b0111110: return VOPC, False
  if sel == 0b0111111: return VOP1, False
  return VOP2, False
|
||||
|
||||
def _unwrap(v) -> int:
  """Reduce a decoded field to a plain value: RawImm yields .val, objects with .value yield that, else v itself."""
  if isinstance(v, RawImm): return v.val
  if hasattr(v, 'value'): return v.value
  return v
|
||||
|
||||
|
|
@ -155,10 +191,10 @@ def decode_program(data: bytes) -> Program:
|
|||
result: Program = {}
|
||||
i = 0
|
||||
while i < len(data):
|
||||
try: inst_class = detect_format(data[i:])
|
||||
except ValueError: break # stop at invalid instruction (padding/metadata after code)
|
||||
word = int.from_bytes(data[i:i+4], 'little')
|
||||
inst_class, is_64 = decode_format(word)
|
||||
if inst_class is None: i += 4; continue
|
||||
base_size = inst_class._size()
|
||||
base_size = 8 if is_64 else 4
|
||||
# Pass enough data for potential 64-bit literal (base + 8 bytes max)
|
||||
inst = inst_class.from_bytes(data[i:i+base_size+8])
|
||||
for name, val in inst._values.items(): setattr(inst, name, _unwrap(val))
|
||||
|
|
@ -173,7 +209,7 @@ def decode_program(data: bytes) -> Program:
|
|||
# Exception: some ops have mixed src sizes (e.g., V_LDEXP_F64 has 32-bit src1)
|
||||
op_val = inst._values.get('op')
|
||||
if hasattr(op_val, 'value'): op_val = op_val.value
|
||||
is_64bit = (inst_class is VOP3 and op_val in _VOP3_64BIT_OPS) or (inst_class is VOPC and op_val in _VOPC_64BIT_OPS)
|
||||
is_64bit = inst_class is VOP3 and op_val in _VOP3_64BIT_OPS
|
||||
# Don't treat literal as 64-bit if the op has 32-bit src1 and src1 is the literal
|
||||
if is_64bit and op_val in _VOP3_64BIT_OPS_32BIT_SRC1 and getattr(inst, 'src1', None) == 255:
|
||||
is_64bit = False
|
||||
|
|
@ -193,11 +229,21 @@ def exec_scalar(st: WaveState, inst: Inst) -> int:
|
|||
compiled = _get_compiled()
|
||||
inst_type = type(inst)
|
||||
|
||||
# SOPP: special cases for control flow that has no pseudocode
|
||||
# SOPP: control flow (not ALU)
|
||||
if inst_type is SOPP:
|
||||
op = inst.op
|
||||
if op == SOPPOp.S_ENDPGM: return -1
|
||||
if op == SOPPOp.S_BARRIER: return -2
|
||||
if op == SOPPOp.S_BRANCH: return _sext(inst.simm16, 16)
|
||||
if op == SOPPOp.S_CBRANCH_SCC0: return _sext(inst.simm16, 16) if st.scc == 0 else 0
|
||||
if op == SOPPOp.S_CBRANCH_SCC1: return _sext(inst.simm16, 16) if st.scc == 1 else 0
|
||||
if op == SOPPOp.S_CBRANCH_VCCZ: return _sext(inst.simm16, 16) if (st.vcc & 0xffffffff) == 0 else 0
|
||||
if op == SOPPOp.S_CBRANCH_VCCNZ: return _sext(inst.simm16, 16) if (st.vcc & 0xffffffff) != 0 else 0
|
||||
if op == SOPPOp.S_CBRANCH_EXECZ: return _sext(inst.simm16, 16) if st.exec_mask == 0 else 0
|
||||
if op == SOPPOp.S_CBRANCH_EXECNZ: return _sext(inst.simm16, 16) if st.exec_mask != 0 else 0
|
||||
# Valid SOPP range is 0-61 (max defined opcode); anything above is invalid
|
||||
if op > 61: raise NotImplementedError(f"Invalid SOPP opcode {op}")
|
||||
return 0 # waits, hints, nops
|
||||
|
||||
# SMEM: memory loads (not ALU)
|
||||
if inst_type is SMEM:
|
||||
|
|
@ -207,53 +253,49 @@ def exec_scalar(st: WaveState, inst: Inst) -> int:
|
|||
for i in range(cnt): st.wsgpr(inst.sdata + i, mem_read((addr + i * 4) & 0xffffffffffffffff, 4))
|
||||
return 0
|
||||
|
||||
# SOP1: special handling for ops not in pseudocode
|
||||
if inst_type is SOP1:
|
||||
op = SOP1Op(inst.op)
|
||||
# S_GETPC_B64: Get program counter (PC is tracked in words, so convert it to a byte address)
|
||||
if op == SOP1Op.S_GETPC_B64:
|
||||
pc_bytes = st.pc * 4 # PC is in words, convert to bytes
|
||||
st.wsgpr64(inst.sdst, pc_bytes)
|
||||
return 0
|
||||
# S_SETPC_B64: Set program counter to source value (indirect jump)
|
||||
# Returns delta such that st.pc + inst_words + delta = target_words
|
||||
if op == SOP1Op.S_SETPC_B64:
|
||||
target_bytes = st.rsrc64(inst.ssrc0, 0)
|
||||
target_words = target_bytes // 4
|
||||
inst_words = 1 # SOP1 is always 1 word
|
||||
return target_words - st.pc - inst_words
|
||||
|
||||
# Get op enum and lookup compiled function
|
||||
if inst_type is SOP1: op_cls, ssrc0, sdst = SOP1Op, inst.ssrc0, inst.sdst
|
||||
elif inst_type is SOP2: op_cls, ssrc0, sdst = SOP2Op, inst.ssrc0, inst.sdst
|
||||
elif inst_type is SOPC: op_cls, ssrc0, sdst = SOPCOp, inst.ssrc0, None
|
||||
elif inst_type is SOPK: op_cls, ssrc0, sdst = SOPKOp, inst.sdst, inst.sdst # sdst is both src and dst
|
||||
elif inst_type is SOPP: op_cls, ssrc0, sdst = SOPPOp, None, None
|
||||
else: raise NotImplementedError(f"Unknown scalar type {inst_type}")
|
||||
|
||||
# SOPP has gaps in the opcode enum - treat unknown opcodes as no-ops
|
||||
try: op = op_cls(inst.op)
|
||||
except ValueError:
|
||||
if inst_type is SOPP: return 0
|
||||
raise
|
||||
op = op_cls(inst.op)
|
||||
fn = compiled.get(op_cls, {}).get(op)
|
||||
if fn is None:
|
||||
# SOPP instructions without pseudocode (waits, hints, nops) are no-ops
|
||||
if inst_type is SOPP: return 0
|
||||
raise NotImplementedError(f"{op.name} not in pseudocode")
|
||||
if fn is None: raise NotImplementedError(f"{op.name} not in pseudocode")
|
||||
|
||||
# Build context - handle 64-bit ops that need 64-bit source reads
|
||||
# 64-bit source ops: name ends with _B64, _I64, _U64 or contains _U64, _I64 before last underscore
|
||||
# Read sources - 64-bit ops need 64-bit source reads
|
||||
is_64bit_s0 = op.name.endswith(('_B64', '_I64', '_U64')) or '_U64_' in op.name or '_I64_' in op.name
|
||||
is_64bit_s0s1 = op_cls is SOPCOp and op in (SOPCOp.S_CMP_EQ_U64, SOPCOp.S_CMP_LG_U64)
|
||||
s0 = st.rsrc64(ssrc0, 0) if is_64bit_s0 or is_64bit_s0s1 else (st.rsrc(ssrc0, 0) if inst_type not in (SOPK, SOPP) else (st.rsgpr(inst.sdst) if inst_type is SOPK else 0))
|
||||
s0 = st.rsrc64(ssrc0, 0) if is_64bit_s0 or is_64bit_s0s1 else (st.rsrc(ssrc0, 0) if inst_type != SOPK else st.rsgpr(inst.sdst))
|
||||
is_64bit_sop2 = is_64bit_s0 and inst_type is SOP2
|
||||
s1 = st.rsrc64(inst.ssrc1, 0) if (is_64bit_sop2 or is_64bit_s0s1) else (st.rsrc(inst.ssrc1, 0) if inst_type in (SOP2, SOPC) else inst.simm16 if inst_type is SOPK else 0)
|
||||
s1 = st.rsrc64(inst.ssrc1, 0) if (is_64bit_sop2 or is_64bit_s0s1) else (st.rsrc(inst.ssrc1, 0) if inst_type in (SOP2, SOPC) else 0)
|
||||
s2 = inst.simm16 if inst_type is SOPK else 0 # SOPK: 16-bit immediate passed as S2
|
||||
d0 = st.rsgpr64(sdst) if (is_64bit_s0 or is_64bit_s0s1) and sdst is not None else (st.rsgpr(sdst) if sdst is not None else 0)
|
||||
exec_mask = st.exec_mask
|
||||
literal = inst.simm16 if inst_type in (SOPK, SOPP) else st.literal
|
||||
|
||||
# Execute compiled function - pass PC in bytes for instructions that need it
|
||||
pc_bytes = st.pc * 4
|
||||
result = fn(s0, s1, 0, d0, st.scc, st.vcc, 0, exec_mask, literal, None, {}, pc=pc_bytes)
|
||||
|
||||
# Apply results
|
||||
# Execute and apply results
|
||||
result = _run_pcode(fn, op_cls, op, s0, s1, s2, d0, st.scc, st.vcc, 0, st.exec_mask, 0)
|
||||
if sdst is not None:
|
||||
if result.get('d0_64'):
|
||||
st.wsgpr64(sdst, result['d0'])
|
||||
else:
|
||||
st.wsgpr(sdst, result['d0'])
|
||||
if 'scc' in result: st.scc = result['scc']
|
||||
if result.get('d0_64'): st.wsgpr64(sdst, result['d0'])
|
||||
else: st.wsgpr(sdst, result['d0'])
|
||||
st.scc = result['scc']
|
||||
if 'exec' in result: st.exec_mask = result['exec']
|
||||
if 'new_pc' in result:
|
||||
# Convert absolute byte address to word delta
|
||||
# new_pc is where we want to go, st.pc is current position, inst._words will be added after
|
||||
new_pc_words = result['new_pc'] // 4
|
||||
return new_pc_words - st.pc - 1 # -1 because emulator adds inst_words (1 for scalar)
|
||||
return 0
|
||||
|
||||
def exec_vector(st: WaveState, inst: Inst, lane: int, lds: bytearray | None = None) -> None:
|
||||
|
|
@ -304,19 +346,22 @@ def exec_vector(st: WaveState, inst: Inst, lane: int, lds: bytearray | None = No
|
|||
sx0, sx1 = st.rsrc(inst.srcx0, lane), V[inst.vsrcx1]
|
||||
sy0, sy1 = st.rsrc(inst.srcy0, lane), V[inst.vsrcy1]
|
||||
dx0, dy0 = V[inst.vdstx], V[vdsty]
|
||||
# Execute X op
|
||||
res_x = None
|
||||
# FMAAK/FMAMK in VOPD use literal as S2
|
||||
literal = getattr(inst, '_literal', None) or 0
|
||||
res_x = res_y = None
|
||||
if (op_x := _VOPD_TO_VOP.get(inst.opx)):
|
||||
if (fn_x := compiled.get(type(op_x), {}).get(op_x)):
|
||||
res_x = fn_x(sx0, sx1, 0, dx0, st.scc, st.vcc, lane, st.exec_mask, st.literal, None, {})
|
||||
# Execute Y op
|
||||
res_y = None
|
||||
if (fn := compiled.get(type(op_x), {}).get(op_x)):
|
||||
# opx 1=FMAMK, 2=FMAAK use literal
|
||||
sx2 = literal if inst.opx in (VOPDOp.V_DUAL_FMAMK_F32, VOPDOp.V_DUAL_FMAAK_F32) else 0
|
||||
res_x = _run_pcode(fn, type(op_x), op_x, sx0, sx1, sx2, dx0, st.scc, st.vcc, lane, st.exec_mask, 0)
|
||||
if (op_y := _VOPD_TO_VOP.get(inst.opy)):
|
||||
if (fn_y := compiled.get(type(op_y), {}).get(op_y)):
|
||||
res_y = fn_y(sy0, sy1, 0, dy0, st.scc, st.vcc, lane, st.exec_mask, st.literal, None, {})
|
||||
if (fn := compiled.get(type(op_y), {}).get(op_y)):
|
||||
# opy 1=FMAMK, 2=FMAAK use literal
|
||||
sy2 = literal if inst.opy in (VOPDOp.V_DUAL_FMAMK_F32, VOPDOp.V_DUAL_FMAAK_F32) else 0
|
||||
res_y = _run_pcode(fn, type(op_y), op_y, sy0, sy1, sy2, dy0, st.scc, st.vcc, lane, st.exec_mask, 0)
|
||||
# Write results after both ops complete
|
||||
if res_x is not None: V[inst.vdstx] = res_x['d0']
|
||||
if res_y is not None: V[vdsty] = res_y['d0']
|
||||
if res_x: V[inst.vdstx] = res_x['d0']
|
||||
if res_y: V[vdsty] = res_y['d0']
|
||||
return
|
||||
|
||||
# VOP3SD: has extra scalar dest for carry output
|
||||
|
|
@ -324,30 +369,17 @@ def exec_vector(st: WaveState, inst: Inst, lane: int, lds: bytearray | None = No
|
|||
op = VOP3SDOp(inst.op)
|
||||
fn = compiled.get(VOP3SDOp, {}).get(op)
|
||||
if fn is None: raise NotImplementedError(f"{op.name} not in pseudocode")
|
||||
# VOP3SD has both 32-bit ops (V_ADD_CO_CI_U32, etc.) and 64-bit ops (V_DIV_SCALE_F64, V_MAD_U64_U32, etc.)
|
||||
div_scale_64_ops = (VOP3SDOp.V_DIV_SCALE_F64,)
|
||||
s0, s1, s2 = st.rsrc(inst.src0, lane), st.rsrc(inst.src1, lane), st.rsrc(inst.src2, lane)
|
||||
# For 64-bit src2 ops (V_MAD_U64_U32, V_MAD_I64_I32), read from consecutive registers
|
||||
mad64_ops = (VOP3SDOp.V_MAD_U64_U32, VOP3SDOp.V_MAD_I64_I32)
|
||||
if op in div_scale_64_ops:
|
||||
# V_DIV_SCALE_F64: all sources are 64-bit
|
||||
s0, s1, s2 = st.rsrc64(inst.src0, lane), st.rsrc64(inst.src1, lane), st.rsrc64(inst.src2, lane)
|
||||
elif op in mad64_ops:
|
||||
# V_MAD_U64_U32, V_MAD_I64_I32: src0/src1 are 32-bit, src2 is 64-bit
|
||||
s0, s1 = st.rsrc(inst.src0, lane), st.rsrc(inst.src1, lane)
|
||||
if inst.src2 >= 256: # VGPR
|
||||
s2 = V[inst.src2 - 256] | (V[inst.src2 - 256 + 1] << 32)
|
||||
else: # SGPR - read 64-bit from consecutive SGPRs
|
||||
s2 = st.rsgpr64(inst.src2)
|
||||
else:
|
||||
# Default: 32-bit sources
|
||||
s0, s1, s2 = st.rsrc(inst.src0, lane), st.rsrc(inst.src1, lane), st.rsrc(inst.src2, lane)
|
||||
if op in mad64_ops:
|
||||
s2 = (V[inst.src2 - 256] | (V[inst.src2 - 256 + 1] << 32)) if inst.src2 >= 256 else st.rsgpr64(inst.src2)
|
||||
d0 = V[inst.vdst]
|
||||
# For carry-in operations (V_*_CO_CI_*), src2 register contains the carry bitmask (not VCC).
|
||||
# For carry-in ops (V_*_CO_CI_*), src2 register contains the carry bitmask (not VCC).
|
||||
# The pseudocode uses VCC but in VOP3SD encoding, the actual carry source is inst.src2.
|
||||
# We pass the src2 register value as 'vcc' to the interpreter so it reads the correct carry.
|
||||
carry_ops = (VOP3SDOp.V_ADD_CO_CI_U32, VOP3SDOp.V_SUB_CO_CI_U32, VOP3SDOp.V_SUBREV_CO_CI_U32)
|
||||
vcc_for_exec = st.rsgpr64(inst.src2) if op in carry_ops else st.vcc
|
||||
result = fn(s0, s1, s2, d0, st.scc, vcc_for_exec, lane, st.exec_mask, st.literal, None, {})
|
||||
# Write result - handle 64-bit destinations
|
||||
result = _run_pcode(fn, VOP3SDOp, op, s0, s1, s2, d0, st.scc, vcc_for_exec, lane, st.exec_mask, inst.vdst)
|
||||
if result.get('d0_64'):
|
||||
V[inst.vdst] = result['d0'] & 0xffffffff
|
||||
V[inst.vdst + 1] = (result['d0'] >> 32) & 0xffffffff
|
||||
|
|
@ -360,37 +392,54 @@ def exec_vector(st: WaveState, inst: Inst, lane: int, lds: bytearray | None = No
|
|||
|
||||
|
||||
# Get op enum and sources (None means "no source" for that operand)
|
||||
# vop1_dst_hi/vop2_dst_hi: for VOP1/VOP2 16-bit dst ops, bit 7 of vdst indicates .h (high 16-bit) destination
|
||||
vop1_dst_hi, vop2_dst_hi = False, False
|
||||
if inst_type is VOP1:
|
||||
if inst.op == VOP1Op.V_NOP: return
|
||||
op_cls, op, src0, src1, src2 = VOP1Op, VOP1Op(inst.op), inst.src0, None, None
|
||||
# For 16-bit dst ops, vdst encodes .h in bit 7
|
||||
if op in _VOP1_16BIT_DST_OPS:
|
||||
vop1_dst_hi = (inst.vdst & 0x80) != 0
|
||||
vdst = inst.vdst & 0x7f
|
||||
else:
|
||||
vdst = inst.vdst
|
||||
# V_READFIRSTLANE_B32: read from first active lane's VGPR -> SGPR (not in pseudocode - needs cross-lane access)
|
||||
if inst.op == VOP1Op.V_READFIRSTLANE_B32:
|
||||
first_lane = (st.exec_mask & -st.exec_mask).bit_length() - 1 if st.exec_mask else 0
|
||||
vgpr_idx = inst.src0 - 256 if inst.src0 >= 256 else inst.src0 # VGPR index
|
||||
st.wsgpr(inst.vdst, st.vgpr[first_lane][vgpr_idx])
|
||||
return
|
||||
op_cls, op, src0, src1, src2, vdst = VOP1Op, VOP1Op(inst.op), inst.src0, None, None, inst.vdst
|
||||
elif inst_type is VOP2:
|
||||
op_cls, op, src0, src1, src2 = VOP2Op, VOP2Op(inst.op), inst.src0, inst.vsrc1 + 256, None
|
||||
# For 16-bit dst ops, vdst encodes .h in bit 7
|
||||
if op in _VOP2_16BIT_OPS:
|
||||
vop2_dst_hi = (inst.vdst & 0x80) != 0
|
||||
vdst = inst.vdst & 0x7f
|
||||
else:
|
||||
vdst = inst.vdst
|
||||
op_cls, op = VOP2Op, VOP2Op(inst.op)
|
||||
# FMAAK/FMAMK use inline literal constant as S2
|
||||
literal = getattr(inst, '_literal', None)
|
||||
src0, src1, src2, vdst = inst.src0, inst.vsrc1 + 256, literal, inst.vdst
|
||||
elif inst_type is VOP3:
|
||||
# VOP3 ops 0-255 are VOPC comparisons encoded as VOP3 (use VOPCOp pseudocode)
|
||||
if inst.op < 256:
|
||||
op_cls, op, src0, src1, src2, vdst = VOPCOp, VOPCOp(inst.op), inst.src0, inst.src1, None, inst.vdst
|
||||
else:
|
||||
op_cls, op, src0, src1, src2, vdst = VOP3Op, VOP3Op(inst.op), inst.src0, inst.src1, inst.src2, inst.vdst
|
||||
# V_READFIRSTLANE_B32 in VOP3 encoding - same as VOP1 but with VOP3 format
|
||||
if op == VOP3Op.V_READFIRSTLANE_B32:
|
||||
first_lane = (st.exec_mask & -st.exec_mask).bit_length() - 1 if st.exec_mask else 0
|
||||
vgpr_idx = inst.src0 - 256 if inst.src0 >= 256 else inst.src0
|
||||
st.wsgpr(inst.vdst, st.vgpr[first_lane][vgpr_idx])
|
||||
return
|
||||
# V_READLANE_B32: read from specific lane's VGPR -> SGPR (lane specified in src1)
|
||||
if op == VOP3Op.V_READLANE_B32:
|
||||
read_lane = st.rsrc(inst.src1, lane) & 0x1f # Lane to read from (5 bits)
|
||||
vgpr_idx = inst.src0 - 256 if inst.src0 >= 256 else inst.src0
|
||||
st.wsgpr(inst.vdst, st.vgpr[read_lane][vgpr_idx])
|
||||
return
|
||||
# V_PERM_B32: byte permutation - not in pseudocode PDF, implement directly
|
||||
# D0[byte_i] = selector[byte_i] < 8 ? {src0, src1}[selector[byte_i]] : (selector[byte_i] >= 0xD ? 0xFF : 0x00)
|
||||
if op == VOP3Op.V_PERM_B32:
|
||||
s0, s1, s2 = st.rsrc(inst.src0, lane), st.rsrc(inst.src1, lane), st.rsrc(inst.src2, lane)
|
||||
# Combine src1 and src0 into 8-byte value: src1 is bytes 0-3, src0 is bytes 4-7
|
||||
combined = (s1 & 0xffffffff) | ((s0 & 0xffffffff) << 32)
|
||||
result = 0
|
||||
for i in range(4): # 4 result bytes
|
||||
sel = (s2 >> (i * 8)) & 0xff # byte selector for this position
|
||||
if sel <= 7: result |= (((combined >> (sel * 8)) & 0xff) << (i * 8)) # select byte from combined
|
||||
elif sel >= 0xd: result |= (0xff << (i * 8)) # 0xD-0xF: constant 0xFF
|
||||
# else 0x8-0xC: constant 0x00 (already 0)
|
||||
V[vdst] = result & 0xffffffff
|
||||
return
|
||||
elif inst_type is VOPC:
|
||||
op = VOPCOp(inst.op)
|
||||
# For 16-bit VOPC, vsrc1 uses same encoding as VOP2 16-bit: bit 7 selects hi(1) or lo(0) half
|
||||
# vsrc1 field is 8 bits: [6:0] = VGPR index, [7] = hi flag
|
||||
src1 = inst.vsrc1 + 256 # convert to standard VGPR encoding (256 + vgpr_idx)
|
||||
op_cls, src0, src2, vdst = VOPCOp, inst.src0, None, VCC_LO
|
||||
op_cls, op, src0, src1, src2, vdst = VOPCOp, VOPCOp(inst.op), inst.src0, inst.vsrc1 + 256, None, VCC_LO
|
||||
elif inst_type is VOP3P:
|
||||
# VOP3P: Packed 16-bit operations using compiled functions
|
||||
op = VOP3POp(inst.op)
|
||||
|
|
@ -399,44 +448,26 @@ def exec_vector(st: WaveState, inst: Inst, lane: int, lds: bytearray | None = No
|
|||
if lane == 0: # Only execute once per wave, write results for all lanes
|
||||
exec_wmma(st, inst, op)
|
||||
return
|
||||
# V_FMA_MIX: Mixed precision FMA - inputs can be f16 or f32 controlled by opsel_hi/opsel_hi2
|
||||
# opsel_hi[0]: src0 is f32 (0) or f16 from hi bits (1)
|
||||
# opsel_hi[1]: src1 is f32 (0) or f16 from hi bits (1)
|
||||
# opsel_hi2: src2 is f32 (0) or f16 from hi bits (1)
|
||||
# opsel[i]: when source is f16, use lo (0) or hi (1) 16 bits - BUT for V_FMA_MIX, opsel selects lo/hi when opsel_hi=1
|
||||
# neg_hi[i]: abs modifier for source i (reuses neg_hi field for abs in V_FMA_MIX)
|
||||
# V_FMA_MIX: Mixed precision FMA - inputs can be f16 or f32 controlled by opsel
|
||||
if op in (VOP3POp.V_FMA_MIX_F32, VOP3POp.V_FMA_MIXLO_F16, VOP3POp.V_FMA_MIXHI_F16):
|
||||
opsel = getattr(inst, 'opsel', 0)
|
||||
opsel_hi = getattr(inst, 'opsel_hi', 0)
|
||||
opsel_hi2 = getattr(inst, 'opsel_hi2', 0)
|
||||
neg = getattr(inst, 'neg', 0)
|
||||
abs_ = getattr(inst, 'neg_hi', 0) # neg_hi field is reused as abs for V_FMA_MIX
|
||||
neg_hi = getattr(inst, 'neg_hi', 0)
|
||||
vdst = inst.vdst
|
||||
# Read raw 32-bit values
|
||||
# Read raw 32-bit values - for V_FMA_MIX, sources can be either f32 or f16
|
||||
s0_raw = st.rsrc(inst.src0, lane)
|
||||
s1_raw = st.rsrc(inst.src1, lane)
|
||||
s2_raw = st.rsrc(inst.src2, lane) if inst.src2 is not None else 0
|
||||
# Decode sources based on opsel_hi (controls f32 vs f16) and opsel (controls which half for f16)
|
||||
# src0: opsel_hi[0]=1 means f16, opsel[0] selects hi(1) or lo(0) half
|
||||
if opsel_hi & 1:
|
||||
s0 = _f16((s0_raw >> 16) & 0xffff) if (opsel & 1) else _f16(s0_raw & 0xffff)
|
||||
else:
|
||||
s0 = _f32(s0_raw)
|
||||
# src1: opsel_hi[1]=1 means f16, opsel[1] selects hi(1) or lo(0) half
|
||||
if opsel_hi & 2:
|
||||
s1 = _f16((s1_raw >> 16) & 0xffff) if (opsel & 2) else _f16(s1_raw & 0xffff)
|
||||
else:
|
||||
s1 = _f32(s1_raw)
|
||||
# src2: opsel_hi2=1 means f16, opsel[2] selects hi(1) or lo(0) half
|
||||
if opsel_hi2:
|
||||
s2 = _f16((s2_raw >> 16) & 0xffff) if (opsel & 4) else _f16(s2_raw & 0xffff)
|
||||
else:
|
||||
s2 = _f32(s2_raw)
|
||||
# Apply abs modifiers (abs_ field reuses neg_hi position)
|
||||
if abs_ & 1: s0 = abs(s0)
|
||||
if abs_ & 2: s1 = abs(s1)
|
||||
if abs_ & 4: s2 = abs(s2)
|
||||
# Apply neg modifiers
|
||||
# opsel[i]=0: use as f32, opsel[i]=1: use hi f16 as f32
|
||||
# For src0: opsel[0], for src1: opsel[1], for src2: opsel[2]
|
||||
if opsel & 1: s0 = _f16((s0_raw >> 16) & 0xffff) # hi f16 -> f32
|
||||
else: s0 = _f32(s0_raw) # use as f32
|
||||
if opsel & 2: s1 = _f16((s1_raw >> 16) & 0xffff)
|
||||
else: s1 = _f32(s1_raw)
|
||||
if opsel & 4: s2 = _f16((s2_raw >> 16) & 0xffff)
|
||||
else: s2 = _f32(s2_raw)
|
||||
# Apply neg modifiers (for f32 values)
|
||||
if neg & 1: s0 = -s0
|
||||
if neg & 2: s1 = -s1
|
||||
if neg & 4: s2 = -s2
|
||||
|
|
@ -485,10 +516,10 @@ def exec_vector(st: WaveState, inst: Inst, lane: int, lds: bytearray | None = No
|
|||
s0 = (s0_hi << 16) | s0_lo
|
||||
s1 = (s1_hi << 16) | s1_lo
|
||||
s2 = (s2_hi << 16) | s2_lo
|
||||
op_cls, vdst = VOP3POp, inst.vdst
|
||||
fn = compiled.get(op_cls, {}).get(op)
|
||||
vdst = inst.vdst
|
||||
fn = compiled.get(VOP3POp, {}).get(op)
|
||||
if fn is None: raise NotImplementedError(f"{op.name} not in pseudocode")
|
||||
result = fn(s0, s1, s2, 0, st.scc, st.vcc, lane, st.exec_mask, st.literal, None, {})
|
||||
result = _run_pcode(fn, VOP3POp, op, s0, s1, s2, 0, st.scc, st.vcc, lane, st.exec_mask, vdst)
|
||||
st.vgpr[lane][vdst] = result['d0'] & 0xffffffff
|
||||
return
|
||||
else: raise NotImplementedError(f"Unknown vector type {inst_type}")
|
||||
|
|
@ -510,17 +541,13 @@ def exec_vector(st: WaveState, inst: Inst, lane: int, lds: bytearray | None = No
|
|||
|
||||
# Determine if sources are 64-bit based on instruction type
|
||||
# For 64-bit shift ops: src0 is 32-bit (shift amount), src1 is 64-bit (value to shift)
|
||||
# For V_LDEXP_F64: src0 is 64-bit float, src1 is 32-bit integer exponent
|
||||
# For most other _B64/_I64/_U64/_F64 ops: all sources are 64-bit
|
||||
is_64bit_op = op.name.endswith(('_B64', '_I64', '_U64', '_F64'))
|
||||
# V_LDEXP_F64, V_TRIG_PREOP_F64, V_CMP_CLASS_F64, V_CMPX_CLASS_F64: src0 is 64-bit, src1 is 32-bit
|
||||
is_ldexp_64 = op in (VOP3Op.V_LDEXP_F64, VOP3Op.V_TRIG_PREOP_F64, VOP3Op.V_CMP_CLASS_F64, VOP3Op.V_CMPX_CLASS_F64,
|
||||
VOPCOp.V_CMP_CLASS_F64, VOPCOp.V_CMPX_CLASS_F64)
|
||||
is_ldexp_64 = op in (VOP3Op.V_LDEXP_F64,)
|
||||
is_shift_64 = op in (VOP3Op.V_LSHLREV_B64, VOP3Op.V_LSHRREV_B64, VOP3Op.V_ASHRREV_I64)
|
||||
# 16-bit source ops: use precomputed sets instead of string checks
|
||||
# Note: must check op_cls to avoid cross-enum value collisions
|
||||
is_16bit_src = op_cls is VOP3Op and op in _VOP3_16BIT_OPS and op not in _CVT_32_64_SRC_OPS
|
||||
# VOP2 16-bit ops use f16 inline constants for src0 (vsrc1 is always a VGPR, no inline constants)
|
||||
is_vop2_16bit = op_cls is VOP2Op and op in _VOP2_16BIT_OPS
|
||||
is_vop2_16bit = op_cls is VOP2Op and op in _VOP2_16BIT_OPS # VOP2 16-bit ops use f16 inline constants
|
||||
|
||||
if is_shift_64:
|
||||
s0 = mod_src(st.rsrc(src0, lane), 0) # shift amount is 32-bit
|
||||
|
|
@ -528,107 +555,37 @@ def exec_vector(st: WaveState, inst: Inst, lane: int, lds: bytearray | None = No
|
|||
s2 = mod_src(st.rsrc(src2, lane), 2) if src2 is not None else 0
|
||||
elif is_ldexp_64:
|
||||
s0 = mod_src64(st.rsrc64(src0, lane), 0) # mantissa is 64-bit float
|
||||
# src1 is 32-bit int. For 64-bit ops (like V_CMP_CLASS_F64), the literal is stored shifted left by 32.
|
||||
# For V_LDEXP_F64/V_TRIG_PREOP_F64, _is_64bit_op() returns False so literal is stored as-is.
|
||||
s1_raw = st.rsrc(src1, lane) if src1 is not None else 0
|
||||
# Only shift if src1 is literal AND this is a true 64-bit op (V_CMP_CLASS ops, not LDEXP/TRIG_PREOP)
|
||||
is_class_op = op in (VOP3Op.V_CMP_CLASS_F64, VOP3Op.V_CMPX_CLASS_F64, VOPCOp.V_CMP_CLASS_F64, VOPCOp.V_CMPX_CLASS_F64)
|
||||
s1 = mod_src((s1_raw >> 32) if src1 == 255 and is_class_op else s1_raw, 1)
|
||||
s1 = mod_src(st.rsrc(src1, lane), 1) if src1 is not None else 0 # exponent is 32-bit int
|
||||
s2 = mod_src(st.rsrc(src2, lane), 2) if src2 is not None else 0
|
||||
elif is_64bit_op:
|
||||
# 64-bit ops: apply neg/abs modifiers using f64 interpretation for float ops
|
||||
s0 = mod_src64(st.rsrc64(src0, lane), 0)
|
||||
s1 = mod_src64(st.rsrc64(src1, lane), 1) if src1 is not None else 0
|
||||
s2 = mod_src64(st.rsrc64(src2, lane), 2) if src2 is not None else 0
|
||||
elif is_16bit_src:
|
||||
# For 16-bit source ops, opsel bits select which half to use
|
||||
# Inline constants (128-254) must use f16 encoding, not f32
|
||||
def rsrc_16bit(src, lane): return st.rsrc_f16(src, lane) if 128 <= src < 255 else st.rsrc(src, lane)
|
||||
s0_raw = rsrc_16bit(src0, lane)
|
||||
s1_raw = rsrc_16bit(src1, lane) if src1 is not None else 0
|
||||
s2_raw = rsrc_16bit(src2, lane) if src2 is not None else 0
|
||||
# opsel[0] selects hi(1) or lo(0) for src0, opsel[1] for src1, opsel[2] for src2
|
||||
s0_raw = mod_src(st.rsrc(src0, lane), 0)
|
||||
s1_raw = mod_src(st.rsrc(src1, lane), 1) if src1 is not None else 0
|
||||
s2_raw = mod_src(st.rsrc(src2, lane), 2) if src2 is not None else 0
|
||||
s0 = ((s0_raw >> 16) & 0xffff) if (opsel & 1) else (s0_raw & 0xffff)
|
||||
s1 = ((s1_raw >> 16) & 0xffff) if (opsel & 2) else (s1_raw & 0xffff)
|
||||
s2 = ((s2_raw >> 16) & 0xffff) if (opsel & 4) else (s2_raw & 0xffff)
|
||||
# Apply abs/neg modifiers as f16 operations (toggle sign bit 15)
|
||||
if abs_ & 1: s0 &= 0x7fff
|
||||
if abs_ & 2: s1 &= 0x7fff
|
||||
if abs_ & 4: s2 &= 0x7fff
|
||||
if neg & 1: s0 ^= 0x8000
|
||||
if neg & 2: s1 ^= 0x8000
|
||||
if neg & 4: s2 ^= 0x8000
|
||||
elif is_vop2_16bit:
|
||||
# VOP2 16-bit ops: src0 uses f16 inline constants, or VGPR where v128+ = hi half of v0-v127
|
||||
# RDNA3 encoding: for VGPRs, bit 7 of VGPR index (src0-256) selects hi(1) or lo(0) half
|
||||
if src0 >= 256: # VGPR
|
||||
src0_hi = (src0 - 256) & 0x80 != 0
|
||||
src0_masked = ((src0 - 256) & 0x7f) + 256 # mask out hi bit to get actual VGPR
|
||||
s0_raw = mod_src(st.rsrc(src0_masked, lane), 0)
|
||||
s0 = ((s0_raw >> 16) & 0xffff) if src0_hi else (s0_raw & 0xffff)
|
||||
else: # SGPR or inline constant
|
||||
s0_raw = mod_src(st.rsrc_f16(src0, lane), 0)
|
||||
s0 = s0_raw & 0xffff
|
||||
# vsrc1: .h suffix encoded in bit 7 of VGPR index (src1 = 256 + vgpr_idx + 0x80 if hi)
|
||||
if src1 is not None:
|
||||
src1_hi = (src1 - 256) & 0x80 != 0
|
||||
src1_masked = ((src1 - 256) & 0x7f) + 256
|
||||
s1_raw = mod_src(st.rsrc(src1_masked, lane), 1)
|
||||
s1 = ((s1_raw >> 16) & 0xffff) if src1_hi else (s1_raw & 0xffff)
|
||||
else:
|
||||
s1 = 0
|
||||
s0 = mod_src(st.rsrc_f16(src0, lane), 0)
|
||||
s1 = mod_src(st.rsrc(src1, lane), 1) if src1 is not None else 0
|
||||
s2 = mod_src(st.rsrc(src2, lane), 2) if src2 is not None else 0
|
||||
elif op_cls is VOP1Op and op in _VOP1_16BIT_SRC_OPS:
|
||||
# VOP1 16-bit source ops: .h encoded in bit 7 of VGPR index (src0 >= 384 means hi half)
|
||||
# For VGPRs: src0 = 256 + vgpr_idx + (0x80 if hi else 0), so bit 7 of (src0-256) is the hi flag
|
||||
src0_hi = src0 >= 256 and ((src0 - 256) & 0x80) != 0
|
||||
src0_masked = ((src0 - 256) & 0x7f) + 256 if src0 >= 256 else src0 # mask out hi bit for VGPR
|
||||
s0_raw = mod_src(st.rsrc(src0_masked, lane), 0)
|
||||
s0 = ((s0_raw >> 16) & 0xffff) if src0_hi else (s0_raw & 0xffff)
|
||||
s1, s2 = 0, 0
|
||||
elif op_cls is VOPCOp and op in _VOPC_16BIT_OPS:
|
||||
# VOPC 16-bit ops: src0 and vsrc1 use same encoding as VOP2 16-bit
|
||||
# For VGPRs, bit 7 of VGPR index selects hi(1) or lo(0) half
|
||||
if src0 >= 256: # VGPR
|
||||
src0_hi = (src0 - 256) & 0x80 != 0
|
||||
src0_masked = ((src0 - 256) & 0x7f) + 256
|
||||
s0_raw = mod_src(st.rsrc(src0_masked, lane), 0)
|
||||
s0 = ((s0_raw >> 16) & 0xffff) if src0_hi else (s0_raw & 0xffff)
|
||||
else: # SGPR or inline constant
|
||||
s0_raw = mod_src(st.rsrc_f16(src0, lane), 0)
|
||||
s0 = s0_raw & 0xffff
|
||||
# vsrc1: bit 7 of VGPR index selects hi(1) or lo(0) half
|
||||
if src1 is not None:
|
||||
if src1 >= 256: # VGPR - use hi/lo encoding
|
||||
src1_hi = (src1 - 256) & 0x80 != 0
|
||||
src1_masked = ((src1 - 256) & 0x7f) + 256
|
||||
s1_raw = mod_src(st.rsrc(src1_masked, lane), 1)
|
||||
s1 = ((s1_raw >> 16) & 0xffff) if src1_hi else (s1_raw & 0xffff)
|
||||
else: # SGPR or inline constant - read as 32-bit, use low 16 bits
|
||||
s1_raw = mod_src(st.rsrc(src1, lane), 1)
|
||||
s1 = s1_raw & 0xffffffff # V_CMP_CLASS uses full 32-bit mask
|
||||
else:
|
||||
s1 = 0
|
||||
s2 = 0
|
||||
else:
|
||||
s0 = mod_src(st.rsrc(src0, lane), 0)
|
||||
s1 = mod_src(st.rsrc(src1, lane), 1) if src1 is not None else 0
|
||||
s2 = mod_src(st.rsrc(src2, lane), 2) if src2 is not None else 0
|
||||
# For VOP2 16-bit ops (like V_FMAC_F16), the destination is used as an accumulator.
|
||||
# The pseudocode reads D0.f16 from low 16 bits, so we need to shift hi->lo when vop2_dst_hi is True.
|
||||
if is_vop2_16bit:
|
||||
d0 = ((V[vdst] >> 16) & 0xffff) if vop2_dst_hi else (V[vdst] & 0xffff)
|
||||
else:
|
||||
d0 = V[vdst] if not is_64bit_op else (V[vdst] | (V[vdst + 1] << 32))
|
||||
# src2 can be a register index OR a raw literal value (for FMAAK/FMAMK)
|
||||
# If src2 > 511, it's a raw literal value, not a register index
|
||||
s2 = src2 if src2 is not None and src2 > 511 else (mod_src(st.rsrc(src2, lane), 2) if src2 is not None else 0)
|
||||
d0 = V[vdst] if not is_64bit_op else (V[vdst] | (V[vdst + 1] << 32))
|
||||
|
||||
# V_CNDMASK_B32/B16: VOP3 encoding uses src2 as mask (not VCC); VOP2 uses VCC implicitly
|
||||
# Pass the correct mask as vcc to the function so pseudocode VCC.u64[laneId] works correctly
|
||||
vcc_for_fn = st.rsgpr64(src2) if op in (VOP3Op.V_CNDMASK_B32, VOP3Op.V_CNDMASK_B16) and inst_type is VOP3 and src2 is not None and src2 < 256 else st.vcc
|
||||
# V_CNDMASK_B32: VOP3 encoding uses src2 as mask (not VCC); VOP2 uses VCC implicitly
|
||||
vcc_for_fn = st.rsgpr64(src2) if op in (VOP3Op.V_CNDMASK_B32,) and inst_type is VOP3 and src2 is not None and src2 < 256 else st.vcc
|
||||
|
||||
# Execute compiled function - pass src0_idx and vdst_idx for lane instructions
|
||||
# For VGPR access: src0 index is the VGPR number (src0 - 256 if VGPR, else src0 for SGPR)
|
||||
src0_idx = (src0 - 256) if src0 is not None and src0 >= 256 else (src0 if src0 is not None else 0)
|
||||
result = fn(s0, s1, s2, d0, st.scc, vcc_for_fn, lane, st.exec_mask, st.literal, st.vgpr, {}, src0_idx, vdst)
|
||||
# Execute pseudocode
|
||||
result = _run_pcode(fn, op_cls, op, s0, s1, s2, d0, st.scc, vcc_for_fn, lane, st.exec_mask, vdst)
|
||||
|
||||
# Apply results
|
||||
if 'vgpr_write' in result:
|
||||
|
|
@ -636,8 +593,7 @@ def exec_vector(st: WaveState, inst: Inst, lane: int, lds: bytearray | None = No
|
|||
wr_lane, wr_idx, wr_val = result['vgpr_write']
|
||||
st.vgpr[wr_lane][wr_idx] = wr_val
|
||||
if 'vcc_lane' in result:
|
||||
# VOP2 carry instructions (V_ADD_CO_CI_U32, V_SUB_CO_CI_U32, V_SUBREV_CO_CI_U32) write carry to VCC implicitly
|
||||
# VOPC and VOP3-encoded VOPC write to vdst (which is VCC_LO for VOPC, inst.sdst for VOP3)
|
||||
# VOP2 carry instructions write carry to VCC implicitly; VOPC writes to vdst
|
||||
vcc_dst = VCC_LO if op_cls is VOP2Op and op in (VOP2Op.V_ADD_CO_CI_U32, VOP2Op.V_SUB_CO_CI_U32, VOP2Op.V_SUBREV_CO_CI_U32) else vdst
|
||||
st.pend_sgpr_lane(vcc_dst, lane, result['vcc_lane'])
|
||||
if 'exec_lane' in result:
|
||||
|
|
@ -645,35 +601,18 @@ def exec_vector(st: WaveState, inst: Inst, lane: int, lds: bytearray | None = No
|
|||
st.pend_sgpr_lane(EXEC_LO, lane, result['exec_lane'])
|
||||
if 'd0' in result and op_cls not in (VOPCOp,) and 'vgpr_write' not in result:
|
||||
# V_READFIRSTLANE_B32 and V_READLANE_B32 write to SGPR, not VGPR
|
||||
# V_WRITELANE_B32 uses vgpr_write for cross-lane writes, don't overwrite with d0
|
||||
writes_to_sgpr = op in (VOP1Op.V_READFIRSTLANE_B32,) or \
|
||||
(op_cls is VOP3Op and op in (VOP3Op.V_READFIRSTLANE_B32, VOP3Op.V_READLANE_B32))
|
||||
# Check for 16-bit destination ops (opsel[3] controls hi/lo write)
|
||||
# Must check op_cls to avoid cross-enum value collisions (e.g., VOP1Op.V_MOV_B32=1 vs VOP3Op.V_CMP_LT_F16=1)
|
||||
is_16bit_dst = (op_cls is VOP3Op and op in _VOP3_16BIT_DST_OPS) or (op_cls is VOP1Op and op in _VOP1_16BIT_DST_OPS)
|
||||
is_16bit_dst = op in _VOP3_16BIT_DST_OPS or op in _VOP1_16BIT_DST_OPS
|
||||
if writes_to_sgpr:
|
||||
st.wsgpr(vdst, result['d0'] & 0xffffffff)
|
||||
elif result.get('d0_64'):
|
||||
elif result.get('d0_64') or is_64bit_op:
|
||||
V[vdst] = result['d0'] & 0xffffffff
|
||||
V[vdst + 1] = (result['d0'] >> 32) & 0xffffffff
|
||||
elif is_16bit_dst and inst_type is VOP3:
|
||||
# VOP3 16-bit ops: opsel[3] (bit 3 of opsel field) controls hi/lo destination
|
||||
if opsel & 8: # opsel[3] = 1: write to high 16 bits
|
||||
V[vdst] = (V[vdst] & 0x0000ffff) | ((result['d0'] & 0xffff) << 16)
|
||||
else: # opsel[3] = 0: write to low 16 bits
|
||||
V[vdst] = (V[vdst] & 0xffff0000) | (result['d0'] & 0xffff)
|
||||
elif is_16bit_dst and inst_type is VOP1:
|
||||
# VOP1 16-bit ops: .h suffix encoded in bit 7 of vdst (extracted as vop1_dst_hi)
|
||||
if vop1_dst_hi: # .h: write to high 16 bits
|
||||
V[vdst] = (V[vdst] & 0x0000ffff) | ((result['d0'] & 0xffff) << 16)
|
||||
else: # .l: write to low 16 bits
|
||||
V[vdst] = (V[vdst] & 0xffff0000) | (result['d0'] & 0xffff)
|
||||
elif is_vop2_16bit:
|
||||
# VOP2 16-bit ops: .h suffix encoded in bit 7 of vdst (extracted as vop2_dst_hi)
|
||||
if vop2_dst_hi: # .h: write to high 16 bits
|
||||
V[vdst] = (V[vdst] & 0x0000ffff) | ((result['d0'] & 0xffff) << 16)
|
||||
else: # .l: write to low 16 bits
|
||||
V[vdst] = (V[vdst] & 0xffff0000) | (result['d0'] & 0xffff)
|
||||
# VOP3 16-bit ops: opsel[3] controls hi/lo destination
|
||||
if opsel & 8: V[vdst] = (V[vdst] & 0x0000ffff) | ((result['d0'] & 0xffff) << 16)
|
||||
else: V[vdst] = (V[vdst] & 0xffff0000) | (result['d0'] & 0xffff)
|
||||
else:
|
||||
V[vdst] = result['d0'] & 0xffffffff
|
||||
|
||||
|
|
|
|||
|
|
@ -35,18 +35,12 @@ def _isnan(x):
|
|||
try: return math.isnan(float(x))
|
||||
except (TypeError, ValueError): return False
|
||||
def _isquietnan(x):
|
||||
"""Check if x is a quiet NaN.
|
||||
f16: exponent=31, bit9=1, mantissa!=0
|
||||
f32: exponent=255, bit22=1, mantissa!=0
|
||||
f64: exponent=2047, bit51=1, mantissa!=0
|
||||
"""
|
||||
"""Check if x is a quiet NaN. For f32: exponent=255, bit22=1, mantissa!=0"""
|
||||
try:
|
||||
if not math.isnan(float(x)): return False
|
||||
# Get raw bits from TypedView or similar object with _reg attribute
|
||||
if hasattr(x, '_reg') and hasattr(x, '_bits'):
|
||||
bits = x._reg._val & ((1 << x._bits) - 1)
|
||||
if x._bits == 16:
|
||||
return ((bits >> 10) & 0x1f) == 31 and ((bits >> 9) & 1) == 1 and (bits & 0x3ff) != 0
|
||||
if x._bits == 32:
|
||||
return ((bits >> 23) & 0xff) == 255 and ((bits >> 22) & 1) == 1 and (bits & 0x7fffff) != 0
|
||||
if x._bits == 64:
|
||||
|
|
@ -54,18 +48,12 @@ def _isquietnan(x):
|
|||
return True # Default to quiet NaN if we can't determine bit pattern
|
||||
except (TypeError, ValueError): return False
|
||||
def _issignalnan(x):
|
||||
"""Check if x is a signaling NaN.
|
||||
f16: exponent=31, bit9=0, mantissa!=0
|
||||
f32: exponent=255, bit22=0, mantissa!=0
|
||||
f64: exponent=2047, bit51=0, mantissa!=0
|
||||
"""
|
||||
"""Check if x is a signaling NaN. For f32: exponent=255, bit22=0, mantissa!=0"""
|
||||
try:
|
||||
if not math.isnan(float(x)): return False
|
||||
# Get raw bits from TypedView or similar object with _reg attribute
|
||||
if hasattr(x, '_reg') and hasattr(x, '_bits'):
|
||||
bits = x._reg._val & ((1 << x._bits) - 1)
|
||||
if x._bits == 16:
|
||||
return ((bits >> 10) & 0x1f) == 31 and ((bits >> 9) & 1) == 0 and (bits & 0x3ff) != 0
|
||||
if x._bits == 32:
|
||||
return ((bits >> 23) & 0xff) == 255 and ((bits >> 22) & 1) == 0 and (bits & 0x7fffff) != 0
|
||||
if x._bits == 64:
|
||||
|
|
@ -85,11 +73,7 @@ def floor(x):
|
|||
def ceil(x):
  """Round x up to the nearest integer and return it as a float.

  NaN and the infinities have no integer ceiling (math.ceil raises on them), so they are
  handed back unchanged after conversion to float.
  """
  value = float(x)
  if math.isfinite(value):
    return float(math.ceil(value))
  return value
|
||||
class _SafeFloat(float):
|
||||
"""Float subclass that uses _div for division to handle 0/inf correctly."""
|
||||
def __truediv__(self, o): return _div(float(self), float(o))
|
||||
def __rtruediv__(self, o): return _div(float(o), float(self))
|
||||
def sqrt(x): return _SafeFloat(math.sqrt(x)) if x >= 0 else _SafeFloat(float("nan"))
|
||||
def sqrt(x):
  """Square root that returns NaN instead of raising when x is negative (or NaN)."""
  # NaN fails the comparison too, so it lands on the NaN result below.
  if x >= 0:
    return math.sqrt(x)
  return float("nan")
|
||||
def log2(x):
  """Base-2 logarithm that never raises: -inf at zero, NaN for negative (or NaN) input."""
  if x > 0:
    return math.log2(x)
  if x == 0:
    return float("-inf")
  return float("nan")
|
||||
i32_to_f32 = u32_to_f32 = i32_to_f64 = u32_to_f64 = f32_to_f64 = f64_to_f32 = float
|
||||
def f32_to_i32(f):
|
||||
|
|
@ -123,10 +107,7 @@ def u4_to_u32(v): return int(v) & 0xf
|
|||
def _sign(f):
  """Return the sign bit of f as an int: 1 when negative (including -0.0), otherwise 0."""
  # copysign is used rather than `f < 0` so that negative zero is reported as negative.
  return int(math.copysign(1.0, f) < 0)
|
||||
def _mantissa_f32(f):
  """Return the low 23 bits of f packed as a little-endian 32-bit float; 0 for NaN/infinity."""
  # Non-finite values are answered before packing, exactly as the guard requires.
  if math.isinf(f) or math.isnan(f):
    return 0
  (bits,) = struct.unpack("<I", struct.pack("<f", f))
  return bits & 0x7fffff
|
||||
def _ldexp(m, e):
  """Return m * 2**e by delegating to math.ldexp."""
  scaled = math.ldexp(m, e)
  return scaled
|
||||
def isEven(x):
|
||||
x = float(x)
|
||||
if math.isinf(x) or math.isnan(x): return False
|
||||
return int(x) % 2 == 0
|
||||
def isEven(x):
  """Return True when x, truncated toward zero by int(), is an even integer.

  Infinities and NaN are reported as not even rather than raising: int() rejects them with
  OverflowError / ValueError, which would otherwise abort the caller. Every input that int()
  accepts is evaluated exactly as before.
  """
  try:
    return int(x) % 2 == 0
  except (OverflowError, ValueError):
    # int(inf) -> OverflowError, int(nan) -> ValueError: neither value has a parity.
    return False
|
||||
def fract(x):
  """Return the fractional part of x, computed as x minus math.floor(x)."""
  whole = math.floor(x)
  return x - whole
|
||||
PI = math.pi
|
||||
def sin(x):
|
||||
|
|
@ -280,7 +261,7 @@ def f32_to_u8(f): return max(0, min(255, int(f))) if not math.isnan(f) else 0
|
|||
def mantissa(f):
  """Return the signed binary significand of f, with magnitude in [0.5, 1.0).

  Zero, infinities and NaN have no normalised significand and are returned unchanged.
  """
  if f == 0.0 or math.isinf(f) or math.isnan(f): return f
  # math.frexp yields f == m * 2**e with 0.5 <= abs(m) < 1, and m carries the sign of f.
  m, _ = math.frexp(f)
  return m  # AMD V_FREXP_MANT returns mantissa in [0.5, 1.0) range
|
||||
def signext_from_bit(val, bit):
|
||||
bit = int(bit)
|
||||
if bit == 0: return 0
|
||||
|
|
@ -301,7 +282,6 @@ __all__ = [
|
|||
# Constants
|
||||
'WAVE32', 'WAVE64', 'MASK32', 'MASK64', 'WAVE_MODE', 'DENORM', 'OVERFLOW_F32', 'UNDERFLOW_F32',
|
||||
'OVERFLOW_F64', 'UNDERFLOW_F64', 'MAX_FLOAT_F32', 'ROUND_MODE', 'cvtToQuietNAN', 'DST', 'INF', 'PI',
|
||||
'TWO_OVER_PI_1201',
|
||||
# Aliases for pseudocode
|
||||
's_ff1_i32_b32', 's_ff1_i32_b64', 'GT_NEG_ZERO', 'LT_NEG_ZERO',
|
||||
'isNAN', 'isQuietNAN', 'isSignalNAN', 'fma', 'ldexp', 'sign', 'exponent', 'F', 'signext',
|
||||
|
|
@ -342,7 +322,7 @@ def F(x):
|
|||
if isinstance(x, int): return _f32(x) # int -> interpret as f32 bits
|
||||
if isinstance(x, TypedView): return x # preserve TypedView for bit-pattern checks
|
||||
return float(x) # already a float or float-like
|
||||
signext = lambda x: int(x) # sign-extend to full width - already handled by Python's arbitrary precision ints
|
||||
def signext(x):
  """Identity: return x unchanged, with no conversion or masking applied."""
  return x
|
||||
def pack(hi, lo):
  """Combine two 16-bit fields into one integer: hi in bits 31:16, lo in bits 15:0.

  Both arguments are converted with int() and masked to 16 bits before being joined.
  """
  upper = int(hi) & 0xffff
  lower = int(lo) & 0xffff
  return (upper << 16) | lower
|
||||
def pack32(hi, lo):
  """Combine two 32-bit fields into one integer: hi in bits 63:32, lo in bits 31:0.

  Both arguments are converted with int() and masked to 32 bits before being joined.
  """
  upper = int(hi) & 0xffffffff
  lower = int(lo) & 0xffffffff
  return (upper << 32) | lower
|
||||
_pack, _pack32 = pack, pack32 # Aliases for internal use
|
||||
|
|
@ -360,14 +340,12 @@ class _Inf:
|
|||
f16 = f32 = f64 = float('inf')
|
||||
def __neg__(self): return _NegInf()
|
||||
def __pos__(self): return self
|
||||
def __float__(self): return float('inf')
|
||||
def __eq__(self, other): return float(other) == float('inf') if not isinstance(other, _NegInf) else False
|
||||
def __req__(self, other): return self.__eq__(other)
|
||||
class _NegInf:
|
||||
f16 = f32 = f64 = float('-inf')
|
||||
def __neg__(self): return _Inf()
|
||||
def __pos__(self): return self
|
||||
def __float__(self): return float('-inf')
|
||||
def __eq__(self, other): return float(other) == float('-inf') if not isinstance(other, _Inf) else False
|
||||
def __req__(self, other): return self.__eq__(other)
|
||||
INF = _Inf()
|
||||
|
|
@ -383,31 +361,6 @@ DST = None # Placeholder, will be set in context
|
|||
|
||||
MASK32, MASK64 = 0xffffffff, 0xffffffffffffffff
|
||||
|
||||
# 2/PI with 1201 bits of precision for V_TRIG_PREOP_F64
|
||||
# Computed as: int((2/pi) * 2^1201) - this is the fractional part of 2/pi scaled to integer
|
||||
# The MSB (bit 1200) corresponds to 2^0 position in the fraction 0.b1200 b1199 ... b1 b0
|
||||
_TWO_OVER_PI_1201_RAW = 0x0145f306dc9c882a53f84eafa3ea69bb81b6c52b3278872083fca2c757bd778ac36e48dc74849ba5c00c925dd413a32439fc3bd63962534e7dd1046bea5d768909d338e04d68befc827323ac7306a673e93908bf177bf250763ff12fffbc0b301fde5e2316b414da3eda6cfd9e4f96136e9e8c7ecd3cbfd45aea4f758fd7cbe2f67a0e73ef14a525d4d7f6bf623f1aba10ac06608df8f6
|
||||
|
||||
class _BigInt:
|
||||
"""Wrapper for large integers that supports bit slicing [high:low]."""
|
||||
__slots__ = ('_val',)
|
||||
def __init__(self, val): self._val = val
|
||||
def __getitem__(self, key):
|
||||
if isinstance(key, slice):
|
||||
high, low = key.start, key.stop
|
||||
if high < low: high, low = low, high # Handle reversed slice
|
||||
mask = (1 << (high - low + 1)) - 1
|
||||
return (self._val >> low) & mask
|
||||
return (self._val >> key) & 1
|
||||
def __int__(self): return self._val
|
||||
def __index__(self): return self._val
|
||||
def __lshift__(self, n): return self._val << int(n)
|
||||
def __rshift__(self, n): return self._val >> int(n)
|
||||
def __and__(self, n): return self._val & int(n)
|
||||
def __or__(self, n): return self._val | int(n)
|
||||
|
||||
TWO_OVER_PI_1201 = _BigInt(_TWO_OVER_PI_1201_RAW)
|
||||
|
||||
class _WaveMode:
|
||||
IEEE = False
|
||||
WAVE_MODE = _WaveMode()
|
||||
|
|
@ -547,17 +500,6 @@ class TypedView:
|
|||
|
||||
def __bool__(s): return bool(int(s))
|
||||
|
||||
# Allow chained type access like jump_addr.i64 when jump_addr is already a TypedView
|
||||
# These just return self or convert appropriately
|
||||
@property
|
||||
def i64(s): return s if s._bits == 64 and s._signed else int(s)
|
||||
@property
|
||||
def u64(s): return s if s._bits == 64 and not s._signed else int(s) & MASK64
|
||||
@property
|
||||
def i32(s): return s if s._bits == 32 and s._signed else _sext(int(s) & MASK32, 32)
|
||||
@property
|
||||
def u32(s): return s if s._bits == 32 and not s._signed else int(s) & MASK32
|
||||
|
||||
class Reg:
|
||||
"""GPU register: D0.f32 = S0.f32 + S1.f32 just works."""
|
||||
__slots__ = ('_val',)
|
||||
|
|
@ -581,7 +523,6 @@ class Reg:
|
|||
bf16 = property(lambda s: TypedView(s, 16, is_float=True, is_bf16=True), lambda s, v: setattr(s, '_val', (s._val & 0xffff0000) | ((v if isinstance(v, int) else _ibf16(float(v))) & 0xffff)))
|
||||
u8 = property(lambda s: TypedView(s, 8))
|
||||
i8 = property(lambda s: TypedView(s, 8, signed=True))
|
||||
u1 = property(lambda s: TypedView(s, 1)) # single bit
|
||||
|
||||
def __getitem__(s, key):
|
||||
if isinstance(key, slice): return SliceProxy(s, int(key.start), int(key.stop))
|
||||
|
|
@ -624,6 +565,19 @@ class Reg:
|
|||
def __eq__(s, o): return s._val == int(o)
|
||||
def __ne__(s, o): return s._val != int(o)
|
||||
|
||||
class ExecContext:
  """Execution context for running compiled pseudocode strings (for testing)."""

  def __init__(self, s0=0, s1=0, s2=0, d0=0, scc=0, vcc=0, lane=0, exec_mask=0xffffffff):
    # Every attribute set here is a local name for the executed code, because run()
    # passes this instance's __dict__ as the locals mapping.
    self.S0 = Reg(s0)
    self.S1 = Reg(s1)
    self.S2 = Reg(s2)
    self.D0 = Reg(d0)
    self.D1 = Reg(0)
    self.SCC = Reg(scc)
    self.VCC = Reg(vcc)
    self.EXEC = Reg(exec_mask)
    self.tmp = Reg(0)
    self.saveexec = Reg(exec_mask)
    self.laneId = lane

  def run(self, code: str):
    """Execute `code` against this context; blank input is a no-op.

    A string without a newline is passed through compile_pseudocode() first; a string that
    contains a newline is executed as given.
    """
    if not code.strip():
      return
    if '\n' not in code:
      code = compile_pseudocode(code)
    # Globals: a copy of the module namespace with Reg and SliceProxy bound explicitly.
    namespace = dict(globals())
    namespace['Reg'] = Reg
    namespace['SliceProxy'] = SliceProxy
    exec(code, namespace, self.__dict__)

  def result(self) -> dict:
    """Return the raw D0, D1, VCC and EXEC values plus bit 0 of SCC."""
    return {
      'd0': self.D0._val,
      'd1': self.D1._val,
      'scc': self.SCC._val & 1,
      'vcc': self.VCC._val,
      'exec': self.EXEC._val,
    }
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
# COMPILER: pseudocode -> Python (minimal transforms)
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
|
|
@ -642,7 +596,7 @@ def compile_pseudocode(pseudocode: str) -> str:
|
|||
joined_lines.append(line)
|
||||
|
||||
lines = []
|
||||
indent, need_pass, in_first_match_loop = 0, False, False
|
||||
indent, need_pass = 0, False
|
||||
for line in joined_lines:
|
||||
line = line.strip()
|
||||
if not line or line.startswith('//'): continue
|
||||
|
|
@ -671,14 +625,14 @@ def compile_pseudocode(pseudocode: str) -> str:
|
|||
elif line.startswith('endfor'):
|
||||
if need_pass: lines.append(' ' * indent + "pass")
|
||||
indent -= 1
|
||||
need_pass, in_first_match_loop = False, False
|
||||
need_pass = False
|
||||
elif line.startswith('declare '):
|
||||
pass
|
||||
elif m := re.match(r'for (\w+) in (.+?)\s*:\s*(.+?) do', line):
|
||||
start, end = _expr(m[2].strip()), _expr(m[3].strip())
|
||||
lines.append(' ' * indent + f"for {m[1]} in range({start}, int({end})+1):")
|
||||
indent += 1
|
||||
need_pass, in_first_match_loop = True, True
|
||||
need_pass = True
|
||||
elif '=' in line and not line.startswith('=='):
|
||||
need_pass = False
|
||||
line = line.rstrip(';')
|
||||
|
|
@ -697,20 +651,20 @@ def compile_pseudocode(pseudocode: str) -> str:
|
|||
break
|
||||
else:
|
||||
lhs, rhs = line.split('=', 1)
|
||||
lhs_s, rhs_s = lhs.strip(), rhs.strip()
|
||||
stmt = _assign(lhs_s, _expr(rhs_s))
|
||||
# CLZ/CTZ pattern: assignment of loop var to tmp/D0.i32 in first-match loop needs break
|
||||
if in_first_match_loop and rhs_s == 'i' and (lhs_s == 'tmp' or lhs_s == 'D0.i32'):
|
||||
stmt += "; break"
|
||||
lines.append(' ' * indent + stmt)
|
||||
lines.append(' ' * indent + _assign(lhs.strip(), _expr(rhs.strip())))
|
||||
# If we ended with a control statement that needs a body, add pass
|
||||
if need_pass: lines.append(' ' * indent + "pass")
|
||||
return '\n'.join(lines)
|
||||
|
||||
def _assign(lhs: str, rhs: str) -> str:
|
||||
"""Generate assignment. Bare tmp/SCC/etc get wrapped in Reg()."""
|
||||
if lhs in ('tmp', 'SCC', 'VCC', 'EXEC', 'D0', 'D1', 'saveexec', 'PC'):
|
||||
return f"{lhs} = Reg({rhs})"
|
||||
"""Generate assignment. Outputs modify Reg in-place via ._val."""
|
||||
# Output registers and tmp: modify in-place so caller sees changes
|
||||
if lhs in ('SCC', 'VCC', 'EXEC', 'D0', 'D1', 'tmp'):
|
||||
return f"{lhs}._val = int({rhs})"
|
||||
# saveexec needs to be a new Reg for typed accessor access
|
||||
if lhs == 'saveexec':
|
||||
return f"{lhs} = Reg(int({rhs}))"
|
||||
# Other locals: natural style
|
||||
return f"{lhs} = {rhs}"
|
||||
|
||||
def _expr(e: str) -> str:
|
||||
|
|
@ -726,9 +680,6 @@ def _expr(e: str) -> str:
|
|||
return f'_pack({hi}, {lo})'
|
||||
e = re.sub(r'\{\s*([^,{}]+)\s*,\s*([^,{}]+)\s*\}', pack, e)
|
||||
|
||||
# Special constant: 1201'B(2.0 / PI) -> TWO_OVER_PI_1201 (precomputed 1201-bit 2/pi)
|
||||
e = re.sub(r"1201'B\(2\.0\s*/\s*PI\)", "TWO_OVER_PI_1201", e)
|
||||
|
||||
# Literals: 1'0U -> 0, 32'I(x) -> (x), B(x) -> (x)
|
||||
e = re.sub(r"\d+'([0-9a-fA-Fx]+)[UuFf]*", r'\1', e)
|
||||
e = re.sub(r"\d+'[FIBU]\(", "(", e)
|
||||
|
|
@ -795,50 +746,7 @@ def _expr(e: str) -> str:
|
|||
e = f'(({t}) if ({cond}) else ({f}))'
|
||||
return e
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
# EXECUTION CONTEXT
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
|
||||
class ExecContext:
|
||||
"""Context for running compiled pseudocode."""
|
||||
def __init__(self, s0=0, s1=0, s2=0, d0=0, scc=0, vcc=0, lane=0, exec_mask=MASK32, literal=0, vgprs=None, src0_idx=0, vdst_idx=0):
|
||||
self.S0, self.S1, self.S2 = Reg(s0), Reg(s1), Reg(s2)
|
||||
self.D0, self.D1 = Reg(d0), Reg(0)
|
||||
self.SCC, self.VCC, self.EXEC = Reg(scc), Reg(vcc), Reg(exec_mask)
|
||||
self.tmp, self.saveexec = Reg(0), Reg(exec_mask)
|
||||
self.lane, self.laneId, self.literal = lane, lane, literal
|
||||
self.SIMM16, self.SIMM32 = Reg(literal), Reg(literal)
|
||||
self.VGPR = vgprs if vgprs is not None else {}
|
||||
self.SRC0, self.VDST = Reg(src0_idx), Reg(vdst_idx)
|
||||
|
||||
def run(self, code: str):
|
||||
"""Execute compiled code."""
|
||||
# Start with module globals (helpers, aliases), then add instance-specific bindings
|
||||
ns = dict(globals())
|
||||
ns.update({
|
||||
'S0': self.S0, 'S1': self.S1, 'S2': self.S2, 'D0': self.D0, 'D1': self.D1,
|
||||
'SCC': self.SCC, 'VCC': self.VCC, 'EXEC': self.EXEC,
|
||||
'EXEC_LO': SliceProxy(self.EXEC, 31, 0), 'EXEC_HI': SliceProxy(self.EXEC, 63, 32),
|
||||
'tmp': self.tmp, 'saveexec': self.saveexec,
|
||||
'lane': self.lane, 'laneId': self.laneId, 'literal': self.literal,
|
||||
'SIMM16': self.SIMM16, 'SIMM32': self.SIMM32,
|
||||
'VGPR': self.VGPR, 'SRC0': self.SRC0, 'VDST': self.VDST,
|
||||
})
|
||||
exec(code, ns)
|
||||
# Sync rebinds: if register was reassigned to new Reg or value, copy it back
|
||||
def _sync(ctx_reg, ns_val):
|
||||
if isinstance(ns_val, Reg): ctx_reg._val = ns_val._val
|
||||
else: ctx_reg._val = int(ns_val) & MASK64
|
||||
if ns.get('SCC') is not self.SCC: _sync(self.SCC, ns['SCC'])
|
||||
if ns.get('VCC') is not self.VCC: _sync(self.VCC, ns['VCC'])
|
||||
if ns.get('EXEC') is not self.EXEC: _sync(self.EXEC, ns['EXEC'])
|
||||
if ns.get('D0') is not self.D0: _sync(self.D0, ns['D0'])
|
||||
if ns.get('D1') is not self.D1: _sync(self.D1, ns['D1'])
|
||||
if ns.get('tmp') is not self.tmp: _sync(self.tmp, ns['tmp'])
|
||||
if ns.get('saveexec') is not self.saveexec: _sync(self.saveexec, ns['saveexec'])
|
||||
|
||||
def result(self) -> dict:
|
||||
return {"d0": self.D0._val, "scc": self.SCC._val & 1}
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
# PDF EXTRACTION AND CODE GENERATION
|
||||
|
|
@ -849,14 +757,14 @@ INST_PATTERN = re.compile(r'^([SV]_[A-Z0-9_]+)\s+(\d+)\s*$', re.M)
|
|||
|
||||
# Patterns that can't be handled by the DSL (require special handling in emu.py)
|
||||
UNSUPPORTED = ['SGPR[', 'V_SWAP', 'eval ', 'FATAL_HALT', 'HW_REGISTERS',
|
||||
'vscnt', 'vmcnt', 'expcnt', 'lgkmcnt',
|
||||
'PC =', 'PC=', 'PC+', '= PC', 'vscnt', 'vmcnt', 'expcnt', 'lgkmcnt',
|
||||
'CVT_OFF_TABLE', 'ThreadMask',
|
||||
'S1[i', 'C.i32', 'S[i]', 'in[',
|
||||
'S1[i', 'C.i32', 'S[i]', 'in[', '2.0 / PI',
|
||||
'if n.', 'DST.u32', 'addrd = DST', 'addr = DST'] # Malformed pseudocode from PDF
|
||||
|
||||
def extract_pseudocode(text: str) -> str | None:
|
||||
"""Extract pseudocode from an instruction description snippet."""
|
||||
lines, result, depth, in_lambda = text.split('\n'), [], 0, 0
|
||||
lines, result, depth = text.split('\n'), [], 0
|
||||
for line in lines:
|
||||
s = line.strip()
|
||||
if not s: continue
|
||||
|
|
@ -865,17 +773,12 @@ def extract_pseudocode(text: str) -> str | None:
|
|||
# Skip document headers (RDNA or CDNA)
|
||||
if s.startswith('"RDNA') or s.startswith('AMD ') or s.startswith('CDNA'): continue
|
||||
if s.startswith('Notes') or s.startswith('Functional examples'): break
|
||||
# Track lambda definitions (e.g., BYTE_PERMUTE = lambda(data, sel) (...))
|
||||
if '= lambda(' in s: in_lambda += 1; continue
|
||||
if in_lambda > 0:
|
||||
if s.endswith(');'): in_lambda -= 1
|
||||
continue
|
||||
if s.startswith('if '): depth += 1
|
||||
elif s.startswith('endif'): depth = max(0, depth - 1)
|
||||
if s.endswith('.') and not any(p in s for p in ['D0', 'D1', 'S0', 'S1', 'S2', 'SCC', 'VCC', 'tmp', '=']): continue
|
||||
if re.match(r'^[a-z].*\.$', s) and '=' not in s: continue
|
||||
is_code = (
|
||||
any(p in s for p in ['D0.', 'D1.', 'S0.', 'S1.', 'S2.', 'SCC =', 'SCC ?', 'VCC', 'EXEC', 'tmp =', 'tmp[', 'lane =', 'PC =']) or
|
||||
any(p in s for p in ['D0.', 'D1.', 'S0.', 'S1.', 'S2.', 'SCC =', 'SCC ?', 'VCC', 'EXEC', 'tmp =', 'tmp[', 'lane =']) or
|
||||
any(p in s for p in ['D0[', 'D1[', 'S0[', 'S1[', 'S2[']) or
|
||||
s.startswith(('if ', 'else', 'elsif', 'endif', 'declare ', 'for ', 'endfor', '//')) or
|
||||
re.match(r'^[a-z_]+\s*=', s) or re.match(r'^[a-z_]+\[', s) or (depth > 0 and '=' in s)
|
||||
|
|
@ -1017,8 +920,11 @@ from extra.assembly.amd.pcode import *
|
|||
|
||||
try:
|
||||
code = compile_pseudocode(pc)
|
||||
# NOTE: Do NOT add more code.replace() hacks here. Fix issues properly in the DSL
|
||||
# (compile_pseudocode, helper functions, or Reg/TypedView classes) instead.
|
||||
# CLZ/CTZ: The PDF pseudocode searches for the first 1 bit but doesn't break.
|
||||
# Hardware stops at first match. SOP1 uses tmp=i, VOP1/VOP3 use D0.i32=i
|
||||
if 'CLZ' in op.name or 'CTZ' in op.name:
|
||||
code = code.replace('tmp = Reg(i)', 'tmp = Reg(i); break')
|
||||
code = code.replace('D0.i32 = i', 'D0.i32 = i; break')
|
||||
# V_DIV_FMAS_F32/F64: PDF page 449 says 2^32/2^64 but hardware behavior is more complex.
|
||||
# The scale direction depends on S2 (the addend): if exponent(S2) > 127 (i.e., S2 >= 2.0),
|
||||
# scale by 2^+64 (to unscale a numerator that was scaled). Otherwise scale by 2^-64
|
||||
|
|
@ -1040,7 +946,7 @@ from extra.assembly.amd.pcode import *
|
|||
# Fix 1: Set VCC=1 when zero operands produce NaN
|
||||
code = code.replace(
|
||||
'D0.f32 = float("nan")',
|
||||
'VCC = Reg(0x1); D0.f32 = float("nan")')
|
||||
'VCC._val = 0x1; D0.f32 = float("nan")')
|
||||
# Fix 2: Denorm denom returns NaN. Must check this AFTER all VCC-setting logic runs.
|
||||
# Insert at end of all branches, before the final result is used
|
||||
code = code.replace(
|
||||
|
|
@ -1051,26 +957,26 @@ from extra.assembly.amd.pcode import *
|
|||
# Fix 3: Tiny numer should set VCC=1
|
||||
code = code.replace(
|
||||
'elif exponent(S2.f32) <= 23:\n D0.f32 = ldexp(S0.f32, 64)',
|
||||
'elif exponent(S2.f32) <= 23:\n VCC = Reg(0x1); D0.f32 = ldexp(S0.f32, 64)')
|
||||
'elif exponent(S2.f32) <= 23:\n VCC._val = 0x1; D0.f32 = ldexp(S0.f32, 64)')
|
||||
# Fix 4: S2/S1 would be denorm - don't scale, just set VCC
|
||||
code = code.replace(
|
||||
'elif S2.f32 / S1.f32 == DENORM.f32:\n VCC = Reg(0x1)\n if S0.f32 == S2.f32:\n D0.f32 = ldexp(S0.f32, 64)',
|
||||
'elif S2.f32 / S1.f32 == DENORM.f32:\n VCC = Reg(0x1)')
|
||||
'elif S2.f32 / S1.f32 == DENORM.f32:\n VCC._val = int(0x1)\n if S0.f32 == S2.f32:\n D0.f32 = ldexp(S0.f32, 64)',
|
||||
'elif S2.f32 / S1.f32 == DENORM.f32:\n VCC._val = 0x1')
|
||||
if op.name == 'V_DIV_SCALE_F64':
|
||||
# Same fixes for f64 version
|
||||
code = code.replace(
|
||||
'D0.f64 = float("nan")',
|
||||
'VCC = Reg(0x1); D0.f64 = float("nan")')
|
||||
'VCC._val = 0x1; D0.f64 = float("nan")')
|
||||
code = code.replace(
|
||||
'elif S1.f64 == DENORM.f64:\n D0.f64 = ldexp(S0.f64, 128)',
|
||||
'elif False:\n pass # denorm check moved to end')
|
||||
code += '\nif S1.f64 == DENORM.f64:\n D0.f64 = float("nan")'
|
||||
code = code.replace(
|
||||
'elif exponent(S2.f64) <= 52:\n D0.f64 = ldexp(S0.f64, 128)',
|
||||
'elif exponent(S2.f64) <= 52:\n VCC = Reg(0x1); D0.f64 = ldexp(S0.f64, 128)')
|
||||
'elif exponent(S2.f64) <= 52:\n VCC._val = 0x1; D0.f64 = ldexp(S0.f64, 128)')
|
||||
code = code.replace(
|
||||
'elif S2.f64 / S1.f64 == DENORM.f64:\n VCC = Reg(0x1)\n if S0.f64 == S2.f64:\n D0.f64 = ldexp(S0.f64, 128)',
|
||||
'elif S2.f64 / S1.f64 == DENORM.f64:\n VCC = Reg(0x1)')
|
||||
'elif S2.f64 / S1.f64 == DENORM.f64:\n VCC._val = int(0x1)\n if S0.f64 == S2.f64:\n D0.f64 = ldexp(S0.f64, 128)',
|
||||
'elif S2.f64 / S1.f64 == DENORM.f64:\n VCC._val = 0x1')
|
||||
# V_DIV_FIXUP_F32/F64: PDF doesn't check isNAN(S0), but hardware returns OVERFLOW if S0 is NaN.
|
||||
# When division fails (e.g., due to denorm denom), S0 becomes NaN, and fixup should return ±inf.
|
||||
if op.name == 'V_DIV_FIXUP_F32':
|
||||
|
|
@ -1081,86 +987,23 @@ from extra.assembly.amd.pcode import *
|
|||
code = code.replace(
|
||||
'D0.f64 = ((-abs(S0.f64)) if (sign_out) else (abs(S0.f64)))',
|
||||
'D0.f64 = ((-OVERFLOW_F64) if (sign_out) else (OVERFLOW_F64)) if isNAN(S0.f64) else ((-abs(S0.f64)) if (sign_out) else (abs(S0.f64)))')
|
||||
# V_TRIG_PREOP_F64: AMD pseudocode uses (x << shift) & mask but mask needs to extract TOP bits.
|
||||
# The PDF shows: result = 64'F((1201'B(2.0/PI)[1200:0] << shift) & 1201'0x1fffffffffffff)
|
||||
# Issues to fix:
|
||||
# 1. After left shift, the interesting bits are at the top, not bottom - need >> (1201-53)
|
||||
# 2. shift.u32 fails because shift is a plain int after * 53 - use int(shift)
|
||||
# 3. 64'F(...) means convert int to float (not interpret as bit pattern) - use float()
|
||||
if op.name == 'V_TRIG_PREOP_F64':
|
||||
code = code.replace(
|
||||
'result = F((TWO_OVER_PI_1201[1200 : 0] << shift.u32) & 0x1fffffffffffff)',
|
||||
'result = float(((TWO_OVER_PI_1201[1200 : 0] << int(shift)) >> (1201 - 53)) & 0x1fffffffffffff)')
|
||||
# Detect flags for result handling
|
||||
is_64 = any(p in pc for p in ['D0.u64', 'D0.b64', 'D0.f64', 'D0.i64', 'D1.u64', 'D1.b64', 'D1.f64', 'D1.i64'])
|
||||
has_d1 = '{ D1' in pc
|
||||
if has_d1: is_64 = True
|
||||
is_cmp = (cls_name == 'VOPCOp' or cls_name == 'VOP3Op') and 'D0.u64[laneId]' in pc
|
||||
is_cmpx = (cls_name == 'VOPCOp' or cls_name == 'VOP3Op') and 'EXEC.u64[laneId]' in pc # V_CMPX writes to EXEC per-lane
|
||||
# V_DIV_SCALE passes through S0 if no branch taken
|
||||
is_div_scale = 'DIV_SCALE' in op.name
|
||||
# VOP3SD instructions that write VCC per-lane (either via VCC.u64[laneId] or by setting VCC = 0/1)
|
||||
has_sdst = cls_name == 'VOP3SDOp' and ('VCC.u64[laneId]' in pc or is_div_scale)
|
||||
# Instructions that use/modify PC
|
||||
has_pc = 'PC' in pc
|
||||
|
||||
# Generate function with indented body
|
||||
# SIMM16/SIMM32 (inline literal constants) are passed as S2
|
||||
code = code.replace('SIMM16', 'S2').replace('SIMM32', 'S2')
|
||||
# Generate function with standard signature
|
||||
fn_name = f"_{cls_name}_{op.name}"
|
||||
lines.append(f"def {fn_name}(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGPR, _vars, src0_idx=0, vdst_idx=0, pc=0):")
|
||||
# Add original pseudocode as comment
|
||||
lines.append(f"def {fn_name}(S0, S1, S2, D0, D1, SCC, VCC, EXEC, tmp, laneId):")
|
||||
for pc_line in pc.split('\n'):
|
||||
lines.append(f" # {pc_line}")
|
||||
# Only create Reg objects for registers actually used in the pseudocode
|
||||
# Add EXEC_LO/EXEC_HI if needed
|
||||
combined = code + pc
|
||||
regs = [('S0', 'Reg(s0)'), ('S1', 'Reg(s1)'), ('S2', 'Reg(s2)'),
|
||||
('D0', 'Reg(s0)' if is_div_scale else 'Reg(d0)'), ('D1', 'Reg(0)'),
|
||||
('SCC', 'Reg(scc)'), ('VCC', 'Reg(vcc)'), ('EXEC', 'Reg(exec_mask)'),
|
||||
('tmp', 'Reg(0)'), ('saveexec', 'Reg(exec_mask)'), ('laneId', 'lane'),
|
||||
('SIMM16', 'Reg(literal)'), ('SIMM32', 'Reg(literal)'),
|
||||
('SRC0', 'Reg(src0_idx)'), ('VDST', 'Reg(vdst_idx)'),
|
||||
('PC', 'Reg(pc)')] # PC is passed in as byte address
|
||||
used = {name for name, _ in regs if name in combined}
|
||||
# EXEC_LO/EXEC_HI need EXEC
|
||||
if 'EXEC_LO' in combined or 'EXEC_HI' in combined: used.add('EXEC')
|
||||
# VCCZ/EXECZ need VCC/EXEC
|
||||
if 'VCCZ' in combined: used.add('VCC')
|
||||
if 'EXECZ' in combined: used.add('EXEC')
|
||||
for name, init in regs:
|
||||
if name in used: lines.append(f" {name} = {init}")
|
||||
if 'EXEC_LO' in combined: lines.append(" EXEC_LO = SliceProxy(EXEC, 31, 0)")
|
||||
if 'EXEC_HI' in combined: lines.append(" EXEC_HI = SliceProxy(EXEC, 63, 32)")
|
||||
# VCCZ = 1 if VCC == 0, EXECZ = 1 if EXEC == 0
|
||||
if 'VCCZ' in combined: lines.append(" VCCZ = Reg(1 if VCC._val == 0 else 0)")
|
||||
if 'EXECZ' in combined: lines.append(" EXECZ = Reg(1 if EXEC._val == 0 else 0)")
|
||||
# Add compiled pseudocode with markers
|
||||
lines.append(" # --- compiled pseudocode ---")
|
||||
for line in code.split('\n'):
|
||||
lines.append(f" {line}")
|
||||
lines.append(" # --- end pseudocode ---")
|
||||
# Generate result dict - use raw params if Reg wasn't created
|
||||
d0_val = "D0._val" if 'D0' in used else "d0"
|
||||
scc_val = "SCC._val & 1" if 'SCC' in used else "scc & 1"
|
||||
lines.append(f" result = {{'d0': {d0_val}, 'scc': {scc_val}}}")
|
||||
if has_sdst:
|
||||
lines.append(" result['vcc_lane'] = (VCC._val >> lane) & 1")
|
||||
elif 'VCC' in used:
|
||||
lines.append(" if VCC._val != vcc: result['vcc_lane'] = (VCC._val >> lane) & 1")
|
||||
if is_cmpx:
|
||||
lines.append(" result['exec_lane'] = (EXEC._val >> lane) & 1")
|
||||
elif 'EXEC' in used:
|
||||
lines.append(" if EXEC._val != exec_mask: result['exec'] = EXEC._val")
|
||||
if is_cmp:
|
||||
lines.append(" result['vcc_lane'] = (D0._val >> lane) & 1")
|
||||
if is_64:
|
||||
lines.append(" result['d0_64'] = True")
|
||||
if has_d1:
|
||||
lines.append(" result['d1'] = D1._val & 1")
|
||||
if has_pc:
|
||||
# Return new PC as absolute byte address, emulator will compute delta
|
||||
# Handle negative values (backward jumps): PC._val is stored as unsigned, convert to signed
|
||||
lines.append(" _pc = PC._val if PC._val < 0x8000000000000000 else PC._val - 0x10000000000000000")
|
||||
lines.append(" result['new_pc'] = _pc # absolute byte address")
|
||||
lines.append(" return result")
|
||||
code_lines = [line for line in code.split('\n') if line.strip()]
|
||||
if code_lines:
|
||||
for line in code_lines:
|
||||
lines.append(f" {line}")
|
||||
else:
|
||||
lines.append(" pass")
|
||||
lines.append("")
|
||||
|
||||
fn_entries.append((op, fn_name))
|
||||
|
|
@ -1181,9 +1024,8 @@ from extra.assembly.amd.pcode import *
|
|||
if 'VOP3Op' in enum_names:
|
||||
lines.append('''
|
||||
# V_WRITELANE_B32: Write scalar to specific lane's VGPR (not in PDF pseudocode)
|
||||
def _VOP3Op_V_WRITELANE_B32(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGPR, _vars, src0_idx=0, vdst_idx=0):
|
||||
wr_lane = s1 & 0x1f # lane select (5 bits for wave32)
|
||||
return {'d0': d0, 'scc': scc, 'vgpr_write': (wr_lane, vdst_idx, s0 & 0xffffffff)}
|
||||
def _VOP3Op_V_WRITELANE_B32(S0, S1, S2, D0, D1, SCC, VCC, EXEC, tmp, laneId):
|
||||
return (int(S1) & 0x1f, int(S0) & 0xffffffff) # (wr_lane, value)
|
||||
VOP3Op_FUNCTIONS[VOP3Op.V_WRITELANE_B32] = _VOP3Op_V_WRITELANE_B32
|
||||
''')
|
||||
|
||||
|
|
|
|||
|
|
@ -4,7 +4,9 @@ import unittest
|
|||
from extra.assembly.amd.pcode import (Reg, TypedView, SliceProxy, ExecContext, compile_pseudocode, _expr, MASK32, MASK64,
|
||||
_f32, _i32, _f16, _i16, f32_to_f16, _isnan, _bf16, _ibf16, bf16_to_f32, f32_to_bf16,
|
||||
BYTE_PERMUTE, v_sad_u8, v_msad_u8)
|
||||
from extra.assembly.amd.autogen.rdna3.gen_pcode import _VOP3SDOp_V_DIV_SCALE_F32, _VOPCOp_V_CMP_CLASS_F32
|
||||
from extra.assembly.amd.autogen.rdna3.gen_pcode import VOP3SDOp_FUNCTIONS, VOPCOp_FUNCTIONS
|
||||
from extra.assembly.amd.autogen.rdna3 import VOP3SDOp, VOPCOp
|
||||
from extra.assembly.amd.emu import _run_pcode
|
||||
|
||||
class TestReg(unittest.TestCase):
|
||||
def test_u32_read(self):
|
||||
|
|
@ -234,7 +236,8 @@ class TestPseudocodeRegressions(unittest.TestCase):
|
|||
s0 = 0x3f800000 # 1.0
|
||||
s1 = 0x40400000 # 3.0
|
||||
s2 = 0x3f800000 # 1.0 (numerator)
|
||||
result = _VOP3SDOp_V_DIV_SCALE_F32(s0, s1, s2, 0, 0, 0, 0, 0xffffffff, 0, None, {})
|
||||
fn = VOP3SDOp_FUNCTIONS[VOP3SDOp.V_DIV_SCALE_F32]
|
||||
result = _run_pcode(fn, VOP3SDOp, VOP3SDOp.V_DIV_SCALE_F32, s0, s1, s2, 0, 0, 0, 0, 0xffffffff, 0)
|
||||
# Must always have vcc_lane in result
|
||||
self.assertIn('vcc_lane', result, "V_DIV_SCALE_F32 must always return vcc_lane")
|
||||
self.assertEqual(result['vcc_lane'], 0, "vcc_lane should be 0 when no scaling needed")
|
||||
|
|
@ -244,19 +247,20 @@ class TestPseudocodeRegressions(unittest.TestCase):
|
|||
Bug: isQuietNAN and isSignalNAN both used math.isnan which can't distinguish them."""
|
||||
quiet_nan = 0x7fc00000 # quiet NaN: exponent=255, bit22=1
|
||||
signal_nan = 0x7f800001 # signaling NaN: exponent=255, bit22=0
|
||||
fn = VOPCOp_FUNCTIONS[VOPCOp.V_CMP_CLASS_F32]
|
||||
# Test quiet NaN detection (bit 1 in mask)
|
||||
s1_quiet = 0b0000000010 # bit 1 = quiet NaN
|
||||
result = _VOPCOp_V_CMP_CLASS_F32(quiet_nan, s1_quiet, 0, 0, 0, 0, 0, 0xffffffff, 0, None, {})
|
||||
result = _run_pcode(fn, VOPCOp, VOPCOp.V_CMP_CLASS_F32, quiet_nan, s1_quiet, 0, 0, 0, 0, 0, 0xffffffff, 0)
|
||||
self.assertEqual(result['vcc_lane'], 1, "Should detect quiet NaN with quiet NaN mask")
|
||||
# Test signaling NaN detection (bit 0 in mask)
|
||||
s1_signal = 0b0000000001 # bit 0 = signaling NaN
|
||||
result = _VOPCOp_V_CMP_CLASS_F32(signal_nan, s1_signal, 0, 0, 0, 0, 0, 0xffffffff, 0, None, {})
|
||||
result = _run_pcode(fn, VOPCOp, VOPCOp.V_CMP_CLASS_F32, signal_nan, s1_signal, 0, 0, 0, 0, 0, 0xffffffff, 0)
|
||||
self.assertEqual(result['vcc_lane'], 1, "Should detect signaling NaN with signaling NaN mask")
|
||||
# Test that quiet NaN doesn't match signaling NaN mask
|
||||
result = _VOPCOp_V_CMP_CLASS_F32(quiet_nan, s1_signal, 0, 0, 0, 0, 0, 0xffffffff, 0, None, {})
|
||||
result = _run_pcode(fn, VOPCOp, VOPCOp.V_CMP_CLASS_F32, quiet_nan, s1_signal, 0, 0, 0, 0, 0, 0xffffffff, 0)
|
||||
self.assertEqual(result['vcc_lane'], 0, "Quiet NaN should not match signaling NaN mask")
|
||||
# Test that signaling NaN doesn't match quiet NaN mask
|
||||
result = _VOPCOp_V_CMP_CLASS_F32(signal_nan, s1_quiet, 0, 0, 0, 0, 0, 0xffffffff, 0, None, {})
|
||||
result = _run_pcode(fn, VOPCOp, VOPCOp.V_CMP_CLASS_F32, signal_nan, s1_quiet, 0, 0, 0, 0, 0, 0xffffffff, 0)
|
||||
self.assertEqual(result['vcc_lane'], 0, "Signaling NaN should not match quiet NaN mask")
|
||||
|
||||
def test_isnan_with_typed_view(self):
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue