mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
assembly/amd: add pcode ds ops
This commit is contained in:
parent
2bb07d4824
commit
b596f77e33
7 changed files with 4750 additions and 63 deletions
File diff suppressed because it is too large
Load diff
File diff suppressed because it is too large
Load diff
File diff suppressed because it is too large
Load diff
|
|
@ -3,7 +3,7 @@
|
|||
from __future__ import annotations
|
||||
import ctypes
|
||||
from extra.assembly.amd.dsl import Inst, unwrap, FLOAT_ENC, MASK32, MASK64, _f32, _i32, _sext, _f16, _i16, _f64, _i64
|
||||
from extra.assembly.amd.pcode import Reg
|
||||
from extra.assembly.amd.pcode import Reg, LDSMem
|
||||
from extra.assembly.amd.asm import detect_format
|
||||
from extra.assembly.amd.autogen.rdna3.gen_pcode import get_compiled_functions
|
||||
from extra.assembly.amd.autogen.rdna3.ins import (SOP1, SOP2, SOPC, SOPK, SOPP, SMEM, VOP1, VOP2, VOP3, VOP3SD, VOP3P, VOPC, DS, FLAT, VOPD,
|
||||
|
|
@ -51,11 +51,6 @@ _D16_LOAD_MAP = {'LOAD_D16_U8': (1,0,0), 'LOAD_D16_I8': (1,1,0), 'LOAD_D16_B16':
|
|||
_D16_STORE_MAP = {'STORE_D16_HI_B8': (1,1), 'STORE_D16_HI_B16': (2,1)} # (size, hi)
|
||||
FLAT_D16_LOAD = _mem_ops([GLOBALOp, FLATOp], _D16_LOAD_MAP)
|
||||
FLAT_D16_STORE = _mem_ops([GLOBALOp, FLATOp], _D16_STORE_MAP)
|
||||
DS_LOAD = {DSOp.DS_LOAD_B32: (1,4,0), DSOp.DS_LOAD_B64: (2,4,0), DSOp.DS_LOAD_B128: (4,4,0), DSOp.DS_LOAD_U8: (1,1,0), DSOp.DS_LOAD_I8: (1,1,1), DSOp.DS_LOAD_U16: (1,2,0), DSOp.DS_LOAD_I16: (1,2,1)}
|
||||
DS_STORE = {DSOp.DS_STORE_B32: (1,4), DSOp.DS_STORE_B64: (2,4), DSOp.DS_STORE_B128: (4,4), DSOp.DS_STORE_B8: (1,1), DSOp.DS_STORE_B16: (1,2)}
|
||||
# 2ADDR ops: load/store two values using offset0 and offset1
|
||||
DS_LOAD_2ADDR = {DSOp.DS_LOAD_2ADDR_B32: 4, DSOp.DS_LOAD_2ADDR_B64: 8}
|
||||
DS_STORE_2ADDR = {DSOp.DS_STORE_2ADDR_B32: 4, DSOp.DS_STORE_2ADDR_B64: 8}
|
||||
SMEM_LOAD = {SMEMOp.S_LOAD_B32: 1, SMEMOp.S_LOAD_B64: 2, SMEMOp.S_LOAD_B128: 4, SMEMOp.S_LOAD_B256: 8, SMEMOp.S_LOAD_B512: 16}
|
||||
|
||||
# VOPD op -> VOP3 op mapping (VOPD is dual-issue of VOP1/VOP2 ops, use VOP3 enums for pseudocode lookup)
|
||||
|
|
@ -225,31 +220,32 @@ def exec_vector(st: WaveState, inst: Inst, lane: int, lds: bytearray | None = No
|
|||
return
|
||||
|
||||
if isinstance(inst, DS):
|
||||
op, addr0, vdst = inst.op, (V[inst.addr] + inst.offset0) & 0xffff, inst.vdst
|
||||
if op in DS_LOAD:
|
||||
cnt, sz, sign = DS_LOAD[op]
|
||||
for i in range(cnt): val = int.from_bytes(lds[addr0+i*sz:addr0+i*sz+sz], 'little'); V[vdst + i] = _sext(val, sz * 8) & MASK32 if sign else val
|
||||
elif op in DS_STORE:
|
||||
cnt, sz = DS_STORE[op]
|
||||
for i in range(cnt): lds[addr0+i*sz:addr0+i*sz+sz] = (V[inst.data0 + i] & ((1 << (sz * 8)) - 1)).to_bytes(sz, 'little')
|
||||
elif op in DS_LOAD_2ADDR:
|
||||
# Load two values from addr+offset0*sz and addr+offset1*sz into vdst (B32: 1 dword each, B64: 2 dwords each)
|
||||
# Note: offsets are scaled by data size (4 for B32, 8 for B64) per AMD ISA
|
||||
sz = DS_LOAD_2ADDR[op]
|
||||
addr0 = (V[inst.addr] + inst.offset0 * sz) & 0xffff
|
||||
addr1 = (V[inst.addr] + inst.offset1 * sz) & 0xffff
|
||||
cnt = sz // 4 # 1 for B32, 2 for B64
|
||||
for i in range(cnt): V[vdst + i] = int.from_bytes(lds[addr0+i*4:addr0+i*4+4], 'little')
|
||||
for i in range(cnt): V[vdst + cnt + i] = int.from_bytes(lds[addr1+i*4:addr1+i*4+4], 'little')
|
||||
elif op in DS_STORE_2ADDR:
|
||||
# Store two values from data0 and data1 to addr+offset0*sz and addr+offset1*sz
|
||||
# Note: offsets are scaled by data size (4 for B32, 8 for B64) per AMD ISA
|
||||
sz = DS_STORE_2ADDR[op]
|
||||
addr0 = (V[inst.addr] + inst.offset0 * sz) & 0xffff
|
||||
addr1 = (V[inst.addr] + inst.offset1 * sz) & 0xffff
|
||||
cnt = sz // 4
|
||||
for i in range(cnt): lds[addr0+i*4:addr0+i*4+4] = (V[inst.data0 + i] & MASK32).to_bytes(4, 'little')
|
||||
for i in range(cnt): lds[addr1+i*4:addr1+i*4+4] = (V[inst.data1 + i] & MASK32).to_bytes(4, 'little')
|
||||
op, vdst = inst.op, inst.vdst
|
||||
if DSOp in compiled and op in compiled[DSOp]:
|
||||
fn = compiled[DSOp][op]
|
||||
# For B64 operations, DATA/DATA2 access bits [63:32] so we need to pass 64-bit values
|
||||
op_name = op.name
|
||||
is_b64 = '_B64' in op_name or '_U64' in op_name or '_I64' in op_name or '_F64' in op_name
|
||||
is_b128 = '_B128' in op_name
|
||||
if is_b128:
|
||||
s1 = Reg((V[inst.data0 + 3] << 96) | (V[inst.data0 + 2] << 64) | (V[inst.data0 + 1] << 32) | V[inst.data0])
|
||||
s2 = Reg((V[inst.data1 + 3] << 96) | (V[inst.data1 + 2] << 64) | (V[inst.data1 + 1] << 32) | V[inst.data1]) if inst.data1 else Reg(0)
|
||||
elif is_b64:
|
||||
s1 = Reg((V[inst.data0 + 1] << 32) | V[inst.data0])
|
||||
s2 = Reg((V[inst.data1 + 1] << 32) | V[inst.data1]) if inst.data1 else Reg(0)
|
||||
else:
|
||||
s1, s2 = Reg(V[inst.data0]), Reg(V[inst.data1] if inst.data1 else 0)
|
||||
s0, mem = Reg(V[inst.addr]), LDSMem(lds)
|
||||
result = fn(s0, s1, s2, Reg(V[vdst]), Reg(st.scc), Reg(st.vcc), lane, Reg(st.exec_mask), st.literal, None,
|
||||
offset0=inst.offset0, offset1=inst.offset1, MEM=mem)
|
||||
if 'D0' in result: V[vdst] = result['D0']._val & MASK32
|
||||
if 'RETURN_DATA' in result:
|
||||
# RETURN_DATA can be multi-dword (up to 128 bits for B128 ops)
|
||||
val = result['RETURN_DATA']._val
|
||||
V[vdst] = val & MASK32
|
||||
V[vdst + 1] = (val >> 32) & MASK32
|
||||
V[vdst + 2] = (val >> 64) & MASK32
|
||||
V[vdst + 3] = (val >> 96) & MASK32
|
||||
else: raise NotImplementedError(f"DS op {op}")
|
||||
return
|
||||
|
||||
|
|
|
|||
|
|
@ -212,7 +212,7 @@ def signext_from_bit(val, bit):
|
|||
|
||||
__all__ = [
|
||||
# Classes
|
||||
'Reg', 'SliceProxy', 'TypedView',
|
||||
'Reg', 'SliceProxy', 'TypedView', 'LDSMem',
|
||||
# Pack functions
|
||||
'_pack', '_pack32', 'pack', 'pack32',
|
||||
# Constants
|
||||
|
|
@ -341,11 +341,42 @@ class _Denorm:
|
|||
f64 = _DenormChecker(64)
|
||||
DENORM = _Denorm()
|
||||
|
||||
def _brev(v, bits):
|
||||
"""Bit-reverse a value."""
|
||||
result = 0
|
||||
for i in range(bits): result |= ((v >> i) & 1) << (bits - 1 - i)
|
||||
return result
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
# LDS MEMORY ACCESS (for DS instructions)
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
|
||||
class _LDSAccessor:
|
||||
"""Accessor for LDS memory at a specific address. Supports .u32/.f32 etc."""
|
||||
__slots__ = ('_lds', '_addr')
|
||||
def __init__(self, lds: bytearray, addr: int): self._lds, self._addr = lds, addr & 0xffff
|
||||
|
||||
def _read(self, size: int) -> int:
|
||||
return int.from_bytes(self._lds[self._addr:self._addr+size], 'little') if self._addr + size <= len(self._lds) else 0
|
||||
def _write(self, size: int, val: int):
|
||||
if self._addr + size <= len(self._lds): self._lds[self._addr:self._addr+size] = (int(val) & ((1 << (size*8)) - 1)).to_bytes(size, 'little')
|
||||
|
||||
# Unsigned integer access
|
||||
u8 = property(lambda s: s._read(1), lambda s, v: s._write(1, v))
|
||||
u16 = property(lambda s: s._read(2), lambda s, v: s._write(2, v))
|
||||
u32 = property(lambda s: s._read(4), lambda s, v: s._write(4, v))
|
||||
u64 = property(lambda s: s._read(8), lambda s, v: s._write(8, v))
|
||||
# Signed integer access
|
||||
i8 = property(lambda s: _sext(s._read(1), 8), lambda s, v: s._write(1, v))
|
||||
i16 = property(lambda s: _sext(s._read(2), 16), lambda s, v: s._write(2, v))
|
||||
i32 = property(lambda s: _sext(s._read(4), 32), lambda s, v: s._write(4, v))
|
||||
i64 = property(lambda s: _sext(s._read(8), 64), lambda s, v: s._write(8, v))
|
||||
# Float access
|
||||
f16 = property(lambda s: _f16(s._read(2)), lambda s, v: s._write(2, v if isinstance(v, int) else _i16(float(v))))
|
||||
f32 = property(lambda s: _f32(s._read(4)), lambda s, v: s._write(4, _i32(float(v))))
|
||||
f64 = property(lambda s: _f64(s._read(8)), lambda s, v: s._write(8, _i64(float(v))))
|
||||
# Bit/byte aliases
|
||||
b8, b16, b32, b64 = u8, u16, u32, u64
|
||||
|
||||
class LDSMem:
  """Subscriptable wrapper around an LDS bytearray so that MEM[addr].u32 style access works."""
  __slots__ = ('_lds',)

  def __init__(self, lds: bytearray):
    self._lds = lds

  def __getitem__(self, addr) -> _LDSAccessor:
    # Coerce the index to a plain int before handing it to the accessor.
    address = int(addr)
    return _LDSAccessor(self._lds, address)
|
||||
|
||||
class SliceProxy:
|
||||
"""Proxy for D0[31:16] that supports .f16/.u16 etc getters and setters."""
|
||||
|
|
@ -473,10 +504,12 @@ class TypedView:
|
|||
@property
|
||||
def u32(s): return s if s._bits == 32 and not s._signed else int(s) & MASK32
|
||||
|
||||
MASK128 = (1 << 128) - 1
|
||||
|
||||
class Reg:
|
||||
"""GPU register: D0.f32 = S0.f32 + S1.f32 just works."""
|
||||
"""GPU register: D0.f32 = S0.f32 + S1.f32 just works. Supports up to 128 bits for DS_LOAD_B128."""
|
||||
__slots__ = ('_val',)
|
||||
def __init__(self, val=0): self._val = int(val) & MASK64
|
||||
def __init__(self, val=0): self._val = int(val) & MASK128
|
||||
|
||||
# Typed views
|
||||
u64 = property(lambda s: TypedView(s, 64), lambda s, v: setattr(s, '_val', int(v) & MASK64))
|
||||
|
|
|
|||
|
|
@ -36,7 +36,7 @@ FIELD_ORDER = {
|
|||
SRC_EXTRAS = {233: 'DPP8', 234: 'DPP8FI', 250: 'DPP16', 251: 'VCCZ', 252: 'EXECZ', 254: 'LDS_DIRECT'}
|
||||
FLOAT_MAP = {'0.5': 'POS_HALF', '-0.5': 'NEG_HALF', '1.0': 'POS_ONE', '-1.0': 'NEG_ONE', '2.0': 'POS_TWO', '-2.0': 'NEG_TWO',
|
||||
'4.0': 'POS_FOUR', '-4.0': 'NEG_FOUR', '1/(2*PI)': 'INV_2PI', '0': 'ZERO'}
|
||||
INST_PATTERN = re.compile(r'^([SV]_[A-Z0-9_]+)\s+(\d+)\s*$', re.M)
|
||||
INST_PATTERN = re.compile(r'^([SVD]S?_[A-Z0-9_]+)\s+(\d+)\s*$', re.M)
|
||||
|
||||
# Patterns that can't be handled by the DSL (require special handling in emu.py)
|
||||
UNSUPPORTED = ['SGPR[', 'V_SWAP', 'eval ', 'FATAL_HALT', 'HW_REGISTERS',
|
||||
|
|
@ -46,7 +46,7 @@ UNSUPPORTED = ['SGPR[', 'V_SWAP', 'eval ', 'FATAL_HALT', 'HW_REGISTERS',
|
|||
'if n.', 'DST.u32', 'addrd = DST', 'addr = DST',
|
||||
'BARRIER_STATE', 'ReallocVgprs',
|
||||
'GPR_IDX', 'VSKIP', 'specified in', 'TTBL',
|
||||
'fp6', 'bf6'] # Malformed pseudocode from PDF
|
||||
'fp6', 'bf6', 'GS_REGS', 'M0.base', 'DS_DATA', '= 0..', 'sign(src', 'if no LDS', 'gds_base', 'vector mask'] # Malformed pseudocode from PDF
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
# COMPILER: pseudocode -> Python (minimal transforms)
|
||||
|
|
@ -68,8 +68,8 @@ def compile_pseudocode(pseudocode: str) -> str:
|
|||
lines = []
|
||||
indent, need_pass, in_first_match_loop = 0, False, False
|
||||
for line in joined_lines:
|
||||
line = line.strip()
|
||||
if not line or line.startswith('//'): continue
|
||||
line = line.split('//')[0].strip() # Strip C-style comments
|
||||
if not line: continue
|
||||
if line.startswith('if '):
|
||||
lines.append(' ' * indent + f"if {_expr(line[3:].rstrip(' then'))}:")
|
||||
indent += 1
|
||||
|
|
@ -351,8 +351,9 @@ def _extract_pseudocode(text: str) -> str | None:
|
|||
for line in lines:
|
||||
s = line.strip()
|
||||
if not s or re.match(r'^\d+ of \d+$', s) or re.match(r'^\d+\.\d+\..*Instructions', s): continue
|
||||
if s.startswith(('Notes', 'Functional examples')): break
|
||||
if s.startswith(('Notes', 'Functional examples', '•', '-')): break # Stop at notes/bullets
|
||||
if s.startswith(('"RDNA', 'AMD ', 'CDNA')): continue
|
||||
if '•' in s or '–' in s: continue # Skip lines with bullets/dashes
|
||||
if '= lambda(' in s: in_lambda += 1; continue
|
||||
if in_lambda > 0:
|
||||
if s.endswith(');'): in_lambda -= 1
|
||||
|
|
@ -362,7 +363,7 @@ def _extract_pseudocode(text: str) -> str | None:
|
|||
if s.endswith('.') and not any(p in s for p in ['D0', 'D1', 'S0', 'S1', 'S2', 'SCC', 'VCC', 'tmp', '=']): continue
|
||||
if re.match(r'^[a-z].*\.$', s) and '=' not in s: continue
|
||||
is_code = (any(p in s for p in ['D0.', 'D1.', 'S0.', 'S1.', 'S2.', 'SCC =', 'SCC ?', 'VCC', 'EXEC', 'tmp =', 'tmp[', 'lane =', 'PC =',
|
||||
'D0[', 'D1[', 'S0[', 'S1[', 'S2[']) or
|
||||
'D0[', 'D1[', 'S0[', 'S1[', 'S2[', 'MEM[', 'RETURN_DATA']) or
|
||||
s.startswith(('if ', 'else', 'elsif', 'endif', 'declare ', 'for ', 'endfor', '//')) or
|
||||
re.match(r'^[a-z_]+\s*=', s) or re.match(r'^[a-z_]+\[', s) or (depth > 0 and '=' in s))
|
||||
if is_code: result.append(s)
|
||||
|
|
@ -448,13 +449,13 @@ def _generate_gen_pcode_py(enums, pseudocode, arch) -> str:
|
|||
# Get op enums for this arch (import from .ins which re-exports from .enum)
|
||||
import importlib
|
||||
autogen = importlib.import_module(f"extra.assembly.amd.autogen.{arch}.ins")
|
||||
OP_ENUMS = [getattr(autogen, name) for name in ['SOP1Op', 'SOP2Op', 'SOPCOp', 'SOPKOp', 'SOPPOp', 'VOP1Op', 'VOP2Op', 'VOP3Op', 'VOP3SDOp', 'VOP3POp', 'VOPCOp', 'VOP3AOp', 'VOP3BOp'] if hasattr(autogen, name)]
|
||||
OP_ENUMS = [getattr(autogen, name) for name in ['SOP1Op', 'SOP2Op', 'SOPCOp', 'SOPKOp', 'SOPPOp', 'VOP1Op', 'VOP2Op', 'VOP3Op', 'VOP3SDOp', 'VOP3POp', 'VOPCOp', 'VOP3AOp', 'VOP3BOp', 'DSOp'] if hasattr(autogen, name)]
|
||||
|
||||
# Build defined ops mapping
|
||||
defined_ops: dict[tuple, list] = {}
|
||||
for enum_cls in OP_ENUMS:
|
||||
for op in enum_cls:
|
||||
if op.name.startswith(('S_', 'V_')): defined_ops.setdefault((op.name, op.value), []).append((enum_cls, op))
|
||||
if op.name.startswith(('S_', 'V_', 'DS_')): defined_ops.setdefault((op.name, op.value), []).append((enum_cls, op))
|
||||
|
||||
enum_names = [e.__name__ for e in OP_ENUMS]
|
||||
lines = [f'''# autogenerated by pdf.py - do not edit
|
||||
|
|
@ -541,11 +542,16 @@ def _generate_function(cls_name: str, op, pc: str, code: str) -> tuple[str, str]
|
|||
is_cmpx = (cls_name in ('VOPCOp', 'VOP3Op')) and 'EXEC.u64[laneId]' in pc
|
||||
is_div_scale = 'DIV_SCALE' in op.name
|
||||
has_sdst = cls_name == 'VOP3SDOp' and ('VCC.u64[laneId]' in pc or is_div_scale)
|
||||
is_ds = cls_name == 'DSOp'
|
||||
combined = code + pc
|
||||
|
||||
fn_name = f"_{cls_name}_{op.name}"
|
||||
# Function accepts Reg objects directly (uppercase names), laneId is passed directly as int
|
||||
lines = [f"def {fn_name}(S0, S1, S2, D0, SCC, VCC, laneId, EXEC, literal, VGPR, src0_idx=0, vdst_idx=0, PC=None):"]
|
||||
# DSOp functions get additional MEM and offset parameters
|
||||
if is_ds:
|
||||
lines = [f"def {fn_name}(S0, S1, S2, D0, SCC, VCC, laneId, EXEC, literal, VGPR, src0_idx=0, vdst_idx=0, PC=None, MEM=None, offset0=0, offset1=0):"]
|
||||
else:
|
||||
lines = [f"def {fn_name}(S0, S1, S2, D0, SCC, VCC, laneId, EXEC, literal, VGPR, src0_idx=0, vdst_idx=0, PC=None):"]
|
||||
|
||||
# Registers that need special handling (not passed directly)
|
||||
# Only init if used but not first assigned as `name = Reg(...)` in the compiled code
|
||||
|
|
@ -554,6 +560,11 @@ def _generate_function(cls_name: str, op, pc: str, code: str) -> tuple[str, str]
|
|||
('SRC0', 'Reg(src0_idx)'), ('VDST', 'Reg(vdst_idx)')]
|
||||
if needs_init('tmp'): special_regs.insert(0, ('tmp', 'Reg(0)'))
|
||||
if needs_init('saveexec'): special_regs.insert(0, ('saveexec', 'Reg(EXEC._val)'))
|
||||
# DS ops: add ADDR, DATA, OFFSET, RETURN_DATA variables
|
||||
if is_ds:
|
||||
special_regs.extend([('ADDR', 'S0._val'), ('DATA', 'S1'), ('DATA0', 'S1'), ('DATA1', 'S2'), ('DATA2', 'S2'),
|
||||
('OFFSET', 'Reg(offset0)'), ('OFFSET0', 'Reg(offset0)'), ('OFFSET1', 'Reg(offset1)'),
|
||||
('RETURN_DATA', 'Reg(0)')])
|
||||
used = {name for name, _ in special_regs if name in combined}
|
||||
|
||||
# Detect which registers are modified (not just read) - look for assignments
|
||||
|
|
@ -562,6 +573,8 @@ def _generate_function(cls_name: str, op, pc: str, code: str) -> tuple[str, str]
|
|||
modifies_vcc = has_sdst or bool(re.search(r'VCC\.(u32|u64|b32|b64)\s*=|VCC\.u64\[laneId\]\s*=', combined))
|
||||
modifies_scc = bool(re.search(r'\bSCC\s*=', combined))
|
||||
modifies_pc = bool(re.search(r'\bPC\s*=', combined))
|
||||
# DS ops: detect memory writes (MEM[...] = ...)
|
||||
modifies_mem = is_ds and bool(re.search(r'MEM\[.*\]\.[a-z0-9]+\s*=', combined))
|
||||
|
||||
# Build init code for special registers
|
||||
init_lines = []
|
||||
|
|
@ -587,6 +600,9 @@ def _generate_function(cls_name: str, op, pc: str, code: str) -> tuple[str, str]
|
|||
if modifies_exec: result_items.append("'EXEC': EXEC")
|
||||
if has_d1: result_items.append("'D1': D1")
|
||||
if modifies_pc: result_items.append("'PC': PC")
|
||||
# DS ops: return RETURN_DATA if it was set (matches RETURN_DATA.u32 = or RETURN_DATA[...] =)
|
||||
if is_ds and 'RETURN_DATA' in combined and re.search(r'RETURN_DATA[\.\[].*=', combined):
|
||||
result_items.append("'RETURN_DATA': RETURN_DATA")
|
||||
lines.append(f" return {{{', '.join(result_items)}}}\n")
|
||||
return fn_name, '\n'.join(lines)
|
||||
|
||||
|
|
|
|||
|
|
@ -4256,3 +4256,217 @@ class TestDS2Addr(unittest.TestCase):
|
|||
# v6,v7 from addr 8-15: 0x33333333, 0x44444444
|
||||
self.assertEqual(st.vgpr[0][6], 0x33333333, "v6 should be 0x33333333")
|
||||
self.assertEqual(st.vgpr[0][7], 0x44444444, "v7 should be 0x44444444")
|
||||
|
||||
|
||||
class TestDSAtomic(unittest.TestCase):
  """Single-lane tests for DS read-modify-write ops: max/min, and/or/xor, cmpstore, inc/dec."""

  def _rmw(self, atomic_op, initial, operand):
    """Seed LDS address 0 with `initial`, apply `atomic_op` with `operand`, then load the word back.

    Register use: v10 = address, v0 = seed, v1 = operand, v2 = vdst of the atomic, v3 = final memory value.
    Returns the state produced by run_program.
    """
    return run_program([
      v_mov_b32_e32(v[10], 0),  # addr = 0
      s_mov_b32(s[2], initial),
      v_mov_b32_e32(v[0], s[2]),  # initial memory value
      ds_store_b32(addr=v[10], data0=v[0], offset0=0),
      s_waitcnt(lgkmcnt=0),
      s_mov_b32(s[2], operand),
      v_mov_b32_e32(v[1], s[2]),  # data operand of the atomic
      atomic_op(addr=v[10], data0=v[1], vdst=v[2], offset0=0),
      s_waitcnt(lgkmcnt=0),
      ds_load_b32(addr=v[10], vdst=v[3], offset0=0),  # read result
      s_waitcnt(lgkmcnt=0),
    ], n_lanes=1)

  def _cmpstore(self, compare):
    """Seed LDS address 0 with 100, then DS_CMPSTORE_B32 a new value of 200 against `compare`.

    Register use: v10 = address, v0 = seed, v1 = new value, v2 = compare value, v4 = final memory value.
    Returns the state produced by run_program.
    """
    return run_program([
      v_mov_b32_e32(v[10], 0),
      s_mov_b32(s[2], 100),
      v_mov_b32_e32(v[0], s[2]),  # initial = 100
      ds_store_b32(addr=v[10], data0=v[0], offset0=0),
      s_waitcnt(lgkmcnt=0),
      s_mov_b32(s[2], 200),
      v_mov_b32_e32(v[1], s[2]),  # new value = 200
      s_mov_b32(s[2], compare),
      v_mov_b32_e32(v[2], s[2]),  # compare value
      ds_cmpstore_b32(addr=v[10], data0=v[1], data1=v[2], vdst=v[3], offset0=0),
      s_waitcnt(lgkmcnt=0),
      ds_load_b32(addr=v[10], vdst=v[4], offset0=0),
      s_waitcnt(lgkmcnt=0),
    ], n_lanes=1)

  def test_ds_max_rtn_u32(self):
    """DS_MAX_RTN_U32: memory becomes max(mem, data); the old value is returned."""
    st = self._rmw(ds_max_rtn_u32, 100, 200)
    self.assertEqual(st.vgpr[0][2], 100, "v2 should have old value (100)")
    self.assertEqual(st.vgpr[0][3], 200, "v3 should have max(100, 200) = 200")

  def test_ds_max_u32_no_rtn(self):
    """DS_MAX_U32 (no RTN): memory becomes max(mem, data); only memory is checked."""
    st = self._rmw(ds_max_u32, 100, 200)
    self.assertEqual(st.vgpr[0][3], 200, "v3 should have max(100, 200) = 200")

  def test_ds_min_rtn_u32(self):
    """DS_MIN_RTN_U32: memory becomes min(mem, data); the old value is returned."""
    st = self._rmw(ds_min_rtn_u32, 200, 100)
    self.assertEqual(st.vgpr[0][2], 200, "v2 should have old value (200)")
    self.assertEqual(st.vgpr[0][3], 100, "v3 should have min(200, 100) = 100")

  def test_ds_and_rtn_b32(self):
    """DS_AND_RTN_B32: memory is ANDed with data; the old value is returned."""
    st = self._rmw(ds_and_rtn_b32, 0xFF00FF00, 0xFFFF0000)
    self.assertEqual(st.vgpr[0][2], 0xFF00FF00, "v2 should have old value")
    self.assertEqual(st.vgpr[0][3], 0xFF000000, "v3 should have 0xFF00FF00 & 0xFFFF0000 = 0xFF000000")

  def test_ds_or_rtn_b32(self):
    """DS_OR_RTN_B32: memory is ORed with data; the old value is returned."""
    st = self._rmw(ds_or_rtn_b32, 0x00FF0000, 0x000000FF)
    self.assertEqual(st.vgpr[0][2], 0x00FF0000, "v2 should have old value")
    self.assertEqual(st.vgpr[0][3], 0x00FF00FF, "v3 should have 0x00FF0000 | 0x000000FF = 0x00FF00FF")

  def test_ds_xor_rtn_b32(self):
    """DS_XOR_RTN_B32: memory is XORed with data; the old value is returned."""
    st = self._rmw(ds_xor_rtn_b32, 0xAAAAAAAA, 0xFFFFFFFF)
    self.assertEqual(st.vgpr[0][2], 0xAAAAAAAA, "v2 should have old value")
    self.assertEqual(st.vgpr[0][3], 0x55555555, "v3 should have 0xAAAAAAAA ^ 0xFFFFFFFF = 0x55555555")

  def test_ds_cmpstore_b32_match(self):
    """DS_CMPSTORE_B32: the store happens when the compare value equals memory."""
    st = self._cmpstore(100)  # compare = 100 (matches current)
    self.assertEqual(st.vgpr[0][4], 200, "mem should be updated to 200 (compare matched)")

  def test_ds_cmpstore_b32_no_match(self):
    """DS_CMPSTORE_B32: memory is left alone when the compare value differs."""
    st = self._cmpstore(50)  # compare = 50 (doesn't match 100)
    self.assertEqual(st.vgpr[0][4], 100, "mem should still be 100 (compare didn't match)")

  # NOTE(review): the three tests below issue the non-RTN opcodes (ds_inc_u32 / ds_dec_u32) yet assert
  # that v2 receives the old memory value - confirm this is the intended behaviour for non-RTN forms.

  def test_ds_inc_u32(self):
    """DS_INC_U32: increment with wrap (value below the limit simply increments)."""
    st = self._rmw(ds_inc_u32, 5, 10)  # initial = 5, limit = 10
    self.assertEqual(st.vgpr[0][2], 5, "v2 should have old value (5)")
    self.assertEqual(st.vgpr[0][3], 6, "v3 should have incremented value (6)")

  def test_ds_dec_u32(self):
    """DS_DEC_U32: decrement with wrap (value within the limit simply decrements)."""
    st = self._rmw(ds_dec_u32, 5, 10)  # initial = 5, limit = 10
    self.assertEqual(st.vgpr[0][2], 5, "v2 should have old value (5)")
    self.assertEqual(st.vgpr[0][3], 4, "v3 should have decremented value (4)")

  def test_ds_dec_u32_wrap(self):
    """DS_DEC_U32: wraps to limit when value is 0 or > limit."""
    st = self._rmw(ds_dec_u32, 0, 10)  # initial = 0, limit = 10
    self.assertEqual(st.vgpr[0][2], 0, "v2 should have old value (0)")
    self.assertEqual(st.vgpr[0][3], 10, "v3 should wrap to limit (10)")
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue