assembly/amd: add pcode ds ops

This commit is contained in:
George Hotz 2025-12-31 16:59:02 -05:00
commit b596f77e33
7 changed files with 4750 additions and 63 deletions

File diff suppressed because it is too large Load diff

File diff suppressed because it is too large Load diff

File diff suppressed because it is too large Load diff

View file

@ -3,7 +3,7 @@
from __future__ import annotations
import ctypes
from extra.assembly.amd.dsl import Inst, unwrap, FLOAT_ENC, MASK32, MASK64, _f32, _i32, _sext, _f16, _i16, _f64, _i64
from extra.assembly.amd.pcode import Reg
from extra.assembly.amd.pcode import Reg, LDSMem
from extra.assembly.amd.asm import detect_format
from extra.assembly.amd.autogen.rdna3.gen_pcode import get_compiled_functions
from extra.assembly.amd.autogen.rdna3.ins import (SOP1, SOP2, SOPC, SOPK, SOPP, SMEM, VOP1, VOP2, VOP3, VOP3SD, VOP3P, VOPC, DS, FLAT, VOPD,
@ -51,11 +51,6 @@ _D16_LOAD_MAP = {'LOAD_D16_U8': (1,0,0), 'LOAD_D16_I8': (1,1,0), 'LOAD_D16_B16':
_D16_STORE_MAP = {'STORE_D16_HI_B8': (1,1), 'STORE_D16_HI_B16': (2,1)} # (size, hi)
FLAT_D16_LOAD = _mem_ops([GLOBALOp, FLATOp], _D16_LOAD_MAP)
FLAT_D16_STORE = _mem_ops([GLOBALOp, FLATOp], _D16_STORE_MAP)
DS_LOAD = {DSOp.DS_LOAD_B32: (1,4,0), DSOp.DS_LOAD_B64: (2,4,0), DSOp.DS_LOAD_B128: (4,4,0), DSOp.DS_LOAD_U8: (1,1,0), DSOp.DS_LOAD_I8: (1,1,1), DSOp.DS_LOAD_U16: (1,2,0), DSOp.DS_LOAD_I16: (1,2,1)}
DS_STORE = {DSOp.DS_STORE_B32: (1,4), DSOp.DS_STORE_B64: (2,4), DSOp.DS_STORE_B128: (4,4), DSOp.DS_STORE_B8: (1,1), DSOp.DS_STORE_B16: (1,2)}
# 2ADDR ops: load/store two values using offset0 and offset1
DS_LOAD_2ADDR = {DSOp.DS_LOAD_2ADDR_B32: 4, DSOp.DS_LOAD_2ADDR_B64: 8}
DS_STORE_2ADDR = {DSOp.DS_STORE_2ADDR_B32: 4, DSOp.DS_STORE_2ADDR_B64: 8}
SMEM_LOAD = {SMEMOp.S_LOAD_B32: 1, SMEMOp.S_LOAD_B64: 2, SMEMOp.S_LOAD_B128: 4, SMEMOp.S_LOAD_B256: 8, SMEMOp.S_LOAD_B512: 16}
# VOPD op -> VOP3 op mapping (VOPD is dual-issue of VOP1/VOP2 ops, use VOP3 enums for pseudocode lookup)
@ -225,31 +220,32 @@ def exec_vector(st: WaveState, inst: Inst, lane: int, lds: bytearray | None = No
return
if isinstance(inst, DS):
op, addr0, vdst = inst.op, (V[inst.addr] + inst.offset0) & 0xffff, inst.vdst
if op in DS_LOAD:
cnt, sz, sign = DS_LOAD[op]
for i in range(cnt): val = int.from_bytes(lds[addr0+i*sz:addr0+i*sz+sz], 'little'); V[vdst + i] = _sext(val, sz * 8) & MASK32 if sign else val
elif op in DS_STORE:
cnt, sz = DS_STORE[op]
for i in range(cnt): lds[addr0+i*sz:addr0+i*sz+sz] = (V[inst.data0 + i] & ((1 << (sz * 8)) - 1)).to_bytes(sz, 'little')
elif op in DS_LOAD_2ADDR:
# Load two values from addr+offset0*sz and addr+offset1*sz into vdst (B32: 1 dword each, B64: 2 dwords each)
# Note: offsets are scaled by data size (4 for B32, 8 for B64) per AMD ISA
sz = DS_LOAD_2ADDR[op]
addr0 = (V[inst.addr] + inst.offset0 * sz) & 0xffff
addr1 = (V[inst.addr] + inst.offset1 * sz) & 0xffff
cnt = sz // 4 # 1 for B32, 2 for B64
for i in range(cnt): V[vdst + i] = int.from_bytes(lds[addr0+i*4:addr0+i*4+4], 'little')
for i in range(cnt): V[vdst + cnt + i] = int.from_bytes(lds[addr1+i*4:addr1+i*4+4], 'little')
elif op in DS_STORE_2ADDR:
# Store two values from data0 and data1 to addr+offset0*sz and addr+offset1*sz
# Note: offsets are scaled by data size (4 for B32, 8 for B64) per AMD ISA
sz = DS_STORE_2ADDR[op]
addr0 = (V[inst.addr] + inst.offset0 * sz) & 0xffff
addr1 = (V[inst.addr] + inst.offset1 * sz) & 0xffff
cnt = sz // 4
for i in range(cnt): lds[addr0+i*4:addr0+i*4+4] = (V[inst.data0 + i] & MASK32).to_bytes(4, 'little')
for i in range(cnt): lds[addr1+i*4:addr1+i*4+4] = (V[inst.data1 + i] & MASK32).to_bytes(4, 'little')
op, vdst = inst.op, inst.vdst
if DSOp in compiled and op in compiled[DSOp]:
fn = compiled[DSOp][op]
# For B64 operations, DATA/DATA2 access bits [63:32] so we need to pass 64-bit values
op_name = op.name
is_b64 = '_B64' in op_name or '_U64' in op_name or '_I64' in op_name or '_F64' in op_name
is_b128 = '_B128' in op_name
if is_b128:
s1 = Reg((V[inst.data0 + 3] << 96) | (V[inst.data0 + 2] << 64) | (V[inst.data0 + 1] << 32) | V[inst.data0])
s2 = Reg((V[inst.data1 + 3] << 96) | (V[inst.data1 + 2] << 64) | (V[inst.data1 + 1] << 32) | V[inst.data1]) if inst.data1 else Reg(0)
elif is_b64:
s1 = Reg((V[inst.data0 + 1] << 32) | V[inst.data0])
s2 = Reg((V[inst.data1 + 1] << 32) | V[inst.data1]) if inst.data1 else Reg(0)
else:
s1, s2 = Reg(V[inst.data0]), Reg(V[inst.data1] if inst.data1 else 0)
s0, mem = Reg(V[inst.addr]), LDSMem(lds)
result = fn(s0, s1, s2, Reg(V[vdst]), Reg(st.scc), Reg(st.vcc), lane, Reg(st.exec_mask), st.literal, None,
offset0=inst.offset0, offset1=inst.offset1, MEM=mem)
if 'D0' in result: V[vdst] = result['D0']._val & MASK32
if 'RETURN_DATA' in result:
# RETURN_DATA can be multi-dword (up to 128 bits for B128 ops)
val = result['RETURN_DATA']._val
V[vdst] = val & MASK32
V[vdst + 1] = (val >> 32) & MASK32
V[vdst + 2] = (val >> 64) & MASK32
V[vdst + 3] = (val >> 96) & MASK32
else: raise NotImplementedError(f"DS op {op}")
return

View file

@ -212,7 +212,7 @@ def signext_from_bit(val, bit):
__all__ = [
# Classes
'Reg', 'SliceProxy', 'TypedView',
'Reg', 'SliceProxy', 'TypedView', 'LDSMem',
# Pack functions
'_pack', '_pack32', 'pack', 'pack32',
# Constants
@ -341,11 +341,42 @@ class _Denorm:
f64 = _DenormChecker(64)
DENORM = _Denorm()
def _brev(v, bits):
"""Bit-reverse a value."""
result = 0
for i in range(bits): result |= ((v >> i) & 1) << (bits - 1 - i)
return result
# ═══════════════════════════════════════════════════════════════════════════════
# LDS MEMORY ACCESS (for DS instructions)
# ═══════════════════════════════════════════════════════════════════════════════
class _LDSAccessor:
"""Accessor for LDS memory at a specific address. Supports .u32/.f32 etc."""
__slots__ = ('_lds', '_addr')
def __init__(self, lds: bytearray, addr: int): self._lds, self._addr = lds, addr & 0xffff
def _read(self, size: int) -> int:
return int.from_bytes(self._lds[self._addr:self._addr+size], 'little') if self._addr + size <= len(self._lds) else 0
def _write(self, size: int, val: int):
if self._addr + size <= len(self._lds): self._lds[self._addr:self._addr+size] = (int(val) & ((1 << (size*8)) - 1)).to_bytes(size, 'little')
# Unsigned integer access
u8 = property(lambda s: s._read(1), lambda s, v: s._write(1, v))
u16 = property(lambda s: s._read(2), lambda s, v: s._write(2, v))
u32 = property(lambda s: s._read(4), lambda s, v: s._write(4, v))
u64 = property(lambda s: s._read(8), lambda s, v: s._write(8, v))
# Signed integer access
i8 = property(lambda s: _sext(s._read(1), 8), lambda s, v: s._write(1, v))
i16 = property(lambda s: _sext(s._read(2), 16), lambda s, v: s._write(2, v))
i32 = property(lambda s: _sext(s._read(4), 32), lambda s, v: s._write(4, v))
i64 = property(lambda s: _sext(s._read(8), 64), lambda s, v: s._write(8, v))
# Float access
f16 = property(lambda s: _f16(s._read(2)), lambda s, v: s._write(2, v if isinstance(v, int) else _i16(float(v))))
f32 = property(lambda s: _f32(s._read(4)), lambda s, v: s._write(4, _i32(float(v))))
f64 = property(lambda s: _f64(s._read(8)), lambda s, v: s._write(8, _i64(float(v))))
# Bit/byte aliases
b8, b16, b32, b64 = u8, u16, u32, u64
class LDSMem:
"""LDS memory wrapper that supports MEM[addr].u32 style access."""
__slots__ = ('_lds',)
def __init__(self, lds: bytearray): self._lds = lds
def __getitem__(self, addr) -> _LDSAccessor: return _LDSAccessor(self._lds, int(addr))
class SliceProxy:
"""Proxy for D0[31:16] that supports .f16/.u16 etc getters and setters."""
@ -473,10 +504,12 @@ class TypedView:
@property
def u32(s): return s if s._bits == 32 and not s._signed else int(s) & MASK32
MASK128 = (1 << 128) - 1
class Reg:
"""GPU register: D0.f32 = S0.f32 + S1.f32 just works."""
"""GPU register: D0.f32 = S0.f32 + S1.f32 just works. Supports up to 128 bits for DS_LOAD_B128."""
__slots__ = ('_val',)
def __init__(self, val=0): self._val = int(val) & MASK64
def __init__(self, val=0): self._val = int(val) & MASK128
# Typed views
u64 = property(lambda s: TypedView(s, 64), lambda s, v: setattr(s, '_val', int(v) & MASK64))

View file

@ -36,7 +36,7 @@ FIELD_ORDER = {
SRC_EXTRAS = {233: 'DPP8', 234: 'DPP8FI', 250: 'DPP16', 251: 'VCCZ', 252: 'EXECZ', 254: 'LDS_DIRECT'}
FLOAT_MAP = {'0.5': 'POS_HALF', '-0.5': 'NEG_HALF', '1.0': 'POS_ONE', '-1.0': 'NEG_ONE', '2.0': 'POS_TWO', '-2.0': 'NEG_TWO',
'4.0': 'POS_FOUR', '-4.0': 'NEG_FOUR', '1/(2*PI)': 'INV_2PI', '0': 'ZERO'}
INST_PATTERN = re.compile(r'^([SV]_[A-Z0-9_]+)\s+(\d+)\s*$', re.M)
INST_PATTERN = re.compile(r'^([SVD]S?_[A-Z0-9_]+)\s+(\d+)\s*$', re.M)
# Patterns that can't be handled by the DSL (require special handling in emu.py)
UNSUPPORTED = ['SGPR[', 'V_SWAP', 'eval ', 'FATAL_HALT', 'HW_REGISTERS',
@ -46,7 +46,7 @@ UNSUPPORTED = ['SGPR[', 'V_SWAP', 'eval ', 'FATAL_HALT', 'HW_REGISTERS',
'if n.', 'DST.u32', 'addrd = DST', 'addr = DST',
'BARRIER_STATE', 'ReallocVgprs',
'GPR_IDX', 'VSKIP', 'specified in', 'TTBL',
'fp6', 'bf6'] # Malformed pseudocode from PDF
'fp6', 'bf6', 'GS_REGS', 'M0.base', 'DS_DATA', '= 0..', 'sign(src', 'if no LDS', 'gds_base', 'vector mask'] # Malformed pseudocode from PDF
# ═══════════════════════════════════════════════════════════════════════════════
# COMPILER: pseudocode -> Python (minimal transforms)
@ -68,8 +68,8 @@ def compile_pseudocode(pseudocode: str) -> str:
lines = []
indent, need_pass, in_first_match_loop = 0, False, False
for line in joined_lines:
line = line.strip()
if not line or line.startswith('//'): continue
line = line.split('//')[0].strip() # Strip C-style comments
if not line: continue
if line.startswith('if '):
lines.append(' ' * indent + f"if {_expr(line[3:].rstrip(' then'))}:")
indent += 1
@ -351,8 +351,9 @@ def _extract_pseudocode(text: str) -> str | None:
for line in lines:
s = line.strip()
if not s or re.match(r'^\d+ of \d+$', s) or re.match(r'^\d+\.\d+\..*Instructions', s): continue
if s.startswith(('Notes', 'Functional examples')): break
if s.startswith(('Notes', 'Functional examples', '', '-')): break # Stop at notes/bullets
if s.startswith(('"RDNA', 'AMD ', 'CDNA')): continue
if '' in s or '' in s: continue # Skip lines with bullets/dashes
if '= lambda(' in s: in_lambda += 1; continue
if in_lambda > 0:
if s.endswith(');'): in_lambda -= 1
@ -362,7 +363,7 @@ def _extract_pseudocode(text: str) -> str | None:
if s.endswith('.') and not any(p in s for p in ['D0', 'D1', 'S0', 'S1', 'S2', 'SCC', 'VCC', 'tmp', '=']): continue
if re.match(r'^[a-z].*\.$', s) and '=' not in s: continue
is_code = (any(p in s for p in ['D0.', 'D1.', 'S0.', 'S1.', 'S2.', 'SCC =', 'SCC ?', 'VCC', 'EXEC', 'tmp =', 'tmp[', 'lane =', 'PC =',
'D0[', 'D1[', 'S0[', 'S1[', 'S2[']) or
'D0[', 'D1[', 'S0[', 'S1[', 'S2[', 'MEM[', 'RETURN_DATA']) or
s.startswith(('if ', 'else', 'elsif', 'endif', 'declare ', 'for ', 'endfor', '//')) or
re.match(r'^[a-z_]+\s*=', s) or re.match(r'^[a-z_]+\[', s) or (depth > 0 and '=' in s))
if is_code: result.append(s)
@ -448,13 +449,13 @@ def _generate_gen_pcode_py(enums, pseudocode, arch) -> str:
# Get op enums for this arch (import from .ins which re-exports from .enum)
import importlib
autogen = importlib.import_module(f"extra.assembly.amd.autogen.{arch}.ins")
OP_ENUMS = [getattr(autogen, name) for name in ['SOP1Op', 'SOP2Op', 'SOPCOp', 'SOPKOp', 'SOPPOp', 'VOP1Op', 'VOP2Op', 'VOP3Op', 'VOP3SDOp', 'VOP3POp', 'VOPCOp', 'VOP3AOp', 'VOP3BOp'] if hasattr(autogen, name)]
OP_ENUMS = [getattr(autogen, name) for name in ['SOP1Op', 'SOP2Op', 'SOPCOp', 'SOPKOp', 'SOPPOp', 'VOP1Op', 'VOP2Op', 'VOP3Op', 'VOP3SDOp', 'VOP3POp', 'VOPCOp', 'VOP3AOp', 'VOP3BOp', 'DSOp'] if hasattr(autogen, name)]
# Build defined ops mapping
defined_ops: dict[tuple, list] = {}
for enum_cls in OP_ENUMS:
for op in enum_cls:
if op.name.startswith(('S_', 'V_')): defined_ops.setdefault((op.name, op.value), []).append((enum_cls, op))
if op.name.startswith(('S_', 'V_', 'DS_')): defined_ops.setdefault((op.name, op.value), []).append((enum_cls, op))
enum_names = [e.__name__ for e in OP_ENUMS]
lines = [f'''# autogenerated by pdf.py - do not edit
@ -541,11 +542,16 @@ def _generate_function(cls_name: str, op, pc: str, code: str) -> tuple[str, str]
is_cmpx = (cls_name in ('VOPCOp', 'VOP3Op')) and 'EXEC.u64[laneId]' in pc
is_div_scale = 'DIV_SCALE' in op.name
has_sdst = cls_name == 'VOP3SDOp' and ('VCC.u64[laneId]' in pc or is_div_scale)
is_ds = cls_name == 'DSOp'
combined = code + pc
fn_name = f"_{cls_name}_{op.name}"
# Function accepts Reg objects directly (uppercase names), laneId is passed directly as int
lines = [f"def {fn_name}(S0, S1, S2, D0, SCC, VCC, laneId, EXEC, literal, VGPR, src0_idx=0, vdst_idx=0, PC=None):"]
# DSOp functions get additional MEM and offset parameters
if is_ds:
lines = [f"def {fn_name}(S0, S1, S2, D0, SCC, VCC, laneId, EXEC, literal, VGPR, src0_idx=0, vdst_idx=0, PC=None, MEM=None, offset0=0, offset1=0):"]
else:
lines = [f"def {fn_name}(S0, S1, S2, D0, SCC, VCC, laneId, EXEC, literal, VGPR, src0_idx=0, vdst_idx=0, PC=None):"]
# Registers that need special handling (not passed directly)
# Only init if used but not first assigned as `name = Reg(...)` in the compiled code
@ -554,6 +560,11 @@ def _generate_function(cls_name: str, op, pc: str, code: str) -> tuple[str, str]
('SRC0', 'Reg(src0_idx)'), ('VDST', 'Reg(vdst_idx)')]
if needs_init('tmp'): special_regs.insert(0, ('tmp', 'Reg(0)'))
if needs_init('saveexec'): special_regs.insert(0, ('saveexec', 'Reg(EXEC._val)'))
# DS ops: add ADDR, DATA, OFFSET, RETURN_DATA variables
if is_ds:
special_regs.extend([('ADDR', 'S0._val'), ('DATA', 'S1'), ('DATA0', 'S1'), ('DATA1', 'S2'), ('DATA2', 'S2'),
('OFFSET', 'Reg(offset0)'), ('OFFSET0', 'Reg(offset0)'), ('OFFSET1', 'Reg(offset1)'),
('RETURN_DATA', 'Reg(0)')])
used = {name for name, _ in special_regs if name in combined}
# Detect which registers are modified (not just read) - look for assignments
@ -562,6 +573,8 @@ def _generate_function(cls_name: str, op, pc: str, code: str) -> tuple[str, str]
modifies_vcc = has_sdst or bool(re.search(r'VCC\.(u32|u64|b32|b64)\s*=|VCC\.u64\[laneId\]\s*=', combined))
modifies_scc = bool(re.search(r'\bSCC\s*=', combined))
modifies_pc = bool(re.search(r'\bPC\s*=', combined))
# DS ops: detect memory writes (MEM[...] = ...)
modifies_mem = is_ds and bool(re.search(r'MEM\[.*\]\.[a-z0-9]+\s*=', combined))
# Build init code for special registers
init_lines = []
@ -587,6 +600,9 @@ def _generate_function(cls_name: str, op, pc: str, code: str) -> tuple[str, str]
if modifies_exec: result_items.append("'EXEC': EXEC")
if has_d1: result_items.append("'D1': D1")
if modifies_pc: result_items.append("'PC': PC")
# DS ops: return RETURN_DATA if it was set (matches RETURN_DATA.u32 = or RETURN_DATA[...] =)
if is_ds and 'RETURN_DATA' in combined and re.search(r'RETURN_DATA[\.\[].*=', combined):
result_items.append("'RETURN_DATA': RETURN_DATA")
lines.append(f" return {{{', '.join(result_items)}}}\n")
return fn_name, '\n'.join(lines)

View file

@ -4256,3 +4256,217 @@ class TestDS2Addr(unittest.TestCase):
# v6,v7 from addr 8-15: 0x33333333, 0x44444444
self.assertEqual(st.vgpr[0][6], 0x33333333, "v6 should be 0x33333333")
self.assertEqual(st.vgpr[0][7], 0x44444444, "v7 should be 0x44444444")
class TestDSAtomic(unittest.TestCase):
"""Tests for DS atomic instructions (add, max, min, and, or, xor, cmpstore, etc.)."""
def test_ds_max_rtn_u32(self):
"""DS_MAX_RTN_U32: atomically store max(mem, data) and return old value."""
instructions = [
v_mov_b32_e32(v[10], 0), # addr = 0
s_mov_b32(s[2], 100),
v_mov_b32_e32(v[0], s[2]), # initial value = 100
ds_store_b32(addr=v[10], data0=v[0], offset0=0),
s_waitcnt(lgkmcnt=0),
s_mov_b32(s[2], 200),
v_mov_b32_e32(v[1], s[2]), # data = 200 (greater than 100)
ds_max_rtn_u32(addr=v[10], data0=v[1], vdst=v[2], offset0=0),
s_waitcnt(lgkmcnt=0),
ds_load_b32(addr=v[10], vdst=v[3], offset0=0), # read result
s_waitcnt(lgkmcnt=0),
]
st = run_program(instructions, n_lanes=1)
self.assertEqual(st.vgpr[0][2], 100, "v2 should have old value (100)")
self.assertEqual(st.vgpr[0][3], 200, "v3 should have max(100, 200) = 200")
def test_ds_max_u32_no_rtn(self):
"""DS_MAX_U32 (no RTN): atomically store max, no return value."""
instructions = [
v_mov_b32_e32(v[10], 0),
s_mov_b32(s[2], 100),
v_mov_b32_e32(v[0], s[2]), # initial = 100
ds_store_b32(addr=v[10], data0=v[0], offset0=0),
s_waitcnt(lgkmcnt=0),
s_mov_b32(s[2], 200),
v_mov_b32_e32(v[1], s[2]), # data = 200
ds_max_u32(addr=v[10], data0=v[1], vdst=v[2], offset0=0),
s_waitcnt(lgkmcnt=0),
ds_load_b32(addr=v[10], vdst=v[3], offset0=0),
s_waitcnt(lgkmcnt=0),
]
st = run_program(instructions, n_lanes=1)
self.assertEqual(st.vgpr[0][3], 200, "v3 should have max(100, 200) = 200")
def test_ds_min_rtn_u32(self):
"""DS_MIN_RTN_U32: atomically store min(mem, data) and return old value."""
instructions = [
v_mov_b32_e32(v[10], 0),
s_mov_b32(s[2], 200),
v_mov_b32_e32(v[0], s[2]), # initial = 200
ds_store_b32(addr=v[10], data0=v[0], offset0=0),
s_waitcnt(lgkmcnt=0),
s_mov_b32(s[2], 100),
v_mov_b32_e32(v[1], s[2]), # data = 100
ds_min_rtn_u32(addr=v[10], data0=v[1], vdst=v[2], offset0=0),
s_waitcnt(lgkmcnt=0),
ds_load_b32(addr=v[10], vdst=v[3], offset0=0),
s_waitcnt(lgkmcnt=0),
]
st = run_program(instructions, n_lanes=1)
self.assertEqual(st.vgpr[0][2], 200, "v2 should have old value (200)")
self.assertEqual(st.vgpr[0][3], 100, "v3 should have min(200, 100) = 100")
def test_ds_and_rtn_b32(self):
"""DS_AND_RTN_B32: atomically AND mem with data and return old value."""
instructions = [
v_mov_b32_e32(v[10], 0),
s_mov_b32(s[2], 0xFF00FF00),
v_mov_b32_e32(v[0], s[2]), # initial = 0xFF00FF00
ds_store_b32(addr=v[10], data0=v[0], offset0=0),
s_waitcnt(lgkmcnt=0),
s_mov_b32(s[2], 0xFFFF0000),
v_mov_b32_e32(v[1], s[2]), # data = 0xFFFF0000
ds_and_rtn_b32(addr=v[10], data0=v[1], vdst=v[2], offset0=0),
s_waitcnt(lgkmcnt=0),
ds_load_b32(addr=v[10], vdst=v[3], offset0=0),
s_waitcnt(lgkmcnt=0),
]
st = run_program(instructions, n_lanes=1)
self.assertEqual(st.vgpr[0][2], 0xFF00FF00, "v2 should have old value")
self.assertEqual(st.vgpr[0][3], 0xFF000000, "v3 should have 0xFF00FF00 & 0xFFFF0000 = 0xFF000000")
def test_ds_or_rtn_b32(self):
"""DS_OR_RTN_B32: atomically OR mem with data and return old value."""
instructions = [
v_mov_b32_e32(v[10], 0),
s_mov_b32(s[2], 0x00FF0000),
v_mov_b32_e32(v[0], s[2]), # initial = 0x00FF0000
ds_store_b32(addr=v[10], data0=v[0], offset0=0),
s_waitcnt(lgkmcnt=0),
s_mov_b32(s[2], 0x000000FF),
v_mov_b32_e32(v[1], s[2]), # data = 0x000000FF
ds_or_rtn_b32(addr=v[10], data0=v[1], vdst=v[2], offset0=0),
s_waitcnt(lgkmcnt=0),
ds_load_b32(addr=v[10], vdst=v[3], offset0=0),
s_waitcnt(lgkmcnt=0),
]
st = run_program(instructions, n_lanes=1)
self.assertEqual(st.vgpr[0][2], 0x00FF0000, "v2 should have old value")
self.assertEqual(st.vgpr[0][3], 0x00FF00FF, "v3 should have 0x00FF0000 | 0x000000FF = 0x00FF00FF")
def test_ds_xor_rtn_b32(self):
"""DS_XOR_RTN_B32: atomically XOR mem with data and return old value."""
instructions = [
v_mov_b32_e32(v[10], 0),
s_mov_b32(s[2], 0xAAAAAAAA),
v_mov_b32_e32(v[0], s[2]), # initial = 0xAAAAAAAA
ds_store_b32(addr=v[10], data0=v[0], offset0=0),
s_waitcnt(lgkmcnt=0),
s_mov_b32(s[2], 0xFFFFFFFF),
v_mov_b32_e32(v[1], s[2]), # data = 0xFFFFFFFF
ds_xor_rtn_b32(addr=v[10], data0=v[1], vdst=v[2], offset0=0),
s_waitcnt(lgkmcnt=0),
ds_load_b32(addr=v[10], vdst=v[3], offset0=0),
s_waitcnt(lgkmcnt=0),
]
st = run_program(instructions, n_lanes=1)
self.assertEqual(st.vgpr[0][2], 0xAAAAAAAA, "v2 should have old value")
self.assertEqual(st.vgpr[0][3], 0x55555555, "v3 should have 0xAAAAAAAA ^ 0xFFFFFFFF = 0x55555555")
def test_ds_cmpstore_b32_match(self):
"""DS_CMPSTORE_B32: conditional store when compare matches."""
instructions = [
v_mov_b32_e32(v[10], 0),
s_mov_b32(s[2], 100),
v_mov_b32_e32(v[0], s[2]), # initial = 100
ds_store_b32(addr=v[10], data0=v[0], offset0=0),
s_waitcnt(lgkmcnt=0),
s_mov_b32(s[2], 200),
v_mov_b32_e32(v[1], s[2]), # new value = 200
s_mov_b32(s[2], 100),
v_mov_b32_e32(v[2], s[2]), # compare = 100 (matches current)
ds_cmpstore_b32(addr=v[10], data0=v[1], data1=v[2], vdst=v[3], offset0=0),
s_waitcnt(lgkmcnt=0),
ds_load_b32(addr=v[10], vdst=v[4], offset0=0),
s_waitcnt(lgkmcnt=0),
]
st = run_program(instructions, n_lanes=1)
self.assertEqual(st.vgpr[0][4], 200, "mem should be updated to 200 (compare matched)")
def test_ds_cmpstore_b32_no_match(self):
"""DS_CMPSTORE_B32: no store when compare doesn't match."""
instructions = [
v_mov_b32_e32(v[10], 0),
s_mov_b32(s[2], 100),
v_mov_b32_e32(v[0], s[2]), # initial = 100
ds_store_b32(addr=v[10], data0=v[0], offset0=0),
s_waitcnt(lgkmcnt=0),
s_mov_b32(s[2], 200),
v_mov_b32_e32(v[1], s[2]), # new value = 200
s_mov_b32(s[2], 50),
v_mov_b32_e32(v[2], s[2]), # compare = 50 (doesn't match 100)
ds_cmpstore_b32(addr=v[10], data0=v[1], data1=v[2], vdst=v[3], offset0=0),
s_waitcnt(lgkmcnt=0),
ds_load_b32(addr=v[10], vdst=v[4], offset0=0),
s_waitcnt(lgkmcnt=0),
]
st = run_program(instructions, n_lanes=1)
self.assertEqual(st.vgpr[0][4], 100, "mem should still be 100 (compare didn't match)")
def test_ds_inc_u32(self):
"""DS_INC_U32: increment with wrap."""
instructions = [
v_mov_b32_e32(v[10], 0),
s_mov_b32(s[2], 5),
v_mov_b32_e32(v[0], s[2]), # initial = 5
ds_store_b32(addr=v[10], data0=v[0], offset0=0),
s_waitcnt(lgkmcnt=0),
s_mov_b32(s[2], 10),
v_mov_b32_e32(v[1], s[2]), # limit = 10
ds_inc_u32(addr=v[10], data0=v[1], vdst=v[2], offset0=0),
s_waitcnt(lgkmcnt=0),
ds_load_b32(addr=v[10], vdst=v[3], offset0=0),
s_waitcnt(lgkmcnt=0),
]
st = run_program(instructions, n_lanes=1)
self.assertEqual(st.vgpr[0][2], 5, "v2 should have old value (5)")
self.assertEqual(st.vgpr[0][3], 6, "v3 should have incremented value (6)")
def test_ds_dec_u32(self):
"""DS_DEC_U32: decrement with wrap."""
instructions = [
v_mov_b32_e32(v[10], 0),
s_mov_b32(s[2], 5),
v_mov_b32_e32(v[0], s[2]), # initial = 5
ds_store_b32(addr=v[10], data0=v[0], offset0=0),
s_waitcnt(lgkmcnt=0),
s_mov_b32(s[2], 10),
v_mov_b32_e32(v[1], s[2]), # limit = 10
ds_dec_u32(addr=v[10], data0=v[1], vdst=v[2], offset0=0),
s_waitcnt(lgkmcnt=0),
ds_load_b32(addr=v[10], vdst=v[3], offset0=0),
s_waitcnt(lgkmcnt=0),
]
st = run_program(instructions, n_lanes=1)
self.assertEqual(st.vgpr[0][2], 5, "v2 should have old value (5)")
self.assertEqual(st.vgpr[0][3], 4, "v3 should have decremented value (4)")
def test_ds_dec_u32_wrap(self):
"""DS_DEC_U32: wraps to limit when value is 0 or > limit."""
instructions = [
v_mov_b32_e32(v[10], 0),
s_mov_b32(s[2], 0),
v_mov_b32_e32(v[0], s[2]), # initial = 0
ds_store_b32(addr=v[10], data0=v[0], offset0=0),
s_waitcnt(lgkmcnt=0),
s_mov_b32(s[2], 10),
v_mov_b32_e32(v[1], s[2]), # limit = 10
ds_dec_u32(addr=v[10], data0=v[1], vdst=v[2], offset0=0),
s_waitcnt(lgkmcnt=0),
ds_load_b32(addr=v[10], vdst=v[3], offset0=0),
s_waitcnt(lgkmcnt=0),
]
st = run_program(instructions, n_lanes=1)
self.assertEqual(st.vgpr[0][2], 0, "v2 should have old value (0)")
self.assertEqual(st.vgpr[0][3], 10, "v3 should wrap to limit (10)")