mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
continue work on amd_uop_matmul
This commit is contained in:
parent
b0dc97d1f7
commit
0eec0e3dc0
1 changed files with 10 additions and 7 deletions
|
|
@ -8,17 +8,20 @@ N = 4096
|
|||
run_count = 5
|
||||
|
||||
def hand_spec_kernel3():
|
||||
BLOCK_SIZE = 256
|
||||
|
||||
# block tile size
|
||||
BN = 128
|
||||
BM = 128
|
||||
# number of row/column we read per batch
|
||||
BK = 8
|
||||
|
||||
# thread tile size 4x4
|
||||
TN = 4
|
||||
TM = 4
|
||||
|
||||
BLOCK_SIZE = 128
|
||||
nbWaves = BLOCK_SIZE // 32
|
||||
WN = 64
|
||||
# wave tile size
|
||||
WN = 128
|
||||
WM = BN * BM // nbWaves // WN
|
||||
|
||||
nbWaveX = BN // WN
|
||||
|
|
@ -65,7 +68,7 @@ def hand_spec_kernel3():
|
|||
A_col = UOp(Ops.DEFINE_REG, dtypes.float.ptr(nbIterWaveM * TM, AddrSpace.REG), src=(junk,), arg=0)
|
||||
B_row = UOp(Ops.DEFINE_REG, dtypes.float.ptr(nbIterWaveN * TN, AddrSpace.REG), src=(junk,), arg=1)
|
||||
|
||||
As = UOp(Ops.DEFINE_LOCAL, dtypes.float.ptr(BK*BM, AddrSpace.LOCAL), arg=0)
|
||||
As = UOp(Ops.DEFINE_LOCAL, dtypes.float.ptr(BK*(BM+4), AddrSpace.LOCAL), arg=0) # 4 padding to avoid bank conflicts
|
||||
Bs = UOp(Ops.DEFINE_LOCAL, dtypes.float.ptr(BK*BN, AddrSpace.LOCAL), arg=1)
|
||||
|
||||
c_regs = UOp(Ops.DEFINE_REG, dtypes.float.ptr(TM * nbIterWaveM * TN * nbIterWaveN), src=(junk,), arg=2)
|
||||
|
|
@ -77,12 +80,12 @@ def hand_spec_kernel3():
|
|||
i = UOp.range(dtypes.int, nbReadsB, 1)
|
||||
index_x = BN * blockIdx_x + rBIdx
|
||||
index_y = rBIdy + i * strideReadB + kId
|
||||
Bs_store = Bs[(index_y % BK) * BN + index_x % BN].store(b[N * index_y + index_x].load(), i)
|
||||
Bs_store = Bs[(index_y % BK)*BN + index_x%BN].store(b[N * index_y + index_x].load(), i)
|
||||
|
||||
i = UOp.range(dtypes.int, nbReadsA, 2)
|
||||
index_x = rAIdx + kId
|
||||
index_y = BM * blockIdx_y + rAIdy + i * strideReadA
|
||||
As_store = As[(index_x % BK) * BM + index_y % BM].store(a[N * index_y + index_x].load(), i)
|
||||
As_store = As[(index_x % BK)*(BM+4) + index_y%BM].store(a[N * index_y + index_x].load(), i)
|
||||
|
||||
barrier = UOp(Ops.BARRIER, src=(As_store, Bs_store))
|
||||
|
||||
|
|
@ -97,7 +100,7 @@ def hand_spec_kernel3():
|
|||
iterWave = UOp.range(dtypes.int, nbIterWaveM, 6)
|
||||
i = UOp.range(dtypes.int, TM, 7)
|
||||
index = waveIdy * WM + iterWave * SUBWM + TM * idyInWave + i
|
||||
A_col_store = A_col[iterWave*TM + i].store(As[k*BM + index].load(barrier), iterWave, i)
|
||||
A_col_store = A_col[iterWave*TM + i].store(As[k*(BM+4) + index].load(barrier), iterWave, i)
|
||||
|
||||
# do the GEMM math
|
||||
iterWaveM = UOp.range(dtypes.int, nbIterWaveM, 8)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue