amd asm emulator fixes + run it in CI (#14786)

* amd asm fix, try 2

* fix tests
This commit is contained in:
George Hotz 2026-02-16 13:24:21 +08:00 committed by GitHub
commit dff9cf35c2
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 283 additions and 97 deletions

View file

@ -659,6 +659,8 @@ jobs:
run: |
PYTHONPATH=. NULL=1 EMULATE=AMD python extra/mmapeak/mmapeak.py
PYTHONPATH=. NULL=1 EMULATE=AMD_CDNA4 python3 -m pytest -n=auto test/testextra/test_tk.py test/testextra/test_asm_gemm.py
- name: Run ASM matmul on MOCKGPU
run: PYTHONPATH="." AMD=1 MOCKGPU=1 N=256 python3 extra/gemm/amd_asm_matmul.py
- name: Run LLVM test
run: AMD_LLVM=1 python test/device/test_amd_llvm.py

View file

@ -1,5 +1,5 @@
# RDNA3 128x128 tiled GEMM kernel - DSL version
# Computes C = A @ B for 4096x4096 float32 matrices using 128x128 tiles
# Computes C = A @ B for NxN float32 matrices using 128x128 tiles
#
# Architecture: RDNA3 (gfx1100)
# Tile size: 128x128 (each workgroup computes one tile of C)
@ -21,7 +21,6 @@ from tinygrad.runtime.autogen.amd.rdna3.ins import *
# Kernel constants
# =============================================================================
LDS_SIZE = 8320 # Local data share size in bytes
MATRIX_DIM = 4096 # Matrix dimension N (assumes square NxN matrices)
LDS_A_STRIDE = 0x210 # LDS stride for A tile (528 bytes)
LDS_B_STRIDE = 0x200 # LDS stride for B tile (512 bytes)
LDS_BASE_OFFSET = 0x1080 # Base LDS offset for tiles
@ -62,7 +61,7 @@ S_TILE_Y = 15 # workgroup_y << 7
# Kernarg load destinations
S_KERNARG_A = (20, 21) # A pointer from kernarg
S_KERNARG_B = (22, 23) # B pointer from kernarg
# Prefetch base pointers (8 pairs each, 16KB/256KB apart)
# Prefetch base pointers (8 pairs each, B: N*4 bytes apart, A: N*64 bytes apart)
S_PREFETCH_B = 24 # s[24:39] - 8 B tile pointers
S_PREFETCH_A = 40 # s[40:55] - 8 A tile pointers
@ -197,7 +196,9 @@ class Kernel:
# Kernel builder
# =============================================================================
def build_kernel(arch='gfx1100'):
def build_kernel(N, arch='gfx1100'):
assert N % 128 == 0, f"N must be a multiple of 128 (tile size), got {N}"
assert N >= 256, f"N must be >= 256 (prefetch pipeline requires at least 2 K-blocks), got {N}"
k = Kernel(arch)
# ===========================================================================
@ -205,7 +206,7 @@ def build_kernel(arch='gfx1100'):
# ===========================================================================
k.emit(s_load_b128(sdata=s[S_KERNARG_A[0]:S_KERNARG_B[1]], sbase=s[0:1], offset=0x0, soffset=NULL))
k.emit(s_load_b64(sdata=s[S_OUT_PTR[0]:S_OUT_PTR[1]], sbase=s[0:1], offset=0x10, soffset=NULL))
k.emit(s_mov_b32(s[S_DIM_N], MATRIX_DIM))
k.emit(s_mov_b32(s[S_DIM_N], N))
k.emit(s_mov_b32(s[S_LOOP_CTR], 0)) # used by LDS swizzle, always 0 for valid workgroups
k.emit(s_lshl_b32(s[S_TILE_X], s[S_WORKGROUP_X], 7))
k.emit(s_lshl_b32(s[S_TILE_Y], s[S_WORKGROUP_Y], 7))
@ -220,19 +221,20 @@ def build_kernel(arch='gfx1100'):
# Compute 8 A and B matrix tile base pointers for prefetch
k.emit(s_mov_b64(s[S_PREFETCH_B:S_PREFETCH_B+1], s[S_KERNARG_B[0]:S_KERNARG_B[1]])) # B[0]: no offset
for i in range(1, 8): # B: 16KB apart
k.emit(s_add_u32(s[S_PREFETCH_B+i*2], s[S_KERNARG_B[0]], i * 0x4000))
for i in range(1, 8): # B: each pointer 1 row of B apart (N*4 bytes)
k.emit(s_add_u32(s[S_PREFETCH_B+i*2], s[S_KERNARG_B[0]], i * N * 4))
k.emit(s_addc_u32(s[S_PREFETCH_B+i*2+1], s[S_KERNARG_B[1]], 0))
k.emit(s_mov_b64(s[S_PREFETCH_A:S_PREFETCH_A+1], s[S_KERNARG_A[0]:S_KERNARG_A[1]])) # A[0]: no offset
for i in range(1, 8): # A: 256KB apart
k.emit(s_add_u32(s[S_PREFETCH_A+i*2], s[S_KERNARG_A[0]], i * 0x40000))
for i in range(1, 8): # A: each pointer 16 rows of A apart (16*N*4 bytes)
k.emit(s_add_u32(s[S_PREFETCH_A+i*2], s[S_KERNARG_A[0]], i * N * 64))
k.emit(s_addc_u32(s[S_PREFETCH_A+i*2+1], s[S_KERNARG_A[1]], 0))
# Global prefetch addresses: B = (tile_x + lane_id) * 4, A = ((tile_y << 12) + (lane_id/8)*4K + lane_id%8) * 4
# Global prefetch addresses: B = (tile_x + lane_id) * 4, A = (tile_y*N + (lane_id/8)*N + lane_id%8) * 4
k.emit(v_add_nc_u32_e32(v[V_GLOBAL_B_ADDR], s[S_TILE_X], v[V_LANE_ID]))
k.emit(v_lshlrev_b32_e32(v[V_GLOBAL_B_ADDR], 2, v[V_GLOBAL_B_ADDR]))
k.emit(s_lshl_b32(s[19], s[S_TILE_Y], 12))
k.emit(v_lshl_add_u32(v[V_GLOBAL_A_ADDR], v[4], 12, v[V_LANE_ID_MOD8])) # (lane_id/8)*4K + lane_id%8
k.emit(s_mul_i32(s[19], s[S_TILE_Y], N))
k.emit(v_mul_lo_u32(v[V_GLOBAL_A_ADDR], v[4], N)) # (lane_id/8)*N
k.emit(v_add_nc_u32_e32(v[V_GLOBAL_A_ADDR], v[V_LANE_ID_MOD8], v[V_GLOBAL_A_ADDR])) # + lane_id%8
k.emit(v_add_nc_u32_e32(v[V_GLOBAL_A_ADDR], s[19], v[V_GLOBAL_A_ADDR]))
k.emit(v_lshlrev_b32_e32(v[V_GLOBAL_A_ADDR], 2, v[V_GLOBAL_A_ADDR]))
@ -303,13 +305,13 @@ def build_kernel(arch='gfx1100'):
if not NO_GLOBAL:
# Advance prefetch pointers (VGPR)
#k.emit(v_add_nc_u32_e32(v[V_GLOBAL_B_ADDR], 0x20000, v[V_GLOBAL_B_ADDR]))
#k.emit(v_add_nc_u32_e32(v[V_GLOBAL_B_ADDR], N * 32, v[V_GLOBAL_B_ADDR]))
#k.emit(v_add_nc_u32_e32(v[V_GLOBAL_A_ADDR], 0x20, v[V_GLOBAL_A_ADDR]))
# Advance prefetch pointers (64-bit adds)
# Advance prefetch pointers (64-bit adds): B advances 8 rows (8*N*4 bytes), A advances 8 cols (8*4 bytes)
k.emit(s_clause(simm16=31))
for i in range(8):
k.emit(s_add_u32(s[S_PREFETCH_B+i*2], s[S_PREFETCH_B+i*2], 0x20000))
k.emit(s_add_u32(s[S_PREFETCH_B+i*2], s[S_PREFETCH_B+i*2], N * 32))
k.emit(s_addc_u32(s[S_PREFETCH_B+i*2+1], s[S_PREFETCH_B+i*2+1], 0))
for i in range(8):
k.emit(s_add_u32(s[S_PREFETCH_A+i*2], s[S_PREFETCH_A+i*2], 0x20))
@ -440,7 +442,7 @@ def test_matmul():
dev = Device[Device.DEFAULT]
print(f"Device arch: {dev.renderer.arch}")
insts = build_kernel(dev.renderer.arch)
insts = build_kernel(N, dev.renderer.arch)
rng = np.random.default_rng(42)
a = Tensor(rng.random((N, N), dtype=np.float32) - 0.5)
@ -472,33 +474,23 @@ def test_matmul():
with Context(DEBUG=2): tc = (a @ b).realize()
with Context(DEBUG=0): err = (c - tc).square().mean().item()
print(f"mean squared error {err}")
if err != err or err > 1e-06: raise RuntimeError("matmul is wrong!")
def run_sqtt():
"""Run with SQTT profiling and write trace files."""
import subprocess, os
# Run test_matmul in a subprocess with SQTT enabled from the start (no verify)
env = {**os.environ, "AMD": "1", "SQTT": "1", "CNT": "1", "PROFILE": "1", "PYTHONPATH": ".", "VERIFY": "0"}
result = subprocess.run(
["python", "-c", "from extra.gemm.amd_asm_matmul import test_matmul; test_matmul()"],
capture_output=True, text=True, env=env, timeout=120
)
print(result.stdout)
# Run roc.py to extract trace data
result = subprocess.run(
["python", "extra/sqtt/roc.py", "--profile", "/tmp/profile.pkl.tiny", "--kernel", "kernel"],
capture_output=True, text=True, env={**os.environ, "DEBUG": "5"}, timeout=60
)
output = result.stdout + result.stderr
# Write full output to trace file
with open("/tmp/sqtt_trace.txt", "w") as f:
f.write(output)
print(f"Wrote {len(output)} bytes to /tmp/sqtt_trace.txt")
if err != err or err > 1e-06:
c_np, tc_np = c.numpy(), tc.numpy()
for bi in range(N // 128):
for bj in range(N // 128):
blk_c = c_np[bi*128:(bi+1)*128, bj*128:(bj+1)*128]
blk_ref = tc_np[bi*128:(bi+1)*128, bj*128:(bj+1)*128]
blk_diff = blk_c - blk_ref
zero_rows = [i for i in range(128) if np.all(np.abs(blk_c[i,:]) < 1e-10)]
nz_rows = [i for i in range(128) if i not in zero_rows]
nz_mse = float(np.mean(blk_diff[nz_rows,:]**2)) if nz_rows else 0
print(f"Block ({bi},{bj}): zero_rows={zero_rows}, nz_rows_mse={nz_mse:.2e}")
# show first few non-zero row comparisons
if nz_rows and nz_mse > 1e-6:
for r in nz_rows[:3]:
print(f" row {r} asm[0:8]: {blk_c[r,:8]}")
print(f" row {r} ref[0:8]: {blk_ref[r,:8]}")
raise RuntimeError("matmul is wrong!")
if __name__ == "__main__":
if getenv("ASM", 0): print("\n".join(str(inst) for inst in build_kernel(Device[Device.DEFAULT].renderer.arch)))
elif getenv("SQTT", 0): run_sqtt()
else: test_matmul()
test_matmul()

View file

@ -760,5 +760,110 @@ class TestDsPermute(unittest.TestCase):
self.assertEqual(st.vgpr[0][2], 0x11111111)
class TestDSLargeOffset(unittest.TestCase):
  """Tests for DS instructions with offsets > 255 (offset1 > 0).

  The DS offset is a 16-bit value encoded as (offset1 << 8) | offset0.
  These tests verify that offset1 is used correctly, not just offset0.
  """

  def _roundtrip_b32(self, value, offset0, offset1):
    """Store one dword at the given offset pair, load it back into v[1], and return the final state."""
    program = [
      v_mov_b32_e32(v[10], 0),            # base address 0, so only the encoded offset selects the location
      s_mov_b32(s[0], value),
      v_mov_b32_e32(v[0], s[0]),
      ds_store_b32(addr=v[10], data0=v[0], offset0=offset0, offset1=offset1),
      s_waitcnt(lgkmcnt=0),
      ds_load_b32(addr=v[10], vdst=v[1], offset0=offset0, offset1=offset1),
      s_waitcnt(lgkmcnt=0),
    ]
    return run_program(program, n_lanes=1)

  def _roundtrip_b64(self, lo, hi, offset0, offset1):
    """Store the dword pair (lo, hi) from v[0:1] at the given offset pair, load it back into v[2:3]."""
    program = [
      v_mov_b32_e32(v[10], 0),
      s_mov_b32(s[0], lo),
      v_mov_b32_e32(v[0], s[0]),
      s_mov_b32(s[0], hi),
      v_mov_b32_e32(v[1], s[0]),
      ds_store_b64(addr=v[10], data0=v[0:1], offset0=offset0, offset1=offset1),
      s_waitcnt(lgkmcnt=0),
      ds_load_b64(addr=v[10], vdst=v[2:3], offset0=offset0, offset1=offset1),
      s_waitcnt(lgkmcnt=0),
    ]
    return run_program(program, n_lanes=1)

  def test_ds_store_load_b32_offset_256(self):
    """DS_STORE_B32/DS_LOAD_B32 with offset=256 (offset0=0, offset1=1)."""
    state = self._roundtrip_b32(0xDEADBEEF, offset0=0, offset1=1)
    self.assertEqual(state.vgpr[0][1], 0xDEADBEEF)

  def test_ds_store_load_b32_offset_300(self):
    """DS_STORE_B32/DS_LOAD_B32 with offset=300 (offset0=44, offset1=1)."""
    state = self._roundtrip_b32(0xCAFEBABE, offset0=44, offset1=1)
    self.assertEqual(state.vgpr[0][1], 0xCAFEBABE)

  def test_ds_store_load_b64_offset_512(self):
    """DS_STORE_B64/DS_LOAD_B64 with offset=512 (offset0=0, offset1=2)."""
    state = self._roundtrip_b64(0x11111111, 0x22222222, offset0=0, offset1=2)
    self.assertEqual(state.vgpr[0][2], 0x11111111)
    self.assertEqual(state.vgpr[0][3], 0x22222222)

  def test_ds_large_offset_distinct_from_small(self):
    """Verify offset=256 and offset=0 address different LDS locations."""
    program = [
      v_mov_b32_e32(v[10], 0),
      s_mov_b32(s[0], 0xAAAAAAAA),
      v_mov_b32_e32(v[0], s[0]),
      s_mov_b32(s[0], 0xBBBBBBBB),
      v_mov_b32_e32(v[1], s[0]),
      # Two stores to the same base address: one at offset 0, one at offset 256
      ds_store_b32(addr=v[10], data0=v[0], offset0=0, offset1=0),
      ds_store_b32(addr=v[10], data0=v[1], offset0=0, offset1=1),
      s_waitcnt(lgkmcnt=0),
      # Load both back into separate registers so an aliasing bug shows up as equal values
      ds_load_b32(addr=v[10], vdst=v[2], offset0=0, offset1=0),
      ds_load_b32(addr=v[10], vdst=v[3], offset0=0, offset1=1),
      s_waitcnt(lgkmcnt=0),
    ]
    state = run_program(program, n_lanes=1)
    self.assertEqual(state.vgpr[0][2], 0xAAAAAAAA, "offset=0 should read 0xAAAAAAAA")
    self.assertEqual(state.vgpr[0][3], 0xBBBBBBBB, "offset=256 should read 0xBBBBBBBB")

  def test_ds_store_load_b32_offset_448(self):
    """DS_STORE_B32/DS_LOAD_B32 with offset=448 (offset0=192, offset1=1) - matches matmul B tile."""
    state = self._roundtrip_b32(0x12345678, offset0=192, offset1=1)
    self.assertEqual(state.vgpr[0][1], 0x12345678)

  def test_ds_load_b64_offset_392(self):
    """DS_LOAD_B64 with offset=392 (offset0=136, offset1=1) - matches matmul B tile load."""
    state = self._roundtrip_b64(0xAABBCCDD, 0x11223344, offset0=136, offset1=1)
    self.assertEqual(state.vgpr[0][2], 0xAABBCCDD)
    self.assertEqual(state.vgpr[0][3], 0x11223344)
if __name__ == '__main__':
unittest.main()

View file

@ -1579,5 +1579,55 @@ class TestPermlane64(unittest.TestCase):
self.assertEqual(st.vgpr[0][1], 0x12345678)
class TestSwap(unittest.TestCase):
  """Tests for V_SWAP_B32 - swap two VGPRs."""

  def test_v_swap_b32_basic(self):
    """V_SWAP_B32 swaps two VGPR values."""
    instructions = [
      v_mov_b32_e32(v[0], 42),
      v_mov_b32_e32(v[1], 99),
      v_swap_b32_e32(v[0], v[1]),
    ]
    st = run_program(instructions, n_lanes=1)
    self.assertEqual(st.vgpr[0][0], 99)
    self.assertEqual(st.vgpr[0][1], 42)

  def test_v_swap_b32_same_reg(self):
    """V_SWAP_B32 with same src and dst is a no-op."""
    instructions = [
      v_mov_b32_e32(v[0], 0xDEADBEEF),
      v_swap_b32_e32(v[0], v[0]),
    ]
    st = run_program(instructions, n_lanes=1)
    self.assertEqual(st.vgpr[0][0], 0xDEADBEEF)

  def test_v_swap_b32_multi_lane(self):
    """V_SWAP_B32 swaps per-lane values independently."""
    instructions = [
      # Before the swap: v[0] = lane_id * 3, v[1] = lane_id * 100
      # NOTE(review): relies on v[255] holding the lane id when the program starts - confirm against run_program
      v_lshlrev_b32_e32(v[0], 1, v[255]),    # v[0] = lane_id * 2
      v_add_nc_u32_e32(v[0], v[0], v[255]),  # v[0] = lane_id * 3
      v_mul_u32_u24_e32(v[1], 100, v[255]),  # v[1] = lane_id * 100
      v_swap_b32_e32(v[0], v[1]),
    ]
    st = run_program(instructions, n_lanes=4)
    for lane in range(4):
      # After the swap each lane must see its own values exchanged, not another lane's
      self.assertEqual(st.vgpr[lane][0], lane * 100)
      self.assertEqual(st.vgpr[lane][1], lane * 3)

  def test_v_swap_b32_chain(self):
    """Two swaps in sequence restore original values."""
    instructions = [
      v_mov_b32_e32(v[0], 0xAAAAAAAA),
      v_mov_b32_e32(v[1], 0x55555555),
      v_swap_b32_e32(v[0], v[1]),
      v_swap_b32_e32(v[0], v[1]),
    ]
    st = run_program(instructions, n_lanes=1)
    self.assertEqual(st.vgpr[0][0], 0xAAAAAAAA)
    self.assertEqual(st.vgpr[0][1], 0x55555555)
if __name__ == '__main__':
unittest.main()

View file

@ -113,7 +113,7 @@ class PythonEmulator:
def _ensure_decoded(self, pc: int):
if pc not in self.program:
runner = _decode_at(pc, "rdna3")
runner, _ = _decode_at(pc, "rdna3")
self.program[pc] = (runner.p.function_name, runner._prg.fxn, runner.p.globals)
def step(self) -> int:

View file

@ -116,6 +116,8 @@ WAVE_SIZE = 32
PC_LO_IDX, PC_HI_IDX, SCRATCH_STRIDE_IDX = 256, 257, 259
# SGPR buffer: 0-127 = SGPRs, 128-255 = inline constants, 256-259 = special registers
SGPR_COUNT, VGPR_SIZE = 260, 256 * 32
# Sentinel PC value for s_endpgm
ENDPGM_PC = 0xFFFFFFFFFFFFFFFF
def _op_name(inst) -> str:
if hasattr(inst, 'opx'): return f"{inst.opx.name}_{inst.opy.name}" # VOPD has opx/opy not op
@ -210,7 +212,7 @@ def parse_pcode(pcode: str, srcs: dict[str, UOp] | None = None) -> tuple[dict, l
_, final, _ = parse_block(lines, 0, env, assigns=assigns)
sliced = set(d.split('[')[0] for d, _ in assigns if '[' in d)
for var, val in final.items():
if var in ['D0', 'SCC', 'VCC', 'EXEC', 'PC', 'RETURN_DATA', 'VDATA'] and isinstance(val, UOp):
if var in ['D0', 'S0', 'SCC', 'VCC', 'EXEC', 'PC', 'RETURN_DATA', 'VDATA'] and isinstance(val, UOp):
if var in sliced and not any(re.match(rf'{var}\.\w+\s*=', l) for l in lines): continue
for l in lines:
if (m := re.match(rf'{var}\.(\w+(?:\[\w+\])?)', l)):
@ -462,7 +464,8 @@ class _Ctx:
return UOp.sink(*stores, *self.inc_pc())
def compile_vop_pcode(self, op, srcs: dict[str, UOp], lane: UOp, vdst_reg: UOp, exec_mask: UOp,
opsel_dst_hi: bool | UOp = False, sdst_reg: int | None = None, clmp: int = 0) -> UOp:
opsel_dst_hi: bool | UOp = False, sdst_reg: int | None = None, clmp: int = 0,
src0_off: UOp | None = None) -> UOp:
"""Compile VOP instruction. Returns sink with stores and inc_pc."""
pcode = get_pcode(op)
vcc_reg = sdst_reg if sdst_reg is not None else VCC_LO.offset
@ -519,11 +522,15 @@ class _Ctx:
result = opsel_dst_hi.where(hi_result, lo_result) if isinstance(opsel_dst_hi, UOp) else hi_result if opsel_dst_hi else lo_result
raw_stores.append(('vgpr', self.wvgpr_dyn(vdst_reg, lane, result, exec_mask)))
else: raw_stores.append(('vgpr', self.wvgpr_dyn(vdst_reg, lane, _val_to_u32(val), exec_mask)))
elif dest.startswith('S0') and src0_off is not None:
# Write back to src0 VGPR (e.g. v_swap_b32). src0_off is raw encoding (256+ = VGPR)
src0_vgpr = src0_off - _c(256)
raw_stores.append(('vgpr_s0', self.wvgpr_dyn(src0_vgpr, lane, _val_to_u32(val), exec_mask)))
elif dest.startswith('VCC'): vcc_val = val
elif dest.startswith('EXEC'): exec_val = val
elif dest.startswith('SCC'): raw_stores.append(('scc', self.wsgpr_dyn(_c(SCC.offset), _to_u32(val))))
stores, lane_stores, scalar_stores = [], [s for t, s in raw_stores if t == 'vgpr'], [s for t, s in raw_stores if t == 'scc']
stores, lane_stores, scalar_stores = [], [s for t, s in raw_stores if t in ('vgpr', 'vgpr_s0')], [s for t, s in raw_stores if t == 'scc']
slice_stores = [s for t, s in raw_stores if t == 'vgpr_slice']
if slice_stores:
result = self.rvgpr_dyn(vdst_reg, lane)
@ -548,6 +555,10 @@ def _compile_sopp(inst: ir3.SOPP | ir4.SOPP, ctx: _Ctx) -> UOp:
if inst.op in (ir3.SOPPOp.S_ENDPGM, ir4.SOPPOp.S_ENDPGM, irc.SOPPOp.S_ENDPGM):
return UOp.sink(ctx.wsgpr_dyn(_c(PC_LO_IDX), UOp.const(dtypes.uint32, 0xFFFFFFFF)),
ctx.wsgpr_dyn(_c(PC_HI_IDX), UOp.const(dtypes.uint32, 0xFFFFFFFF)))
# S_BARRIER: advance PC past the barrier instruction. The execution loop detects barriers before executing and handles synchronization.
barrier_ops = {ir3.SOPPOp.S_BARRIER, irc.SOPPOp.S_BARRIER}
if hasattr(ir4.SOPPOp, 'S_BARRIER_WAIT'): barrier_ops.add(ir4.SOPPOp.S_BARRIER_WAIT)
if inst.op in barrier_ops: return UOp.sink(*ctx.inc_pc())
# S_NOP and S_WAITCNT are no-ops in emulator (no pipeline/cache to wait on)
if inst.op in (ir3.SOPPOp.S_NOP, ir4.SOPPOp.S_NOP, irc.SOPPOp.S_NOP, irc.SOPPOp.S_WAITCNT): return UOp.sink(*ctx.inc_pc())
# NOTE: we ignore SOPPs without PCODE
@ -653,7 +664,8 @@ def _compile_vop12(inst: ir3.VOP1 | ir3.VOP1_SDST | ir3.VOP2 | ir4.VOP1 | ir4.VO
# Only compute hi-half when src0_off >= 384, use guarded index to prevent OOB access
src0_reg = src0_hi.where(src0_off - _c(384), _c(0))
s0 = src0_hi.where(_hi16(ctx.rvgpr_dyn(src0_reg, lane)), s0)
srcs = {'S0': s0}
d0 = _cond_hi16(write_hi_half, ctx.rvgpr_dyn(vdst_reg, lane))
srcs = {'S0': s0, 'D0': d0}
else:
vsrc1_reg = ctx.inst_field(type(inst).vsrc1)
vsrc1_hi = bits['s0'] == 16 and (vsrc1_reg >= _c(128))
@ -675,7 +687,7 @@ def _compile_vop12(inst: ir3.VOP1 | ir3.VOP1_SDST | ir3.VOP2 | ir4.VOP1 | ir4.VO
if 'V_FMAA' in _op_name(inst) or 'V_FMAM' in _op_name(inst):
assert literal is not None
srcs['SIMM32'] = literal
return ctx.compile_vop_pcode(inst.op, srcs, lane, vdst_reg, exec_mask, opsel_dst_hi=write_hi_half)
return ctx.compile_vop_pcode(inst.op, srcs, lane, vdst_reg, exec_mask, opsel_dst_hi=write_hi_half, src0_off=src0_off)
def _compile_vopc(inst: ir3.VOPC|ir3.VOP3|ir4.VOPC|ir4.VOP3|irc.VOPC|irc.VOP3, ctx: _Ctx,
opsel: int = 0, abs_bits: int = 0, neg_bits: int = 0) -> UOp:
@ -1022,7 +1034,7 @@ def _compile_mem_op(inst: ir3.DS|ir3.FLAT|ir3.GLOBAL|ir3.SCRATCH|ir4.DS|ir4.VFLA
vdst_reg = ctx.inst_field(type(inst).vdst)
offset0 = ctx.inst_field(type(inst).offset0) # type: ignore[union-attr]
offset1 = ctx.inst_field(type(inst).offset1) # type: ignore[union-attr]
offset = offset0 # DS uses offset0 as primary offset
offset = (offset1 << _c(8)) | offset0 # DS offset is 16-bit: (offset1 << 8) | offset0
saddr_reg = None
elif isinstance(inst, (ir4.VGLOBAL, ir4.VSCRATCH, ir4.VFLAT)): # RDNA4: vaddr, vsrc, ioffset
addr_reg = ctx.inst_field(type(inst).vaddr)
@ -1225,11 +1237,14 @@ def _get_runner(inst_bytes: bytes, arch: str = "rdna3"):
_canonical_runner_cache.append((base, mask, size, runner))
return runner
# SOPP opcodes treated as barriers: run_asm's _ensure_compiled tests membership in this set to flag a decoded instruction as a barrier.
_BARRIER_OPS = {ir3.SOPPOp.S_BARRIER, irc.SOPPOp.S_BARRIER}
# S_BARRIER_WAIT is added only when the ir4 SOPPOp enum actually defines it.
if hasattr(ir4.SOPPOp, 'S_BARRIER_WAIT'): _BARRIER_OPS.add(ir4.SOPPOp.S_BARRIER_WAIT)
def _decode_at(pc: int, arch: str):
"""Decode and compile instruction at absolute address pc. Returns CompiledRunner."""
"""Decode and compile instruction at absolute address pc. Returns (runner, decoded_inst)."""
inst_bytes = bytes((ctypes.c_char * 16).from_address(pc).raw)
inst = decode_inst(inst_bytes, arch)
try: return _get_runner(bytes(inst_bytes[:inst.size() + 4]), arch)
try: return _get_runner(bytes(inst_bytes[:inst.size() + 4]), arch), inst
except Exception as e:
try: inst_str = repr(inst)
except Exception: inst_str = f"<{type(inst).__name__}>"
@ -1279,10 +1294,37 @@ class WaveState:
# EXECUTION
# ═══════════════════════════════════════════════════════════════════════════════
def _init_wave(lib: int, wave_start: int, total_threads: int, lx: int, ly: int, lz: int, args_ptr: int, rsrc2: int,
               scratch_size: int, arch: str, gidx: int, gidy: int, gidz: int, user_data: list[int]|None) -> WaveState:
  """Initialize a single wavefront and return its WaveState.

  The wave covers threads [wave_start, wave_start + WAVE_SIZE) of the workgroup, clipped to total_threads.
  Its PC is set to `lib`, the user SGPRs and workgroup-id SGPRs are filled in, VGPR 0 of every lane gets the
  packed workitem id, and the scratch stride is stored in the SCRATCH_STRIDE_IDX slot.

  `lz` is accepted but not read here: the z workitem id falls out of tid // (lx * ly).
  Only the WaveState is returned; the caller builds the ctypes buffer list (c_bufs) itself.
  """
  n_lanes = min(WAVE_SIZE, total_threads - wave_start)
  st = WaveState(n_lanes)
  st.pc = lib  # PC starts at the code base address
  # User SGPRs: hardware loads COMPUTE_USER_DATA registers directly into s[0:N]
  if user_data:
    for i, val in enumerate(user_data): st._write_sgpr(i, val)
  else:
    # No explicit user data: s[0:1] carry the 64-bit kernarg pointer (low dword, then high dword)
    st._write_sgpr(0, args_ptr & MASK32)
    st._write_sgpr(1, (args_ptr >> 32) & MASK32)
  # Workgroup IDs go in the SGPRs right after the user SGPRs, one per enabled dimension
  sgpr_idx = (rsrc2 & hsa.AMD_COMPUTE_PGM_RSRC_TWO_USER_SGPR_COUNT) >> hsa.AMD_COMPUTE_PGM_RSRC_TWO_USER_SGPR_COUNT_SHIFT
  for enabled, gid in [(hsa.AMD_COMPUTE_PGM_RSRC_TWO_ENABLE_SGPR_WORKGROUP_ID_X, gidx),
                       (hsa.AMD_COMPUTE_PGM_RSRC_TWO_ENABLE_SGPR_WORKGROUP_ID_Y, gidy),
                       (hsa.AMD_COMPUTE_PGM_RSRC_TWO_ENABLE_SGPR_WORKGROUP_ID_Z, gidz)]:
    if rsrc2 & enabled:
      st._write_sgpr(sgpr_idx, gid)
      sgpr_idx += 1
  # RDNA4 uses TTMP registers for workgroup IDs: ttmp[9] = gidx, ttmp[7] = gidy (low 16 bits) | gidz (high 16 bits)
  if arch == "rdna4":
    st._write_sgpr(ttmp[7].offset, (gidy & 0xFFFF) | ((gidz & 0xFFFF) << 16))
    st._write_sgpr(ttmp[9].offset, gidx)
  # v0 = packed workitem IDs: x in bits 0-9, y in bits 10-19, z from bit 20 up
  for lane in range(n_lanes):
    tid = wave_start + lane
    st._write_vgpr(0, lane, ((tid // (lx * ly)) << 20) | (((tid // lx) % ly) << 10) | (tid % lx))
  st._write_sgpr(SCRATCH_STRIDE_IDX, scratch_size)
  return st
def run_asm(lib: int, lib_sz: int, gx: int, gy: int, gz: int, lx: int, ly: int, lz: int, args_ptr: int, rsrc2: int = 0x19c,
scratch_size: int = 0, arch: str = "rdna3", user_data: list[int]|None = None) -> int:
"""Execute AMD assembly program. scratch_size is private_segment_fixed_size from kernel descriptor (per-lane)."""
program: dict[int, tuple[Callable, list[int]]] = {} # lazily populated: pc -> (fxn, globals) extracted from runner
program: dict[int, tuple[Callable, list[int], bool]] = {} # pc -> (fxn, globals, is_barrier)
lds_size = ((rsrc2 & hsa.AMD_COMPUTE_PGM_RSRC_TWO_GRANULATED_LDS_SIZE) >> hsa.AMD_COMPUTE_PGM_RSRC_TWO_GRANULATED_LDS_SIZE_SHIFT) * 512
total_threads = lx * ly * lz
@ -1291,56 +1333,51 @@ def run_asm(lib: int, lib_sz: int, gx: int, gy: int, gz: int, lx: int, ly: int,
lds_buf = Buffer('CPU', max(lds_size // 4, 1), dtypes.uint32).ensure_allocated()
scratch_buf = Buffer('CPU', scratch_size * WAVE_SIZE, dtypes.uint8).ensure_allocated() if scratch_size else None
def _ensure_compiled(pc: int) -> tuple[Callable, list[int], bool]:
if pc not in program:
prev_len = len(_canonical_runner_cache)
runner, inst = _decode_at(pc, arch)
is_barrier = isinstance(inst, (ir3.SOPP, ir4.SOPP, irc.SOPP)) and inst.op in _BARRIER_OPS
program[pc] = (runner._prg.fxn, runner.p.globals, is_barrier)
if DEBUG >= 3:
msg = f"[emu] PC={pc - lib}: {inst!r}"
print(colored(msg, 'green') if len(_canonical_runner_cache) > prev_len else msg)
return program[pc]
# Set DAZ+FTZ during emulator execution, restore afterward to avoid breaking hypothesis tests
with _MXCSRContext():
for gidz in range(gz):
for gidy in range(gy):
for gidx in range(gx):
# Initialize all wavefronts for this workgroup
waves: list[tuple[WaveState, list]] = []
for wave_start in range(0, total_threads, WAVE_SIZE):
n_lanes, st = min(WAVE_SIZE, total_threads - wave_start), WaveState(min(WAVE_SIZE, total_threads - wave_start))
st.pc = lib # Set PC to code base address
# Initialize user SGPRs: hardware loads COMPUTE_USER_DATA registers directly into s[0:N]
if user_data:
for i, val in enumerate(user_data): st._write_sgpr(i, val)
else:
st._write_sgpr(0, args_ptr & MASK32)
st._write_sgpr(1, (args_ptr >> 32) & MASK32)
# Workgroup IDs in SGPRs after user SGPRs
sgpr_idx = (rsrc2 & hsa.AMD_COMPUTE_PGM_RSRC_TWO_USER_SGPR_COUNT) >> hsa.AMD_COMPUTE_PGM_RSRC_TWO_USER_SGPR_COUNT_SHIFT
for enabled, gid in [(hsa.AMD_COMPUTE_PGM_RSRC_TWO_ENABLE_SGPR_WORKGROUP_ID_X, gidx),
(hsa.AMD_COMPUTE_PGM_RSRC_TWO_ENABLE_SGPR_WORKGROUP_ID_Y, gidy),
(hsa.AMD_COMPUTE_PGM_RSRC_TWO_ENABLE_SGPR_WORKGROUP_ID_Z, gidz)]:
if rsrc2 & enabled:
st._write_sgpr(sgpr_idx, gid)
sgpr_idx += 1
# RDNA4 uses TTMP registers for workgroup IDs (16 bit) ttmp[9:7] -> gidx-gidz
if arch == "rdna4":
st._write_sgpr(ttmp[7].offset, (gidy & 0xFFFF) | ((gidz & 0xFFFF) << 16))
st._write_sgpr(ttmp[9].offset, gidx)
# v0 = packed workitem IDs, scratch stride in secret SGPR
for lane in range(n_lanes):
tid = wave_start + lane
st._write_vgpr(0, lane, ((tid // (lx * ly)) << 20) | (((tid // lx) % ly) << 10) | (tid % lx))
st._write_sgpr(SCRATCH_STRIDE_IDX, scratch_size)
# Pass buffer addresses via ctypes (pre-create to avoid allocation in loop)
st = _init_wave(lib, wave_start, total_threads, lx, ly, lz, args_ptr, rsrc2, scratch_size, arch, gidx, gidy, gidz, user_data)
c_bufs = [ctypes.c_uint64(st.sgpr_buf._buf.va_addr), ctypes.c_uint64(st.vgpr_buf._buf.va_addr),
ctypes.c_uint64(vmem_buf._buf.va_addr), ctypes.c_uint64(lds_buf._buf.va_addr),
ctypes.c_uint64(scratch_buf._buf.va_addr if scratch_buf else 0)]
for inst_count in range(1_000_000):
if (pc := st.pc) == 0xFFFFFFFFFFFFFFFF: break
if pc not in program:
prev_len = len(_canonical_runner_cache)
runner = _decode_at(pc, arch)
program[pc] = (runner._prg.fxn, runner.p.globals)
if DEBUG >= 3:
inst = decode_inst(bytes((ctypes.c_char * 16).from_address(pc).raw), arch)
msg = f"[emu] PC={pc - lib}: {inst!r}"
print(colored(msg, 'green') if len(_canonical_runner_cache) > prev_len else msg)
fxn, globals_list = program[pc]
fxn(*[c_bufs[g] for g in globals_list])
else: raise RuntimeError("exceeded 1M instructions, likely infinite loop")
waves.append((st, c_bufs))
# Execute wavefronts with barrier synchronization
# Each wave runs until it hits s_barrier or s_endpgm. When all waves have stopped, release barrier waves.
done = [False] * len(waves)
for total_inst in range(10_000_000):
if all(done): break
for wi, (st, c_bufs) in enumerate(waves):
if done[wi]: continue
# Run this wave until barrier or endpgm
for _ in range(1_000_000):
pc = st.pc
if pc == ENDPGM_PC:
done[wi] = True
break
fxn, globals_list, is_barrier = _ensure_compiled(pc)
fxn(*[c_bufs[g] for g in globals_list])
if is_barrier: break # s_barrier hit: PC already advanced past it, pause this wave
else: raise RuntimeError("exceeded 1M instructions in single wave, likely infinite loop")
# All waves have either hit barrier or endpgm — release barrier waves for next round
else: raise RuntimeError("exceeded 10M total scheduling rounds")
# Reset LDS for next workgroup
if lds_size > 0: ctypes.memset(lds_buf._buf.va_addr, 0, max(lds_size, 4))
return 0