mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
Compare commits
2 commits
master
...
more_pcode
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
aec4d65241 | ||
|
|
f022a7d8a7 |
4 changed files with 271 additions and 56 deletions
|
|
@ -28,4 +28,4 @@ The ops tests also pass, but they are very slow, so you should run them one at a
|
|||
|
||||
When something is caught by main tinygrad tests, a local regression test should be added to `extra/assembly/amd/test`. While working with tinygrad, you can dump the assembly with `DEBUG=7`. These tests all pass on real hardware, so if a test is failing with `AMD=1 PYTHON_REMU=1 MOCKGPU=1` it's likely because an instruction is emulated incorrectly. You can test without `MOCKGPU=1` to test on real hardware, if it works on real hardware there's a bug in the emulator.
|
||||
|
||||
Currently, only RDNA3 is well supported, but when finished, this will support RDNA3+RDNA4+CDNA in ~2000 lines.
|
||||
Currently, only RDNA3 is well supported, but when finished, this will support RDNA3+RDNA4+CDNA in ~2000 lines. Count lines with `cloc --by-file extra/assembly/amd/*.py`
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@
|
|||
# to regenerate: python -m extra.assembly.amd.pdf --arch rdna3
|
||||
# ruff: noqa: E501,F405,F403
|
||||
# mypy: ignore-errors
|
||||
from extra.assembly.amd.autogen.rdna3.enum import SOP1Op, SOP2Op, SOPCOp, SOPKOp, SOPPOp, VOP1Op, VOP2Op, VOP3Op, VOP3SDOp, VOP3POp, VOPCOp
|
||||
from extra.assembly.amd.autogen.rdna3.enum import SOP1Op, SOP2Op, SOPCOp, SOPKOp, SOPPOp, VOP1Op, VOP2Op, VOP3Op, VOP3SDOp, VOP3POp, VOPCOp, DSOp
|
||||
from extra.assembly.amd.pcode import *
|
||||
|
||||
def _SOP1Op_S_MOV_B32(S0, S1, S2, D0, SCC, VCC, laneId, EXEC, literal, VGPR, src0_idx=0, vdst_idx=0, PC=None):
|
||||
|
|
@ -5203,6 +5203,102 @@ def _VOP3POp_V_DOT2_F32_BF16(S0, S1, S2, D0, SCC, VCC, laneId, EXEC, literal, VG
|
|||
D0.f32 = tmp
|
||||
return {'D0': D0}
|
||||
|
||||
def _VOP3POp_V_FMA_MIX_F32(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGPR, _vars, src0_idx=0, vdst_idx=0, pc=0, OPSEL=0, OPSEL_HI=0):
|
||||
# declare in : 32'F[3];
|
||||
# declare S : 32'B[3];
|
||||
# for i in 0 : 2 do
|
||||
# if !OPSEL_HI.u3[i] then
|
||||
# in[i] = S[i].f32
|
||||
# elsif OPSEL.u3[i] then
|
||||
# in[i] = f16_to_f32(S[i][31 : 16].f16)
|
||||
# else
|
||||
# in[i] = f16_to_f32(S[i][15 : 0].f16)
|
||||
# endif
|
||||
# endfor;
|
||||
# D0[31 : 0].f32 = fma(in[0], in[1], in[2])
|
||||
S0 = Reg(s0)
|
||||
S1 = Reg(s1)
|
||||
S2 = Reg(s2)
|
||||
D0 = Reg(d0)
|
||||
# --- compiled pseudocode ---
|
||||
in_ = [Reg(0) for _ in range(3)]
|
||||
S = [S0, S1, S2]
|
||||
for i in range(0, int(2)+1):
|
||||
if not ((OPSEL_HI >> i) & 1):
|
||||
in_[i] = S[i].f32
|
||||
elif ((OPSEL >> i) & 1):
|
||||
in_[i] = f16_to_f32(S[i][31 : 16].f16)
|
||||
else:
|
||||
in_[i] = f16_to_f32(S[i][15 : 0].f16)
|
||||
D0[31 : 0].f32 = fma(in_[0], in_[1], in_[2])
|
||||
# --- end pseudocode ---
|
||||
result = {'d0': D0._val, 'scc': scc & 1}
|
||||
return result
|
||||
|
||||
def _VOP3POp_V_FMA_MIXLO_F16(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGPR, _vars, src0_idx=0, vdst_idx=0, pc=0, OPSEL=0, OPSEL_HI=0):
|
||||
# declare in : 32'F[3];
|
||||
# declare S : 32'B[3];
|
||||
# for i in 0 : 2 do
|
||||
# if !OPSEL_HI.u3[i] then
|
||||
# in[i] = S[i].f32
|
||||
# elsif OPSEL.u3[i] then
|
||||
# in[i] = f16_to_f32(S[i][31 : 16].f16)
|
||||
# else
|
||||
# in[i] = f16_to_f32(S[i][15 : 0].f16)
|
||||
# endif
|
||||
# endfor;
|
||||
# D0[15 : 0].f16 = f32_to_f16(fma(in[0], in[1], in[2]))
|
||||
S0 = Reg(s0)
|
||||
S1 = Reg(s1)
|
||||
S2 = Reg(s2)
|
||||
D0 = Reg(d0)
|
||||
# --- compiled pseudocode ---
|
||||
in_ = [Reg(0) for _ in range(3)]
|
||||
S = [S0, S1, S2]
|
||||
for i in range(0, int(2)+1):
|
||||
if not ((OPSEL_HI >> i) & 1):
|
||||
in_[i] = S[i].f32
|
||||
elif ((OPSEL >> i) & 1):
|
||||
in_[i] = f16_to_f32(S[i][31 : 16].f16)
|
||||
else:
|
||||
in_[i] = f16_to_f32(S[i][15 : 0].f16)
|
||||
D0[15 : 0].f16 = f32_to_f16(fma(in_[0], in_[1], in_[2]))
|
||||
# --- end pseudocode ---
|
||||
result = {'d0': D0._val, 'scc': scc & 1}
|
||||
return result
|
||||
|
||||
def _VOP3POp_V_FMA_MIXHI_F16(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGPR, _vars, src0_idx=0, vdst_idx=0, pc=0, OPSEL=0, OPSEL_HI=0):
|
||||
# declare in : 32'F[3];
|
||||
# declare S : 32'B[3];
|
||||
# for i in 0 : 2 do
|
||||
# if !OPSEL_HI.u3[i] then
|
||||
# in[i] = S[i].f32
|
||||
# elsif OPSEL.u3[i] then
|
||||
# in[i] = f16_to_f32(S[i][31 : 16].f16)
|
||||
# else
|
||||
# in[i] = f16_to_f32(S[i][15 : 0].f16)
|
||||
# endif
|
||||
# endfor;
|
||||
# D0[31 : 16].f16 = f32_to_f16(fma(in[0], in[1], in[2]))
|
||||
S0 = Reg(s0)
|
||||
S1 = Reg(s1)
|
||||
S2 = Reg(s2)
|
||||
D0 = Reg(d0)
|
||||
# --- compiled pseudocode ---
|
||||
in_ = [Reg(0) for _ in range(3)]
|
||||
S = [S0, S1, S2]
|
||||
for i in range(0, int(2)+1):
|
||||
if not ((OPSEL_HI >> i) & 1):
|
||||
in_[i] = S[i].f32
|
||||
elif ((OPSEL >> i) & 1):
|
||||
in_[i] = f16_to_f32(S[i][31 : 16].f16)
|
||||
else:
|
||||
in_[i] = f16_to_f32(S[i][15 : 0].f16)
|
||||
D0[31 : 16].f16 = f32_to_f16(fma(in_[0], in_[1], in_[2]))
|
||||
# --- end pseudocode ---
|
||||
result = {'d0': D0._val, 'scc': scc & 1}
|
||||
return result
|
||||
|
||||
VOP3POp_FUNCTIONS = {
|
||||
VOP3POp.V_PK_MAD_I16: _VOP3POp_V_PK_MAD_I16,
|
||||
VOP3POp.V_PK_MUL_LO_U16: _VOP3POp_V_PK_MUL_LO_U16,
|
||||
|
|
@ -5227,6 +5323,9 @@ VOP3POp_FUNCTIONS = {
|
|||
VOP3POp.V_DOT4_U32_U8: _VOP3POp_V_DOT4_U32_U8,
|
||||
VOP3POp.V_DOT8_U32_U4: _VOP3POp_V_DOT8_U32_U4,
|
||||
VOP3POp.V_DOT2_F32_BF16: _VOP3POp_V_DOT2_F32_BF16,
|
||||
VOP3POp.V_FMA_MIX_F32: _VOP3POp_V_FMA_MIX_F32,
|
||||
VOP3POp.V_FMA_MIXLO_F16: _VOP3POp_V_FMA_MIXLO_F16,
|
||||
VOP3POp.V_FMA_MIXHI_F16: _VOP3POp_V_FMA_MIXHI_F16,
|
||||
}
|
||||
|
||||
def _VOPCOp_V_CMP_F_F16(S0, S1, S2, D0, SCC, VCC, laneId, EXEC, literal, VGPR, src0_idx=0, vdst_idx=0, PC=None):
|
||||
|
|
@ -6261,6 +6360,119 @@ def _VOP3Op_V_WRITELANE_B32(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal,
|
|||
return {'d0': d0, 'scc': scc, 'vgpr_write': (wr_lane, vdst_idx, s0 & 0xffffffff)}
|
||||
VOP3Op_FUNCTIONS[VOP3Op.V_WRITELANE_B32] = _VOP3Op_V_WRITELANE_B32
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
# DS (Data Share) INSTRUCTIONS
|
||||
# DS instructions operate on LDS (Local Data Share) memory
|
||||
# They receive: addr (address), data0/data1 (data VGPRs), offset0/offset1 (byte offsets)
|
||||
# They return: {'vdst': [...]} for loads, {'lds_writes': [...]} for stores
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
|
||||
def _ds_load(lds, addr, offset, size, sign_extend=False):
|
||||
"""Load from LDS memory. Returns list of 32-bit values."""
|
||||
a = (addr + offset) & 0xffff
|
||||
if size <= 4:
|
||||
val = int.from_bytes(lds[a:a+size], 'little')
|
||||
if sign_extend and size < 4:
|
||||
# Sign extend from size*8 bits to 32 bits
|
||||
sign_bit = 1 << (size * 8 - 1)
|
||||
if val & sign_bit: val |= ~((1 << (size * 8)) - 1)
|
||||
return [val & 0xffffffff]
|
||||
# Multi-dword load
|
||||
return [int.from_bytes(lds[a+i*4:a+i*4+4], 'little') for i in range(size // 4)]
|
||||
|
||||
def _ds_store(lds, addr, offset, values, size):
|
||||
"""Store to LDS memory. values is list of 32-bit dwords, size is bytes per element."""
|
||||
a = (addr + offset) & 0xffff
|
||||
if size <= 4:
|
||||
lds[a:a+size] = (values[0] & ((1 << (size * 8)) - 1)).to_bytes(size, 'little')
|
||||
else:
|
||||
for i, v in enumerate(values):
|
||||
lds[a+i*4:a+i*4+4] = (v & 0xffffffff).to_bytes(4, 'little')
|
||||
|
||||
# Load operations: DS_LOAD_B32, DS_LOAD_B64, DS_LOAD_B128, DS_LOAD_U8, DS_LOAD_I8, DS_LOAD_U16, DS_LOAD_I16
|
||||
def _DSOp_DS_LOAD_B32(lds, addr, data0, data1, vdst, offset0, offset1):
|
||||
return {'vdst': _ds_load(lds, addr, offset0, 4)}
|
||||
|
||||
def _DSOp_DS_LOAD_B64(lds, addr, data0, data1, vdst, offset0, offset1):
|
||||
return {'vdst': _ds_load(lds, addr, offset0, 8)}
|
||||
|
||||
def _DSOp_DS_LOAD_B128(lds, addr, data0, data1, vdst, offset0, offset1):
|
||||
return {'vdst': _ds_load(lds, addr, offset0, 16)}
|
||||
|
||||
def _DSOp_DS_LOAD_U8(lds, addr, data0, data1, vdst, offset0, offset1):
|
||||
return {'vdst': _ds_load(lds, addr, offset0, 1, sign_extend=False)}
|
||||
|
||||
def _DSOp_DS_LOAD_I8(lds, addr, data0, data1, vdst, offset0, offset1):
|
||||
return {'vdst': _ds_load(lds, addr, offset0, 1, sign_extend=True)}
|
||||
|
||||
def _DSOp_DS_LOAD_U16(lds, addr, data0, data1, vdst, offset0, offset1):
|
||||
return {'vdst': _ds_load(lds, addr, offset0, 2, sign_extend=False)}
|
||||
|
||||
def _DSOp_DS_LOAD_I16(lds, addr, data0, data1, vdst, offset0, offset1):
|
||||
return {'vdst': _ds_load(lds, addr, offset0, 2, sign_extend=True)}
|
||||
|
||||
# Store operations: DS_STORE_B32, DS_STORE_B64, DS_STORE_B128, DS_STORE_B8, DS_STORE_B16
|
||||
def _DSOp_DS_STORE_B32(lds, addr, data0, data1, vdst, offset0, offset1):
|
||||
_ds_store(lds, addr, offset0, data0, 4)
|
||||
return {}
|
||||
|
||||
def _DSOp_DS_STORE_B64(lds, addr, data0, data1, vdst, offset0, offset1):
|
||||
_ds_store(lds, addr, offset0, data0, 8)
|
||||
return {}
|
||||
|
||||
def _DSOp_DS_STORE_B128(lds, addr, data0, data1, vdst, offset0, offset1):
|
||||
_ds_store(lds, addr, offset0, data0, 16)
|
||||
return {}
|
||||
|
||||
def _DSOp_DS_STORE_B8(lds, addr, data0, data1, vdst, offset0, offset1):
|
||||
_ds_store(lds, addr, offset0, data0, 1)
|
||||
return {}
|
||||
|
||||
def _DSOp_DS_STORE_B16(lds, addr, data0, data1, vdst, offset0, offset1):
|
||||
_ds_store(lds, addr, offset0, data0, 2)
|
||||
return {}
|
||||
|
||||
# 2-address operations: DS_LOAD_2ADDR_B32, DS_LOAD_2ADDR_B64, DS_STORE_2ADDR_B32, DS_STORE_2ADDR_B64
|
||||
# Note: offsets are scaled by data size (4 for B32, 8 for B64)
|
||||
def _DSOp_DS_LOAD_2ADDR_B32(lds, addr, data0, data1, vdst, offset0, offset1):
|
||||
v0 = _ds_load(lds, addr, offset0 * 4, 4)
|
||||
v1 = _ds_load(lds, addr, offset1 * 4, 4)
|
||||
return {'vdst': v0 + v1}
|
||||
|
||||
def _DSOp_DS_LOAD_2ADDR_B64(lds, addr, data0, data1, vdst, offset0, offset1):
|
||||
v0 = _ds_load(lds, addr, offset0 * 8, 8)
|
||||
v1 = _ds_load(lds, addr, offset1 * 8, 8)
|
||||
return {'vdst': v0 + v1}
|
||||
|
||||
def _DSOp_DS_STORE_2ADDR_B32(lds, addr, data0, data1, vdst, offset0, offset1):
|
||||
_ds_store(lds, addr, offset0 * 4, data0, 4)
|
||||
_ds_store(lds, addr, offset1 * 4, data1, 4)
|
||||
return {}
|
||||
|
||||
def _DSOp_DS_STORE_2ADDR_B64(lds, addr, data0, data1, vdst, offset0, offset1):
|
||||
_ds_store(lds, addr, offset0 * 8, data0, 8)
|
||||
_ds_store(lds, addr, offset1 * 8, data1, 8)
|
||||
return {}
|
||||
|
||||
DSOp_FUNCTIONS = {
|
||||
DSOp.DS_LOAD_B32: _DSOp_DS_LOAD_B32,
|
||||
DSOp.DS_LOAD_B64: _DSOp_DS_LOAD_B64,
|
||||
DSOp.DS_LOAD_B128: _DSOp_DS_LOAD_B128,
|
||||
DSOp.DS_LOAD_U8: _DSOp_DS_LOAD_U8,
|
||||
DSOp.DS_LOAD_I8: _DSOp_DS_LOAD_I8,
|
||||
DSOp.DS_LOAD_U16: _DSOp_DS_LOAD_U16,
|
||||
DSOp.DS_LOAD_I16: _DSOp_DS_LOAD_I16,
|
||||
DSOp.DS_STORE_B32: _DSOp_DS_STORE_B32,
|
||||
DSOp.DS_STORE_B64: _DSOp_DS_STORE_B64,
|
||||
DSOp.DS_STORE_B128: _DSOp_DS_STORE_B128,
|
||||
DSOp.DS_STORE_B8: _DSOp_DS_STORE_B8,
|
||||
DSOp.DS_STORE_B16: _DSOp_DS_STORE_B16,
|
||||
DSOp.DS_LOAD_2ADDR_B32: _DSOp_DS_LOAD_2ADDR_B32,
|
||||
DSOp.DS_LOAD_2ADDR_B64: _DSOp_DS_LOAD_2ADDR_B64,
|
||||
DSOp.DS_STORE_2ADDR_B32: _DSOp_DS_STORE_2ADDR_B32,
|
||||
DSOp.DS_STORE_2ADDR_B64: _DSOp_DS_STORE_2ADDR_B64,
|
||||
}
|
||||
|
||||
COMPILED_FUNCTIONS = {
|
||||
SOP1Op: SOP1Op_FUNCTIONS,
|
||||
SOP2Op: SOP2Op_FUNCTIONS,
|
||||
|
|
@ -6273,6 +6485,7 @@ COMPILED_FUNCTIONS = {
|
|||
VOP3SDOp: VOP3SDOp_FUNCTIONS,
|
||||
VOP3POp: VOP3POp_FUNCTIONS,
|
||||
VOPCOp: VOPCOp_FUNCTIONS,
|
||||
DSOp: DSOp_FUNCTIONS,
|
||||
}
|
||||
|
||||
def get_compiled_functions(): return COMPILED_FUNCTIONS
|
||||
|
|
@ -7,7 +7,7 @@ from extra.assembly.amd.pcode import Reg
|
|||
from extra.assembly.amd.asm import detect_format
|
||||
from extra.assembly.amd.autogen.rdna3.gen_pcode import get_compiled_functions
|
||||
from extra.assembly.amd.autogen.rdna3.ins import (SOP1, SOP2, SOPC, SOPK, SOPP, SMEM, VOP1, VOP2, VOP3, VOP3SD, VOP3P, VOPC, DS, FLAT, VOPD,
|
||||
SrcEnum, SOP1Op, SOP2Op, SOPCOp, SOPKOp, SOPPOp, SMEMOp, VOP1Op, VOP2Op, VOP3Op, VOP3SDOp, VOP3POp, VOPCOp, DSOp, FLATOp, GLOBALOp, VOPDOp)
|
||||
SrcEnum, SOPPOp, SMEMOp, VOP1Op, VOP2Op, VOP3Op, VOP3SDOp, VOP3POp, VOPCOp, GLOBALOp, FLATOp, DSOp, VOPDOp)
|
||||
|
||||
Program = dict[int, Inst]
|
||||
WAVE_SIZE, SGPR_COUNT, VGPR_COUNT = 32, 128, 256
|
||||
|
|
@ -51,11 +51,6 @@ _D16_LOAD_MAP = {'LOAD_D16_U8': (1,0,0), 'LOAD_D16_I8': (1,1,0), 'LOAD_D16_B16':
|
|||
_D16_STORE_MAP = {'STORE_D16_HI_B8': (1,1), 'STORE_D16_HI_B16': (2,1)} # (size, hi)
|
||||
FLAT_D16_LOAD = _mem_ops([GLOBALOp, FLATOp], _D16_LOAD_MAP)
|
||||
FLAT_D16_STORE = _mem_ops([GLOBALOp, FLATOp], _D16_STORE_MAP)
|
||||
DS_LOAD = {DSOp.DS_LOAD_B32: (1,4,0), DSOp.DS_LOAD_B64: (2,4,0), DSOp.DS_LOAD_B128: (4,4,0), DSOp.DS_LOAD_U8: (1,1,0), DSOp.DS_LOAD_I8: (1,1,1), DSOp.DS_LOAD_U16: (1,2,0), DSOp.DS_LOAD_I16: (1,2,1)}
|
||||
DS_STORE = {DSOp.DS_STORE_B32: (1,4), DSOp.DS_STORE_B64: (2,4), DSOp.DS_STORE_B128: (4,4), DSOp.DS_STORE_B8: (1,1), DSOp.DS_STORE_B16: (1,2)}
|
||||
# 2ADDR ops: load/store two values using offset0 and offset1
|
||||
DS_LOAD_2ADDR = {DSOp.DS_LOAD_2ADDR_B32: 4, DSOp.DS_LOAD_2ADDR_B64: 8}
|
||||
DS_STORE_2ADDR = {DSOp.DS_STORE_2ADDR_B32: 4, DSOp.DS_STORE_2ADDR_B64: 8}
|
||||
SMEM_LOAD = {SMEMOp.S_LOAD_B32: 1, SMEMOp.S_LOAD_B64: 2, SMEMOp.S_LOAD_B128: 4, SMEMOp.S_LOAD_B256: 8, SMEMOp.S_LOAD_B512: 16}
|
||||
|
||||
# VOPD op -> VOP3 op mapping (VOPD is dual-issue of VOP1/VOP2 ops, use VOP3 enums for pseudocode lookup)
|
||||
|
|
@ -225,32 +220,15 @@ def exec_vector(st: WaveState, inst: Inst, lane: int, lds: bytearray | None = No
|
|||
return
|
||||
|
||||
if isinstance(inst, DS):
|
||||
op, addr0, vdst = inst.op, (V[inst.addr] + inst.offset0) & 0xffff, inst.vdst
|
||||
if op in DS_LOAD:
|
||||
cnt, sz, sign = DS_LOAD[op]
|
||||
for i in range(cnt): val = int.from_bytes(lds[addr0+i*sz:addr0+i*sz+sz], 'little'); V[vdst + i] = _sext(val, sz * 8) & MASK32 if sign else val
|
||||
elif op in DS_STORE:
|
||||
cnt, sz = DS_STORE[op]
|
||||
for i in range(cnt): lds[addr0+i*sz:addr0+i*sz+sz] = (V[inst.data0 + i] & ((1 << (sz * 8)) - 1)).to_bytes(sz, 'little')
|
||||
elif op in DS_LOAD_2ADDR:
|
||||
# Load two values from addr+offset0*sz and addr+offset1*sz into vdst (B32: 1 dword each, B64: 2 dwords each)
|
||||
# Note: offsets are scaled by data size (4 for B32, 8 for B64) per AMD ISA
|
||||
sz = DS_LOAD_2ADDR[op]
|
||||
addr0 = (V[inst.addr] + inst.offset0 * sz) & 0xffff
|
||||
addr1 = (V[inst.addr] + inst.offset1 * sz) & 0xffff
|
||||
cnt = sz // 4 # 1 for B32, 2 for B64
|
||||
for i in range(cnt): V[vdst + i] = int.from_bytes(lds[addr0+i*4:addr0+i*4+4], 'little')
|
||||
for i in range(cnt): V[vdst + cnt + i] = int.from_bytes(lds[addr1+i*4:addr1+i*4+4], 'little')
|
||||
elif op in DS_STORE_2ADDR:
|
||||
# Store two values from data0 and data1 to addr+offset0*sz and addr+offset1*sz
|
||||
# Note: offsets are scaled by data size (4 for B32, 8 for B64) per AMD ISA
|
||||
sz = DS_STORE_2ADDR[op]
|
||||
addr0 = (V[inst.addr] + inst.offset0 * sz) & 0xffff
|
||||
addr1 = (V[inst.addr] + inst.offset1 * sz) & 0xffff
|
||||
cnt = sz // 4
|
||||
for i in range(cnt): lds[addr0+i*4:addr0+i*4+4] = (V[inst.data0 + i] & MASK32).to_bytes(4, 'little')
|
||||
for i in range(cnt): lds[addr1+i*4:addr1+i*4+4] = (V[inst.data1 + i] & MASK32).to_bytes(4, 'little')
|
||||
else: raise NotImplementedError(f"DS op {op}")
|
||||
fn = compiled.get(DSOp, {}).get(inst.op)
|
||||
if fn is None: raise NotImplementedError(f"DS op {inst.op.name} not in pseudocode")
|
||||
# Prepare data registers as lists of dwords
|
||||
data0 = [V[inst.data0 + i] for i in range(4)] # up to 4 dwords
|
||||
data1 = [V[inst.data1 + i] for i in range(4)] if inst.data1 else [0, 0, 0, 0]
|
||||
result = fn(lds, V[inst.addr], data0, data1, inst.vdst, inst.offset0, inst.offset1)
|
||||
# Write results for loads
|
||||
if 'vdst' in result:
|
||||
for i, val in enumerate(result['vdst']): V[inst.vdst + i] = val & MASK32
|
||||
return
|
||||
|
||||
# VOPD: dual-issue, execute two ops simultaneously (read all inputs before writes)
|
||||
|
|
@ -306,17 +284,19 @@ def exec_vector(st: WaveState, inst: Inst, lane: int, lds: bytearray | None = No
|
|||
exec_wmma(st, inst, inst.op)
|
||||
return
|
||||
# V_FMA_MIX: Mixed precision FMA - opsel_hi controls f32(0) vs f16(1), opsel selects which f16 half
|
||||
if 'FMA_MIX' in inst.op_name:
|
||||
# Handle inline because abs/neg must be applied AFTER type conversion
|
||||
if inst.op in (VOP3POp.V_FMA_MIX_F32, VOP3POp.V_FMA_MIXLO_F16, VOP3POp.V_FMA_MIXHI_F16):
|
||||
opsel, opsel_hi, opsel_hi2 = getattr(inst, 'opsel', 0), getattr(inst, 'opsel_hi', 0), getattr(inst, 'opsel_hi2', 0)
|
||||
neg, abs_ = getattr(inst, 'neg', 0), getattr(inst, 'neg_hi', 0) # neg_hi reused as abs
|
||||
neg, abs_ = getattr(inst, 'neg', 0), getattr(inst, 'neg_hi', 0) # neg_hi reused as abs for FMA_MIX
|
||||
raws = [st.rsrc(inst.src0, lane), st.rsrc(inst.src1, lane), st.rsrc(inst.src2, lane) if inst.src2 is not None else 0]
|
||||
is_f16 = [opsel_hi & 1, opsel_hi & 2, opsel_hi2]
|
||||
srcs = [_f16(_src16(raws[i], bool(opsel & (1<<i)))) if is_f16[i] else _f32(raws[i]) for i in range(3)]
|
||||
for i in range(3):
|
||||
if abs_ & (1<<i): srcs[i] = abs(srcs[i])
|
||||
if neg & (1<<i): srcs[i] = -srcs[i]
|
||||
result = srcs[0] * srcs[1] + srcs[2]
|
||||
st.vgpr[lane][inst.vdst] = _i32(result) if inst.op == VOP3POp.V_FMA_MIX_F32 else _dst16(V[inst.vdst], _i16(result), inst.op == VOP3POp.V_FMA_MIXHI_F16)
|
||||
result_f = srcs[0] * srcs[1] + srcs[2]
|
||||
V = st.vgpr[lane]
|
||||
V[inst.vdst] = _i32(result_f) if inst.op == VOP3POp.V_FMA_MIX_F32 else _dst16(V[inst.vdst], _i16(result_f), inst.op == VOP3POp.V_FMA_MIXHI_F16)
|
||||
return
|
||||
# VOP3P packed ops: opsel selects halves for lo, opsel_hi for hi; neg toggles f16 sign
|
||||
raws = [st.rsrc_f16(inst.src0, lane), st.rsrc_f16(inst.src1, lane), st.rsrc_f16(inst.src2, lane) if inst.src2 is not None else 0]
|
||||
|
|
|
|||
|
|
@ -36,13 +36,13 @@ FIELD_ORDER = {
|
|||
SRC_EXTRAS = {233: 'DPP8', 234: 'DPP8FI', 250: 'DPP16', 251: 'VCCZ', 252: 'EXECZ', 254: 'LDS_DIRECT'}
|
||||
FLOAT_MAP = {'0.5': 'POS_HALF', '-0.5': 'NEG_HALF', '1.0': 'POS_ONE', '-1.0': 'NEG_ONE', '2.0': 'POS_TWO', '-2.0': 'NEG_TWO',
|
||||
'4.0': 'POS_FOUR', '-4.0': 'NEG_FOUR', '1/(2*PI)': 'INV_2PI', '0': 'ZERO'}
|
||||
INST_PATTERN = re.compile(r'^([SV]_[A-Z0-9_]+)\s+(\d+)\s*$', re.M)
|
||||
INST_PATTERN = re.compile(r'^([SVD]S?_[A-Z0-9_]+)\s+(\d+)\s*$', re.M)
|
||||
|
||||
# Patterns that can't be handled by the DSL (require special handling in emu.py)
|
||||
UNSUPPORTED = ['SGPR[', 'V_SWAP', 'eval ', 'FATAL_HALT', 'HW_REGISTERS',
|
||||
'vscnt', 'vmcnt', 'expcnt', 'lgkmcnt',
|
||||
'CVT_OFF_TABLE', 'ThreadMask',
|
||||
'S1[i', 'C.i32', 'S[i]', 'in[',
|
||||
'S1[i', 'C.i32',
|
||||
'if n.', 'DST.u32', 'addrd = DST', 'addr = DST',
|
||||
'BARRIER_STATE', 'ReallocVgprs',
|
||||
'GPR_IDX', 'VSKIP', 'specified in', 'TTBL',
|
||||
|
|
@ -67,17 +67,18 @@ def compile_pseudocode(pseudocode: str) -> str:
|
|||
|
||||
lines = []
|
||||
indent, need_pass, in_first_match_loop = 0, False, False
|
||||
declared_arrays: dict[str, int] = {} # Track declared arrays: name -> size
|
||||
for line in joined_lines:
|
||||
line = line.strip()
|
||||
if not line or line.startswith('//'): continue
|
||||
if line.startswith('if '):
|
||||
lines.append(' ' * indent + f"if {_expr(line[3:].rstrip(' then'))}:")
|
||||
lines.append(' ' * indent + f"if {_expr(line[3:].rstrip(' then'), declared_arrays)}:")
|
||||
indent += 1
|
||||
need_pass = True
|
||||
elif line.startswith('elsif '):
|
||||
if need_pass: lines.append(' ' * indent + "pass")
|
||||
indent -= 1
|
||||
lines.append(' ' * indent + f"elif {_expr(line[6:].rstrip(' then'))}:")
|
||||
lines.append(' ' * indent + f"elif {_expr(line[6:].rstrip(' then'), declared_arrays)}:")
|
||||
indent += 1
|
||||
need_pass = True
|
||||
elif line == 'else':
|
||||
|
|
@ -94,10 +95,19 @@ def compile_pseudocode(pseudocode: str) -> str:
|
|||
if need_pass: lines.append(' ' * indent + "pass")
|
||||
indent -= 1
|
||||
need_pass, in_first_match_loop = False, False
|
||||
elif m := re.match(r'declare\s+(\w+)\s*:\s*\d+\'[FBU]\[(\d+)\]', line):
|
||||
# Handle array declarations: declare in : 32'F[3] or declare S : 32'B[3]
|
||||
arr_name, arr_size = m[1], int(m[2])
|
||||
declared_arrays[arr_name] = arr_size
|
||||
py_name = f"{arr_name}_" if arr_name == 'in' else arr_name # 'in' is Python keyword
|
||||
if arr_name == 'S':
|
||||
lines.append(' ' * indent + f"{py_name} = [S0, S1, S2]") # Map to source registers
|
||||
else:
|
||||
lines.append(' ' * indent + f"{py_name} = [Reg(0) for _ in range({arr_size})]")
|
||||
elif line.startswith('declare '):
|
||||
pass
|
||||
pass # Ignore other declare statements
|
||||
elif m := re.match(r'for (\w+) in (.+?)\s*:\s*(.+?) do', line):
|
||||
start, end = _expr(m[2].strip()), _expr(m[3].strip())
|
||||
start, end = _expr(m[2].strip(), declared_arrays), _expr(m[3].strip(), declared_arrays)
|
||||
lines.append(' ' * indent + f"for {m[1]} in range({start}, int({end})+1):")
|
||||
indent += 1
|
||||
need_pass, in_first_match_loop = True, True
|
||||
|
|
@ -105,7 +115,7 @@ def compile_pseudocode(pseudocode: str) -> str:
|
|||
need_pass = False
|
||||
line = line.rstrip(';')
|
||||
if m := re.match(r'\{\s*D1\.[ui]1\s*,\s*D0\.[ui]64\s*\}\s*=\s*(.+)', line):
|
||||
rhs = _expr(m[1])
|
||||
rhs = _expr(m[1], declared_arrays)
|
||||
lines.append(' ' * indent + f"_full = {rhs}")
|
||||
lines.append(' ' * indent + f"D0.u64 = int(_full) & 0xffffffffffffffff")
|
||||
lines.append(' ' * indent + f"D1 = Reg((int(_full) >> 64) & 1)")
|
||||
|
|
@ -113,30 +123,39 @@ def compile_pseudocode(pseudocode: str) -> str:
|
|||
for op in ('+=', '-=', '*=', '/=', '|=', '&=', '^='):
|
||||
if op in line:
|
||||
lhs, rhs = line.split(op, 1)
|
||||
lines.append(' ' * indent + f"{lhs.strip()} {op} {_expr(rhs.strip())}")
|
||||
lhs_s = _expr(lhs.strip(), declared_arrays) # Transform LHS too for array access
|
||||
lines.append(' ' * indent + f"{lhs_s} {op} {_expr(rhs.strip(), declared_arrays)}")
|
||||
break
|
||||
else:
|
||||
lhs, rhs = line.split('=', 1)
|
||||
lhs_s, rhs_s = _expr(lhs.strip()), rhs.strip()
|
||||
stmt = _assign(lhs_s, _expr(rhs_s))
|
||||
lhs_s, rhs_s = lhs.strip(), rhs.strip()
|
||||
lhs_t = _expr(lhs_s, declared_arrays) # Transform LHS for array access
|
||||
stmt = _assign(lhs_t, _expr(rhs_s, declared_arrays), declared_arrays)
|
||||
if in_first_match_loop and rhs_s == 'i' and (lhs_s == 'tmp' or lhs_s == 'D0.i32'):
|
||||
stmt += "; break"
|
||||
lines.append(' ' * indent + stmt)
|
||||
if need_pass: lines.append(' ' * indent + "pass")
|
||||
return '\n'.join(lines)
|
||||
|
||||
def _assign(lhs: str, rhs: str) -> str:
|
||||
def _assign(lhs: str, rhs: str, declared_arrays: dict[str, int] | None = None) -> str:
|
||||
# Check for array element assignment: in_[i] should not wrap in Reg()
|
||||
if declared_arrays and re.match(r'\w+_?\[\w+\]', lhs):
|
||||
return f"{lhs} = {rhs}"
|
||||
if lhs in ('tmp', 'SCC', 'VCC', 'EXEC', 'D0', 'D1', 'saveexec', 'PC'):
|
||||
return f"{lhs} = Reg({rhs})"
|
||||
return f"{lhs} = {rhs}"
|
||||
|
||||
def _expr(e: str) -> str:
|
||||
def _expr(e: str, declared_arrays: dict[str, int] | None = None) -> str:
|
||||
e = e.strip()
|
||||
# Handle OPSEL_HI.u3[i] and OPSEL.u3[i] - bit extraction from opsel fields
|
||||
e = re.sub(r'(OPSEL(?:_HI)?)\.u\d+\[(\w+)\]', r'((\1 >> \2) & 1)', e)
|
||||
# Rename 'in' to 'in_' to avoid Python keyword conflict
|
||||
e = re.sub(r'\bin\[', 'in_[', e)
|
||||
e = e.replace('&&', ' and ').replace('||', ' or ').replace('<>', ' != ')
|
||||
e = re.sub(r'!([^=])', r' not \1', e)
|
||||
e = re.sub(r'\{\s*(\w+\.u32)\s*,\s*(\w+\.u32)\s*\}', r'_pack32(\1, \2)', e)
|
||||
def pack(m):
|
||||
hi, lo = _expr(m[1].strip()), _expr(m[2].strip())
|
||||
hi, lo = _expr(m[1].strip(), declared_arrays), _expr(m[2].strip(), declared_arrays)
|
||||
return f'_pack({hi}, {lo})'
|
||||
e = re.sub(r'\{\s*([^,{}]+)\s*,\s*([^,{}]+)\s*\}', pack, e)
|
||||
e = re.sub(r"1201'B\(2\.0\s*/\s*PI\)", "TWO_OVER_PI_1201", e)
|
||||
|
|
@ -164,7 +183,7 @@ def _expr(e: str) -> str:
|
|||
if s[j] == '[': depth += 1
|
||||
elif s[j] == ']': depth -= 1
|
||||
j += 1
|
||||
inner = _expr(s[start:j-1])
|
||||
inner = _expr(s[start:j-1], declared_arrays)
|
||||
result.append('[' + inner + ']')
|
||||
i = j
|
||||
else:
|
||||
|
|
@ -362,7 +381,7 @@ def _extract_pseudocode(text: str) -> str | None:
|
|||
if s.endswith('.') and not any(p in s for p in ['D0', 'D1', 'S0', 'S1', 'S2', 'SCC', 'VCC', 'tmp', '=']): continue
|
||||
if re.match(r'^[a-z].*\.$', s) and '=' not in s: continue
|
||||
is_code = (any(p in s for p in ['D0.', 'D1.', 'S0.', 'S1.', 'S2.', 'SCC =', 'SCC ?', 'VCC', 'EXEC', 'tmp =', 'tmp[', 'lane =', 'PC =',
|
||||
'D0[', 'D1[', 'S0[', 'S1[', 'S2[']) or
|
||||
'D0[', 'D1[', 'S0[', 'S1[', 'S2[', 'MEM[', 'RETURN_DATA', 'DATA.', 'DATA0', 'DATA1', 'ADDR']) or
|
||||
s.startswith(('if ', 'else', 'elsif', 'endif', 'declare ', 'for ', 'endfor', '//')) or
|
||||
re.match(r'^[a-z_]+\s*=', s) or re.match(r'^[a-z_]+\[', s) or (depth > 0 and '=' in s))
|
||||
if is_code: result.append(s)
|
||||
|
|
@ -448,13 +467,13 @@ def _generate_gen_pcode_py(enums, pseudocode, arch) -> str:
|
|||
# Get op enums for this arch (import from .ins which re-exports from .enum)
|
||||
import importlib
|
||||
autogen = importlib.import_module(f"extra.assembly.amd.autogen.{arch}.ins")
|
||||
OP_ENUMS = [getattr(autogen, name) for name in ['SOP1Op', 'SOP2Op', 'SOPCOp', 'SOPKOp', 'SOPPOp', 'VOP1Op', 'VOP2Op', 'VOP3Op', 'VOP3SDOp', 'VOP3POp', 'VOPCOp', 'VOP3AOp', 'VOP3BOp'] if hasattr(autogen, name)]
|
||||
OP_ENUMS = [getattr(autogen, name) for name in ['SOP1Op', 'SOP2Op', 'SOPCOp', 'SOPKOp', 'SOPPOp', 'VOP1Op', 'VOP2Op', 'VOP3Op', 'VOP3SDOp', 'VOP3POp', 'VOPCOp', 'VOP3AOp', 'VOP3BOp', 'DSOp'] if hasattr(autogen, name)]
|
||||
|
||||
# Build defined ops mapping
|
||||
defined_ops: dict[tuple, list] = {}
|
||||
for enum_cls in OP_ENUMS:
|
||||
for op in enum_cls:
|
||||
if op.name.startswith(('S_', 'V_')): defined_ops.setdefault((op.name, op.value), []).append((enum_cls, op))
|
||||
if op.name.startswith(('S_', 'V_', 'DS_')): defined_ops.setdefault((op.name, op.value), []).append((enum_cls, op))
|
||||
|
||||
enum_names = [e.__name__ for e in OP_ENUMS]
|
||||
lines = [f'''# autogenerated by pdf.py - do not edit
|
||||
|
|
@ -541,11 +560,14 @@ def _generate_function(cls_name: str, op, pc: str, code: str) -> tuple[str, str]
|
|||
is_cmpx = (cls_name in ('VOPCOp', 'VOP3Op')) and 'EXEC.u64[laneId]' in pc
|
||||
is_div_scale = 'DIV_SCALE' in op.name
|
||||
has_sdst = cls_name == 'VOP3SDOp' and ('VCC.u64[laneId]' in pc or is_div_scale)
|
||||
has_opsel = 'OPSEL' in pc # FMA_MIX and similar instructions need OPSEL/OPSEL_HI
|
||||
combined = code + pc
|
||||
|
||||
fn_name = f"_{cls_name}_{op.name}"
|
||||
# Function accepts Reg objects directly (uppercase names), laneId is passed directly as int
|
||||
lines = [f"def {fn_name}(S0, S1, S2, D0, SCC, VCC, laneId, EXEC, literal, VGPR, src0_idx=0, vdst_idx=0, PC=None):"]
|
||||
params = "S0, S1, S2, D0, SCC, VCC, laneId, EXEC, literal, VGPR, src0_idx=0, vdst_idx=0, PC=None"
|
||||
if has_opsel: params += ", OPSEL=0, OPSEL_HI=0"
|
||||
lines = [f"def {fn_name}({params}):"]
|
||||
|
||||
# Registers that need special handling (not passed directly)
|
||||
# Only init if used but not first assigned as `name = Reg(...)` in the compiled code
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue