Compare commits

...

2 commits

Author SHA1 Message Date
George Hotz
aec4d65241 ds compiled 2025-12-31 15:43:44 -05:00
George Hotz
f022a7d8a7 assembly/amd: move more instructions to pcode 2025-12-31 15:42:59 -05:00
4 changed files with 271 additions and 56 deletions

View file

@ -28,4 +28,4 @@ The ops tests also pass, but they are very slow, so you should run them one at a
When something is caught by main tinygrad tests, a local regression test should be added to `extra/assembly/amd/test`. While working with tinygrad, you can dump the assembly with `DEBUG=7`. These tests all pass on real hardware, so if a test is failing with `AMD=1 PYTHON_REMU=1 MOCKGPU=1` it's likely because an instruction is emulated incorrectly. You can test without `MOCKGPU=1` to test on real hardware, if it works on real hardware there's a bug in the emulator.
Currently, only RDNA3 is well supported, but when finished, this will support RDNA3+RDNA4+CDNA in ~2000 lines.
Currently, only RDNA3 is well supported, but when finished, this will support RDNA3+RDNA4+CDNA in ~2000 lines. Count lines with `cloc --by-file extra/assembly/amd/*.py`

View file

@ -2,7 +2,7 @@
# to regenerate: python -m extra.assembly.amd.pdf --arch rdna3
# ruff: noqa: E501,F405,F403
# mypy: ignore-errors
from extra.assembly.amd.autogen.rdna3.enum import SOP1Op, SOP2Op, SOPCOp, SOPKOp, SOPPOp, VOP1Op, VOP2Op, VOP3Op, VOP3SDOp, VOP3POp, VOPCOp
from extra.assembly.amd.autogen.rdna3.enum import SOP1Op, SOP2Op, SOPCOp, SOPKOp, SOPPOp, VOP1Op, VOP2Op, VOP3Op, VOP3SDOp, VOP3POp, VOPCOp, DSOp
from extra.assembly.amd.pcode import *
def _SOP1Op_S_MOV_B32(S0, S1, S2, D0, SCC, VCC, laneId, EXEC, literal, VGPR, src0_idx=0, vdst_idx=0, PC=None):
@ -5203,6 +5203,102 @@ def _VOP3POp_V_DOT2_F32_BF16(S0, S1, S2, D0, SCC, VCC, laneId, EXEC, literal, VG
D0.f32 = tmp
return {'D0': D0}
def _VOP3POp_V_FMA_MIX_F32(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGPR, _vars, src0_idx=0, vdst_idx=0, pc=0, OPSEL=0, OPSEL_HI=0):
# declare in : 32'F[3];
# declare S : 32'B[3];
# for i in 0 : 2 do
# if !OPSEL_HI.u3[i] then
# in[i] = S[i].f32
# elsif OPSEL.u3[i] then
# in[i] = f16_to_f32(S[i][31 : 16].f16)
# else
# in[i] = f16_to_f32(S[i][15 : 0].f16)
# endif
# endfor;
# D0[31 : 0].f32 = fma(in[0], in[1], in[2])
S0 = Reg(s0)
S1 = Reg(s1)
S2 = Reg(s2)
D0 = Reg(d0)
# --- compiled pseudocode ---
in_ = [Reg(0) for _ in range(3)]
S = [S0, S1, S2]
for i in range(0, int(2)+1):
if not ((OPSEL_HI >> i) & 1):
in_[i] = S[i].f32
elif ((OPSEL >> i) & 1):
in_[i] = f16_to_f32(S[i][31 : 16].f16)
else:
in_[i] = f16_to_f32(S[i][15 : 0].f16)
D0[31 : 0].f32 = fma(in_[0], in_[1], in_[2])
# --- end pseudocode ---
result = {'d0': D0._val, 'scc': scc & 1}
return result
def _VOP3POp_V_FMA_MIXLO_F16(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGPR, _vars, src0_idx=0, vdst_idx=0, pc=0, OPSEL=0, OPSEL_HI=0):
# declare in : 32'F[3];
# declare S : 32'B[3];
# for i in 0 : 2 do
# if !OPSEL_HI.u3[i] then
# in[i] = S[i].f32
# elsif OPSEL.u3[i] then
# in[i] = f16_to_f32(S[i][31 : 16].f16)
# else
# in[i] = f16_to_f32(S[i][15 : 0].f16)
# endif
# endfor;
# D0[15 : 0].f16 = f32_to_f16(fma(in[0], in[1], in[2]))
S0 = Reg(s0)
S1 = Reg(s1)
S2 = Reg(s2)
D0 = Reg(d0)
# --- compiled pseudocode ---
in_ = [Reg(0) for _ in range(3)]
S = [S0, S1, S2]
for i in range(0, int(2)+1):
if not ((OPSEL_HI >> i) & 1):
in_[i] = S[i].f32
elif ((OPSEL >> i) & 1):
in_[i] = f16_to_f32(S[i][31 : 16].f16)
else:
in_[i] = f16_to_f32(S[i][15 : 0].f16)
D0[15 : 0].f16 = f32_to_f16(fma(in_[0], in_[1], in_[2]))
# --- end pseudocode ---
result = {'d0': D0._val, 'scc': scc & 1}
return result
def _VOP3POp_V_FMA_MIXHI_F16(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGPR, _vars, src0_idx=0, vdst_idx=0, pc=0, OPSEL=0, OPSEL_HI=0):
# declare in : 32'F[3];
# declare S : 32'B[3];
# for i in 0 : 2 do
# if !OPSEL_HI.u3[i] then
# in[i] = S[i].f32
# elsif OPSEL.u3[i] then
# in[i] = f16_to_f32(S[i][31 : 16].f16)
# else
# in[i] = f16_to_f32(S[i][15 : 0].f16)
# endif
# endfor;
# D0[31 : 16].f16 = f32_to_f16(fma(in[0], in[1], in[2]))
S0 = Reg(s0)
S1 = Reg(s1)
S2 = Reg(s2)
D0 = Reg(d0)
# --- compiled pseudocode ---
in_ = [Reg(0) for _ in range(3)]
S = [S0, S1, S2]
for i in range(0, int(2)+1):
if not ((OPSEL_HI >> i) & 1):
in_[i] = S[i].f32
elif ((OPSEL >> i) & 1):
in_[i] = f16_to_f32(S[i][31 : 16].f16)
else:
in_[i] = f16_to_f32(S[i][15 : 0].f16)
D0[31 : 16].f16 = f32_to_f16(fma(in_[0], in_[1], in_[2]))
# --- end pseudocode ---
result = {'d0': D0._val, 'scc': scc & 1}
return result
VOP3POp_FUNCTIONS = {
VOP3POp.V_PK_MAD_I16: _VOP3POp_V_PK_MAD_I16,
VOP3POp.V_PK_MUL_LO_U16: _VOP3POp_V_PK_MUL_LO_U16,
@ -5227,6 +5323,9 @@ VOP3POp_FUNCTIONS = {
VOP3POp.V_DOT4_U32_U8: _VOP3POp_V_DOT4_U32_U8,
VOP3POp.V_DOT8_U32_U4: _VOP3POp_V_DOT8_U32_U4,
VOP3POp.V_DOT2_F32_BF16: _VOP3POp_V_DOT2_F32_BF16,
VOP3POp.V_FMA_MIX_F32: _VOP3POp_V_FMA_MIX_F32,
VOP3POp.V_FMA_MIXLO_F16: _VOP3POp_V_FMA_MIXLO_F16,
VOP3POp.V_FMA_MIXHI_F16: _VOP3POp_V_FMA_MIXHI_F16,
}
def _VOPCOp_V_CMP_F_F16(S0, S1, S2, D0, SCC, VCC, laneId, EXEC, literal, VGPR, src0_idx=0, vdst_idx=0, PC=None):
@ -6261,6 +6360,119 @@ def _VOP3Op_V_WRITELANE_B32(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal,
return {'d0': d0, 'scc': scc, 'vgpr_write': (wr_lane, vdst_idx, s0 & 0xffffffff)}
VOP3Op_FUNCTIONS[VOP3Op.V_WRITELANE_B32] = _VOP3Op_V_WRITELANE_B32
# ═══════════════════════════════════════════════════════════════════════════════
# DS (Data Share) INSTRUCTIONS
# DS instructions operate on LDS (Local Data Share) memory
# They receive: addr (address), data0/data1 (data VGPRs), offset0/offset1 (byte offsets)
# They return: {'vdst': [...]} for loads, {'lds_writes': [...]} for stores
# ═══════════════════════════════════════════════════════════════════════════════
def _ds_load(lds, addr, offset, size, sign_extend=False):
"""Load from LDS memory. Returns list of 32-bit values."""
a = (addr + offset) & 0xffff
if size <= 4:
val = int.from_bytes(lds[a:a+size], 'little')
if sign_extend and size < 4:
# Sign extend from size*8 bits to 32 bits
sign_bit = 1 << (size * 8 - 1)
if val & sign_bit: val |= ~((1 << (size * 8)) - 1)
return [val & 0xffffffff]
# Multi-dword load
return [int.from_bytes(lds[a+i*4:a+i*4+4], 'little') for i in range(size // 4)]
def _ds_store(lds, addr, offset, values, size):
"""Store to LDS memory. values is list of 32-bit dwords, size is bytes per element."""
a = (addr + offset) & 0xffff
if size <= 4:
lds[a:a+size] = (values[0] & ((1 << (size * 8)) - 1)).to_bytes(size, 'little')
else:
for i, v in enumerate(values):
lds[a+i*4:a+i*4+4] = (v & 0xffffffff).to_bytes(4, 'little')
# Load operations: DS_LOAD_B32, DS_LOAD_B64, DS_LOAD_B128, DS_LOAD_U8, DS_LOAD_I8, DS_LOAD_U16, DS_LOAD_I16
def _DSOp_DS_LOAD_B32(lds, addr, data0, data1, vdst, offset0, offset1):
return {'vdst': _ds_load(lds, addr, offset0, 4)}
def _DSOp_DS_LOAD_B64(lds, addr, data0, data1, vdst, offset0, offset1):
return {'vdst': _ds_load(lds, addr, offset0, 8)}
def _DSOp_DS_LOAD_B128(lds, addr, data0, data1, vdst, offset0, offset1):
return {'vdst': _ds_load(lds, addr, offset0, 16)}
def _DSOp_DS_LOAD_U8(lds, addr, data0, data1, vdst, offset0, offset1):
return {'vdst': _ds_load(lds, addr, offset0, 1, sign_extend=False)}
def _DSOp_DS_LOAD_I8(lds, addr, data0, data1, vdst, offset0, offset1):
return {'vdst': _ds_load(lds, addr, offset0, 1, sign_extend=True)}
def _DSOp_DS_LOAD_U16(lds, addr, data0, data1, vdst, offset0, offset1):
return {'vdst': _ds_load(lds, addr, offset0, 2, sign_extend=False)}
def _DSOp_DS_LOAD_I16(lds, addr, data0, data1, vdst, offset0, offset1):
return {'vdst': _ds_load(lds, addr, offset0, 2, sign_extend=True)}
# Store operations: DS_STORE_B32, DS_STORE_B64, DS_STORE_B128, DS_STORE_B8, DS_STORE_B16
def _DSOp_DS_STORE_B32(lds, addr, data0, data1, vdst, offset0, offset1):
_ds_store(lds, addr, offset0, data0, 4)
return {}
def _DSOp_DS_STORE_B64(lds, addr, data0, data1, vdst, offset0, offset1):
_ds_store(lds, addr, offset0, data0, 8)
return {}
def _DSOp_DS_STORE_B128(lds, addr, data0, data1, vdst, offset0, offset1):
_ds_store(lds, addr, offset0, data0, 16)
return {}
def _DSOp_DS_STORE_B8(lds, addr, data0, data1, vdst, offset0, offset1):
_ds_store(lds, addr, offset0, data0, 1)
return {}
def _DSOp_DS_STORE_B16(lds, addr, data0, data1, vdst, offset0, offset1):
_ds_store(lds, addr, offset0, data0, 2)
return {}
# 2-address operations: DS_LOAD_2ADDR_B32, DS_LOAD_2ADDR_B64, DS_STORE_2ADDR_B32, DS_STORE_2ADDR_B64
# Note: offsets are scaled by data size (4 for B32, 8 for B64)
def _DSOp_DS_LOAD_2ADDR_B32(lds, addr, data0, data1, vdst, offset0, offset1):
v0 = _ds_load(lds, addr, offset0 * 4, 4)
v1 = _ds_load(lds, addr, offset1 * 4, 4)
return {'vdst': v0 + v1}
def _DSOp_DS_LOAD_2ADDR_B64(lds, addr, data0, data1, vdst, offset0, offset1):
v0 = _ds_load(lds, addr, offset0 * 8, 8)
v1 = _ds_load(lds, addr, offset1 * 8, 8)
return {'vdst': v0 + v1}
def _DSOp_DS_STORE_2ADDR_B32(lds, addr, data0, data1, vdst, offset0, offset1):
_ds_store(lds, addr, offset0 * 4, data0, 4)
_ds_store(lds, addr, offset1 * 4, data1, 4)
return {}
def _DSOp_DS_STORE_2ADDR_B64(lds, addr, data0, data1, vdst, offset0, offset1):
_ds_store(lds, addr, offset0 * 8, data0, 8)
_ds_store(lds, addr, offset1 * 8, data1, 8)
return {}
DSOp_FUNCTIONS = {
DSOp.DS_LOAD_B32: _DSOp_DS_LOAD_B32,
DSOp.DS_LOAD_B64: _DSOp_DS_LOAD_B64,
DSOp.DS_LOAD_B128: _DSOp_DS_LOAD_B128,
DSOp.DS_LOAD_U8: _DSOp_DS_LOAD_U8,
DSOp.DS_LOAD_I8: _DSOp_DS_LOAD_I8,
DSOp.DS_LOAD_U16: _DSOp_DS_LOAD_U16,
DSOp.DS_LOAD_I16: _DSOp_DS_LOAD_I16,
DSOp.DS_STORE_B32: _DSOp_DS_STORE_B32,
DSOp.DS_STORE_B64: _DSOp_DS_STORE_B64,
DSOp.DS_STORE_B128: _DSOp_DS_STORE_B128,
DSOp.DS_STORE_B8: _DSOp_DS_STORE_B8,
DSOp.DS_STORE_B16: _DSOp_DS_STORE_B16,
DSOp.DS_LOAD_2ADDR_B32: _DSOp_DS_LOAD_2ADDR_B32,
DSOp.DS_LOAD_2ADDR_B64: _DSOp_DS_LOAD_2ADDR_B64,
DSOp.DS_STORE_2ADDR_B32: _DSOp_DS_STORE_2ADDR_B32,
DSOp.DS_STORE_2ADDR_B64: _DSOp_DS_STORE_2ADDR_B64,
}
COMPILED_FUNCTIONS = {
SOP1Op: SOP1Op_FUNCTIONS,
SOP2Op: SOP2Op_FUNCTIONS,
@ -6273,6 +6485,7 @@ COMPILED_FUNCTIONS = {
VOP3SDOp: VOP3SDOp_FUNCTIONS,
VOP3POp: VOP3POp_FUNCTIONS,
VOPCOp: VOPCOp_FUNCTIONS,
DSOp: DSOp_FUNCTIONS,
}
def get_compiled_functions(): return COMPILED_FUNCTIONS

View file

@ -7,7 +7,7 @@ from extra.assembly.amd.pcode import Reg
from extra.assembly.amd.asm import detect_format
from extra.assembly.amd.autogen.rdna3.gen_pcode import get_compiled_functions
from extra.assembly.amd.autogen.rdna3.ins import (SOP1, SOP2, SOPC, SOPK, SOPP, SMEM, VOP1, VOP2, VOP3, VOP3SD, VOP3P, VOPC, DS, FLAT, VOPD,
SrcEnum, SOP1Op, SOP2Op, SOPCOp, SOPKOp, SOPPOp, SMEMOp, VOP1Op, VOP2Op, VOP3Op, VOP3SDOp, VOP3POp, VOPCOp, DSOp, FLATOp, GLOBALOp, VOPDOp)
SrcEnum, SOPPOp, SMEMOp, VOP1Op, VOP2Op, VOP3Op, VOP3SDOp, VOP3POp, VOPCOp, GLOBALOp, FLATOp, DSOp, VOPDOp)
Program = dict[int, Inst]
WAVE_SIZE, SGPR_COUNT, VGPR_COUNT = 32, 128, 256
@ -51,11 +51,6 @@ _D16_LOAD_MAP = {'LOAD_D16_U8': (1,0,0), 'LOAD_D16_I8': (1,1,0), 'LOAD_D16_B16':
_D16_STORE_MAP = {'STORE_D16_HI_B8': (1,1), 'STORE_D16_HI_B16': (2,1)} # (size, hi)
FLAT_D16_LOAD = _mem_ops([GLOBALOp, FLATOp], _D16_LOAD_MAP)
FLAT_D16_STORE = _mem_ops([GLOBALOp, FLATOp], _D16_STORE_MAP)
DS_LOAD = {DSOp.DS_LOAD_B32: (1,4,0), DSOp.DS_LOAD_B64: (2,4,0), DSOp.DS_LOAD_B128: (4,4,0), DSOp.DS_LOAD_U8: (1,1,0), DSOp.DS_LOAD_I8: (1,1,1), DSOp.DS_LOAD_U16: (1,2,0), DSOp.DS_LOAD_I16: (1,2,1)}
DS_STORE = {DSOp.DS_STORE_B32: (1,4), DSOp.DS_STORE_B64: (2,4), DSOp.DS_STORE_B128: (4,4), DSOp.DS_STORE_B8: (1,1), DSOp.DS_STORE_B16: (1,2)}
# 2ADDR ops: load/store two values using offset0 and offset1
DS_LOAD_2ADDR = {DSOp.DS_LOAD_2ADDR_B32: 4, DSOp.DS_LOAD_2ADDR_B64: 8}
DS_STORE_2ADDR = {DSOp.DS_STORE_2ADDR_B32: 4, DSOp.DS_STORE_2ADDR_B64: 8}
SMEM_LOAD = {SMEMOp.S_LOAD_B32: 1, SMEMOp.S_LOAD_B64: 2, SMEMOp.S_LOAD_B128: 4, SMEMOp.S_LOAD_B256: 8, SMEMOp.S_LOAD_B512: 16}
# VOPD op -> VOP3 op mapping (VOPD is dual-issue of VOP1/VOP2 ops, use VOP3 enums for pseudocode lookup)
@ -225,32 +220,15 @@ def exec_vector(st: WaveState, inst: Inst, lane: int, lds: bytearray | None = No
return
if isinstance(inst, DS):
op, addr0, vdst = inst.op, (V[inst.addr] + inst.offset0) & 0xffff, inst.vdst
if op in DS_LOAD:
cnt, sz, sign = DS_LOAD[op]
for i in range(cnt): val = int.from_bytes(lds[addr0+i*sz:addr0+i*sz+sz], 'little'); V[vdst + i] = _sext(val, sz * 8) & MASK32 if sign else val
elif op in DS_STORE:
cnt, sz = DS_STORE[op]
for i in range(cnt): lds[addr0+i*sz:addr0+i*sz+sz] = (V[inst.data0 + i] & ((1 << (sz * 8)) - 1)).to_bytes(sz, 'little')
elif op in DS_LOAD_2ADDR:
# Load two values from addr+offset0*sz and addr+offset1*sz into vdst (B32: 1 dword each, B64: 2 dwords each)
# Note: offsets are scaled by data size (4 for B32, 8 for B64) per AMD ISA
sz = DS_LOAD_2ADDR[op]
addr0 = (V[inst.addr] + inst.offset0 * sz) & 0xffff
addr1 = (V[inst.addr] + inst.offset1 * sz) & 0xffff
cnt = sz // 4 # 1 for B32, 2 for B64
for i in range(cnt): V[vdst + i] = int.from_bytes(lds[addr0+i*4:addr0+i*4+4], 'little')
for i in range(cnt): V[vdst + cnt + i] = int.from_bytes(lds[addr1+i*4:addr1+i*4+4], 'little')
elif op in DS_STORE_2ADDR:
# Store two values from data0 and data1 to addr+offset0*sz and addr+offset1*sz
# Note: offsets are scaled by data size (4 for B32, 8 for B64) per AMD ISA
sz = DS_STORE_2ADDR[op]
addr0 = (V[inst.addr] + inst.offset0 * sz) & 0xffff
addr1 = (V[inst.addr] + inst.offset1 * sz) & 0xffff
cnt = sz // 4
for i in range(cnt): lds[addr0+i*4:addr0+i*4+4] = (V[inst.data0 + i] & MASK32).to_bytes(4, 'little')
for i in range(cnt): lds[addr1+i*4:addr1+i*4+4] = (V[inst.data1 + i] & MASK32).to_bytes(4, 'little')
else: raise NotImplementedError(f"DS op {op}")
fn = compiled.get(DSOp, {}).get(inst.op)
if fn is None: raise NotImplementedError(f"DS op {inst.op.name} not in pseudocode")
# Prepare data registers as lists of dwords
data0 = [V[inst.data0 + i] for i in range(4)] # up to 4 dwords
data1 = [V[inst.data1 + i] for i in range(4)] if inst.data1 else [0, 0, 0, 0]
result = fn(lds, V[inst.addr], data0, data1, inst.vdst, inst.offset0, inst.offset1)
# Write results for loads
if 'vdst' in result:
for i, val in enumerate(result['vdst']): V[inst.vdst + i] = val & MASK32
return
# VOPD: dual-issue, execute two ops simultaneously (read all inputs before writes)
@ -306,17 +284,19 @@ def exec_vector(st: WaveState, inst: Inst, lane: int, lds: bytearray | None = No
exec_wmma(st, inst, inst.op)
return
# V_FMA_MIX: Mixed precision FMA - opsel_hi controls f32(0) vs f16(1), opsel selects which f16 half
if 'FMA_MIX' in inst.op_name:
# Handle inline because abs/neg must be applied AFTER type conversion
if inst.op in (VOP3POp.V_FMA_MIX_F32, VOP3POp.V_FMA_MIXLO_F16, VOP3POp.V_FMA_MIXHI_F16):
opsel, opsel_hi, opsel_hi2 = getattr(inst, 'opsel', 0), getattr(inst, 'opsel_hi', 0), getattr(inst, 'opsel_hi2', 0)
neg, abs_ = getattr(inst, 'neg', 0), getattr(inst, 'neg_hi', 0) # neg_hi reused as abs
neg, abs_ = getattr(inst, 'neg', 0), getattr(inst, 'neg_hi', 0) # neg_hi reused as abs for FMA_MIX
raws = [st.rsrc(inst.src0, lane), st.rsrc(inst.src1, lane), st.rsrc(inst.src2, lane) if inst.src2 is not None else 0]
is_f16 = [opsel_hi & 1, opsel_hi & 2, opsel_hi2]
srcs = [_f16(_src16(raws[i], bool(opsel & (1<<i)))) if is_f16[i] else _f32(raws[i]) for i in range(3)]
for i in range(3):
if abs_ & (1<<i): srcs[i] = abs(srcs[i])
if neg & (1<<i): srcs[i] = -srcs[i]
result = srcs[0] * srcs[1] + srcs[2]
st.vgpr[lane][inst.vdst] = _i32(result) if inst.op == VOP3POp.V_FMA_MIX_F32 else _dst16(V[inst.vdst], _i16(result), inst.op == VOP3POp.V_FMA_MIXHI_F16)
result_f = srcs[0] * srcs[1] + srcs[2]
V = st.vgpr[lane]
V[inst.vdst] = _i32(result_f) if inst.op == VOP3POp.V_FMA_MIX_F32 else _dst16(V[inst.vdst], _i16(result_f), inst.op == VOP3POp.V_FMA_MIXHI_F16)
return
# VOP3P packed ops: opsel selects halves for lo, opsel_hi for hi; neg toggles f16 sign
raws = [st.rsrc_f16(inst.src0, lane), st.rsrc_f16(inst.src1, lane), st.rsrc_f16(inst.src2, lane) if inst.src2 is not None else 0]

View file

@ -36,13 +36,13 @@ FIELD_ORDER = {
SRC_EXTRAS = {233: 'DPP8', 234: 'DPP8FI', 250: 'DPP16', 251: 'VCCZ', 252: 'EXECZ', 254: 'LDS_DIRECT'}
FLOAT_MAP = {'0.5': 'POS_HALF', '-0.5': 'NEG_HALF', '1.0': 'POS_ONE', '-1.0': 'NEG_ONE', '2.0': 'POS_TWO', '-2.0': 'NEG_TWO',
'4.0': 'POS_FOUR', '-4.0': 'NEG_FOUR', '1/(2*PI)': 'INV_2PI', '0': 'ZERO'}
INST_PATTERN = re.compile(r'^([SV]_[A-Z0-9_]+)\s+(\d+)\s*$', re.M)
INST_PATTERN = re.compile(r'^([SVD]S?_[A-Z0-9_]+)\s+(\d+)\s*$', re.M)
# Patterns that can't be handled by the DSL (require special handling in emu.py)
UNSUPPORTED = ['SGPR[', 'V_SWAP', 'eval ', 'FATAL_HALT', 'HW_REGISTERS',
'vscnt', 'vmcnt', 'expcnt', 'lgkmcnt',
'CVT_OFF_TABLE', 'ThreadMask',
'S1[i', 'C.i32', 'S[i]', 'in[',
'S1[i', 'C.i32',
'if n.', 'DST.u32', 'addrd = DST', 'addr = DST',
'BARRIER_STATE', 'ReallocVgprs',
'GPR_IDX', 'VSKIP', 'specified in', 'TTBL',
@ -67,17 +67,18 @@ def compile_pseudocode(pseudocode: str) -> str:
lines = []
indent, need_pass, in_first_match_loop = 0, False, False
declared_arrays: dict[str, int] = {} # Track declared arrays: name -> size
for line in joined_lines:
line = line.strip()
if not line or line.startswith('//'): continue
if line.startswith('if '):
lines.append(' ' * indent + f"if {_expr(line[3:].rstrip(' then'))}:")
lines.append(' ' * indent + f"if {_expr(line[3:].rstrip(' then'), declared_arrays)}:")
indent += 1
need_pass = True
elif line.startswith('elsif '):
if need_pass: lines.append(' ' * indent + "pass")
indent -= 1
lines.append(' ' * indent + f"elif {_expr(line[6:].rstrip(' then'))}:")
lines.append(' ' * indent + f"elif {_expr(line[6:].rstrip(' then'), declared_arrays)}:")
indent += 1
need_pass = True
elif line == 'else':
@ -94,10 +95,19 @@ def compile_pseudocode(pseudocode: str) -> str:
if need_pass: lines.append(' ' * indent + "pass")
indent -= 1
need_pass, in_first_match_loop = False, False
elif m := re.match(r'declare\s+(\w+)\s*:\s*\d+\'[FBU]\[(\d+)\]', line):
# Handle array declarations: declare in : 32'F[3] or declare S : 32'B[3]
arr_name, arr_size = m[1], int(m[2])
declared_arrays[arr_name] = arr_size
py_name = f"{arr_name}_" if arr_name == 'in' else arr_name # 'in' is Python keyword
if arr_name == 'S':
lines.append(' ' * indent + f"{py_name} = [S0, S1, S2]") # Map to source registers
else:
lines.append(' ' * indent + f"{py_name} = [Reg(0) for _ in range({arr_size})]")
elif line.startswith('declare '):
pass
pass # Ignore other declare statements
elif m := re.match(r'for (\w+) in (.+?)\s*:\s*(.+?) do', line):
start, end = _expr(m[2].strip()), _expr(m[3].strip())
start, end = _expr(m[2].strip(), declared_arrays), _expr(m[3].strip(), declared_arrays)
lines.append(' ' * indent + f"for {m[1]} in range({start}, int({end})+1):")
indent += 1
need_pass, in_first_match_loop = True, True
@ -105,7 +115,7 @@ def compile_pseudocode(pseudocode: str) -> str:
need_pass = False
line = line.rstrip(';')
if m := re.match(r'\{\s*D1\.[ui]1\s*,\s*D0\.[ui]64\s*\}\s*=\s*(.+)', line):
rhs = _expr(m[1])
rhs = _expr(m[1], declared_arrays)
lines.append(' ' * indent + f"_full = {rhs}")
lines.append(' ' * indent + f"D0.u64 = int(_full) & 0xffffffffffffffff")
lines.append(' ' * indent + f"D1 = Reg((int(_full) >> 64) & 1)")
@ -113,30 +123,39 @@ def compile_pseudocode(pseudocode: str) -> str:
for op in ('+=', '-=', '*=', '/=', '|=', '&=', '^='):
if op in line:
lhs, rhs = line.split(op, 1)
lines.append(' ' * indent + f"{lhs.strip()} {op} {_expr(rhs.strip())}")
lhs_s = _expr(lhs.strip(), declared_arrays) # Transform LHS too for array access
lines.append(' ' * indent + f"{lhs_s} {op} {_expr(rhs.strip(), declared_arrays)}")
break
else:
lhs, rhs = line.split('=', 1)
lhs_s, rhs_s = _expr(lhs.strip()), rhs.strip()
stmt = _assign(lhs_s, _expr(rhs_s))
lhs_s, rhs_s = lhs.strip(), rhs.strip()
lhs_t = _expr(lhs_s, declared_arrays) # Transform LHS for array access
stmt = _assign(lhs_t, _expr(rhs_s, declared_arrays), declared_arrays)
if in_first_match_loop and rhs_s == 'i' and (lhs_s == 'tmp' or lhs_s == 'D0.i32'):
stmt += "; break"
lines.append(' ' * indent + stmt)
if need_pass: lines.append(' ' * indent + "pass")
return '\n'.join(lines)
def _assign(lhs: str, rhs: str) -> str:
def _assign(lhs: str, rhs: str, declared_arrays: dict[str, int] | None = None) -> str:
# Check for array element assignment: in_[i] should not wrap in Reg()
if declared_arrays and re.match(r'\w+_?\[\w+\]', lhs):
return f"{lhs} = {rhs}"
if lhs in ('tmp', 'SCC', 'VCC', 'EXEC', 'D0', 'D1', 'saveexec', 'PC'):
return f"{lhs} = Reg({rhs})"
return f"{lhs} = {rhs}"
def _expr(e: str) -> str:
def _expr(e: str, declared_arrays: dict[str, int] | None = None) -> str:
e = e.strip()
# Handle OPSEL_HI.u3[i] and OPSEL.u3[i] - bit extraction from opsel fields
e = re.sub(r'(OPSEL(?:_HI)?)\.u\d+\[(\w+)\]', r'((\1 >> \2) & 1)', e)
# Rename 'in' to 'in_' to avoid Python keyword conflict
e = re.sub(r'\bin\[', 'in_[', e)
e = e.replace('&&', ' and ').replace('||', ' or ').replace('<>', ' != ')
e = re.sub(r'!([^=])', r' not \1', e)
e = re.sub(r'\{\s*(\w+\.u32)\s*,\s*(\w+\.u32)\s*\}', r'_pack32(\1, \2)', e)
def pack(m):
hi, lo = _expr(m[1].strip()), _expr(m[2].strip())
hi, lo = _expr(m[1].strip(), declared_arrays), _expr(m[2].strip(), declared_arrays)
return f'_pack({hi}, {lo})'
e = re.sub(r'\{\s*([^,{}]+)\s*,\s*([^,{}]+)\s*\}', pack, e)
e = re.sub(r"1201'B\(2\.0\s*/\s*PI\)", "TWO_OVER_PI_1201", e)
@ -164,7 +183,7 @@ def _expr(e: str) -> str:
if s[j] == '[': depth += 1
elif s[j] == ']': depth -= 1
j += 1
inner = _expr(s[start:j-1])
inner = _expr(s[start:j-1], declared_arrays)
result.append('[' + inner + ']')
i = j
else:
@ -362,7 +381,7 @@ def _extract_pseudocode(text: str) -> str | None:
if s.endswith('.') and not any(p in s for p in ['D0', 'D1', 'S0', 'S1', 'S2', 'SCC', 'VCC', 'tmp', '=']): continue
if re.match(r'^[a-z].*\.$', s) and '=' not in s: continue
is_code = (any(p in s for p in ['D0.', 'D1.', 'S0.', 'S1.', 'S2.', 'SCC =', 'SCC ?', 'VCC', 'EXEC', 'tmp =', 'tmp[', 'lane =', 'PC =',
'D0[', 'D1[', 'S0[', 'S1[', 'S2[']) or
'D0[', 'D1[', 'S0[', 'S1[', 'S2[', 'MEM[', 'RETURN_DATA', 'DATA.', 'DATA0', 'DATA1', 'ADDR']) or
s.startswith(('if ', 'else', 'elsif', 'endif', 'declare ', 'for ', 'endfor', '//')) or
re.match(r'^[a-z_]+\s*=', s) or re.match(r'^[a-z_]+\[', s) or (depth > 0 and '=' in s))
if is_code: result.append(s)
@ -448,13 +467,13 @@ def _generate_gen_pcode_py(enums, pseudocode, arch) -> str:
# Get op enums for this arch (import from .ins which re-exports from .enum)
import importlib
autogen = importlib.import_module(f"extra.assembly.amd.autogen.{arch}.ins")
OP_ENUMS = [getattr(autogen, name) for name in ['SOP1Op', 'SOP2Op', 'SOPCOp', 'SOPKOp', 'SOPPOp', 'VOP1Op', 'VOP2Op', 'VOP3Op', 'VOP3SDOp', 'VOP3POp', 'VOPCOp', 'VOP3AOp', 'VOP3BOp'] if hasattr(autogen, name)]
OP_ENUMS = [getattr(autogen, name) for name in ['SOP1Op', 'SOP2Op', 'SOPCOp', 'SOPKOp', 'SOPPOp', 'VOP1Op', 'VOP2Op', 'VOP3Op', 'VOP3SDOp', 'VOP3POp', 'VOPCOp', 'VOP3AOp', 'VOP3BOp', 'DSOp'] if hasattr(autogen, name)]
# Build defined ops mapping
defined_ops: dict[tuple, list] = {}
for enum_cls in OP_ENUMS:
for op in enum_cls:
if op.name.startswith(('S_', 'V_')): defined_ops.setdefault((op.name, op.value), []).append((enum_cls, op))
if op.name.startswith(('S_', 'V_', 'DS_')): defined_ops.setdefault((op.name, op.value), []).append((enum_cls, op))
enum_names = [e.__name__ for e in OP_ENUMS]
lines = [f'''# autogenerated by pdf.py - do not edit
@ -541,11 +560,14 @@ def _generate_function(cls_name: str, op, pc: str, code: str) -> tuple[str, str]
is_cmpx = (cls_name in ('VOPCOp', 'VOP3Op')) and 'EXEC.u64[laneId]' in pc
is_div_scale = 'DIV_SCALE' in op.name
has_sdst = cls_name == 'VOP3SDOp' and ('VCC.u64[laneId]' in pc or is_div_scale)
has_opsel = 'OPSEL' in pc # FMA_MIX and similar instructions need OPSEL/OPSEL_HI
combined = code + pc
fn_name = f"_{cls_name}_{op.name}"
# Function accepts Reg objects directly (uppercase names), laneId is passed directly as int
lines = [f"def {fn_name}(S0, S1, S2, D0, SCC, VCC, laneId, EXEC, literal, VGPR, src0_idx=0, vdst_idx=0, PC=None):"]
params = "S0, S1, S2, D0, SCC, VCC, laneId, EXEC, literal, VGPR, src0_idx=0, vdst_idx=0, PC=None"
if has_opsel: params += ", OPSEL=0, OPSEL_HI=0"
lines = [f"def {fn_name}({params}):"]
# Registers that need special handling (not passed directly)
# Only init if used but not first assigned as `name = Reg(...)` in the compiled code