mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
Compare commits
7 commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
c89fb841f8 | ||
|
|
29f7d03d43 | ||
|
|
b41e49472c | ||
|
|
165d8e1263 | ||
|
|
4b57aa2655 | ||
|
|
ad1a2a68d5 | ||
|
|
1e07dff384 |
7 changed files with 308 additions and 51 deletions
|
|
@ -19,6 +19,9 @@ if __name__ == "__main__":
|
||||||
elif getenv("ASM") == -1:
|
elif getenv("ASM") == -1:
|
||||||
src = (pathlib.Path(__file__).parent / "amd_seb" / "kernel3_registers.cpp").read_text()
|
src = (pathlib.Path(__file__).parent / "amd_seb" / "kernel3_registers.cpp").read_text()
|
||||||
prgfast = replace(prg, name="kernel3_registers", src=src, global_size=[N//128, N//128, 1], local_size=[256, 1, 1])
|
prgfast = replace(prg, name="kernel3_registers", src=src, global_size=[N//128, N//128, 1], local_size=[256, 1, 1])
|
||||||
|
elif getenv("ASM") == -2:
|
||||||
|
src = (pathlib.Path(__file__).parent / "amd_seb" / "kernel4_gmem_df.cpp").read_text()
|
||||||
|
prgfast = replace(prg, name="kernel4_gmem_db", src=src, global_size=[N//128, N//128, 1], local_size=[256, 1, 1])
|
||||||
else:
|
else:
|
||||||
src = (pathlib.Path(__file__).parent / "amd_seb" / "kernel5_lds_optim.cpp").read_text()
|
src = (pathlib.Path(__file__).parent / "amd_seb" / "kernel5_lds_optim.cpp").read_text()
|
||||||
prgfast = replace(prg, name="kernel5_lds_optim", src=src, global_size=[N//128, N//128, 1], local_size=[128, 1, 1])
|
prgfast = replace(prg, name="kernel5_lds_optim", src=src, global_size=[N//128, N//128, 1], local_size=[128, 1, 1])
|
||||||
|
|
|
||||||
|
|
@ -10,7 +10,8 @@ __attribute__((device)) inline void __syncthreads() {
|
||||||
}
|
}
|
||||||
|
|
||||||
#define BLOCK_SIZE 256
|
#define BLOCK_SIZE 256
|
||||||
extern "C" __attribute__((global)) void kernel3_registers(float *a, float *b, float *c)
|
extern "C" __attribute__((global)) void __attribute__((amdgpu_flat_work_group_size(1, BLOCK_SIZE)))
|
||||||
|
kernel3_registers(float *a, float *b, float *c)
|
||||||
{
|
{
|
||||||
constexpr int N = 4096;
|
constexpr int N = 4096;
|
||||||
constexpr float alpha = 1.0;
|
constexpr float alpha = 1.0;
|
||||||
|
|
|
||||||
172
extra/gemm/amd_seb/kernel4_gmem_df.cpp
Normal file
172
extra/gemm/amd_seb/kernel4_gmem_df.cpp
Normal file
|
|
@ -0,0 +1,172 @@
|
||||||
|
typedef long unsigned int size_t;
|
||||||
|
extern "C" __attribute__((device, const)) size_t __ockl_get_local_id(unsigned int);
|
||||||
|
extern "C" __attribute__((device, const)) size_t __ockl_get_group_id(unsigned int);
|
||||||
|
struct Dim3 { size_t x, y, z; };
|
||||||
|
#define __shared__ __attribute__((shared, aligned(16)))
|
||||||
|
__attribute__((device)) inline void __syncthreads() {
|
||||||
|
__builtin_amdgcn_fence(__ATOMIC_RELEASE, "workgroup");
|
||||||
|
__builtin_amdgcn_s_barrier();
|
||||||
|
__builtin_amdgcn_fence(__ATOMIC_ACQUIRE, "workgroup");
|
||||||
|
}
|
||||||
|
|
||||||
|
#define BLOCK_SIZE 256
|
||||||
|
extern "C" __attribute__((global)) void __attribute__((amdgpu_flat_work_group_size(1, BLOCK_SIZE)))
|
||||||
|
kernel4_gmem_db(float *a, float *b, float *c)
|
||||||
|
{
|
||||||
|
constexpr int N = 4096;
|
||||||
|
constexpr float alpha = 1.0;
|
||||||
|
constexpr float beta = 0.0;
|
||||||
|
|
||||||
|
const Dim3 blockIdx{ __ockl_get_group_id(0), __ockl_get_group_id(1), __ockl_get_group_id(2) };
|
||||||
|
const Dim3 threadIdx{ __ockl_get_local_id(0), __ockl_get_local_id(1), __ockl_get_local_id(2) };
|
||||||
|
|
||||||
|
// Block Tile size
|
||||||
|
constexpr int BN = 128;
|
||||||
|
constexpr int BM = 128;
|
||||||
|
// Number of rows or columns we read per batch
|
||||||
|
constexpr int BK = 8;
|
||||||
|
|
||||||
|
// Thread Tile size
|
||||||
|
constexpr int TN = 4;
|
||||||
|
constexpr int TM = 4;
|
||||||
|
|
||||||
|
constexpr int nbWaves = BLOCK_SIZE / 32;
|
||||||
|
// Wave Tile size
|
||||||
|
constexpr int WN = 64;
|
||||||
|
constexpr int WM = BN * BM / nbWaves / WN;
|
||||||
|
|
||||||
|
// Number of waves along the X & Y axes of the block tile
|
||||||
|
constexpr int nbWaveX = BN / WN;
|
||||||
|
constexpr int nbWaveY = BM / WM;
|
||||||
|
|
||||||
|
const int waveIndex = threadIdx.x / 32;
|
||||||
|
const int waveIdx = waveIndex % nbWaveX;
|
||||||
|
const int waveIdy = waveIndex / nbWaveX;
|
||||||
|
const int indexInWave = threadIdx.x % 32;
|
||||||
|
|
||||||
|
// Threads within a wave are arranged as an 8x4 (X x Y) grid over the output matrix
|
||||||
|
constexpr int nbThreadXPerWave = 8;
|
||||||
|
constexpr int nbThreadYPerWave = 4;
|
||||||
|
|
||||||
|
// Thread coordinates in Wave
|
||||||
|
const int idxInWave = indexInWave % nbThreadXPerWave;
|
||||||
|
const int idyInWave = indexInWave / nbThreadXPerWave;
|
||||||
|
|
||||||
|
constexpr int nbIterWaveN = WN / (nbThreadXPerWave * TN);
|
||||||
|
constexpr int nbIterWaveM = WM / (nbThreadYPerWave * TM);
|
||||||
|
|
||||||
|
// Wave Sub-tile size
|
||||||
|
constexpr int SUBWN = WN / nbIterWaveN;
|
||||||
|
constexpr int SUBWM = WM / nbIterWaveM;
|
||||||
|
|
||||||
|
// Thread mapping to read BKxBM block from A
|
||||||
|
int rAIdx = threadIdx.x % BK;
|
||||||
|
int rAIdy = threadIdx.x / BK;
|
||||||
|
// Thread mapping to read BNxBK block from B
|
||||||
|
int rBIdx = threadIdx.x % BN;
|
||||||
|
int rBIdy = threadIdx.x / BN;
|
||||||
|
|
||||||
|
constexpr int strideReadB = BLOCK_SIZE / BN;
|
||||||
|
constexpr int strideReadA = BLOCK_SIZE / BK;
|
||||||
|
constexpr int nbReadsB = BN * BK / BLOCK_SIZE;
|
||||||
|
constexpr int nbReadsA = BM * BK / BLOCK_SIZE;
|
||||||
|
|
||||||
|
float A_col[nbIterWaveM * TM];
|
||||||
|
float B_row[nbIterWaveN * TN];
|
||||||
|
|
||||||
|
__shared__ float As[BK][BM];
|
||||||
|
__shared__ float Bs[BK][BN];
|
||||||
|
|
||||||
|
float c_regs[TM * nbIterWaveM * TN * nbIterWaveN] = {0.0f};
|
||||||
|
|
||||||
|
for (int i = 0; i < nbReadsB; i++) {
|
||||||
|
int index_x = BN * blockIdx.x + rBIdx;
|
||||||
|
int index_y = rBIdy + i * strideReadB;
|
||||||
|
Bs[index_y % BK][index_x % BN] = b[N * index_y + index_x];
|
||||||
|
}
|
||||||
|
|
||||||
|
for (int i = 0; i < nbReadsA; i++) {
|
||||||
|
int index_x = rAIdx;
|
||||||
|
int index_y = BM * blockIdx.y + rAIdy + i * strideReadA;
|
||||||
|
As[(index_x % BK)][(index_y % BM)] = a[N * index_y + index_x];
|
||||||
|
}
|
||||||
|
|
||||||
|
__syncthreads();
|
||||||
|
// Iteration over BK blocks.
|
||||||
|
for (int kId = 0; kId < N; kId += BK) {
|
||||||
|
float regA[nbReadsA];
|
||||||
|
float regB[nbReadsB];
|
||||||
|
if (kId < N - BK) {
|
||||||
|
// Prefetch the next BK rows of B and BK columns of A from global memory into registers (regB/regA)
|
||||||
|
for (int i = 0; i < nbReadsB; i++) {
|
||||||
|
int index_x = BN * blockIdx.x + rBIdx;
|
||||||
|
int index_y = rBIdy + i * strideReadB + kId + BK;
|
||||||
|
regB[i] = b[N * index_y + index_x];
|
||||||
|
}
|
||||||
|
|
||||||
|
for (int i = 0; i < nbReadsA; i++) {
|
||||||
|
int index_x = rAIdx + kId + BK;
|
||||||
|
int index_y = BM * blockIdx.y + rAIdy + i * strideReadA;
|
||||||
|
regA[i] = a[N * index_y + index_x];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for (int k = 0; k < BK; k++) {
|
||||||
|
// we cache A & B for the entire Wave tile
|
||||||
|
for (int iterWave = 0; iterWave < nbIterWaveN; iterWave++) {
|
||||||
|
for (int i = 0; i < TN; i++) {
|
||||||
|
int index = waveIdx * WN + iterWave * SUBWN + TN * idxInWave + i;
|
||||||
|
B_row[iterWave * TN + i] = Bs[k][index];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for (int iterWave = 0; iterWave < nbIterWaveM; iterWave++) {
|
||||||
|
for (int i = 0; i < TM; i++) {
|
||||||
|
int index = waveIdy * WM + iterWave * SUBWM + TM * idyInWave + i;
|
||||||
|
A_col[iterWave * TM + i] = As[k][index];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// we accumulate to C_regs
|
||||||
|
for (int iterWaveM = 0; iterWaveM < nbIterWaveM; iterWaveM++) {
|
||||||
|
for (int iterWaveN = 0; iterWaveN < nbIterWaveN; iterWaveN++) {
|
||||||
|
for (int yt = 0; yt < TM; yt++) {
|
||||||
|
for (int xt = 0; xt < TN; xt++) {
|
||||||
|
const int x = iterWaveN * TN + xt;
|
||||||
|
const int y = iterWaveM * TM + yt;
|
||||||
|
c_regs[y * TN * nbIterWaveN + x] += A_col[y] * B_row[x];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
__syncthreads();
|
||||||
|
if (kId < N - BK) {
|
||||||
|
for (int i = 0; i < nbReadsB; i++) {
|
||||||
|
int index_x = BN * blockIdx.x + rBIdx;
|
||||||
|
int index_y = rBIdy + i * strideReadB + kId + BK;
|
||||||
|
Bs[index_y % BK][index_x % BN] = regB[i]; // row
|
||||||
|
}
|
||||||
|
|
||||||
|
for (int i = 0; i < nbReadsA; i++) {
|
||||||
|
int index_x = rAIdx + kId + BK;
|
||||||
|
int index_y = BM * blockIdx.y + rAIdy + i * strideReadA;
|
||||||
|
As[(index_x % BK)][(index_y % BM)] = regA[i];
|
||||||
|
}
|
||||||
|
__syncthreads();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for (int iterWaveM = 0; iterWaveM < nbIterWaveM; iterWaveM++) {
|
||||||
|
for (int iterWaveN = 0; iterWaveN < nbIterWaveN; iterWaveN++) {
|
||||||
|
int xOut = blockIdx.x * BN + waveIdx * WN + iterWaveN * SUBWN + TN * idxInWave;
|
||||||
|
int yOut = blockIdx.y * BM + waveIdy * WM + iterWaveM * SUBWM + TM * idyInWave;
|
||||||
|
for (int yt = 0; yt < TM; yt++) {
|
||||||
|
for (int xt = 0; xt < TN; xt++) {
|
||||||
|
int indexC = N * (yOut + yt) + xOut + xt;
|
||||||
|
c[indexC] = beta * c[indexC] + alpha * c_regs[TN * nbIterWaveN * (iterWaveM * TM + yt) + (iterWaveN * TN + xt)];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
@ -26,7 +26,7 @@ kernel5_lds_optim(float *a, float *b, float *c)
|
||||||
// Number of Row or column we read per batch
|
// Number of Row or column we read per batch
|
||||||
constexpr int BK = 8;
|
constexpr int BK = 8;
|
||||||
|
|
||||||
// Thread Tile size . 4x4
|
// Thread Tile size
|
||||||
constexpr int TN = 4;
|
constexpr int TN = 4;
|
||||||
constexpr int TM = 4;
|
constexpr int TM = 4;
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -91,11 +91,11 @@ def hl_spec_kernel3():
|
||||||
sink = graph_rewrite(sink, merge_views)
|
sink = graph_rewrite(sink, merge_views)
|
||||||
return sink
|
return sink
|
||||||
|
|
||||||
def hand_spec_kernel3():
|
def hand_spec_kernel3(kernel4=getenv("K4", 0), kernel5=getenv("K5", 0)):
|
||||||
BLOCK_SIZE = 256
|
BLOCK_SIZE = 128 if kernel5 else 256
|
||||||
|
|
||||||
nbWaves = BLOCK_SIZE // 32
|
nbWaves = BLOCK_SIZE // 32
|
||||||
WN = 64
|
WN = 128 if kernel5 else 64
|
||||||
WM = BN * BM // nbWaves // WN
|
WM = BN * BM // nbWaves // WN
|
||||||
|
|
||||||
nbWaveX = BN // WN
|
nbWaveX = BN // WN
|
||||||
|
|
@ -141,7 +141,8 @@ def hand_spec_kernel3():
|
||||||
A_col = UOp(Ops.DEFINE_REG, dtypes.float.ptr(nbIterWaveM * TM, AddrSpace.REG), arg=0)
|
A_col = UOp(Ops.DEFINE_REG, dtypes.float.ptr(nbIterWaveM * TM, AddrSpace.REG), arg=0)
|
||||||
B_row = UOp(Ops.DEFINE_REG, dtypes.float.ptr(nbIterWaveN * TN, AddrSpace.REG), arg=1)
|
B_row = UOp(Ops.DEFINE_REG, dtypes.float.ptr(nbIterWaveN * TN, AddrSpace.REG), arg=1)
|
||||||
|
|
||||||
As = UOp(Ops.DEFINE_LOCAL, dtypes.float.ptr(BK*BM, AddrSpace.LOCAL), arg=0)
|
BM_As_stride = (BM+4) if kernel5 else BM
|
||||||
|
As = UOp(Ops.DEFINE_LOCAL, dtypes.float.ptr(BK*BM_As_stride, AddrSpace.LOCAL), arg=0)
|
||||||
Bs = UOp(Ops.DEFINE_LOCAL, dtypes.float.ptr(BK*BN, AddrSpace.LOCAL), arg=1)
|
Bs = UOp(Ops.DEFINE_LOCAL, dtypes.float.ptr(BK*BN, AddrSpace.LOCAL), arg=1)
|
||||||
|
|
||||||
c_regs = UOp(Ops.DEFINE_REG, dtypes.float.ptr(TM * nbIterWaveM * TN * nbIterWaveN), arg=2)
|
c_regs = UOp(Ops.DEFINE_REG, dtypes.float.ptr(TM * nbIterWaveM * TN * nbIterWaveN), arg=2)
|
||||||
|
|
@ -149,51 +150,131 @@ def hand_spec_kernel3():
|
||||||
i = UOp.range(dtypes.int, c_regs.dtype.size, 16)
|
i = UOp.range(dtypes.int, c_regs.dtype.size, 16)
|
||||||
init_store = c_regs[i].store(UOp.const(dtypes.float, 0.0), i)
|
init_store = c_regs[i].store(UOp.const(dtypes.float, 0.0), i)
|
||||||
|
|
||||||
kId_range = UOp.range(dtypes.int, N//BK, 0)
|
if kernel4:
|
||||||
kId = kId_range*BK
|
regA = UOp(Ops.DEFINE_REG, dtypes.float.ptr(nbReadsA, AddrSpace.REG), arg=3)
|
||||||
|
regB = UOp(Ops.DEFINE_REG, dtypes.float.ptr(nbReadsB, AddrSpace.REG), arg=4)
|
||||||
|
|
||||||
# load from globals into locals
|
# initial load from globals into locals (0)
|
||||||
i = UOp.range(dtypes.int, nbReadsB, 1)
|
kId = 0
|
||||||
index_x = BN * blockIdx_x + rBIdx
|
|
||||||
index_y = rBIdy + i * strideReadB + kId
|
|
||||||
Bs_store = Bs[(index_y % BK) * BN + index_x % BN].store(b[N * index_y + index_x].load(), i)
|
|
||||||
|
|
||||||
i = UOp.range(dtypes.int, nbReadsA, 2)
|
# load from globals into locals
|
||||||
index_x = rAIdx + kId
|
i = UOp.range(dtypes.int, nbReadsB, 0)
|
||||||
index_y = BM * blockIdx_y + rAIdy + i * strideReadA
|
index_x = BN * blockIdx_x + rBIdx
|
||||||
As_store = As[(index_x % BK) * BM + index_y % BM].store(a[N * index_y + index_x].load(), i)
|
index_y = rBIdy + i * strideReadB + kId
|
||||||
|
Bs_store = Bs[(index_y % BK) * BN + index_x % BN].store(b[N * index_y + index_x].load(), i)
|
||||||
|
|
||||||
barrier = UOp(Ops.BARRIER, src=(As_store, Bs_store))
|
i = UOp.range(dtypes.int, nbReadsA, 1)
|
||||||
|
index_x = rAIdx + kId
|
||||||
|
index_y = BM * blockIdx_y + rAIdy + i * strideReadA
|
||||||
|
As_store = As[(index_x % BK) * BM_As_stride + index_y % BM].store(a[N * index_y + index_x].load(), i)
|
||||||
|
|
||||||
k = UOp.range(dtypes.int, BK, 3)
|
# iterate over the middle chunk
|
||||||
|
kId_range = UOp.range(dtypes.int, N//BK-1, 2)
|
||||||
|
kId = kId_range*BK
|
||||||
|
|
||||||
# load from locals into registers
|
barrier = UOp.barrier(As_store, Bs_store)
|
||||||
iterWave = UOp.range(dtypes.int, nbIterWaveN, 4)
|
|
||||||
i = UOp.range(dtypes.int, TN, 5)
|
|
||||||
index = waveIdx * WN + iterWave * SUBWN + TN * idxInWave + i
|
|
||||||
B_row_store = B_row[iterWave*TN + i].store(Bs[k*BN + index].load(barrier), iterWave, i)
|
|
||||||
|
|
||||||
iterWave = UOp.range(dtypes.int, nbIterWaveM, 6)
|
# load from globals into registers (next round)
|
||||||
i = UOp.range(dtypes.int, TM, 7)
|
i = UOp.range(dtypes.int, nbReadsB, 3)
|
||||||
index = waveIdy * WM + iterWave * SUBWM + TM * idyInWave + i
|
index_x = BN * blockIdx_x + rBIdx
|
||||||
A_col_store = A_col[iterWave*TM + i].store(As[k*BM + index].load(barrier), iterWave, i)
|
index_y = rBIdy + i * strideReadB + kId + BK
|
||||||
|
regB_store = regB[i].store(b[N * index_y + index_x].load(), i)
|
||||||
|
|
||||||
# do the GEMM math
|
i = UOp.range(dtypes.int, nbReadsA, 4)
|
||||||
iterWaveM = UOp.range(dtypes.int, nbIterWaveM, 8)
|
index_x = rAIdx + kId + BK
|
||||||
iterWaveN = UOp.range(dtypes.int, nbIterWaveN, 9)
|
index_y = BM * blockIdx_y + rAIdy + i * strideReadA
|
||||||
yt = UOp.range(dtypes.int, TM, 10)
|
regA_store = regA[i].store(a[N * index_y + index_x].load(), i)
|
||||||
xt = UOp.range(dtypes.int, TN, 11)
|
|
||||||
x = iterWaveN * TN + xt
|
def inner_loop(first_range, inp_dep=()):
|
||||||
y = iterWaveM * TM + yt
|
# inner unroll
|
||||||
c_regs_idx = c_regs[y * TN * nbIterWaveN + x]
|
k = UOp.range(dtypes.int, BK, first_range+0)
|
||||||
sink = c_regs_idx.store(c_regs_idx.load(init_store) + A_col[y].load(A_col_store) * B_row[x].load(B_row_store),
|
|
||||||
iterWaveM, iterWaveN, yt, xt, k, kId_range)
|
# load from locals into registers
|
||||||
|
iterWave = UOp.range(dtypes.int, nbIterWaveN, first_range+1)
|
||||||
|
i = UOp.range(dtypes.int, TN, first_range+2)
|
||||||
|
index = waveIdx * WN + iterWave * SUBWN + TN * idxInWave + i
|
||||||
|
B_row_store = B_row[iterWave*TN + i].store(Bs[k*BN + index].load(*inp_dep), iterWave, i)
|
||||||
|
|
||||||
|
iterWave = UOp.range(dtypes.int, nbIterWaveM, first_range+3)
|
||||||
|
i = UOp.range(dtypes.int, TM, first_range+4)
|
||||||
|
index = waveIdy * WM + iterWave * SUBWM + TM * idyInWave + i
|
||||||
|
A_col_store = A_col[iterWave*TM + i].store(As[k*BM_As_stride + index].load(*inp_dep), iterWave, i)
|
||||||
|
|
||||||
|
# do the GEMM math
|
||||||
|
iterWaveM = UOp.range(dtypes.int, nbIterWaveM, first_range+5)
|
||||||
|
yt = UOp.range(dtypes.int, TM, first_range+6)
|
||||||
|
iterWaveN = UOp.range(dtypes.int, nbIterWaveN, first_range+7)
|
||||||
|
xt = UOp.range(dtypes.int, TN, first_range+8)
|
||||||
|
x = iterWaveN * TN + xt
|
||||||
|
y = iterWaveM * TM + yt
|
||||||
|
c_regs_idx = c_regs[y * TN * nbIterWaveN + x]
|
||||||
|
# sketchy, this should end the kId_range but it doesn't
|
||||||
|
sink = c_regs_idx.store(c_regs_idx.load(init_store) + A_col[y].load(A_col_store) * B_row[x].load(B_row_store),
|
||||||
|
iterWaveM, iterWaveN, yt, xt, k)
|
||||||
|
return sink
|
||||||
|
|
||||||
|
# TODO: kId_range should endrange after a barrier
|
||||||
|
sink = inner_loop(5, (barrier, regB_store, regA_store)).barrier()
|
||||||
|
|
||||||
|
# load from registers into locals
|
||||||
|
i = UOp.range(dtypes.int, nbReadsB, 14)
|
||||||
|
index_x = BN * blockIdx_x + rBIdx
|
||||||
|
index_y = rBIdy + i * strideReadB + kId + BK
|
||||||
|
Bs_store = Bs[(index_y % BK) * BN + index_x % BN].store(regB[i].load(sink), i, kId_range)
|
||||||
|
|
||||||
|
i = UOp.range(dtypes.int, nbReadsA, 15)
|
||||||
|
index_x = rAIdx + kId + BK
|
||||||
|
index_y = BM * blockIdx_y + rAIdy + i * strideReadA
|
||||||
|
As_store = As[(index_x % BK) * BM_As_stride + index_y % BM].store(regA[i].load(sink), i, kId_range)
|
||||||
|
|
||||||
|
# final iteration without the copy
|
||||||
|
sink = inner_loop(16, (UOp.barrier(Bs_store, As_store),))
|
||||||
|
else:
|
||||||
|
kId_range = UOp.range(dtypes.int, N//BK, 0)
|
||||||
|
kId = kId_range*BK
|
||||||
|
|
||||||
|
# load from globals into locals
|
||||||
|
i = UOp.range(dtypes.int, nbReadsB, 1)
|
||||||
|
index_x = BN * blockIdx_x + rBIdx
|
||||||
|
index_y = rBIdy + i * strideReadB + kId
|
||||||
|
Bs_store = Bs[(index_y % BK) * BN + index_x % BN].store(b[N * index_y + index_x].load(), i)
|
||||||
|
|
||||||
|
i = UOp.range(dtypes.int, nbReadsA, 2)
|
||||||
|
index_x = rAIdx + kId
|
||||||
|
index_y = BM * blockIdx_y + rAIdy + i * strideReadA
|
||||||
|
As_store = As[(index_x % BK) * BM_As_stride + index_y % BM].store(a[N * index_y + index_x].load(), i)
|
||||||
|
|
||||||
|
barrier = UOp.barrier(As_store, Bs_store)
|
||||||
|
|
||||||
|
k = UOp.range(dtypes.int, BK, 3)
|
||||||
|
|
||||||
|
# load from locals into registers
|
||||||
|
iterWave = UOp.range(dtypes.int, nbIterWaveN, 4)
|
||||||
|
i = UOp.range(dtypes.int, TN, 5)
|
||||||
|
index = waveIdx * WN + iterWave * SUBWN + TN * idxInWave + i
|
||||||
|
B_row_store = B_row[iterWave*TN + i].store(Bs[k*BN + index].load(barrier), iterWave, i)
|
||||||
|
|
||||||
|
iterWave = UOp.range(dtypes.int, nbIterWaveM, 6)
|
||||||
|
i = UOp.range(dtypes.int, TM, 7)
|
||||||
|
index = waveIdy * WM + iterWave * SUBWM + TM * idyInWave + i
|
||||||
|
A_col_store = A_col[iterWave*TM + i].store(As[k*BM_As_stride + index].load(barrier), iterWave, i)
|
||||||
|
|
||||||
|
# do the GEMM math
|
||||||
|
iterWaveM = UOp.range(dtypes.int, nbIterWaveM, 8)
|
||||||
|
yt = UOp.range(dtypes.int, TM, 9)
|
||||||
|
iterWaveN = UOp.range(dtypes.int, nbIterWaveN, 10)
|
||||||
|
xt = UOp.range(dtypes.int, TN, 12)
|
||||||
|
x = iterWaveN * TN + xt
|
||||||
|
y = iterWaveM * TM + yt
|
||||||
|
c_regs_idx = c_regs[y * TN * nbIterWaveN + x]
|
||||||
|
sink = c_regs_idx.store(c_regs_idx.load(init_store) + A_col[y].load(A_col_store) * B_row[x].load(B_row_store),
|
||||||
|
iterWaveM, iterWaveN, yt, xt, k, kId_range)
|
||||||
|
|
||||||
# store c_regs into c
|
# store c_regs into c
|
||||||
iterWaveM = UOp.range(dtypes.int, nbIterWaveM, 12)
|
iterWaveM = UOp.range(dtypes.int, nbIterWaveM, 1000)
|
||||||
iterWaveN = UOp.range(dtypes.int, nbIterWaveN, 13)
|
yt = UOp.range(dtypes.int, TM, 1001)
|
||||||
yt = UOp.range(dtypes.int, TM, 14)
|
iterWaveN = UOp.range(dtypes.int, nbIterWaveN, 1002)
|
||||||
xt = UOp.range(dtypes.int, TN, 15)
|
xt = UOp.range(dtypes.int, TN, 1003)
|
||||||
xOut = blockIdx_x * BN + waveIdx * WN + iterWaveN * SUBWN + TN * idxInWave
|
xOut = blockIdx_x * BN + waveIdx * WN + iterWaveN * SUBWN + TN * idxInWave
|
||||||
yOut = blockIdx_y * BM + waveIdy * WM + iterWaveM * SUBWM + TM * idyInWave
|
yOut = blockIdx_y * BM + waveIdy * WM + iterWaveM * SUBWM + TM * idyInWave
|
||||||
indexC = N * (yOut + yt) + xOut + xt
|
indexC = N * (yOut + yt) + xOut + xt
|
||||||
|
|
|
||||||
|
|
@ -97,7 +97,7 @@ class BlockContext:
|
||||||
|
|
||||||
# ***** make blocks *****
|
# ***** make blocks *****
|
||||||
|
|
||||||
DONT_PLACE_IN_BLOCK = {Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL, Ops.DEFINE_VAR, Ops.SPECIAL, Ops.CONST}
|
DONT_PLACE_IN_BLOCK = {Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL, Ops.DEFINE_REG, Ops.DEFINE_VAR, Ops.SPECIAL, Ops.CONST}
|
||||||
|
|
||||||
def add_blockends(base_block:UOp, new_ctx:tuple[UOp, ...], current_ctx:tuple[UOp, ...], cnt:int=1) -> UOp:
|
def add_blockends(base_block:UOp, new_ctx:tuple[UOp, ...], current_ctx:tuple[UOp, ...], cnt:int=1) -> UOp:
|
||||||
ends_to_add = [z for z in new_ctx if z not in current_ctx]
|
ends_to_add = [z for z in new_ctx if z not in current_ctx]
|
||||||
|
|
|
||||||
|
|
@ -238,7 +238,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
|
||||||
def load(self, *src:UOp, **kwargs): return UOp(Ops.LOAD, dtype=kwargs.pop("dtype", self.dtype.base), src=(self,)+src, **kwargs)
|
def load(self, *src:UOp, **kwargs): return UOp(Ops.LOAD, dtype=kwargs.pop("dtype", self.dtype.base), src=(self,)+src, **kwargs)
|
||||||
def store(self, *src:UOp, **kwargs): return UOp(Ops.STORE, dtypes.void, (self,)+src, **kwargs)
|
def store(self, *src:UOp, **kwargs): return UOp(Ops.STORE, dtypes.void, (self,)+src, **kwargs)
|
||||||
def assign(self, x:UOp): return UOp(Ops.ASSIGN, self.dtype, (self, x))
|
def assign(self, x:UOp): return UOp(Ops.ASSIGN, self.dtype, (self, x))
|
||||||
def barrier(self): return UOp(Ops.BARRIER, src=(self,))
|
def barrier(self, *src:UOp): return UOp(Ops.BARRIER, src=(self,)+src)
|
||||||
def alu(self, op, *src:UOp, **kwargs):
|
def alu(self, op, *src:UOp, **kwargs):
|
||||||
out_dtype = (self, *src)[-1].dtype
|
out_dtype = (self, *src)[-1].dtype
|
||||||
if op in {Ops.CMPLT, Ops.CMPNE}: out_dtype = dtypes.bool.vec(out_dtype.count) if out_dtype.count > 1 else dtypes.bool
|
if op in {Ops.CMPLT, Ops.CMPNE}: out_dtype = dtypes.bool.vec(out_dtype.count) if out_dtype.count > 1 else dtypes.bool
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue