Compare commits

...

9 commits

Author SHA1 Message Date
George Hotz
35932239d0
Merge branch 'master' into rdna4_gemm 2026-04-08 10:36:34 +08:00
qazal
f819f82c89 r4 2026-03-25 14:21:12 +02:00
qazal
be78f614f9
Merge branch 'master' into rdna4_gemm 2026-03-25 14:19:47 +02:00
George Hotz
683bb01ead 125 2026-03-23 23:19:15 +08:00
qazal
ccca6b8ecc test 2026-03-23 13:05:26 +00:00
qazal
e1039af42f more diff 2026-03-23 12:42:31 +00:00
qazal
793ad3e150 Merge remote-tracking branch 'upstream/master' into rdna4 2026-03-23 12:36:06 +00:00
qazal
003bd9534c diff cleanup 2026-03-23 12:36:03 +00:00
qazal
e0d151560a sqtt: rdna4 decoder work 2026-03-23 13:17:14 +02:00

View file

@@ -0,0 +1,245 @@
# RDNA4 128x128 GEMM using WMMA — optimized DS scheduling
import numpy as np
from tinygrad import Tensor, Device, Context, GlobalCounters
from tinygrad.uop.ops import UOp, Ops, KernelInfo
from tinygrad.helpers import getenv, colored
from tinygrad.dtype import dtypes, AddrSpace
from tinygrad.engine.realize import Estimates
from tinygrad.renderer.amd.dsl import s, v, VCC_LO, NULL, src, ttmp
from tinygrad.runtime.autogen.amd.rdna4.ins import *
# Output tile computed by one workgroup (BLOCK_M x BLOCK_N) and the K-slab consumed per step.
BLOCK_M = 128
BLOCK_N = 128
BLOCK_K = 16

# Number of 16x16 WMMA tiles each wave accumulates along M and along N.
TILES_M = 4
TILES_N = 4

# Threads per workgroup and bytes per input element (fp16).
THREADS = 128
ELEM = 2

# LDS layout: the A slab sits at offset 0 and the B slab directly after it.
LDS_A_ROW = BLOCK_K * ELEM          # 32 bytes per A row in LDS
LDS_B_ROW = BLOCK_N * ELEM          # 256 bytes; in this file only used to size the B region
LDS_A_SIZE = BLOCK_M * LDS_A_ROW    # 4096
LDS_B_SIZE = BLOCK_K * LDS_B_ROW    # 4096
LDS_B_OFF = LDS_A_SIZE
LDS_SIZE = LDS_A_SIZE + LDS_B_SIZE  # 8192

# VGPR base indices used by build_kernel.
ACC = 60    # 128 fp32 accumulator registers (16 tiles x 8 registers)
DA = 188    # 8 registers receiving the global loads of A
DB = 196    # 8 registers receiving the global loads of B
FA = 204    # 16 registers holding the A fragments read back from LDS
FB = 44     # 16 registers holding the B fragments read back from LDS
ET = 10     # first of the epilogue temporaries (ET .. ET+7)
def build_kernel(N, arch='gfx1200'):
  """Assemble the RDNA4 WMMA GEMM kernel as a flat list of instruction objects.

  The emitted program has four parts: a prologue (argument loads, address setup,
  accumulator clear, first global->LDS copy), a K loop made of two unrolled
  iteration bodies, a tail iteration plus a final compute-only iteration, and an
  epilogue that converts the fp32 accumulators to fp16 and stores them.

  Args:
    N: matrix dimension. Must be a multiple of BLOCK_M and at least 256. It is
      baked into the instruction stream as an immediate, so a kernel is only
      valid for the N it was built with.
    arch: target architecture name. Not referenced anywhere in this function.

  Returns:
    The list of emitted instructions, with the simm16 field of every recorded
    branch patched to the offset of its target label.

  Environment switches (read through getenv, all default to 0):
    NO_ALU: skip the WMMA section (including the DS loads interleaved in it).
    NO_DS: skip LDS loads, LDS stores and the DS counter waits.
    NO_GLOBAL: skip global loads and the global pointer increments.
  """
  assert N % BLOCK_M == 0 and N >= 256
  NO_ALU, NO_DS, NO_GLOBAL = getenv("NO_ALU", 0), getenv("NO_DS", 0), getenv("NO_GLOBAL", 0)
  # I: emitted instructions, L: label name -> offset (sum of size() of everything emitted before it),
  # B: pending branch fixups as (index into I, label name).
  I, L, B = [], {}, []
  def e(i): I.append(i); return i
  def label(n): L[n] = sum(i.size() for i in I)
  # Records the most recently emitted instruction as a branch to label t; the argument i itself is not used.
  def br(i, t): B.append((len(I)-1, t))
  # Load 16 + 8 bytes of arguments through s[0:1] (presumably the kernarg pointer — TODO confirm).
  # s[4:5], s[6:7] and s[8:9] are used below as the saddr base of the A loads, the B loads and the C stores.
  e(s_load_b128(sdata=s[4:7], sbase=s[0:1], ioffset=0, soffset=NULL))
  e(s_load_b64(sdata=s[8:9], sbase=s[0:1], ioffset=0x10, soffset=NULL))
  e(s_wait_kmcnt(simm16=0))
  # s[10] = ttmp[9]*128 is used as the column origin of this workgroup's tile, s[11] = (ttmp[7] & 0xFFFF)*128
  # as its row origin (presumably the x / y workgroup ids — TODO confirm against the RDNA4 ttmp layout).
  e(s_mov_b32(s[10], ttmp[9])); e(s_and_b32(s[11], ttmp[7], 0xFFFF))
  e(s_lshl_b32(s[10], s[10], 7)); e(s_lshl_b32(s[11], s[11], 7))
  # s[12] = N, s[13] = N*2 (byte stride of one fp16 row), s[14] = N*BLOCK_K*ELEM (bytes to advance B by one K slab).
  e(s_mov_b32(s[12], N)); e(s_lshl_b32(s[13], s[12], 1))
  e(s_mul_i32(s[14], s[12], BLOCK_K*ELEM))
  e(s_add_co_i32(s[17], s[12], -2*BLOCK_K)) # loop bound
  # v[0] is used as the thread index (assumed to be the local id — TODO confirm).
  # v[1] = tid%32 (lane), v[2] = tid/32 (wave); then v[3] = wave%2 (wave_n) and v[2] = wave/2 (wave_m).
  e(v_and_b32_e32(v[1], 31, v[0])); e(v_lshrrev_b32_e32(v[2], 5, v[0]))
  e(v_and_b32_e32(v[3], 1, v[2])); e(v_lshrrev_b32_e32(v[2], 1, v[2]))
  # A store address in LDS: tid*32, i.e. one LDS_A_ROW per thread.
  e(v_lshlrev_b32_e32(v[4], 5, v[0]))
  # B store: transposed layout for stride-32 reads. addr = LDS_B_OFF + (tid%8)*512 + (tid/8)*32
  e(v_and_b32_e32(v[48], 7, v[0])); e(v_lshlrev_b32_e32(v[5], 9, v[48])) # (tid%8)*512
  e(v_lshrrev_b32_e32(v[48], 3, v[0])); e(v_lshlrev_b32_e32(v[48], 5, v[48])) # (tid/8)*32
  e(v_add_nc_u32_e32(v[5], v[5], v[48])); e(v_add_nc_u32_e32(v[5], LDS_B_OFF, v[5]))
  # Global A byte offset v[6:7]: (row_origin + tid) * N*ELEM, so each thread reads from its own row of A.
  e(v_add_nc_u32_e32(v[48], s[11], v[0]))
  e(v_mul_lo_u32(v[6], v[48], N*ELEM)); e(v_mov_b32_e32(v[7], 0))
  # Global B byte offset v[8:9]: (tid/8) * N*ELEM + (tid%8)*32 + col_origin*ELEM.
  e(v_lshrrev_b32_e32(v[48], 3, v[0])); e(v_mul_lo_u32(v[8], v[48], N*ELEM))
  e(v_and_b32_e32(v[48], 7, v[0])); e(v_lshlrev_b32_e32(v[48], 5, v[48]))
  e(v_add_nc_u32_e32(v[8], v[8], v[48]))
  e(s_mul_i32(s[15], s[10], ELEM)); e(v_add_nc_u32_e32(v[8], s[15], v[8]))
  e(v_mov_b32_e32(v[9], 0))
  # LDS read addresses, kept in v[LLA] and v[LLB]:
  # A: (lane%16)*32 + (lane/16)*16 + wave_m*2048
  # B: LDS_B_OFF + (lane%16)*32 + (lane/16)*16 + wave_n*2048
  LLA, LLB = 40, 43
  e(v_and_b32_e32(v[50], 15, v[1])); e(v_lshrrev_b32_e32(v[51], 4, v[1]))
  e(v_lshlrev_b32_e32(v[LLA], 5, v[50])) # (lane%16) * 32
  e(v_lshlrev_b32_e32(v[51], 4, v[51])) # (lane/16) * 16
  e(v_add_nc_u32_e32(v[LLA], v[LLA], v[51]))
  e(v_lshlrev_b32_e32(v[52], 11, v[2])) # wave_m * 2048
  e(v_add_nc_u32_e32(v[LLA], v[LLA], v[52]))
  # B read address. Per the store layout above, a 16-column panel occupies 512 bytes (stride 32 between its entries).
  # A wave covers columns [wave_n*64 : (wave_n+1)*64], i.e. 4 panels, so its base is wave_n*4*512 = wave_n*2048.
  e(v_lshlrev_b32_e32(v[LLB], 5, v[50])) # (lane%16) * 32 (stride 32!)
  e(v_add_nc_u32_e32(v[LLB], v[LLB], v[51])) # + (lane/16)*16
  e(v_lshlrev_b32_e32(v[52], 11, v[3])) # wave_n * 2048
  e(v_add_nc_u32_e32(v[LLB], v[LLB], v[52]))
  e(v_add_nc_u32_e32(v[LLB], LDS_B_OFF, v[LLB]))
  # Zero the 128 accumulator registers v[ACC : ACC+128], two per dual-issue VOPD move.
  for i in range(0, 128, 2):
    e(VOPD(VOPDOp.V_DUAL_MOV_B32, VOPDOp.V_DUAL_MOV_B32, vdstx=v[ACC+i], vdsty=v[ACC+i+1], srcx0=0, srcy0=0))
  # s[16] is the K counter compared against s[17] at the bottom of the loop.
  e(s_mov_b32(s[16], 0))
  # Prologue: fetch the first K slab of A and B (2 x 16 bytes each per thread), copy it to LDS
  # and advance both global pointers to the next slab.
  if not NO_GLOBAL:
    for i in range(2): e(global_load_b128(vdst=v[DA+i*4:DA+i*4+3], vaddr=v[6:7], saddr=s[4:5], ioffset=i*16))
    for i in range(2): e(global_load_b128(vdst=v[DB+i*4:DB+i*4+3], vaddr=v[8:9], saddr=s[6:7], ioffset=i*16))
    e(s_wait_loadcnt(simm16=0))
  if not NO_DS:
    for i in range(2): e(ds_store_b128(addr=v[4], data0=v[DA+i*4:DA+i*4+3], offset0=(i*16)&0xFF, offset1=(i*16)>>8))
    for i in range(2): e(ds_store_b128(addr=v[5], data0=v[DB+i*4:DB+i*4+3], offset0=(i*16)&0xFF, offset1=(i*16)>>8))
  if not NO_GLOBAL:
    e(v_add_nc_u32_e32(v[6], BLOCK_K*ELEM, v[6]))
    e(v_add_nc_u32_e32(v[8], s[14], v[8]))
  # =============================================================================
  # One K step: barrier, prefetch the next slab from global memory, read the current slab from LDS,
  # run the 16 WMMAs, then write the prefetched registers to LDS and bump the K counter.
  # load_set chooses which of 'A' / 'B' are prefetched from global memory in this step.
  # NOTE(review): the DS stores at the end always write both DA and DB, so with load_set='A' or 'B'
  # the other operand's registers are not refreshed before being stored — confirm this is intentional.
  def emit_iter_body(load_set='AB'):
    if not NO_DS:
      e(s_wait_dscnt(simm16=0))
    # src[193] is passed as the barrier id (presumably the inline constant -1 — TODO confirm).
    e(s_barrier_signal(ssrc0=src[193])); e(s_barrier_wait(simm16=0xFFFF))
    if not NO_GLOBAL:
      if 'A' in load_set:
        for i in range(2): e(global_load_b128(vdst=v[DA+i*4:DA+i*4+3], vaddr=v[6:7], saddr=s[4:5], ioffset=i*16))
        e(v_add_nc_u32_e32(v[6], BLOCK_K*ELEM, v[6]))
      if 'B' in load_set:
        for i in range(2): e(global_load_b128(vdst=v[DB+i*4:DB+i*4+3], vaddr=v[8:9], saddr=s[6:7], ioffset=i*16))
        e(v_add_nc_u32_e32(v[8], s[14], v[8]))
    if not NO_DS:
      # Issue 6 loads: A[0:3] + B[0] + B[1]. B[2:3] interleaved with WMMAs.
      # The byte offset is split into its low byte (offset0) and high byte (offset1).
      for tm in range(TILES_M):
        aoff = tm * 16 * LDS_A_ROW
        e(ds_load_b128(vdst=v[FA+tm*4:FA+tm*4+3], addr=v[LLA], offset0=aoff&0xFF, offset1=aoff>>8))
      e(ds_load_b128(vdst=v[FB:FB+3], addr=v[LLB], offset0=0, offset1=0))
      e(ds_load_b128(vdst=v[FB+4:FB+7], addr=v[LLB], offset0=0, offset1=2))
      e(s_wait_dscnt(simm16=0)) # wait for the 6 loads above
    if not NO_ALU:
      # Accumulator tile (tm, tn) lives in the 8 registers starting at ACC + (tm*TILES_N+tn)*8.
      # B[0] WMMAs — issue B[2] during compute
      if not NO_DS: e(ds_load_b128(vdst=v[FB+8:FB+11], addr=v[LLB], offset0=0, offset1=4))
      for tm in range(TILES_M):
        ac = ACC + (tm*TILES_N+0)*8
        e(v_wmma_f32_16x16x16_f16(vdst=v[ac:ac+7], src0=v[FA+tm*4:FA+tm*4+3], src1=v[FB:FB+3], src2=v[ac:ac+7]))
      # B[1] WMMAs — issue B[3] during compute
      if not NO_DS:
        e(ds_load_b128(vdst=v[FB+12:FB+15], addr=v[LLB], offset0=0, offset1=6))
      for tm in range(TILES_M):
        ac = ACC + (tm*TILES_N+1)*8
        e(v_wmma_f32_16x16x16_f16(vdst=v[ac:ac+7], src0=v[FA+tm*4:FA+tm*4+3], src1=v[FB+4:FB+7], src2=v[ac:ac+7]))
      # B[2] WMMAs — the B[2] load was issued before the B[0] WMMAs
      if not NO_DS: e(s_wait_dscnt(simm16=1)) # B[2] done, B[3] may still be loading
      for tm in range(TILES_M):
        ac = ACC + (tm*TILES_N+2)*8
        e(v_wmma_f32_16x16x16_f16(vdst=v[ac:ac+7], src0=v[FA+tm*4:FA+tm*4+3], src1=v[FB+8:FB+11], src2=v[ac:ac+7]))
      # B[3] WMMAs
      if not NO_DS: e(s_wait_dscnt(simm16=0))
      for tm in range(TILES_M):
        ac = ACC + (tm*TILES_N+3)*8
        e(v_wmma_f32_16x16x16_f16(vdst=v[ac:ac+7], src0=v[FA+tm*4:FA+tm*4+3], src1=v[FB+12:FB+15], src2=v[ac:ac+7]))
    # The prefetched registers must have arrived before they are written to LDS for the next step.
    if not NO_GLOBAL and not NO_DS: e(s_wait_loadcnt(simm16=0))
    if not NO_DS:
      for i in range(2): e(ds_store_b128(addr=v[4], data0=v[DA+i*4:DA+i*4+3], offset0=(i*16)&0xFF, offset1=(i*16)>>8))
      for i in range(2): e(ds_store_b128(addr=v[5], data0=v[DB+i*4:DB+i*4+3], offset0=(i*16)&0xFF, offset1=(i*16)>>8))
    e(s_add_co_i32(s[16], s[16], BLOCK_K))
  # Main loop: two iteration bodies per pass (s[16] advances by 2*BLOCK_K), repeated while s[16] < s[17].
  label('LOOP')
  emit_iter_body(load_set='A')
  emit_iter_body(load_set='B')
  # simm16 is emitted as 0 here and patched by the fixup loop at the end of this function.
  e(s_cmp_lt_i32(s[16], s[17])); e(s_cbranch_scc1(simm16=0)); br(I[-1], 'LOOP')
  emit_iter_body(load_set='AB') # tail with prefetch
  # Final iteration: no prefetch, no ds_store needed
  if not NO_DS:
    e(s_wait_dscnt(simm16=0))
  e(s_barrier_signal(ssrc0=src[193])); e(s_barrier_wait(simm16=0xFFFF))
  if not NO_DS:
    for tm in range(TILES_M):
      aoff = tm * 16 * LDS_A_ROW
      e(ds_load_b128(vdst=v[FA+tm*4:FA+tm*4+3], addr=v[LLA], offset0=aoff&0xFF, offset1=aoff>>8))
    e(ds_load_b128(vdst=v[FB:FB+3], addr=v[LLB], offset0=0, offset1=0))
    e(ds_load_b128(vdst=v[FB+4:FB+7], addr=v[LLB], offset0=0, offset1=2))
    e(s_wait_dscnt(simm16=0))
  if not NO_ALU:
    # Same B[0]..B[3] schedule as in emit_iter_body.
    if not NO_DS: e(ds_load_b128(vdst=v[FB+8:FB+11], addr=v[LLB], offset0=0, offset1=4))
    for tm in range(TILES_M):
      ac = ACC + (tm*TILES_N+0)*8
      e(v_wmma_f32_16x16x16_f16(vdst=v[ac:ac+7], src0=v[FA+tm*4:FA+tm*4+3], src1=v[FB:FB+3], src2=v[ac:ac+7]))
    if not NO_DS: e(ds_load_b128(vdst=v[FB+12:FB+15], addr=v[LLB], offset0=0, offset1=6))
    for tm in range(TILES_M):
      ac = ACC + (tm*TILES_N+1)*8
      e(v_wmma_f32_16x16x16_f16(vdst=v[ac:ac+7], src0=v[FA+tm*4:FA+tm*4+3], src1=v[FB+4:FB+7], src2=v[ac:ac+7]))
    if not NO_DS: e(s_wait_dscnt(simm16=1))
    for tm in range(TILES_M):
      ac = ACC + (tm*TILES_N+2)*8
      e(v_wmma_f32_16x16x16_f16(vdst=v[ac:ac+7], src0=v[FA+tm*4:FA+tm*4+3], src1=v[FB+8:FB+11], src2=v[ac:ac+7]))
    if not NO_DS: e(s_wait_dscnt(simm16=0))
    for tm in range(TILES_M):
      ac = ACC + (tm*TILES_N+3)*8
      e(v_wmma_f32_16x16x16_f16(vdst=v[ac:ac+7], src0=v[FA+tm*4:FA+tm*4+3], src1=v[FB+12:FB+15], src2=v[ac:ac+7]))
  # Epilogue: convert the accumulators to fp16 and store them. No branch in this function targets this label.
  label('EPILOGUE')
  # v[ET] = lane%16 (column inside a tile), v[ET+1] = (lane/16)*8 (first row inside a tile),
  # v[ET+2] = row origin + wave_m*64, v[ET+3] = column origin + wave_n*64 + lane%16,
  # v[ET+5] = 0 (upper half of the v[ET+4:ET+5] address pair).
  e(v_and_b32_e32(v[ET], 15, v[1]))
  e(v_lshrrev_b32_e32(v[ET+1], 4, v[1])); e(v_lshlrev_b32_e32(v[ET+1], 3, v[ET+1]))
  e(v_lshlrev_b32_e32(v[ET+2], 6, v[2])); e(v_add_nc_u32_e32(v[ET+2], s[11], v[ET+2]))
  e(v_lshlrev_b32_e32(v[ET+3], 6, v[3])); e(v_add_nc_u32_e32(v[ET+3], s[10], v[ET+3]))
  e(v_add_nc_u32_e32(v[ET+3], v[ET+3], v[ET])); e(v_mov_b32_e32(v[ET+5], 0))
  for tm in range(TILES_M):
    for tn in range(TILES_N):
      ac = ACC + (tm*TILES_N+tn)*8; r_off, c_off = tm*16, tn*16
      # Byte address of the first element of this tile for this lane: ((row * N) + column) * 2.
      e(v_add_nc_u32_e32(v[ET+6], r_off, v[ET+2])); e(v_add_nc_u32_e32(v[ET+6], v[ET+1], v[ET+6]))
      e(v_mul_lo_u32(v[ET+4], v[ET+6], s[12])); e(v_add_nc_u32_e32(v[ET+4], v[ET+4], v[ET+3]))
      if c_off: e(v_add_nc_u32_e32(v[ET+4], c_off, v[ET+4]))
      e(v_lshlrev_b32_e32(v[ET+4], 1, v[ET+4]))
      # The 8 accumulator registers of a tile are written to 8 consecutive rows (address += s[13] each time).
      for elem in range(8):
        e(v_cvt_f16_f32_e32(v[ET+7], v[ac+elem]))
        e(global_store_b16(vaddr=v[ET+4:ET+5], vsrc=v[ET+7], saddr=s[8:9]))
        if elem < 7: e(v_add_nc_u32_e32(v[ET+4], s[13], v[ET+4]))
  e(s_wait_storecnt(simm16=0)); e(s_sendmsg(simm16=3)); e(s_endpgm())
  # Branch fixup: offset = (label position - end of the branch instruction) // 4, which must fit in a signed 16-bit field.
  # (size() is presumably in bytes, making the offset a dword count — TODO confirm.)
  for idx, target in B:
    off = (L[target] - sum(i.size() for i in I[:idx+1])) // 4
    assert -32768 <= off <= 32767; I[idx].simm16 = off
  return I
# Dimension of the square N x N matrices used by test_matmul; read through getenv("N") with a default of 4096.
N = getenv("N", 4096)
def test_matmul():
  """Benchmark the hand-assembled GEMM on the default device and optionally verify it.

  Builds the instruction list for the module-level N, wraps it into a custom
  kernel over two random fp16 N x N inputs and an fp16 output, runs it CNT
  times (default 5) and prints the throughput of the fastest run. Unless
  VERIFY=0, the output is compared with a NumPy matmul of the same inputs.

  Raises:
    RuntimeError: if the relative RMSE is NaN or larger than 0.05.
  """
  device = Device[Device.DEFAULT]
  arch = getattr(device.renderer, 'arch', 'gfx1200')
  print(f"Device arch: {arch}")
  program = build_kernel(N, arch)

  # Deterministic inputs: uniform float32 samples rounded to fp16.
  gen = np.random.default_rng(42)
  lhs = Tensor(gen.random((N, N), dtype=np.float32).astype(np.float16))
  rhs = Tensor(gen.random((N, N), dtype=np.float32).astype(np.float16))
  out = Tensor.empty(N, N, dtype=dtypes.half)
  Tensor.realize(lhs, rhs, out)

  grid = (N//BLOCK_N, N//BLOCK_M, 1)
  local = (THREADS, 1, 1)
  print(f"Grid: {grid}, Local: {local}")
  device_name = Device.DEFAULT

  def asm_kernel(A, B, C):
    # Launch dimensions: one special index per grid axis plus a single local axis of THREADS.
    gidxs = [UOp.special(extent, f"gidx{axis}") for axis, extent in enumerate(grid)]
    lidxs = [UOp.special(THREADS, "lidx0")]
    # The LDS allocation is at least LDS_SIZE; LIMIT_OCC (default 2) can raise it to 65536//LIMIT_OCC bytes.
    lds_bytes = max(LDS_SIZE, 65536//getenv("LIMIT_OCC",2))
    lds = UOp(Ops.DEFINE_LOCAL, dtypes.uint8.ptr(size=lds_bytes, addrspace=AddrSpace.LOCAL), (), 'lds')
    info = KernelInfo(name=colored("kernel","cyan"), estimates=Estimates(ops=N*N*N*2, mem=N*N*2*3))
    sink = UOp.sink(A.base, B.base, C.base, lds, *gidxs, *lidxs, arg=info)
    linear = UOp(Ops.LINEAR, src=tuple(UOp(Ops.INS, arg=ins) for ins in program))
    return UOp(Ops.PROGRAM, src=(sink, UOp(Ops.DEVICE, arg=device_name), linear))

  out = Tensor.custom_kernel(lhs, rhs, out, fxn=asm_kernel)[2]
  runner = out.schedule()[0].lower()
  with Context(DEBUG=2):
    ets = [runner.run(wait=True) for _ in range(getenv("CNT", 5))]
  print(f"REAL TFLOPS {N*N*N*2 / min(ets) * 1e-12:.2f}")

  if not getenv("VERIFY", 1):
    return
  GlobalCounters.reset()
  got = out.float().numpy()
  lhs_np, rhs_np = lhs.float().numpy(), rhs.float().numpy()
  ref = lhs_np @ rhs_np
  # Relative error: RMS of the difference divided by RMS of the reference.
  err = np.sqrt(np.mean((got - ref)**2)) / np.sqrt(np.mean(ref**2))
  print(f"relative RMSE {err:.6f}")
  if np.isnan(err) or err > 0.05:
    raise RuntimeError(f"matmul is wrong! RMSE={err}")
# Run the benchmark (and verification, unless disabled) when the file is executed as a script.
if __name__ == "__main__":
  test_matmul()