Compare commits

...

2 commits

Author SHA1 Message Date
George Hotz
aec4d65241 ds compiled 2025-12-31 15:43:44 -05:00
George Hotz
f022a7d8a7 assembly/amd: move more instructions to pcode 2025-12-31 15:42:59 -05:00
4 changed files with 271 additions and 56 deletions

View file

@ -28,4 +28,4 @@ The ops tests also pass, but they are very slow, so you should run them one at a
When something is caught by main tinygrad tests, a local regression test should be added to `extra/assembly/amd/test`. While working with tinygrad, you can dump the assembly with `DEBUG=7`. These tests all pass on real hardware, so if a test is failing with `AMD=1 PYTHON_REMU=1 MOCKGPU=1` it's likely because an instruction is emulated incorrectly. You can test without `MOCKGPU=1` to test on real hardware, if it works on real hardware there's a bug in the emulator. When something is caught by main tinygrad tests, a local regression test should be added to `extra/assembly/amd/test`. While working with tinygrad, you can dump the assembly with `DEBUG=7`. These tests all pass on real hardware, so if a test is failing with `AMD=1 PYTHON_REMU=1 MOCKGPU=1` it's likely because an instruction is emulated incorrectly. You can test without `MOCKGPU=1` to test on real hardware, if it works on real hardware there's a bug in the emulator.
Currently, only RDNA3 is well supported, but when finished, this will support RDNA3+RDNA4+CDNA in ~2000 lines. Currently, only RDNA3 is well supported, but when finished, this will support RDNA3+RDNA4+CDNA in ~2000 lines. Count lines with `cloc --by-file extra/assembly/amd/*.py`

View file

@ -2,7 +2,7 @@
# to regenerate: python -m extra.assembly.amd.pdf --arch rdna3 # to regenerate: python -m extra.assembly.amd.pdf --arch rdna3
# ruff: noqa: E501,F405,F403 # ruff: noqa: E501,F405,F403
# mypy: ignore-errors # mypy: ignore-errors
from extra.assembly.amd.autogen.rdna3.enum import SOP1Op, SOP2Op, SOPCOp, SOPKOp, SOPPOp, VOP1Op, VOP2Op, VOP3Op, VOP3SDOp, VOP3POp, VOPCOp from extra.assembly.amd.autogen.rdna3.enum import SOP1Op, SOP2Op, SOPCOp, SOPKOp, SOPPOp, VOP1Op, VOP2Op, VOP3Op, VOP3SDOp, VOP3POp, VOPCOp, DSOp
from extra.assembly.amd.pcode import * from extra.assembly.amd.pcode import *
def _SOP1Op_S_MOV_B32(S0, S1, S2, D0, SCC, VCC, laneId, EXEC, literal, VGPR, src0_idx=0, vdst_idx=0, PC=None): def _SOP1Op_S_MOV_B32(S0, S1, S2, D0, SCC, VCC, laneId, EXEC, literal, VGPR, src0_idx=0, vdst_idx=0, PC=None):
@ -5203,6 +5203,102 @@ def _VOP3POp_V_DOT2_F32_BF16(S0, S1, S2, D0, SCC, VCC, laneId, EXEC, literal, VG
D0.f32 = tmp D0.f32 = tmp
return {'D0': D0} return {'D0': D0}
def _VOP3POp_V_FMA_MIX_F32(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGPR, _vars, src0_idx=0, vdst_idx=0, pc=0, OPSEL=0, OPSEL_HI=0):
# declare in : 32'F[3];
# declare S : 32'B[3];
# for i in 0 : 2 do
# if !OPSEL_HI.u3[i] then
# in[i] = S[i].f32
# elsif OPSEL.u3[i] then
# in[i] = f16_to_f32(S[i][31 : 16].f16)
# else
# in[i] = f16_to_f32(S[i][15 : 0].f16)
# endif
# endfor;
# D0[31 : 0].f32 = fma(in[0], in[1], in[2])
S0 = Reg(s0)
S1 = Reg(s1)
S2 = Reg(s2)
D0 = Reg(d0)
# --- compiled pseudocode ---
in_ = [Reg(0) for _ in range(3)]
S = [S0, S1, S2]
for i in range(0, int(2)+1):
if not ((OPSEL_HI >> i) & 1):
in_[i] = S[i].f32
elif ((OPSEL >> i) & 1):
in_[i] = f16_to_f32(S[i][31 : 16].f16)
else:
in_[i] = f16_to_f32(S[i][15 : 0].f16)
D0[31 : 0].f32 = fma(in_[0], in_[1], in_[2])
# --- end pseudocode ---
result = {'d0': D0._val, 'scc': scc & 1}
return result
def _VOP3POp_V_FMA_MIXLO_F16(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGPR, _vars, src0_idx=0, vdst_idx=0, pc=0, OPSEL=0, OPSEL_HI=0):
# declare in : 32'F[3];
# declare S : 32'B[3];
# for i in 0 : 2 do
# if !OPSEL_HI.u3[i] then
# in[i] = S[i].f32
# elsif OPSEL.u3[i] then
# in[i] = f16_to_f32(S[i][31 : 16].f16)
# else
# in[i] = f16_to_f32(S[i][15 : 0].f16)
# endif
# endfor;
# D0[15 : 0].f16 = f32_to_f16(fma(in[0], in[1], in[2]))
S0 = Reg(s0)
S1 = Reg(s1)
S2 = Reg(s2)
D0 = Reg(d0)
# --- compiled pseudocode ---
in_ = [Reg(0) for _ in range(3)]
S = [S0, S1, S2]
for i in range(0, int(2)+1):
if not ((OPSEL_HI >> i) & 1):
in_[i] = S[i].f32
elif ((OPSEL >> i) & 1):
in_[i] = f16_to_f32(S[i][31 : 16].f16)
else:
in_[i] = f16_to_f32(S[i][15 : 0].f16)
D0[15 : 0].f16 = f32_to_f16(fma(in_[0], in_[1], in_[2]))
# --- end pseudocode ---
result = {'d0': D0._val, 'scc': scc & 1}
return result
def _VOP3POp_V_FMA_MIXHI_F16(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGPR, _vars, src0_idx=0, vdst_idx=0, pc=0, OPSEL=0, OPSEL_HI=0):
# declare in : 32'F[3];
# declare S : 32'B[3];
# for i in 0 : 2 do
# if !OPSEL_HI.u3[i] then
# in[i] = S[i].f32
# elsif OPSEL.u3[i] then
# in[i] = f16_to_f32(S[i][31 : 16].f16)
# else
# in[i] = f16_to_f32(S[i][15 : 0].f16)
# endif
# endfor;
# D0[31 : 16].f16 = f32_to_f16(fma(in[0], in[1], in[2]))
S0 = Reg(s0)
S1 = Reg(s1)
S2 = Reg(s2)
D0 = Reg(d0)
# --- compiled pseudocode ---
in_ = [Reg(0) for _ in range(3)]
S = [S0, S1, S2]
for i in range(0, int(2)+1):
if not ((OPSEL_HI >> i) & 1):
in_[i] = S[i].f32
elif ((OPSEL >> i) & 1):
in_[i] = f16_to_f32(S[i][31 : 16].f16)
else:
in_[i] = f16_to_f32(S[i][15 : 0].f16)
D0[31 : 16].f16 = f32_to_f16(fma(in_[0], in_[1], in_[2]))
# --- end pseudocode ---
result = {'d0': D0._val, 'scc': scc & 1}
return result
VOP3POp_FUNCTIONS = { VOP3POp_FUNCTIONS = {
VOP3POp.V_PK_MAD_I16: _VOP3POp_V_PK_MAD_I16, VOP3POp.V_PK_MAD_I16: _VOP3POp_V_PK_MAD_I16,
VOP3POp.V_PK_MUL_LO_U16: _VOP3POp_V_PK_MUL_LO_U16, VOP3POp.V_PK_MUL_LO_U16: _VOP3POp_V_PK_MUL_LO_U16,
@ -5227,6 +5323,9 @@ VOP3POp_FUNCTIONS = {
VOP3POp.V_DOT4_U32_U8: _VOP3POp_V_DOT4_U32_U8, VOP3POp.V_DOT4_U32_U8: _VOP3POp_V_DOT4_U32_U8,
VOP3POp.V_DOT8_U32_U4: _VOP3POp_V_DOT8_U32_U4, VOP3POp.V_DOT8_U32_U4: _VOP3POp_V_DOT8_U32_U4,
VOP3POp.V_DOT2_F32_BF16: _VOP3POp_V_DOT2_F32_BF16, VOP3POp.V_DOT2_F32_BF16: _VOP3POp_V_DOT2_F32_BF16,
VOP3POp.V_FMA_MIX_F32: _VOP3POp_V_FMA_MIX_F32,
VOP3POp.V_FMA_MIXLO_F16: _VOP3POp_V_FMA_MIXLO_F16,
VOP3POp.V_FMA_MIXHI_F16: _VOP3POp_V_FMA_MIXHI_F16,
} }
def _VOPCOp_V_CMP_F_F16(S0, S1, S2, D0, SCC, VCC, laneId, EXEC, literal, VGPR, src0_idx=0, vdst_idx=0, PC=None): def _VOPCOp_V_CMP_F_F16(S0, S1, S2, D0, SCC, VCC, laneId, EXEC, literal, VGPR, src0_idx=0, vdst_idx=0, PC=None):
@ -6261,6 +6360,119 @@ def _VOP3Op_V_WRITELANE_B32(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal,
return {'d0': d0, 'scc': scc, 'vgpr_write': (wr_lane, vdst_idx, s0 & 0xffffffff)} return {'d0': d0, 'scc': scc, 'vgpr_write': (wr_lane, vdst_idx, s0 & 0xffffffff)}
VOP3Op_FUNCTIONS[VOP3Op.V_WRITELANE_B32] = _VOP3Op_V_WRITELANE_B32 VOP3Op_FUNCTIONS[VOP3Op.V_WRITELANE_B32] = _VOP3Op_V_WRITELANE_B32
# ═══════════════════════════════════════════════════════════════════════════════
# DS (Data Share) INSTRUCTIONS
# DS instructions operate on LDS (Local Data Share) memory
# They receive: addr (address), data0/data1 (data VGPRs), offset0/offset1 (byte offsets)
# They return: {'vdst': [...]} for loads, {'lds_writes': [...]} for stores
# ═══════════════════════════════════════════════════════════════════════════════
def _ds_load(lds, addr, offset, size, sign_extend=False):
"""Load from LDS memory. Returns list of 32-bit values."""
a = (addr + offset) & 0xffff
if size <= 4:
val = int.from_bytes(lds[a:a+size], 'little')
if sign_extend and size < 4:
# Sign extend from size*8 bits to 32 bits
sign_bit = 1 << (size * 8 - 1)
if val & sign_bit: val |= ~((1 << (size * 8)) - 1)
return [val & 0xffffffff]
# Multi-dword load
return [int.from_bytes(lds[a+i*4:a+i*4+4], 'little') for i in range(size // 4)]
def _ds_store(lds, addr, offset, values, size):
"""Store to LDS memory. values is list of 32-bit dwords, size is bytes per element."""
a = (addr + offset) & 0xffff
if size <= 4:
lds[a:a+size] = (values[0] & ((1 << (size * 8)) - 1)).to_bytes(size, 'little')
else:
for i, v in enumerate(values):
lds[a+i*4:a+i*4+4] = (v & 0xffffffff).to_bytes(4, 'little')
# Load operations: DS_LOAD_B32, DS_LOAD_B64, DS_LOAD_B128, DS_LOAD_U8, DS_LOAD_I8, DS_LOAD_U16, DS_LOAD_I16
def _DSOp_DS_LOAD_B32(lds, addr, data0, data1, vdst, offset0, offset1):
return {'vdst': _ds_load(lds, addr, offset0, 4)}
def _DSOp_DS_LOAD_B64(lds, addr, data0, data1, vdst, offset0, offset1):
return {'vdst': _ds_load(lds, addr, offset0, 8)}
def _DSOp_DS_LOAD_B128(lds, addr, data0, data1, vdst, offset0, offset1):
return {'vdst': _ds_load(lds, addr, offset0, 16)}
def _DSOp_DS_LOAD_U8(lds, addr, data0, data1, vdst, offset0, offset1):
return {'vdst': _ds_load(lds, addr, offset0, 1, sign_extend=False)}
def _DSOp_DS_LOAD_I8(lds, addr, data0, data1, vdst, offset0, offset1):
return {'vdst': _ds_load(lds, addr, offset0, 1, sign_extend=True)}
def _DSOp_DS_LOAD_U16(lds, addr, data0, data1, vdst, offset0, offset1):
return {'vdst': _ds_load(lds, addr, offset0, 2, sign_extend=False)}
def _DSOp_DS_LOAD_I16(lds, addr, data0, data1, vdst, offset0, offset1):
return {'vdst': _ds_load(lds, addr, offset0, 2, sign_extend=True)}
# Store operations: DS_STORE_B32, DS_STORE_B64, DS_STORE_B128, DS_STORE_B8, DS_STORE_B16
def _DSOp_DS_STORE_B32(lds, addr, data0, data1, vdst, offset0, offset1):
_ds_store(lds, addr, offset0, data0, 4)
return {}
def _DSOp_DS_STORE_B64(lds, addr, data0, data1, vdst, offset0, offset1):
_ds_store(lds, addr, offset0, data0, 8)
return {}
def _DSOp_DS_STORE_B128(lds, addr, data0, data1, vdst, offset0, offset1):
_ds_store(lds, addr, offset0, data0, 16)
return {}
def _DSOp_DS_STORE_B8(lds, addr, data0, data1, vdst, offset0, offset1):
_ds_store(lds, addr, offset0, data0, 1)
return {}
def _DSOp_DS_STORE_B16(lds, addr, data0, data1, vdst, offset0, offset1):
_ds_store(lds, addr, offset0, data0, 2)
return {}
# 2-address operations: DS_LOAD_2ADDR_B32, DS_LOAD_2ADDR_B64, DS_STORE_2ADDR_B32, DS_STORE_2ADDR_B64
# Note: offsets are scaled by data size (4 for B32, 8 for B64)
def _DSOp_DS_LOAD_2ADDR_B32(lds, addr, data0, data1, vdst, offset0, offset1):
v0 = _ds_load(lds, addr, offset0 * 4, 4)
v1 = _ds_load(lds, addr, offset1 * 4, 4)
return {'vdst': v0 + v1}
def _DSOp_DS_LOAD_2ADDR_B64(lds, addr, data0, data1, vdst, offset0, offset1):
v0 = _ds_load(lds, addr, offset0 * 8, 8)
v1 = _ds_load(lds, addr, offset1 * 8, 8)
return {'vdst': v0 + v1}
def _DSOp_DS_STORE_2ADDR_B32(lds, addr, data0, data1, vdst, offset0, offset1):
_ds_store(lds, addr, offset0 * 4, data0, 4)
_ds_store(lds, addr, offset1 * 4, data1, 4)
return {}
def _DSOp_DS_STORE_2ADDR_B64(lds, addr, data0, data1, vdst, offset0, offset1):
_ds_store(lds, addr, offset0 * 8, data0, 8)
_ds_store(lds, addr, offset1 * 8, data1, 8)
return {}
DSOp_FUNCTIONS = {
DSOp.DS_LOAD_B32: _DSOp_DS_LOAD_B32,
DSOp.DS_LOAD_B64: _DSOp_DS_LOAD_B64,
DSOp.DS_LOAD_B128: _DSOp_DS_LOAD_B128,
DSOp.DS_LOAD_U8: _DSOp_DS_LOAD_U8,
DSOp.DS_LOAD_I8: _DSOp_DS_LOAD_I8,
DSOp.DS_LOAD_U16: _DSOp_DS_LOAD_U16,
DSOp.DS_LOAD_I16: _DSOp_DS_LOAD_I16,
DSOp.DS_STORE_B32: _DSOp_DS_STORE_B32,
DSOp.DS_STORE_B64: _DSOp_DS_STORE_B64,
DSOp.DS_STORE_B128: _DSOp_DS_STORE_B128,
DSOp.DS_STORE_B8: _DSOp_DS_STORE_B8,
DSOp.DS_STORE_B16: _DSOp_DS_STORE_B16,
DSOp.DS_LOAD_2ADDR_B32: _DSOp_DS_LOAD_2ADDR_B32,
DSOp.DS_LOAD_2ADDR_B64: _DSOp_DS_LOAD_2ADDR_B64,
DSOp.DS_STORE_2ADDR_B32: _DSOp_DS_STORE_2ADDR_B32,
DSOp.DS_STORE_2ADDR_B64: _DSOp_DS_STORE_2ADDR_B64,
}
COMPILED_FUNCTIONS = { COMPILED_FUNCTIONS = {
SOP1Op: SOP1Op_FUNCTIONS, SOP1Op: SOP1Op_FUNCTIONS,
SOP2Op: SOP2Op_FUNCTIONS, SOP2Op: SOP2Op_FUNCTIONS,
@ -6273,6 +6485,7 @@ COMPILED_FUNCTIONS = {
VOP3SDOp: VOP3SDOp_FUNCTIONS, VOP3SDOp: VOP3SDOp_FUNCTIONS,
VOP3POp: VOP3POp_FUNCTIONS, VOP3POp: VOP3POp_FUNCTIONS,
VOPCOp: VOPCOp_FUNCTIONS, VOPCOp: VOPCOp_FUNCTIONS,
DSOp: DSOp_FUNCTIONS,
} }
def get_compiled_functions(): return COMPILED_FUNCTIONS def get_compiled_functions(): return COMPILED_FUNCTIONS

View file

@ -7,7 +7,7 @@ from extra.assembly.amd.pcode import Reg
from extra.assembly.amd.asm import detect_format from extra.assembly.amd.asm import detect_format
from extra.assembly.amd.autogen.rdna3.gen_pcode import get_compiled_functions from extra.assembly.amd.autogen.rdna3.gen_pcode import get_compiled_functions
from extra.assembly.amd.autogen.rdna3.ins import (SOP1, SOP2, SOPC, SOPK, SOPP, SMEM, VOP1, VOP2, VOP3, VOP3SD, VOP3P, VOPC, DS, FLAT, VOPD, from extra.assembly.amd.autogen.rdna3.ins import (SOP1, SOP2, SOPC, SOPK, SOPP, SMEM, VOP1, VOP2, VOP3, VOP3SD, VOP3P, VOPC, DS, FLAT, VOPD,
SrcEnum, SOP1Op, SOP2Op, SOPCOp, SOPKOp, SOPPOp, SMEMOp, VOP1Op, VOP2Op, VOP3Op, VOP3SDOp, VOP3POp, VOPCOp, DSOp, FLATOp, GLOBALOp, VOPDOp) SrcEnum, SOPPOp, SMEMOp, VOP1Op, VOP2Op, VOP3Op, VOP3SDOp, VOP3POp, VOPCOp, GLOBALOp, FLATOp, DSOp, VOPDOp)
Program = dict[int, Inst] Program = dict[int, Inst]
WAVE_SIZE, SGPR_COUNT, VGPR_COUNT = 32, 128, 256 WAVE_SIZE, SGPR_COUNT, VGPR_COUNT = 32, 128, 256
@ -51,11 +51,6 @@ _D16_LOAD_MAP = {'LOAD_D16_U8': (1,0,0), 'LOAD_D16_I8': (1,1,0), 'LOAD_D16_B16':
_D16_STORE_MAP = {'STORE_D16_HI_B8': (1,1), 'STORE_D16_HI_B16': (2,1)} # (size, hi) _D16_STORE_MAP = {'STORE_D16_HI_B8': (1,1), 'STORE_D16_HI_B16': (2,1)} # (size, hi)
FLAT_D16_LOAD = _mem_ops([GLOBALOp, FLATOp], _D16_LOAD_MAP) FLAT_D16_LOAD = _mem_ops([GLOBALOp, FLATOp], _D16_LOAD_MAP)
FLAT_D16_STORE = _mem_ops([GLOBALOp, FLATOp], _D16_STORE_MAP) FLAT_D16_STORE = _mem_ops([GLOBALOp, FLATOp], _D16_STORE_MAP)
DS_LOAD = {DSOp.DS_LOAD_B32: (1,4,0), DSOp.DS_LOAD_B64: (2,4,0), DSOp.DS_LOAD_B128: (4,4,0), DSOp.DS_LOAD_U8: (1,1,0), DSOp.DS_LOAD_I8: (1,1,1), DSOp.DS_LOAD_U16: (1,2,0), DSOp.DS_LOAD_I16: (1,2,1)}
DS_STORE = {DSOp.DS_STORE_B32: (1,4), DSOp.DS_STORE_B64: (2,4), DSOp.DS_STORE_B128: (4,4), DSOp.DS_STORE_B8: (1,1), DSOp.DS_STORE_B16: (1,2)}
# 2ADDR ops: load/store two values using offset0 and offset1
DS_LOAD_2ADDR = {DSOp.DS_LOAD_2ADDR_B32: 4, DSOp.DS_LOAD_2ADDR_B64: 8}
DS_STORE_2ADDR = {DSOp.DS_STORE_2ADDR_B32: 4, DSOp.DS_STORE_2ADDR_B64: 8}
SMEM_LOAD = {SMEMOp.S_LOAD_B32: 1, SMEMOp.S_LOAD_B64: 2, SMEMOp.S_LOAD_B128: 4, SMEMOp.S_LOAD_B256: 8, SMEMOp.S_LOAD_B512: 16} SMEM_LOAD = {SMEMOp.S_LOAD_B32: 1, SMEMOp.S_LOAD_B64: 2, SMEMOp.S_LOAD_B128: 4, SMEMOp.S_LOAD_B256: 8, SMEMOp.S_LOAD_B512: 16}
# VOPD op -> VOP3 op mapping (VOPD is dual-issue of VOP1/VOP2 ops, use VOP3 enums for pseudocode lookup) # VOPD op -> VOP3 op mapping (VOPD is dual-issue of VOP1/VOP2 ops, use VOP3 enums for pseudocode lookup)
@ -225,32 +220,15 @@ def exec_vector(st: WaveState, inst: Inst, lane: int, lds: bytearray | None = No
return return
if isinstance(inst, DS): if isinstance(inst, DS):
op, addr0, vdst = inst.op, (V[inst.addr] + inst.offset0) & 0xffff, inst.vdst fn = compiled.get(DSOp, {}).get(inst.op)
if op in DS_LOAD: if fn is None: raise NotImplementedError(f"DS op {inst.op.name} not in pseudocode")
cnt, sz, sign = DS_LOAD[op] # Prepare data registers as lists of dwords
for i in range(cnt): val = int.from_bytes(lds[addr0+i*sz:addr0+i*sz+sz], 'little'); V[vdst + i] = _sext(val, sz * 8) & MASK32 if sign else val data0 = [V[inst.data0 + i] for i in range(4)] # up to 4 dwords
elif op in DS_STORE: data1 = [V[inst.data1 + i] for i in range(4)] if inst.data1 else [0, 0, 0, 0]
cnt, sz = DS_STORE[op] result = fn(lds, V[inst.addr], data0, data1, inst.vdst, inst.offset0, inst.offset1)
for i in range(cnt): lds[addr0+i*sz:addr0+i*sz+sz] = (V[inst.data0 + i] & ((1 << (sz * 8)) - 1)).to_bytes(sz, 'little') # Write results for loads
elif op in DS_LOAD_2ADDR: if 'vdst' in result:
# Load two values from addr+offset0*sz and addr+offset1*sz into vdst (B32: 1 dword each, B64: 2 dwords each) for i, val in enumerate(result['vdst']): V[inst.vdst + i] = val & MASK32
# Note: offsets are scaled by data size (4 for B32, 8 for B64) per AMD ISA
sz = DS_LOAD_2ADDR[op]
addr0 = (V[inst.addr] + inst.offset0 * sz) & 0xffff
addr1 = (V[inst.addr] + inst.offset1 * sz) & 0xffff
cnt = sz // 4 # 1 for B32, 2 for B64
for i in range(cnt): V[vdst + i] = int.from_bytes(lds[addr0+i*4:addr0+i*4+4], 'little')
for i in range(cnt): V[vdst + cnt + i] = int.from_bytes(lds[addr1+i*4:addr1+i*4+4], 'little')
elif op in DS_STORE_2ADDR:
# Store two values from data0 and data1 to addr+offset0*sz and addr+offset1*sz
# Note: offsets are scaled by data size (4 for B32, 8 for B64) per AMD ISA
sz = DS_STORE_2ADDR[op]
addr0 = (V[inst.addr] + inst.offset0 * sz) & 0xffff
addr1 = (V[inst.addr] + inst.offset1 * sz) & 0xffff
cnt = sz // 4
for i in range(cnt): lds[addr0+i*4:addr0+i*4+4] = (V[inst.data0 + i] & MASK32).to_bytes(4, 'little')
for i in range(cnt): lds[addr1+i*4:addr1+i*4+4] = (V[inst.data1 + i] & MASK32).to_bytes(4, 'little')
else: raise NotImplementedError(f"DS op {op}")
return return
# VOPD: dual-issue, execute two ops simultaneously (read all inputs before writes) # VOPD: dual-issue, execute two ops simultaneously (read all inputs before writes)
@ -306,17 +284,19 @@ def exec_vector(st: WaveState, inst: Inst, lane: int, lds: bytearray | None = No
exec_wmma(st, inst, inst.op) exec_wmma(st, inst, inst.op)
return return
# V_FMA_MIX: Mixed precision FMA - opsel_hi controls f32(0) vs f16(1), opsel selects which f16 half # V_FMA_MIX: Mixed precision FMA - opsel_hi controls f32(0) vs f16(1), opsel selects which f16 half
if 'FMA_MIX' in inst.op_name: # Handle inline because abs/neg must be applied AFTER type conversion
if inst.op in (VOP3POp.V_FMA_MIX_F32, VOP3POp.V_FMA_MIXLO_F16, VOP3POp.V_FMA_MIXHI_F16):
opsel, opsel_hi, opsel_hi2 = getattr(inst, 'opsel', 0), getattr(inst, 'opsel_hi', 0), getattr(inst, 'opsel_hi2', 0) opsel, opsel_hi, opsel_hi2 = getattr(inst, 'opsel', 0), getattr(inst, 'opsel_hi', 0), getattr(inst, 'opsel_hi2', 0)
neg, abs_ = getattr(inst, 'neg', 0), getattr(inst, 'neg_hi', 0) # neg_hi reused as abs neg, abs_ = getattr(inst, 'neg', 0), getattr(inst, 'neg_hi', 0) # neg_hi reused as abs for FMA_MIX
raws = [st.rsrc(inst.src0, lane), st.rsrc(inst.src1, lane), st.rsrc(inst.src2, lane) if inst.src2 is not None else 0] raws = [st.rsrc(inst.src0, lane), st.rsrc(inst.src1, lane), st.rsrc(inst.src2, lane) if inst.src2 is not None else 0]
is_f16 = [opsel_hi & 1, opsel_hi & 2, opsel_hi2] is_f16 = [opsel_hi & 1, opsel_hi & 2, opsel_hi2]
srcs = [_f16(_src16(raws[i], bool(opsel & (1<<i)))) if is_f16[i] else _f32(raws[i]) for i in range(3)] srcs = [_f16(_src16(raws[i], bool(opsel & (1<<i)))) if is_f16[i] else _f32(raws[i]) for i in range(3)]
for i in range(3): for i in range(3):
if abs_ & (1<<i): srcs[i] = abs(srcs[i]) if abs_ & (1<<i): srcs[i] = abs(srcs[i])
if neg & (1<<i): srcs[i] = -srcs[i] if neg & (1<<i): srcs[i] = -srcs[i]
result = srcs[0] * srcs[1] + srcs[2] result_f = srcs[0] * srcs[1] + srcs[2]
st.vgpr[lane][inst.vdst] = _i32(result) if inst.op == VOP3POp.V_FMA_MIX_F32 else _dst16(V[inst.vdst], _i16(result), inst.op == VOP3POp.V_FMA_MIXHI_F16) V = st.vgpr[lane]
V[inst.vdst] = _i32(result_f) if inst.op == VOP3POp.V_FMA_MIX_F32 else _dst16(V[inst.vdst], _i16(result_f), inst.op == VOP3POp.V_FMA_MIXHI_F16)
return return
# VOP3P packed ops: opsel selects halves for lo, opsel_hi for hi; neg toggles f16 sign # VOP3P packed ops: opsel selects halves for lo, opsel_hi for hi; neg toggles f16 sign
raws = [st.rsrc_f16(inst.src0, lane), st.rsrc_f16(inst.src1, lane), st.rsrc_f16(inst.src2, lane) if inst.src2 is not None else 0] raws = [st.rsrc_f16(inst.src0, lane), st.rsrc_f16(inst.src1, lane), st.rsrc_f16(inst.src2, lane) if inst.src2 is not None else 0]

View file

@ -36,13 +36,13 @@ FIELD_ORDER = {
SRC_EXTRAS = {233: 'DPP8', 234: 'DPP8FI', 250: 'DPP16', 251: 'VCCZ', 252: 'EXECZ', 254: 'LDS_DIRECT'} SRC_EXTRAS = {233: 'DPP8', 234: 'DPP8FI', 250: 'DPP16', 251: 'VCCZ', 252: 'EXECZ', 254: 'LDS_DIRECT'}
FLOAT_MAP = {'0.5': 'POS_HALF', '-0.5': 'NEG_HALF', '1.0': 'POS_ONE', '-1.0': 'NEG_ONE', '2.0': 'POS_TWO', '-2.0': 'NEG_TWO', FLOAT_MAP = {'0.5': 'POS_HALF', '-0.5': 'NEG_HALF', '1.0': 'POS_ONE', '-1.0': 'NEG_ONE', '2.0': 'POS_TWO', '-2.0': 'NEG_TWO',
'4.0': 'POS_FOUR', '-4.0': 'NEG_FOUR', '1/(2*PI)': 'INV_2PI', '0': 'ZERO'} '4.0': 'POS_FOUR', '-4.0': 'NEG_FOUR', '1/(2*PI)': 'INV_2PI', '0': 'ZERO'}
INST_PATTERN = re.compile(r'^([SV]_[A-Z0-9_]+)\s+(\d+)\s*$', re.M) INST_PATTERN = re.compile(r'^([SVD]S?_[A-Z0-9_]+)\s+(\d+)\s*$', re.M)
# Patterns that can't be handled by the DSL (require special handling in emu.py) # Patterns that can't be handled by the DSL (require special handling in emu.py)
UNSUPPORTED = ['SGPR[', 'V_SWAP', 'eval ', 'FATAL_HALT', 'HW_REGISTERS', UNSUPPORTED = ['SGPR[', 'V_SWAP', 'eval ', 'FATAL_HALT', 'HW_REGISTERS',
'vscnt', 'vmcnt', 'expcnt', 'lgkmcnt', 'vscnt', 'vmcnt', 'expcnt', 'lgkmcnt',
'CVT_OFF_TABLE', 'ThreadMask', 'CVT_OFF_TABLE', 'ThreadMask',
'S1[i', 'C.i32', 'S[i]', 'in[', 'S1[i', 'C.i32',
'if n.', 'DST.u32', 'addrd = DST', 'addr = DST', 'if n.', 'DST.u32', 'addrd = DST', 'addr = DST',
'BARRIER_STATE', 'ReallocVgprs', 'BARRIER_STATE', 'ReallocVgprs',
'GPR_IDX', 'VSKIP', 'specified in', 'TTBL', 'GPR_IDX', 'VSKIP', 'specified in', 'TTBL',
@ -67,17 +67,18 @@ def compile_pseudocode(pseudocode: str) -> str:
lines = [] lines = []
indent, need_pass, in_first_match_loop = 0, False, False indent, need_pass, in_first_match_loop = 0, False, False
declared_arrays: dict[str, int] = {} # Track declared arrays: name -> size
for line in joined_lines: for line in joined_lines:
line = line.strip() line = line.strip()
if not line or line.startswith('//'): continue if not line or line.startswith('//'): continue
if line.startswith('if '): if line.startswith('if '):
lines.append(' ' * indent + f"if {_expr(line[3:].rstrip(' then'))}:") lines.append(' ' * indent + f"if {_expr(line[3:].rstrip(' then'), declared_arrays)}:")
indent += 1 indent += 1
need_pass = True need_pass = True
elif line.startswith('elsif '): elif line.startswith('elsif '):
if need_pass: lines.append(' ' * indent + "pass") if need_pass: lines.append(' ' * indent + "pass")
indent -= 1 indent -= 1
lines.append(' ' * indent + f"elif {_expr(line[6:].rstrip(' then'))}:") lines.append(' ' * indent + f"elif {_expr(line[6:].rstrip(' then'), declared_arrays)}:")
indent += 1 indent += 1
need_pass = True need_pass = True
elif line == 'else': elif line == 'else':
@ -94,10 +95,19 @@ def compile_pseudocode(pseudocode: str) -> str:
if need_pass: lines.append(' ' * indent + "pass") if need_pass: lines.append(' ' * indent + "pass")
indent -= 1 indent -= 1
need_pass, in_first_match_loop = False, False need_pass, in_first_match_loop = False, False
elif m := re.match(r'declare\s+(\w+)\s*:\s*\d+\'[FBU]\[(\d+)\]', line):
# Handle array declarations: declare in : 32'F[3] or declare S : 32'B[3]
arr_name, arr_size = m[1], int(m[2])
declared_arrays[arr_name] = arr_size
py_name = f"{arr_name}_" if arr_name == 'in' else arr_name # 'in' is Python keyword
if arr_name == 'S':
lines.append(' ' * indent + f"{py_name} = [S0, S1, S2]") # Map to source registers
else:
lines.append(' ' * indent + f"{py_name} = [Reg(0) for _ in range({arr_size})]")
elif line.startswith('declare '): elif line.startswith('declare '):
pass pass # Ignore other declare statements
elif m := re.match(r'for (\w+) in (.+?)\s*:\s*(.+?) do', line): elif m := re.match(r'for (\w+) in (.+?)\s*:\s*(.+?) do', line):
start, end = _expr(m[2].strip()), _expr(m[3].strip()) start, end = _expr(m[2].strip(), declared_arrays), _expr(m[3].strip(), declared_arrays)
lines.append(' ' * indent + f"for {m[1]} in range({start}, int({end})+1):") lines.append(' ' * indent + f"for {m[1]} in range({start}, int({end})+1):")
indent += 1 indent += 1
need_pass, in_first_match_loop = True, True need_pass, in_first_match_loop = True, True
@ -105,7 +115,7 @@ def compile_pseudocode(pseudocode: str) -> str:
need_pass = False need_pass = False
line = line.rstrip(';') line = line.rstrip(';')
if m := re.match(r'\{\s*D1\.[ui]1\s*,\s*D0\.[ui]64\s*\}\s*=\s*(.+)', line): if m := re.match(r'\{\s*D1\.[ui]1\s*,\s*D0\.[ui]64\s*\}\s*=\s*(.+)', line):
rhs = _expr(m[1]) rhs = _expr(m[1], declared_arrays)
lines.append(' ' * indent + f"_full = {rhs}") lines.append(' ' * indent + f"_full = {rhs}")
lines.append(' ' * indent + f"D0.u64 = int(_full) & 0xffffffffffffffff") lines.append(' ' * indent + f"D0.u64 = int(_full) & 0xffffffffffffffff")
lines.append(' ' * indent + f"D1 = Reg((int(_full) >> 64) & 1)") lines.append(' ' * indent + f"D1 = Reg((int(_full) >> 64) & 1)")
@ -113,30 +123,39 @@ def compile_pseudocode(pseudocode: str) -> str:
for op in ('+=', '-=', '*=', '/=', '|=', '&=', '^='): for op in ('+=', '-=', '*=', '/=', '|=', '&=', '^='):
if op in line: if op in line:
lhs, rhs = line.split(op, 1) lhs, rhs = line.split(op, 1)
lines.append(' ' * indent + f"{lhs.strip()} {op} {_expr(rhs.strip())}") lhs_s = _expr(lhs.strip(), declared_arrays) # Transform LHS too for array access
lines.append(' ' * indent + f"{lhs_s} {op} {_expr(rhs.strip(), declared_arrays)}")
break break
else: else:
lhs, rhs = line.split('=', 1) lhs, rhs = line.split('=', 1)
lhs_s, rhs_s = _expr(lhs.strip()), rhs.strip() lhs_s, rhs_s = lhs.strip(), rhs.strip()
stmt = _assign(lhs_s, _expr(rhs_s)) lhs_t = _expr(lhs_s, declared_arrays) # Transform LHS for array access
stmt = _assign(lhs_t, _expr(rhs_s, declared_arrays), declared_arrays)
if in_first_match_loop and rhs_s == 'i' and (lhs_s == 'tmp' or lhs_s == 'D0.i32'): if in_first_match_loop and rhs_s == 'i' and (lhs_s == 'tmp' or lhs_s == 'D0.i32'):
stmt += "; break" stmt += "; break"
lines.append(' ' * indent + stmt) lines.append(' ' * indent + stmt)
if need_pass: lines.append(' ' * indent + "pass") if need_pass: lines.append(' ' * indent + "pass")
return '\n'.join(lines) return '\n'.join(lines)
def _assign(lhs: str, rhs: str) -> str: def _assign(lhs: str, rhs: str, declared_arrays: dict[str, int] | None = None) -> str:
# Check for array element assignment: in_[i] should not wrap in Reg()
if declared_arrays and re.match(r'\w+_?\[\w+\]', lhs):
return f"{lhs} = {rhs}"
if lhs in ('tmp', 'SCC', 'VCC', 'EXEC', 'D0', 'D1', 'saveexec', 'PC'): if lhs in ('tmp', 'SCC', 'VCC', 'EXEC', 'D0', 'D1', 'saveexec', 'PC'):
return f"{lhs} = Reg({rhs})" return f"{lhs} = Reg({rhs})"
return f"{lhs} = {rhs}" return f"{lhs} = {rhs}"
def _expr(e: str) -> str: def _expr(e: str, declared_arrays: dict[str, int] | None = None) -> str:
e = e.strip() e = e.strip()
# Handle OPSEL_HI.u3[i] and OPSEL.u3[i] - bit extraction from opsel fields
e = re.sub(r'(OPSEL(?:_HI)?)\.u\d+\[(\w+)\]', r'((\1 >> \2) & 1)', e)
# Rename 'in' to 'in_' to avoid Python keyword conflict
e = re.sub(r'\bin\[', 'in_[', e)
e = e.replace('&&', ' and ').replace('||', ' or ').replace('<>', ' != ') e = e.replace('&&', ' and ').replace('||', ' or ').replace('<>', ' != ')
e = re.sub(r'!([^=])', r' not \1', e) e = re.sub(r'!([^=])', r' not \1', e)
e = re.sub(r'\{\s*(\w+\.u32)\s*,\s*(\w+\.u32)\s*\}', r'_pack32(\1, \2)', e) e = re.sub(r'\{\s*(\w+\.u32)\s*,\s*(\w+\.u32)\s*\}', r'_pack32(\1, \2)', e)
def pack(m): def pack(m):
hi, lo = _expr(m[1].strip()), _expr(m[2].strip()) hi, lo = _expr(m[1].strip(), declared_arrays), _expr(m[2].strip(), declared_arrays)
return f'_pack({hi}, {lo})' return f'_pack({hi}, {lo})'
e = re.sub(r'\{\s*([^,{}]+)\s*,\s*([^,{}]+)\s*\}', pack, e) e = re.sub(r'\{\s*([^,{}]+)\s*,\s*([^,{}]+)\s*\}', pack, e)
e = re.sub(r"1201'B\(2\.0\s*/\s*PI\)", "TWO_OVER_PI_1201", e) e = re.sub(r"1201'B\(2\.0\s*/\s*PI\)", "TWO_OVER_PI_1201", e)
@ -164,7 +183,7 @@ def _expr(e: str) -> str:
if s[j] == '[': depth += 1 if s[j] == '[': depth += 1
elif s[j] == ']': depth -= 1 elif s[j] == ']': depth -= 1
j += 1 j += 1
inner = _expr(s[start:j-1]) inner = _expr(s[start:j-1], declared_arrays)
result.append('[' + inner + ']') result.append('[' + inner + ']')
i = j i = j
else: else:
@ -362,7 +381,7 @@ def _extract_pseudocode(text: str) -> str | None:
if s.endswith('.') and not any(p in s for p in ['D0', 'D1', 'S0', 'S1', 'S2', 'SCC', 'VCC', 'tmp', '=']): continue if s.endswith('.') and not any(p in s for p in ['D0', 'D1', 'S0', 'S1', 'S2', 'SCC', 'VCC', 'tmp', '=']): continue
if re.match(r'^[a-z].*\.$', s) and '=' not in s: continue if re.match(r'^[a-z].*\.$', s) and '=' not in s: continue
is_code = (any(p in s for p in ['D0.', 'D1.', 'S0.', 'S1.', 'S2.', 'SCC =', 'SCC ?', 'VCC', 'EXEC', 'tmp =', 'tmp[', 'lane =', 'PC =', is_code = (any(p in s for p in ['D0.', 'D1.', 'S0.', 'S1.', 'S2.', 'SCC =', 'SCC ?', 'VCC', 'EXEC', 'tmp =', 'tmp[', 'lane =', 'PC =',
'D0[', 'D1[', 'S0[', 'S1[', 'S2[']) or 'D0[', 'D1[', 'S0[', 'S1[', 'S2[', 'MEM[', 'RETURN_DATA', 'DATA.', 'DATA0', 'DATA1', 'ADDR']) or
s.startswith(('if ', 'else', 'elsif', 'endif', 'declare ', 'for ', 'endfor', '//')) or s.startswith(('if ', 'else', 'elsif', 'endif', 'declare ', 'for ', 'endfor', '//')) or
re.match(r'^[a-z_]+\s*=', s) or re.match(r'^[a-z_]+\[', s) or (depth > 0 and '=' in s)) re.match(r'^[a-z_]+\s*=', s) or re.match(r'^[a-z_]+\[', s) or (depth > 0 and '=' in s))
if is_code: result.append(s) if is_code: result.append(s)
@ -448,13 +467,13 @@ def _generate_gen_pcode_py(enums, pseudocode, arch) -> str:
# Get op enums for this arch (import from .ins which re-exports from .enum) # Get op enums for this arch (import from .ins which re-exports from .enum)
import importlib import importlib
autogen = importlib.import_module(f"extra.assembly.amd.autogen.{arch}.ins") autogen = importlib.import_module(f"extra.assembly.amd.autogen.{arch}.ins")
OP_ENUMS = [getattr(autogen, name) for name in ['SOP1Op', 'SOP2Op', 'SOPCOp', 'SOPKOp', 'SOPPOp', 'VOP1Op', 'VOP2Op', 'VOP3Op', 'VOP3SDOp', 'VOP3POp', 'VOPCOp', 'VOP3AOp', 'VOP3BOp'] if hasattr(autogen, name)] OP_ENUMS = [getattr(autogen, name) for name in ['SOP1Op', 'SOP2Op', 'SOPCOp', 'SOPKOp', 'SOPPOp', 'VOP1Op', 'VOP2Op', 'VOP3Op', 'VOP3SDOp', 'VOP3POp', 'VOPCOp', 'VOP3AOp', 'VOP3BOp', 'DSOp'] if hasattr(autogen, name)]
# Build defined ops mapping # Build defined ops mapping
defined_ops: dict[tuple, list] = {} defined_ops: dict[tuple, list] = {}
for enum_cls in OP_ENUMS: for enum_cls in OP_ENUMS:
for op in enum_cls: for op in enum_cls:
if op.name.startswith(('S_', 'V_')): defined_ops.setdefault((op.name, op.value), []).append((enum_cls, op)) if op.name.startswith(('S_', 'V_', 'DS_')): defined_ops.setdefault((op.name, op.value), []).append((enum_cls, op))
enum_names = [e.__name__ for e in OP_ENUMS] enum_names = [e.__name__ for e in OP_ENUMS]
lines = [f'''# autogenerated by pdf.py - do not edit lines = [f'''# autogenerated by pdf.py - do not edit
@ -541,11 +560,14 @@ def _generate_function(cls_name: str, op, pc: str, code: str) -> tuple[str, str]
is_cmpx = (cls_name in ('VOPCOp', 'VOP3Op')) and 'EXEC.u64[laneId]' in pc is_cmpx = (cls_name in ('VOPCOp', 'VOP3Op')) and 'EXEC.u64[laneId]' in pc
is_div_scale = 'DIV_SCALE' in op.name is_div_scale = 'DIV_SCALE' in op.name
has_sdst = cls_name == 'VOP3SDOp' and ('VCC.u64[laneId]' in pc or is_div_scale) has_sdst = cls_name == 'VOP3SDOp' and ('VCC.u64[laneId]' in pc or is_div_scale)
has_opsel = 'OPSEL' in pc # FMA_MIX and similar instructions need OPSEL/OPSEL_HI
combined = code + pc combined = code + pc
fn_name = f"_{cls_name}_{op.name}" fn_name = f"_{cls_name}_{op.name}"
# Function accepts Reg objects directly (uppercase names), laneId is passed directly as int # Function accepts Reg objects directly (uppercase names), laneId is passed directly as int
lines = [f"def {fn_name}(S0, S1, S2, D0, SCC, VCC, laneId, EXEC, literal, VGPR, src0_idx=0, vdst_idx=0, PC=None):"] params = "S0, S1, S2, D0, SCC, VCC, laneId, EXEC, literal, VGPR, src0_idx=0, vdst_idx=0, PC=None"
if has_opsel: params += ", OPSEL=0, OPSEL_HI=0"
lines = [f"def {fn_name}({params}):"]
# Registers that need special handling (not passed directly) # Registers that need special handling (not passed directly)
# Only init if used but not first assigned as `name = Reg(...)` in the compiled code # Only init if used but not first assigned as `name = Reg(...)` in the compiled code