mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
amd asm emulator fixes + run it in CI (#14786)
* amd asm fix, try 2 * fix tests
This commit is contained in:
parent
55a4dfa2e0
commit
dff9cf35c2
6 changed files with 283 additions and 97 deletions
2
.github/workflows/test.yml
vendored
2
.github/workflows/test.yml
vendored
|
|
@ -659,6 +659,8 @@ jobs:
|
|||
run: |
|
||||
PYTHONPATH=. NULL=1 EMULATE=AMD python extra/mmapeak/mmapeak.py
|
||||
PYTHONPATH=. NULL=1 EMULATE=AMD_CDNA4 python3 -m pytest -n=auto test/testextra/test_tk.py test/testextra/test_asm_gemm.py
|
||||
- name: Run ASM matmul on MOCKGPU
|
||||
run: PYTHONPATH="." AMD=1 MOCKGPU=1 N=256 python3 extra/gemm/amd_asm_matmul.py
|
||||
- name: Run LLVM test
|
||||
run: AMD_LLVM=1 python test/device/test_amd_llvm.py
|
||||
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
# RDNA3 128x128 tiled GEMM kernel - DSL version
|
||||
# Computes C = A @ B for 4096x4096 float32 matrices using 128x128 tiles
|
||||
# Computes C = A @ B for NxN float32 matrices using 128x128 tiles
|
||||
#
|
||||
# Architecture: RDNA3 (gfx1100)
|
||||
# Tile size: 128x128 (each workgroup computes one tile of C)
|
||||
|
|
@ -21,7 +21,6 @@ from tinygrad.runtime.autogen.amd.rdna3.ins import *
|
|||
# Kernel constants
|
||||
# =============================================================================
|
||||
LDS_SIZE = 8320 # Local data share size in bytes
|
||||
MATRIX_DIM = 4096 # Matrix dimension N (assumes square NxN matrices)
|
||||
LDS_A_STRIDE = 0x210 # LDS stride for A tile (528 bytes)
|
||||
LDS_B_STRIDE = 0x200 # LDS stride for B tile (512 bytes)
|
||||
LDS_BASE_OFFSET = 0x1080 # Base LDS offset for tiles
|
||||
|
|
@ -62,7 +61,7 @@ S_TILE_Y = 15 # workgroup_y << 7
|
|||
# Kernarg load destinations
|
||||
S_KERNARG_A = (20, 21) # A pointer from kernarg
|
||||
S_KERNARG_B = (22, 23) # B pointer from kernarg
|
||||
# Prefetch base pointers (8 pairs each, 16KB/256KB apart)
|
||||
# Prefetch base pointers (8 pairs each, B: N*4 bytes apart, A: N*64 bytes apart)
|
||||
S_PREFETCH_B = 24 # s[24:39] - 8 B tile pointers
|
||||
S_PREFETCH_A = 40 # s[40:55] - 8 A tile pointers
|
||||
|
||||
|
|
@ -197,7 +196,9 @@ class Kernel:
|
|||
# Kernel builder
|
||||
# =============================================================================
|
||||
|
||||
def build_kernel(arch='gfx1100'):
|
||||
def build_kernel(N, arch='gfx1100'):
|
||||
assert N % 128 == 0, f"N must be a multiple of 128 (tile size), got {N}"
|
||||
assert N >= 256, f"N must be >= 256 (prefetch pipeline requires at least 2 K-blocks), got {N}"
|
||||
k = Kernel(arch)
|
||||
|
||||
# ===========================================================================
|
||||
|
|
@ -205,7 +206,7 @@ def build_kernel(arch='gfx1100'):
|
|||
# ===========================================================================
|
||||
k.emit(s_load_b128(sdata=s[S_KERNARG_A[0]:S_KERNARG_B[1]], sbase=s[0:1], offset=0x0, soffset=NULL))
|
||||
k.emit(s_load_b64(sdata=s[S_OUT_PTR[0]:S_OUT_PTR[1]], sbase=s[0:1], offset=0x10, soffset=NULL))
|
||||
k.emit(s_mov_b32(s[S_DIM_N], MATRIX_DIM))
|
||||
k.emit(s_mov_b32(s[S_DIM_N], N))
|
||||
k.emit(s_mov_b32(s[S_LOOP_CTR], 0)) # used by LDS swizzle, always 0 for valid workgroups
|
||||
k.emit(s_lshl_b32(s[S_TILE_X], s[S_WORKGROUP_X], 7))
|
||||
k.emit(s_lshl_b32(s[S_TILE_Y], s[S_WORKGROUP_Y], 7))
|
||||
|
|
@ -220,19 +221,20 @@ def build_kernel(arch='gfx1100'):
|
|||
|
||||
# Compute 8 A and B matrix tile base pointers for prefetch
|
||||
k.emit(s_mov_b64(s[S_PREFETCH_B:S_PREFETCH_B+1], s[S_KERNARG_B[0]:S_KERNARG_B[1]])) # B[0]: no offset
|
||||
for i in range(1, 8): # B: 16KB apart
|
||||
k.emit(s_add_u32(s[S_PREFETCH_B+i*2], s[S_KERNARG_B[0]], i * 0x4000))
|
||||
for i in range(1, 8): # B: each pointer 1 row of B apart (N*4 bytes)
|
||||
k.emit(s_add_u32(s[S_PREFETCH_B+i*2], s[S_KERNARG_B[0]], i * N * 4))
|
||||
k.emit(s_addc_u32(s[S_PREFETCH_B+i*2+1], s[S_KERNARG_B[1]], 0))
|
||||
k.emit(s_mov_b64(s[S_PREFETCH_A:S_PREFETCH_A+1], s[S_KERNARG_A[0]:S_KERNARG_A[1]])) # A[0]: no offset
|
||||
for i in range(1, 8): # A: 256KB apart
|
||||
k.emit(s_add_u32(s[S_PREFETCH_A+i*2], s[S_KERNARG_A[0]], i * 0x40000))
|
||||
for i in range(1, 8): # A: each pointer 16 rows of A apart (16*N*4 bytes)
|
||||
k.emit(s_add_u32(s[S_PREFETCH_A+i*2], s[S_KERNARG_A[0]], i * N * 64))
|
||||
k.emit(s_addc_u32(s[S_PREFETCH_A+i*2+1], s[S_KERNARG_A[1]], 0))
|
||||
|
||||
# Global prefetch addresses: B = (tile_x + lane_id) * 4, A = ((tile_y << 12) + (lane_id/8)*4K + lane_id%8) * 4
|
||||
# Global prefetch addresses: B = (tile_x + lane_id) * 4, A = (tile_y*N + (lane_id/8)*N + lane_id%8) * 4
|
||||
k.emit(v_add_nc_u32_e32(v[V_GLOBAL_B_ADDR], s[S_TILE_X], v[V_LANE_ID]))
|
||||
k.emit(v_lshlrev_b32_e32(v[V_GLOBAL_B_ADDR], 2, v[V_GLOBAL_B_ADDR]))
|
||||
k.emit(s_lshl_b32(s[19], s[S_TILE_Y], 12))
|
||||
k.emit(v_lshl_add_u32(v[V_GLOBAL_A_ADDR], v[4], 12, v[V_LANE_ID_MOD8])) # (lane_id/8)*4K + lane_id%8
|
||||
k.emit(s_mul_i32(s[19], s[S_TILE_Y], N))
|
||||
k.emit(v_mul_lo_u32(v[V_GLOBAL_A_ADDR], v[4], N)) # (lane_id/8)*N
|
||||
k.emit(v_add_nc_u32_e32(v[V_GLOBAL_A_ADDR], v[V_LANE_ID_MOD8], v[V_GLOBAL_A_ADDR])) # + lane_id%8
|
||||
k.emit(v_add_nc_u32_e32(v[V_GLOBAL_A_ADDR], s[19], v[V_GLOBAL_A_ADDR]))
|
||||
k.emit(v_lshlrev_b32_e32(v[V_GLOBAL_A_ADDR], 2, v[V_GLOBAL_A_ADDR]))
|
||||
|
||||
|
|
@ -303,13 +305,13 @@ def build_kernel(arch='gfx1100'):
|
|||
|
||||
if not NO_GLOBAL:
|
||||
# Advance prefetch pointers (VGPR)
|
||||
#k.emit(v_add_nc_u32_e32(v[V_GLOBAL_B_ADDR], 0x20000, v[V_GLOBAL_B_ADDR]))
|
||||
#k.emit(v_add_nc_u32_e32(v[V_GLOBAL_B_ADDR], N * 32, v[V_GLOBAL_B_ADDR]))
|
||||
#k.emit(v_add_nc_u32_e32(v[V_GLOBAL_A_ADDR], 0x20, v[V_GLOBAL_A_ADDR]))
|
||||
|
||||
# Advance prefetch pointers (64-bit adds)
|
||||
# Advance prefetch pointers (64-bit adds): B advances 8 rows (8*N*4 bytes), A advances 8 cols (8*4 bytes)
|
||||
k.emit(s_clause(simm16=31))
|
||||
for i in range(8):
|
||||
k.emit(s_add_u32(s[S_PREFETCH_B+i*2], s[S_PREFETCH_B+i*2], 0x20000))
|
||||
k.emit(s_add_u32(s[S_PREFETCH_B+i*2], s[S_PREFETCH_B+i*2], N * 32))
|
||||
k.emit(s_addc_u32(s[S_PREFETCH_B+i*2+1], s[S_PREFETCH_B+i*2+1], 0))
|
||||
for i in range(8):
|
||||
k.emit(s_add_u32(s[S_PREFETCH_A+i*2], s[S_PREFETCH_A+i*2], 0x20))
|
||||
|
|
@ -440,7 +442,7 @@ def test_matmul():
|
|||
dev = Device[Device.DEFAULT]
|
||||
print(f"Device arch: {dev.renderer.arch}")
|
||||
|
||||
insts = build_kernel(dev.renderer.arch)
|
||||
insts = build_kernel(N, dev.renderer.arch)
|
||||
|
||||
rng = np.random.default_rng(42)
|
||||
a = Tensor(rng.random((N, N), dtype=np.float32) - 0.5)
|
||||
|
|
@ -472,33 +474,23 @@ def test_matmul():
|
|||
with Context(DEBUG=2): tc = (a @ b).realize()
|
||||
with Context(DEBUG=0): err = (c - tc).square().mean().item()
|
||||
print(f"mean squared error {err}")
|
||||
if err != err or err > 1e-06: raise RuntimeError("matmul is wrong!")
|
||||
|
||||
def run_sqtt():
|
||||
"""Run with SQTT profiling and write trace files."""
|
||||
import subprocess, os
|
||||
|
||||
# Run test_matmul in a subprocess with SQTT enabled from the start (no verify)
|
||||
env = {**os.environ, "AMD": "1", "SQTT": "1", "CNT": "1", "PROFILE": "1", "PYTHONPATH": ".", "VERIFY": "0"}
|
||||
result = subprocess.run(
|
||||
["python", "-c", "from extra.gemm.amd_asm_matmul import test_matmul; test_matmul()"],
|
||||
capture_output=True, text=True, env=env, timeout=120
|
||||
)
|
||||
print(result.stdout)
|
||||
|
||||
# Run roc.py to extract trace data
|
||||
result = subprocess.run(
|
||||
["python", "extra/sqtt/roc.py", "--profile", "/tmp/profile.pkl.tiny", "--kernel", "kernel"],
|
||||
capture_output=True, text=True, env={**os.environ, "DEBUG": "5"}, timeout=60
|
||||
)
|
||||
output = result.stdout + result.stderr
|
||||
|
||||
# Write full output to trace file
|
||||
with open("/tmp/sqtt_trace.txt", "w") as f:
|
||||
f.write(output)
|
||||
print(f"Wrote {len(output)} bytes to /tmp/sqtt_trace.txt")
|
||||
if err != err or err > 1e-06:
|
||||
c_np, tc_np = c.numpy(), tc.numpy()
|
||||
for bi in range(N // 128):
|
||||
for bj in range(N // 128):
|
||||
blk_c = c_np[bi*128:(bi+1)*128, bj*128:(bj+1)*128]
|
||||
blk_ref = tc_np[bi*128:(bi+1)*128, bj*128:(bj+1)*128]
|
||||
blk_diff = blk_c - blk_ref
|
||||
zero_rows = [i for i in range(128) if np.all(np.abs(blk_c[i,:]) < 1e-10)]
|
||||
nz_rows = [i for i in range(128) if i not in zero_rows]
|
||||
nz_mse = float(np.mean(blk_diff[nz_rows,:]**2)) if nz_rows else 0
|
||||
print(f"Block ({bi},{bj}): zero_rows={zero_rows}, nz_rows_mse={nz_mse:.2e}")
|
||||
# show first few non-zero row comparisons
|
||||
if nz_rows and nz_mse > 1e-6:
|
||||
for r in nz_rows[:3]:
|
||||
print(f" row {r} asm[0:8]: {blk_c[r,:8]}")
|
||||
print(f" row {r} ref[0:8]: {blk_ref[r,:8]}")
|
||||
raise RuntimeError("matmul is wrong!")
|
||||
|
||||
if __name__ == "__main__":
|
||||
if getenv("ASM", 0): print("\n".join(str(inst) for inst in build_kernel(Device[Device.DEFAULT].renderer.arch)))
|
||||
elif getenv("SQTT", 0): run_sqtt()
|
||||
else: test_matmul()
|
||||
test_matmul()
|
||||
|
|
|
|||
|
|
@ -760,5 +760,110 @@ class TestDsPermute(unittest.TestCase):
|
|||
self.assertEqual(st.vgpr[0][2], 0x11111111)
|
||||
|
||||
|
||||
class TestDSLargeOffset(unittest.TestCase):
|
||||
"""Tests for DS instructions with offsets > 255 (offset1 > 0).
|
||||
|
||||
The DS offset is a 16-bit value encoded as (offset1 << 8) | offset0.
|
||||
These tests verify that offset1 is used correctly, not just offset0.
|
||||
"""
|
||||
|
||||
def test_ds_store_load_b32_offset_256(self):
|
||||
"""DS_STORE_B32/DS_LOAD_B32 with offset=256 (offset0=0, offset1=1)."""
|
||||
instructions = [
|
||||
v_mov_b32_e32(v[10], 0),
|
||||
s_mov_b32(s[0], 0xDEADBEEF),
|
||||
v_mov_b32_e32(v[0], s[0]),
|
||||
ds_store_b32(addr=v[10], data0=v[0], offset0=0, offset1=1), # offset = 256
|
||||
s_waitcnt(lgkmcnt=0),
|
||||
ds_load_b32(addr=v[10], vdst=v[1], offset0=0, offset1=1), # offset = 256
|
||||
s_waitcnt(lgkmcnt=0),
|
||||
]
|
||||
st = run_program(instructions, n_lanes=1)
|
||||
self.assertEqual(st.vgpr[0][1], 0xDEADBEEF)
|
||||
|
||||
def test_ds_store_load_b32_offset_300(self):
|
||||
"""DS_STORE_B32/DS_LOAD_B32 with offset=300 (offset0=44, offset1=1)."""
|
||||
instructions = [
|
||||
v_mov_b32_e32(v[10], 0),
|
||||
s_mov_b32(s[0], 0xCAFEBABE),
|
||||
v_mov_b32_e32(v[0], s[0]),
|
||||
ds_store_b32(addr=v[10], data0=v[0], offset0=44, offset1=1), # offset = 300
|
||||
s_waitcnt(lgkmcnt=0),
|
||||
ds_load_b32(addr=v[10], vdst=v[1], offset0=44, offset1=1), # offset = 300
|
||||
s_waitcnt(lgkmcnt=0),
|
||||
]
|
||||
st = run_program(instructions, n_lanes=1)
|
||||
self.assertEqual(st.vgpr[0][1], 0xCAFEBABE)
|
||||
|
||||
def test_ds_store_load_b64_offset_512(self):
|
||||
"""DS_STORE_B64/DS_LOAD_B64 with offset=512 (offset0=0, offset1=2)."""
|
||||
instructions = [
|
||||
v_mov_b32_e32(v[10], 0),
|
||||
s_mov_b32(s[0], 0x11111111),
|
||||
v_mov_b32_e32(v[0], s[0]),
|
||||
s_mov_b32(s[0], 0x22222222),
|
||||
v_mov_b32_e32(v[1], s[0]),
|
||||
ds_store_b64(addr=v[10], data0=v[0:1], offset0=0, offset1=2), # offset = 512
|
||||
s_waitcnt(lgkmcnt=0),
|
||||
ds_load_b64(addr=v[10], vdst=v[2:3], offset0=0, offset1=2), # offset = 512
|
||||
s_waitcnt(lgkmcnt=0),
|
||||
]
|
||||
st = run_program(instructions, n_lanes=1)
|
||||
self.assertEqual(st.vgpr[0][2], 0x11111111)
|
||||
self.assertEqual(st.vgpr[0][3], 0x22222222)
|
||||
|
||||
def test_ds_large_offset_distinct_from_small(self):
|
||||
"""Verify offset=256 and offset=0 address different LDS locations."""
|
||||
instructions = [
|
||||
v_mov_b32_e32(v[10], 0),
|
||||
s_mov_b32(s[0], 0xAAAAAAAA),
|
||||
v_mov_b32_e32(v[0], s[0]),
|
||||
s_mov_b32(s[0], 0xBBBBBBBB),
|
||||
v_mov_b32_e32(v[1], s[0]),
|
||||
# Store 0xAAAAAAAA at offset=0, 0xBBBBBBBB at offset=256
|
||||
ds_store_b32(addr=v[10], data0=v[0], offset0=0, offset1=0), # offset = 0
|
||||
ds_store_b32(addr=v[10], data0=v[1], offset0=0, offset1=1), # offset = 256
|
||||
s_waitcnt(lgkmcnt=0),
|
||||
# Read back both
|
||||
ds_load_b32(addr=v[10], vdst=v[2], offset0=0, offset1=0), # offset = 0
|
||||
ds_load_b32(addr=v[10], vdst=v[3], offset0=0, offset1=1), # offset = 256
|
||||
s_waitcnt(lgkmcnt=0),
|
||||
]
|
||||
st = run_program(instructions, n_lanes=1)
|
||||
self.assertEqual(st.vgpr[0][2], 0xAAAAAAAA, "offset=0 should read 0xAAAAAAAA")
|
||||
self.assertEqual(st.vgpr[0][3], 0xBBBBBBBB, "offset=256 should read 0xBBBBBBBB")
|
||||
|
||||
def test_ds_store_load_b32_offset_448(self):
|
||||
"""DS_STORE_B32/DS_LOAD_B32 with offset=448 (offset0=192, offset1=1) - matches matmul B tile."""
|
||||
instructions = [
|
||||
v_mov_b32_e32(v[10], 0),
|
||||
s_mov_b32(s[0], 0x12345678),
|
||||
v_mov_b32_e32(v[0], s[0]),
|
||||
ds_store_b32(addr=v[10], data0=v[0], offset0=192, offset1=1), # offset = 448
|
||||
s_waitcnt(lgkmcnt=0),
|
||||
ds_load_b32(addr=v[10], vdst=v[1], offset0=192, offset1=1), # offset = 448
|
||||
s_waitcnt(lgkmcnt=0),
|
||||
]
|
||||
st = run_program(instructions, n_lanes=1)
|
||||
self.assertEqual(st.vgpr[0][1], 0x12345678)
|
||||
|
||||
def test_ds_load_b64_offset_392(self):
|
||||
"""DS_LOAD_B64 with offset=392 (offset0=136, offset1=1) - matches matmul B tile load."""
|
||||
instructions = [
|
||||
v_mov_b32_e32(v[10], 0),
|
||||
s_mov_b32(s[0], 0xAABBCCDD),
|
||||
v_mov_b32_e32(v[0], s[0]),
|
||||
s_mov_b32(s[0], 0x11223344),
|
||||
v_mov_b32_e32(v[1], s[0]),
|
||||
ds_store_b64(addr=v[10], data0=v[0:1], offset0=136, offset1=1), # offset = 392
|
||||
s_waitcnt(lgkmcnt=0),
|
||||
ds_load_b64(addr=v[10], vdst=v[2:3], offset0=136, offset1=1), # offset = 392
|
||||
s_waitcnt(lgkmcnt=0),
|
||||
]
|
||||
st = run_program(instructions, n_lanes=1)
|
||||
self.assertEqual(st.vgpr[0][2], 0xAABBCCDD)
|
||||
self.assertEqual(st.vgpr[0][3], 0x11223344)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
|
|
|||
|
|
@ -1579,5 +1579,55 @@ class TestPermlane64(unittest.TestCase):
|
|||
self.assertEqual(st.vgpr[0][1], 0x12345678)
|
||||
|
||||
|
||||
class TestSwap(unittest.TestCase):
|
||||
"""Tests for V_SWAP_B32 - swap two VGPRs."""
|
||||
|
||||
def test_v_swap_b32_basic(self):
|
||||
"""V_SWAP_B32 swaps two VGPR values."""
|
||||
instructions = [
|
||||
v_mov_b32_e32(v[0], 42),
|
||||
v_mov_b32_e32(v[1], 99),
|
||||
v_swap_b32_e32(v[0], v[1]),
|
||||
]
|
||||
st = run_program(instructions, n_lanes=1)
|
||||
self.assertEqual(st.vgpr[0][0], 99)
|
||||
self.assertEqual(st.vgpr[0][1], 42)
|
||||
|
||||
def test_v_swap_b32_same_reg(self):
|
||||
"""V_SWAP_B32 with same src and dst is a no-op."""
|
||||
instructions = [
|
||||
v_mov_b32_e32(v[0], 0xDEADBEEF),
|
||||
v_swap_b32_e32(v[0], v[0]),
|
||||
]
|
||||
st = run_program(instructions, n_lanes=1)
|
||||
self.assertEqual(st.vgpr[0][0], 0xDEADBEEF)
|
||||
|
||||
def test_v_swap_b32_multi_lane(self):
|
||||
"""V_SWAP_B32 swaps per-lane values independently."""
|
||||
instructions = [
|
||||
# v[0] = lane_id * 10, v[1] = lane_id * 100
|
||||
v_lshlrev_b32_e32(v[0], 1, v[255]), # v[0] = lane_id * 2
|
||||
v_add_nc_u32_e32(v[0], v[0], v[255]), # v[0] = lane_id * 3
|
||||
v_mul_u32_u24_e32(v[1], 100, v[255]), # v[1] = lane_id * 100
|
||||
v_swap_b32_e32(v[0], v[1]),
|
||||
]
|
||||
st = run_program(instructions, n_lanes=4)
|
||||
for lane in range(4):
|
||||
self.assertEqual(st.vgpr[lane][0], lane * 100)
|
||||
self.assertEqual(st.vgpr[lane][1], lane * 3)
|
||||
|
||||
def test_v_swap_b32_chain(self):
|
||||
"""Two swaps in sequence restore original values."""
|
||||
instructions = [
|
||||
v_mov_b32_e32(v[0], 0xAAAAAAAA),
|
||||
v_mov_b32_e32(v[1], 0x55555555),
|
||||
v_swap_b32_e32(v[0], v[1]),
|
||||
v_swap_b32_e32(v[0], v[1]),
|
||||
]
|
||||
st = run_program(instructions, n_lanes=1)
|
||||
self.assertEqual(st.vgpr[0][0], 0xAAAAAAAA)
|
||||
self.assertEqual(st.vgpr[0][1], 0x55555555)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
|
|
|||
|
|
@ -113,7 +113,7 @@ class PythonEmulator:
|
|||
|
||||
def _ensure_decoded(self, pc: int):
|
||||
if pc not in self.program:
|
||||
runner = _decode_at(pc, "rdna3")
|
||||
runner, _ = _decode_at(pc, "rdna3")
|
||||
self.program[pc] = (runner.p.function_name, runner._prg.fxn, runner.p.globals)
|
||||
|
||||
def step(self) -> int:
|
||||
|
|
|
|||
|
|
@ -116,6 +116,8 @@ WAVE_SIZE = 32
|
|||
PC_LO_IDX, PC_HI_IDX, SCRATCH_STRIDE_IDX = 256, 257, 259
|
||||
# SGPR buffer: 0-127 = SGPRs, 128-255 = inline constants, 256-259 = special registers
|
||||
SGPR_COUNT, VGPR_SIZE = 260, 256 * 32
|
||||
# Sentinel PC value for s_endpgm
|
||||
ENDPGM_PC = 0xFFFFFFFFFFFFFFFF
|
||||
|
||||
def _op_name(inst) -> str:
|
||||
if hasattr(inst, 'opx'): return f"{inst.opx.name}_{inst.opy.name}" # VOPD has opx/opy not op
|
||||
|
|
@ -210,7 +212,7 @@ def parse_pcode(pcode: str, srcs: dict[str, UOp] | None = None) -> tuple[dict, l
|
|||
_, final, _ = parse_block(lines, 0, env, assigns=assigns)
|
||||
sliced = set(d.split('[')[0] for d, _ in assigns if '[' in d)
|
||||
for var, val in final.items():
|
||||
if var in ['D0', 'SCC', 'VCC', 'EXEC', 'PC', 'RETURN_DATA', 'VDATA'] and isinstance(val, UOp):
|
||||
if var in ['D0', 'S0', 'SCC', 'VCC', 'EXEC', 'PC', 'RETURN_DATA', 'VDATA'] and isinstance(val, UOp):
|
||||
if var in sliced and not any(re.match(rf'{var}\.\w+\s*=', l) for l in lines): continue
|
||||
for l in lines:
|
||||
if (m := re.match(rf'{var}\.(\w+(?:\[\w+\])?)', l)):
|
||||
|
|
@ -462,7 +464,8 @@ class _Ctx:
|
|||
return UOp.sink(*stores, *self.inc_pc())
|
||||
|
||||
def compile_vop_pcode(self, op, srcs: dict[str, UOp], lane: UOp, vdst_reg: UOp, exec_mask: UOp,
|
||||
opsel_dst_hi: bool | UOp = False, sdst_reg: int | None = None, clmp: int = 0) -> UOp:
|
||||
opsel_dst_hi: bool | UOp = False, sdst_reg: int | None = None, clmp: int = 0,
|
||||
src0_off: UOp | None = None) -> UOp:
|
||||
"""Compile VOP instruction. Returns sink with stores and inc_pc."""
|
||||
pcode = get_pcode(op)
|
||||
vcc_reg = sdst_reg if sdst_reg is not None else VCC_LO.offset
|
||||
|
|
@ -519,11 +522,15 @@ class _Ctx:
|
|||
result = opsel_dst_hi.where(hi_result, lo_result) if isinstance(opsel_dst_hi, UOp) else hi_result if opsel_dst_hi else lo_result
|
||||
raw_stores.append(('vgpr', self.wvgpr_dyn(vdst_reg, lane, result, exec_mask)))
|
||||
else: raw_stores.append(('vgpr', self.wvgpr_dyn(vdst_reg, lane, _val_to_u32(val), exec_mask)))
|
||||
elif dest.startswith('S0') and src0_off is not None:
|
||||
# Write back to src0 VGPR (e.g. v_swap_b32). src0_off is raw encoding (256+ = VGPR)
|
||||
src0_vgpr = src0_off - _c(256)
|
||||
raw_stores.append(('vgpr_s0', self.wvgpr_dyn(src0_vgpr, lane, _val_to_u32(val), exec_mask)))
|
||||
elif dest.startswith('VCC'): vcc_val = val
|
||||
elif dest.startswith('EXEC'): exec_val = val
|
||||
elif dest.startswith('SCC'): raw_stores.append(('scc', self.wsgpr_dyn(_c(SCC.offset), _to_u32(val))))
|
||||
|
||||
stores, lane_stores, scalar_stores = [], [s for t, s in raw_stores if t == 'vgpr'], [s for t, s in raw_stores if t == 'scc']
|
||||
stores, lane_stores, scalar_stores = [], [s for t, s in raw_stores if t in ('vgpr', 'vgpr_s0')], [s for t, s in raw_stores if t == 'scc']
|
||||
slice_stores = [s for t, s in raw_stores if t == 'vgpr_slice']
|
||||
if slice_stores:
|
||||
result = self.rvgpr_dyn(vdst_reg, lane)
|
||||
|
|
@ -548,6 +555,10 @@ def _compile_sopp(inst: ir3.SOPP | ir4.SOPP, ctx: _Ctx) -> UOp:
|
|||
if inst.op in (ir3.SOPPOp.S_ENDPGM, ir4.SOPPOp.S_ENDPGM, irc.SOPPOp.S_ENDPGM):
|
||||
return UOp.sink(ctx.wsgpr_dyn(_c(PC_LO_IDX), UOp.const(dtypes.uint32, 0xFFFFFFFF)),
|
||||
ctx.wsgpr_dyn(_c(PC_HI_IDX), UOp.const(dtypes.uint32, 0xFFFFFFFF)))
|
||||
# S_BARRIER: advance PC past the barrier instruction. The execution loop detects barriers before executing and handles synchronization.
|
||||
barrier_ops = {ir3.SOPPOp.S_BARRIER, irc.SOPPOp.S_BARRIER}
|
||||
if hasattr(ir4.SOPPOp, 'S_BARRIER_WAIT'): barrier_ops.add(ir4.SOPPOp.S_BARRIER_WAIT)
|
||||
if inst.op in barrier_ops: return UOp.sink(*ctx.inc_pc())
|
||||
# S_NOP and S_WAITCNT are no-ops in emulator (no pipeline/cache to wait on)
|
||||
if inst.op in (ir3.SOPPOp.S_NOP, ir4.SOPPOp.S_NOP, irc.SOPPOp.S_NOP, irc.SOPPOp.S_WAITCNT): return UOp.sink(*ctx.inc_pc())
|
||||
# NOTE: we ignore SOPPs without PCODE
|
||||
|
|
@ -653,7 +664,8 @@ def _compile_vop12(inst: ir3.VOP1 | ir3.VOP1_SDST | ir3.VOP2 | ir4.VOP1 | ir4.VO
|
|||
# Only compute hi-half when src0_off >= 384, use guarded index to prevent OOB access
|
||||
src0_reg = src0_hi.where(src0_off - _c(384), _c(0))
|
||||
s0 = src0_hi.where(_hi16(ctx.rvgpr_dyn(src0_reg, lane)), s0)
|
||||
srcs = {'S0': s0}
|
||||
d0 = _cond_hi16(write_hi_half, ctx.rvgpr_dyn(vdst_reg, lane))
|
||||
srcs = {'S0': s0, 'D0': d0}
|
||||
else:
|
||||
vsrc1_reg = ctx.inst_field(type(inst).vsrc1)
|
||||
vsrc1_hi = bits['s0'] == 16 and (vsrc1_reg >= _c(128))
|
||||
|
|
@ -675,7 +687,7 @@ def _compile_vop12(inst: ir3.VOP1 | ir3.VOP1_SDST | ir3.VOP2 | ir4.VOP1 | ir4.VO
|
|||
if 'V_FMAA' in _op_name(inst) or 'V_FMAM' in _op_name(inst):
|
||||
assert literal is not None
|
||||
srcs['SIMM32'] = literal
|
||||
return ctx.compile_vop_pcode(inst.op, srcs, lane, vdst_reg, exec_mask, opsel_dst_hi=write_hi_half)
|
||||
return ctx.compile_vop_pcode(inst.op, srcs, lane, vdst_reg, exec_mask, opsel_dst_hi=write_hi_half, src0_off=src0_off)
|
||||
|
||||
def _compile_vopc(inst: ir3.VOPC|ir3.VOP3|ir4.VOPC|ir4.VOP3|irc.VOPC|irc.VOP3, ctx: _Ctx,
|
||||
opsel: int = 0, abs_bits: int = 0, neg_bits: int = 0) -> UOp:
|
||||
|
|
@ -1022,7 +1034,7 @@ def _compile_mem_op(inst: ir3.DS|ir3.FLAT|ir3.GLOBAL|ir3.SCRATCH|ir4.DS|ir4.VFLA
|
|||
vdst_reg = ctx.inst_field(type(inst).vdst)
|
||||
offset0 = ctx.inst_field(type(inst).offset0) # type: ignore[union-attr]
|
||||
offset1 = ctx.inst_field(type(inst).offset1) # type: ignore[union-attr]
|
||||
offset = offset0 # DS uses offset0 as primary offset
|
||||
offset = (offset1 << _c(8)) | offset0 # DS offset is 16-bit: (offset1 << 8) | offset0
|
||||
saddr_reg = None
|
||||
elif isinstance(inst, (ir4.VGLOBAL, ir4.VSCRATCH, ir4.VFLAT)): # RDNA4: vaddr, vsrc, ioffset
|
||||
addr_reg = ctx.inst_field(type(inst).vaddr)
|
||||
|
|
@ -1225,11 +1237,14 @@ def _get_runner(inst_bytes: bytes, arch: str = "rdna3"):
|
|||
_canonical_runner_cache.append((base, mask, size, runner))
|
||||
return runner
|
||||
|
||||
_BARRIER_OPS = {ir3.SOPPOp.S_BARRIER, irc.SOPPOp.S_BARRIER}
|
||||
if hasattr(ir4.SOPPOp, 'S_BARRIER_WAIT'): _BARRIER_OPS.add(ir4.SOPPOp.S_BARRIER_WAIT)
|
||||
|
||||
def _decode_at(pc: int, arch: str):
|
||||
"""Decode and compile instruction at absolute address pc. Returns CompiledRunner."""
|
||||
"""Decode and compile instruction at absolute address pc. Returns (runner, decoded_inst)."""
|
||||
inst_bytes = bytes((ctypes.c_char * 16).from_address(pc).raw)
|
||||
inst = decode_inst(inst_bytes, arch)
|
||||
try: return _get_runner(bytes(inst_bytes[:inst.size() + 4]), arch)
|
||||
try: return _get_runner(bytes(inst_bytes[:inst.size() + 4]), arch), inst
|
||||
except Exception as e:
|
||||
try: inst_str = repr(inst)
|
||||
except Exception: inst_str = f"<{type(inst).__name__}>"
|
||||
|
|
@ -1279,10 +1294,37 @@ class WaveState:
|
|||
# EXECUTION
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
|
||||
def _init_wave(lib: int, wave_start: int, total_threads: int, lx: int, ly: int, lz: int, args_ptr: int, rsrc2: int,
|
||||
scratch_size: int, arch: str, gidx: int, gidy: int, gidz: int, user_data: list[int]|None) -> tuple[WaveState, list]:
|
||||
"""Initialize a single wavefront and return (WaveState, c_bufs placeholder). c_bufs filled in by caller."""
|
||||
n_lanes = min(WAVE_SIZE, total_threads - wave_start)
|
||||
st = WaveState(n_lanes)
|
||||
st.pc = lib
|
||||
if user_data:
|
||||
for i, val in enumerate(user_data): st._write_sgpr(i, val)
|
||||
else:
|
||||
st._write_sgpr(0, args_ptr & MASK32)
|
||||
st._write_sgpr(1, (args_ptr >> 32) & MASK32)
|
||||
sgpr_idx = (rsrc2 & hsa.AMD_COMPUTE_PGM_RSRC_TWO_USER_SGPR_COUNT) >> hsa.AMD_COMPUTE_PGM_RSRC_TWO_USER_SGPR_COUNT_SHIFT
|
||||
for enabled, gid in [(hsa.AMD_COMPUTE_PGM_RSRC_TWO_ENABLE_SGPR_WORKGROUP_ID_X, gidx),
|
||||
(hsa.AMD_COMPUTE_PGM_RSRC_TWO_ENABLE_SGPR_WORKGROUP_ID_Y, gidy),
|
||||
(hsa.AMD_COMPUTE_PGM_RSRC_TWO_ENABLE_SGPR_WORKGROUP_ID_Z, gidz)]:
|
||||
if rsrc2 & enabled:
|
||||
st._write_sgpr(sgpr_idx, gid)
|
||||
sgpr_idx += 1
|
||||
if arch == "rdna4":
|
||||
st._write_sgpr(ttmp[7].offset, (gidy & 0xFFFF) | ((gidz & 0xFFFF) << 16))
|
||||
st._write_sgpr(ttmp[9].offset, gidx)
|
||||
for lane in range(n_lanes):
|
||||
tid = wave_start + lane
|
||||
st._write_vgpr(0, lane, ((tid // (lx * ly)) << 20) | (((tid // lx) % ly) << 10) | (tid % lx))
|
||||
st._write_sgpr(SCRATCH_STRIDE_IDX, scratch_size)
|
||||
return st
|
||||
|
||||
def run_asm(lib: int, lib_sz: int, gx: int, gy: int, gz: int, lx: int, ly: int, lz: int, args_ptr: int, rsrc2: int = 0x19c,
|
||||
scratch_size: int = 0, arch: str = "rdna3", user_data: list[int]|None = None) -> int:
|
||||
"""Execute AMD assembly program. scratch_size is private_segment_fixed_size from kernel descriptor (per-lane)."""
|
||||
program: dict[int, tuple[Callable, list[int]]] = {} # lazily populated: pc -> (fxn, globals) extracted from runner
|
||||
program: dict[int, tuple[Callable, list[int], bool]] = {} # pc -> (fxn, globals, is_barrier)
|
||||
lds_size = ((rsrc2 & hsa.AMD_COMPUTE_PGM_RSRC_TWO_GRANULATED_LDS_SIZE) >> hsa.AMD_COMPUTE_PGM_RSRC_TWO_GRANULATED_LDS_SIZE_SHIFT) * 512
|
||||
total_threads = lx * ly * lz
|
||||
|
||||
|
|
@ -1291,56 +1333,51 @@ def run_asm(lib: int, lib_sz: int, gx: int, gy: int, gz: int, lx: int, ly: int,
|
|||
lds_buf = Buffer('CPU', max(lds_size // 4, 1), dtypes.uint32).ensure_allocated()
|
||||
scratch_buf = Buffer('CPU', scratch_size * WAVE_SIZE, dtypes.uint8).ensure_allocated() if scratch_size else None
|
||||
|
||||
def _ensure_compiled(pc: int) -> tuple[Callable, list[int], bool]:
|
||||
if pc not in program:
|
||||
prev_len = len(_canonical_runner_cache)
|
||||
runner, inst = _decode_at(pc, arch)
|
||||
is_barrier = isinstance(inst, (ir3.SOPP, ir4.SOPP, irc.SOPP)) and inst.op in _BARRIER_OPS
|
||||
program[pc] = (runner._prg.fxn, runner.p.globals, is_barrier)
|
||||
if DEBUG >= 3:
|
||||
msg = f"[emu] PC={pc - lib}: {inst!r}"
|
||||
print(colored(msg, 'green') if len(_canonical_runner_cache) > prev_len else msg)
|
||||
return program[pc]
|
||||
|
||||
# Set DAZ+FTZ during emulator execution, restore afterward to avoid breaking hypothesis tests
|
||||
with _MXCSRContext():
|
||||
for gidz in range(gz):
|
||||
for gidy in range(gy):
|
||||
for gidx in range(gx):
|
||||
# Initialize all wavefronts for this workgroup
|
||||
waves: list[tuple[WaveState, list]] = []
|
||||
for wave_start in range(0, total_threads, WAVE_SIZE):
|
||||
n_lanes, st = min(WAVE_SIZE, total_threads - wave_start), WaveState(min(WAVE_SIZE, total_threads - wave_start))
|
||||
st.pc = lib # Set PC to code base address
|
||||
# Initialize user SGPRs: hardware loads COMPUTE_USER_DATA registers directly into s[0:N]
|
||||
if user_data:
|
||||
for i, val in enumerate(user_data): st._write_sgpr(i, val)
|
||||
else:
|
||||
st._write_sgpr(0, args_ptr & MASK32)
|
||||
st._write_sgpr(1, (args_ptr >> 32) & MASK32)
|
||||
|
||||
# Workgroup IDs in SGPRs after user SGPRs
|
||||
sgpr_idx = (rsrc2 & hsa.AMD_COMPUTE_PGM_RSRC_TWO_USER_SGPR_COUNT) >> hsa.AMD_COMPUTE_PGM_RSRC_TWO_USER_SGPR_COUNT_SHIFT
|
||||
for enabled, gid in [(hsa.AMD_COMPUTE_PGM_RSRC_TWO_ENABLE_SGPR_WORKGROUP_ID_X, gidx),
|
||||
(hsa.AMD_COMPUTE_PGM_RSRC_TWO_ENABLE_SGPR_WORKGROUP_ID_Y, gidy),
|
||||
(hsa.AMD_COMPUTE_PGM_RSRC_TWO_ENABLE_SGPR_WORKGROUP_ID_Z, gidz)]:
|
||||
if rsrc2 & enabled:
|
||||
st._write_sgpr(sgpr_idx, gid)
|
||||
sgpr_idx += 1
|
||||
|
||||
# RDNA4 uses TTMP registers for workgroup IDs (16 bit) ttmp[9:7] -> gidx-gidz
|
||||
if arch == "rdna4":
|
||||
st._write_sgpr(ttmp[7].offset, (gidy & 0xFFFF) | ((gidz & 0xFFFF) << 16))
|
||||
st._write_sgpr(ttmp[9].offset, gidx)
|
||||
|
||||
# v0 = packed workitem IDs, scratch stride in secret SGPR
|
||||
for lane in range(n_lanes):
|
||||
tid = wave_start + lane
|
||||
st._write_vgpr(0, lane, ((tid // (lx * ly)) << 20) | (((tid // lx) % ly) << 10) | (tid % lx))
|
||||
st._write_sgpr(SCRATCH_STRIDE_IDX, scratch_size)
|
||||
|
||||
# Pass buffer addresses via ctypes (pre-create to avoid allocation in loop)
|
||||
st = _init_wave(lib, wave_start, total_threads, lx, ly, lz, args_ptr, rsrc2, scratch_size, arch, gidx, gidy, gidz, user_data)
|
||||
c_bufs = [ctypes.c_uint64(st.sgpr_buf._buf.va_addr), ctypes.c_uint64(st.vgpr_buf._buf.va_addr),
|
||||
ctypes.c_uint64(vmem_buf._buf.va_addr), ctypes.c_uint64(lds_buf._buf.va_addr),
|
||||
ctypes.c_uint64(scratch_buf._buf.va_addr if scratch_buf else 0)]
|
||||
for inst_count in range(1_000_000):
|
||||
if (pc := st.pc) == 0xFFFFFFFFFFFFFFFF: break
|
||||
if pc not in program:
|
||||
prev_len = len(_canonical_runner_cache)
|
||||
runner = _decode_at(pc, arch)
|
||||
program[pc] = (runner._prg.fxn, runner.p.globals)
|
||||
if DEBUG >= 3:
|
||||
inst = decode_inst(bytes((ctypes.c_char * 16).from_address(pc).raw), arch)
|
||||
msg = f"[emu] PC={pc - lib}: {inst!r}"
|
||||
print(colored(msg, 'green') if len(_canonical_runner_cache) > prev_len else msg)
|
||||
fxn, globals_list = program[pc]
|
||||
fxn(*[c_bufs[g] for g in globals_list])
|
||||
else: raise RuntimeError("exceeded 1M instructions, likely infinite loop")
|
||||
waves.append((st, c_bufs))
|
||||
|
||||
# Execute wavefronts with barrier synchronization
|
||||
# Each wave runs until it hits s_barrier or s_endpgm. When all waves have stopped, release barrier waves.
|
||||
done = [False] * len(waves)
|
||||
for total_inst in range(10_000_000):
|
||||
if all(done): break
|
||||
for wi, (st, c_bufs) in enumerate(waves):
|
||||
if done[wi]: continue
|
||||
# Run this wave until barrier or endpgm
|
||||
for _ in range(1_000_000):
|
||||
pc = st.pc
|
||||
if pc == ENDPGM_PC:
|
||||
done[wi] = True
|
||||
break
|
||||
fxn, globals_list, is_barrier = _ensure_compiled(pc)
|
||||
fxn(*[c_bufs[g] for g in globals_list])
|
||||
if is_barrier: break # s_barrier hit: PC already advanced past it, pause this wave
|
||||
else: raise RuntimeError("exceeded 1M instructions in single wave, likely infinite loop")
|
||||
# All waves have either hit barrier or endpgm — release barrier waves for next round
|
||||
else: raise RuntimeError("exceeded 10M total scheduling rounds")
|
||||
|
||||
# Reset LDS for next workgroup
|
||||
if lds_size > 0: ctypes.memset(lds_buf._buf.va_addr, 0, max(lds_size, 4))
|
||||
return 0
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue