continue work on amd_uop_matmul

This commit is contained in:
George Hotz 2025-07-23 18:49:19 -07:00
commit 0eec0e3dc0

View file

@ -8,17 +8,20 @@ N = 4096
run_count = 5
def hand_spec_kernel3():
BLOCK_SIZE = 256
# block tile size
BN = 128
BM = 128
# number of row/column we read per batch
BK = 8
# thread tile size 4x4
TN = 4
TM = 4
BLOCK_SIZE = 128
nbWaves = BLOCK_SIZE // 32
WN = 64
# wave tile size
WN = 128
WM = BN * BM // nbWaves // WN
nbWaveX = BN // WN
@ -65,7 +68,7 @@ def hand_spec_kernel3():
A_col = UOp(Ops.DEFINE_REG, dtypes.float.ptr(nbIterWaveM * TM, AddrSpace.REG), src=(junk,), arg=0)
B_row = UOp(Ops.DEFINE_REG, dtypes.float.ptr(nbIterWaveN * TN, AddrSpace.REG), src=(junk,), arg=1)
As = UOp(Ops.DEFINE_LOCAL, dtypes.float.ptr(BK*BM, AddrSpace.LOCAL), arg=0)
As = UOp(Ops.DEFINE_LOCAL, dtypes.float.ptr(BK*(BM+4), AddrSpace.LOCAL), arg=0) # 4 padding to avoid bank conflicts
Bs = UOp(Ops.DEFINE_LOCAL, dtypes.float.ptr(BK*BN, AddrSpace.LOCAL), arg=1)
c_regs = UOp(Ops.DEFINE_REG, dtypes.float.ptr(TM * nbIterWaveM * TN * nbIterWaveN), src=(junk,), arg=2)
@ -77,12 +80,12 @@ def hand_spec_kernel3():
i = UOp.range(dtypes.int, nbReadsB, 1)
index_x = BN * blockIdx_x + rBIdx
index_y = rBIdy + i * strideReadB + kId
Bs_store = Bs[(index_y % BK) * BN + index_x % BN].store(b[N * index_y + index_x].load(), i)
Bs_store = Bs[(index_y % BK)*BN + index_x%BN].store(b[N * index_y + index_x].load(), i)
i = UOp.range(dtypes.int, nbReadsA, 2)
index_x = rAIdx + kId
index_y = BM * blockIdx_y + rAIdy + i * strideReadA
As_store = As[(index_x % BK) * BM + index_y % BM].store(a[N * index_y + index_x].load(), i)
As_store = As[(index_x % BK)*(BM+4) + index_y%BM].store(a[N * index_y + index_x].load(), i)
barrier = UOp(Ops.BARRIER, src=(As_store, Bs_store))
@ -97,7 +100,7 @@ def hand_spec_kernel3():
iterWave = UOp.range(dtypes.int, nbIterWaveM, 6)
i = UOp.range(dtypes.int, TM, 7)
index = waveIdy * WM + iterWave * SUBWM + TM * idyInWave + i
A_col_store = A_col[iterWave*TM + i].store(As[k*BM + index].load(barrier), iterWave, i)
A_col_store = A_col[iterWave*TM + i].store(As[k*(BM+4) + index].load(barrier), iterWave, i)
# do the GEMM math
iterWaveM = UOp.range(dtypes.int, nbIterWaveM, 8)