mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
Compare commits
1 commit
master
...
continue_m
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
0eec0e3dc0 |
1 changed files with 10 additions and 7 deletions
|
|
@ -8,17 +8,20 @@ N = 4096
|
||||||
run_count = 5
|
run_count = 5
|
||||||
|
|
||||||
def hand_spec_kernel3():
|
def hand_spec_kernel3():
|
||||||
BLOCK_SIZE = 256
|
# block tile size
|
||||||
|
|
||||||
BN = 128
|
BN = 128
|
||||||
BM = 128
|
BM = 128
|
||||||
|
# number of row/column we read per batch
|
||||||
BK = 8
|
BK = 8
|
||||||
|
|
||||||
|
# thread tile size 4x4
|
||||||
TN = 4
|
TN = 4
|
||||||
TM = 4
|
TM = 4
|
||||||
|
|
||||||
|
BLOCK_SIZE = 128
|
||||||
nbWaves = BLOCK_SIZE // 32
|
nbWaves = BLOCK_SIZE // 32
|
||||||
WN = 64
|
# wave tile size
|
||||||
|
WN = 128
|
||||||
WM = BN * BM // nbWaves // WN
|
WM = BN * BM // nbWaves // WN
|
||||||
|
|
||||||
nbWaveX = BN // WN
|
nbWaveX = BN // WN
|
||||||
|
|
@ -65,7 +68,7 @@ def hand_spec_kernel3():
|
||||||
A_col = UOp(Ops.DEFINE_REG, dtypes.float.ptr(nbIterWaveM * TM, AddrSpace.REG), src=(junk,), arg=0)
|
A_col = UOp(Ops.DEFINE_REG, dtypes.float.ptr(nbIterWaveM * TM, AddrSpace.REG), src=(junk,), arg=0)
|
||||||
B_row = UOp(Ops.DEFINE_REG, dtypes.float.ptr(nbIterWaveN * TN, AddrSpace.REG), src=(junk,), arg=1)
|
B_row = UOp(Ops.DEFINE_REG, dtypes.float.ptr(nbIterWaveN * TN, AddrSpace.REG), src=(junk,), arg=1)
|
||||||
|
|
||||||
As = UOp(Ops.DEFINE_LOCAL, dtypes.float.ptr(BK*BM, AddrSpace.LOCAL), arg=0)
|
As = UOp(Ops.DEFINE_LOCAL, dtypes.float.ptr(BK*(BM+4), AddrSpace.LOCAL), arg=0) # 4 padding to avoid bank conflicts
|
||||||
Bs = UOp(Ops.DEFINE_LOCAL, dtypes.float.ptr(BK*BN, AddrSpace.LOCAL), arg=1)
|
Bs = UOp(Ops.DEFINE_LOCAL, dtypes.float.ptr(BK*BN, AddrSpace.LOCAL), arg=1)
|
||||||
|
|
||||||
c_regs = UOp(Ops.DEFINE_REG, dtypes.float.ptr(TM * nbIterWaveM * TN * nbIterWaveN), src=(junk,), arg=2)
|
c_regs = UOp(Ops.DEFINE_REG, dtypes.float.ptr(TM * nbIterWaveM * TN * nbIterWaveN), src=(junk,), arg=2)
|
||||||
|
|
@ -77,12 +80,12 @@ def hand_spec_kernel3():
|
||||||
i = UOp.range(dtypes.int, nbReadsB, 1)
|
i = UOp.range(dtypes.int, nbReadsB, 1)
|
||||||
index_x = BN * blockIdx_x + rBIdx
|
index_x = BN * blockIdx_x + rBIdx
|
||||||
index_y = rBIdy + i * strideReadB + kId
|
index_y = rBIdy + i * strideReadB + kId
|
||||||
Bs_store = Bs[(index_y % BK) * BN + index_x % BN].store(b[N * index_y + index_x].load(), i)
|
Bs_store = Bs[(index_y % BK)*BN + index_x%BN].store(b[N * index_y + index_x].load(), i)
|
||||||
|
|
||||||
i = UOp.range(dtypes.int, nbReadsA, 2)
|
i = UOp.range(dtypes.int, nbReadsA, 2)
|
||||||
index_x = rAIdx + kId
|
index_x = rAIdx + kId
|
||||||
index_y = BM * blockIdx_y + rAIdy + i * strideReadA
|
index_y = BM * blockIdx_y + rAIdy + i * strideReadA
|
||||||
As_store = As[(index_x % BK) * BM + index_y % BM].store(a[N * index_y + index_x].load(), i)
|
As_store = As[(index_x % BK)*(BM+4) + index_y%BM].store(a[N * index_y + index_x].load(), i)
|
||||||
|
|
||||||
barrier = UOp(Ops.BARRIER, src=(As_store, Bs_store))
|
barrier = UOp(Ops.BARRIER, src=(As_store, Bs_store))
|
||||||
|
|
||||||
|
|
@ -97,7 +100,7 @@ def hand_spec_kernel3():
|
||||||
iterWave = UOp.range(dtypes.int, nbIterWaveM, 6)
|
iterWave = UOp.range(dtypes.int, nbIterWaveM, 6)
|
||||||
i = UOp.range(dtypes.int, TM, 7)
|
i = UOp.range(dtypes.int, TM, 7)
|
||||||
index = waveIdy * WM + iterWave * SUBWM + TM * idyInWave + i
|
index = waveIdy * WM + iterWave * SUBWM + TM * idyInWave + i
|
||||||
A_col_store = A_col[iterWave*TM + i].store(As[k*BM + index].load(barrier), iterWave, i)
|
A_col_store = A_col[iterWave*TM + i].store(As[k*(BM+4) + index].load(barrier), iterWave, i)
|
||||||
|
|
||||||
# do the GEMM math
|
# do the GEMM math
|
||||||
iterWaveM = UOp.range(dtypes.int, nbIterWaveM, 8)
|
iterWaveM = UOp.range(dtypes.int, nbIterWaveM, 8)
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue