mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
270 lines
10 KiB
C++
#include "kittens.cuh"
|
|
|
|
using namespace kittens;
|
|
|
|
#ifndef GEMM_M
|
|
constexpr int GEMM_M = 8192;
|
|
#endif
|
|
#ifndef GEMM_N
|
|
constexpr int GEMM_N = 8192;
|
|
#endif
|
|
#ifndef GEMM_K
|
|
constexpr int GEMM_K = 8192;
|
|
#endif
|
|
|
|
// Kernel
|
|
constexpr int NUM_WARPS = 8;
|
|
constexpr int WARPS_ROW = 2;
|
|
constexpr int WARPS_COL = 4;
|
|
constexpr int BLOCK_ROW = 256;
|
|
constexpr int BLOCK_COL = 256;
|
|
constexpr int BLOCK_K = 128;
|
|
constexpr int HALF_ROW = BLOCK_ROW / 2;
|
|
constexpr int HALF_COL = BLOCK_COL / 2;
|
|
constexpr int REG_M = BLOCK_ROW / WARPS_ROW / 2;
|
|
constexpr int REG_N = BLOCK_COL / WARPS_COL / 2;
|
|
|
|
using G = kittens::group<NUM_WARPS>;
|
|
|
|
// mxfp8_gemm_kernel: block-scaled (MX) fp8 GEMM, C = A * B^T.
//
//   A : fp8e4m3, M x K   (kittens::gl<fp8e4m3, 1, 1, M, K>)
//   B : fp8e4m3, N x K   (used transposed via mma_ABt_scaled)
//   C : bf16,    M x N   (accumulated in float register tiles, stored at the end)
//   M, N, K are the compile-time GEMM_M / GEMM_N / GEMM_K.
//
// Work decomposition:
//   - Each workgroup produces one BLOCK_ROW x BLOCK_COL tile of C.
//   - The tile is split into two row halves (h0/h1) and two column halves.
//   - Warp (warp_m, warp_n) accumulates one REG_M x REG_N sub-tile per quadrant:
//     cA = (h0 rows, h0 cols), cB = (h0, h1), cC = (h1, h0), cD = (h1, h1).
//
// Launch:
//   - 1-D grid. blockIdx.x is remapped by chiplet_transform_chunked
//     (NUM_XCDS = 8) and then into a grouped tile order (see below).
//   - There is no bounds check, so presumably
//     gridDim.x == (M / BLOCK_ROW) * (N / BLOCK_COL). TODO confirm against the
//     host launcher.
//   - __launch_bounds__(512, 2) hard-codes 512 threads. This is presumably
//     NUM_WARPS (8) x 64-lane wavefronts. NOTE(review): keep it in sync with
//     NUM_WARPS * WARP_THREADS if either changes.
//
// Shared memory (all static):
//   - As[2][2] / Bs[2][2] are indexed as [tic/toc buffer][half]. Each holds an
//     fp8 tile of HALF_ROW (resp. HALF_COL) x BLOCK_K.
//   - smem_scales is 2048 bytes. It is read with A scales at offset 0 and
//     B scales at offset 1024.
//
// Parameters:
//   - C_ptr, A_ptr, B_ptr: raw pointers wrapped by the gl<> views below.
//   - scale_A_iter / scale_B_iter: only consumed by load_scales_to_lds (their
//     layout is not visible here).
//   - a_e8_unused / b_e8_unused: not referenced by the kernel body.
//
// Uses __builtin_amdgcn_* builtins and inline GCN s_waitcnt asm, so it only
// builds for AMD GPUs.
__global__ __launch_bounds__(512, 2) void mxfp8_gemm_kernel(bf16 *C_ptr, fp8e4m3 *A_ptr, fp8e4m3 *B_ptr,
|
|
const uint32_t *__restrict__ scale_A_iter,
|
|
const uint32_t *__restrict__ scale_B_iter,
|
|
const uint8_t *__restrict__ a_e8_unused,
|
|
const uint8_t *__restrict__ b_e8_unused) {
|
|
constexpr int M = GEMM_M, N = GEMM_N, K = GEMM_K;
|
|
|
|
// Global-memory views over the raw pointers: A is M x K, B is N x K, and
// C is M x N (the leading batch/depth dimensions are 1).
kittens::gl<fp8e4m3, 1, 1, M, K> A{A_ptr, nullptr, nullptr, nullptr, nullptr};
|
|
kittens::gl<fp8e4m3, 1, 1, N, K> B{B_ptr, nullptr, nullptr, nullptr, nullptr};
|
|
kittens::gl<bf16, 1, 1, M, N> C{C_ptr, nullptr, nullptr, nullptr, nullptr};
|
|
|
|
// Number of BLOCK_K-wide steps along K. The schedule below requires
// k_iters >= 2: the prologue prefetches steps 0 and 1, and the last two steps
// are peeled into epilogues.
constexpr int k_iters = K / BLOCK_K;
|
|
constexpr int NUM_THREADS = NUM_WARPS * WARP_THREADS;
|
|
|
|
// Shared tiles hold one half (HALF_ROW / HALF_COL rows) of the block's A / B
// slice for one K step. Register tiles hold one warp's REG_M / REG_N rows of
// that slice.
using ST_A = st_fp8e4m3<HALF_ROW, BLOCK_K, st_16x128_s>;
|
|
using ST_B = st_fp8e4m3<HALF_COL, BLOCK_K, st_16x128_s>;
|
|
using RT_A = rt_fp8e4m3<REG_M, BLOCK_K>;
|
|
using RT_B = rt_fp8e4m3<REG_N, BLOCK_K>;
|
|
// Float accumulator tile (one per quadrant).
using RT_C = rt_fl<REG_M, REG_N, col_l, rt_16x16_s>;
|
|
|
|
// As/Bs[buf][half]: buf is the tic/toc double buffer; half selects the row
// half (As) or the column half (Bs) of the block tile.
__shared__ ST_A As[2][2];
|
|
__shared__ ST_B Bs[2][2];
|
|
// Scales for the current K step. This buffer is single-buffered and is
// rewritten by load_scales_to_lds every step.
// NOTE(review): 2048 is presumably (BLOCK_ROW + BLOCK_COL) * (BLOCK_K / 32)
// MX scale bytes. It is not derived from the tile constants; confirm it
// against load_scales_to_lds.
__shared__ uint8_t smem_scales[2048];
|
|
|
|
// Register operands: `a` is reused for both A halves; b0 / b1 hold the two
// B halves. One float accumulator exists per output quadrant.
RT_A a;
|
|
RT_B b0, b1;
|
|
RT_C cA, cB, cC, cD;
|
|
zero(cA); zero(cB); zero(cC); zero(cD);
|
|
|
|
// Map this workgroup to a (block_row, block_col) tile of C.
// 1. blockIdx.x is first permuted across the NUM_XCDS chiplets.
// 2. The tiles are then visited in groups of up to WGM block-rows. Within a
//    group, consecutive wgids step down the rows of one block-column before
//    moving to the next column (presumably for L2 reuse).
constexpr int tiles_M = M / BLOCK_ROW;
|
|
constexpr int tiles_N = N / BLOCK_COL;
|
|
const int NUM_XCDS = 8;
|
|
const int WGM = 8;
|
|
int wgid = chiplet_transform_chunked(blockIdx.x, gridDim.x, NUM_XCDS, WGM * WGM);
|
|
int num_wgid_in_group = WGM * tiles_N;
|
|
int group_id = wgid / num_wgid_in_group;
|
|
int first_pid_m = group_id * WGM;
|
|
// The last group may have fewer than WGM block-rows.
int group_size_m = min(tiles_M - first_pid_m, WGM);
|
|
int block_row = first_pid_m + ((wgid % num_wgid_in_group) % group_size_m);
|
|
int block_col = (wgid % num_wgid_in_group) / group_size_m;
|
|
// Element offsets of this tile (passed to load_scales_to_lds).
int block_m = block_row * BLOCK_ROW;
|
|
int block_n = block_col * BLOCK_COL;
|
|
|
|
// Warp position in the WARPS_ROW x WARPS_COL warp grid (row-major by warp id).
int warp_m = warpid() / WARPS_COL;
|
|
int warp_n = warpid() % WARPS_COL;
|
|
|
|
// Per-thread swizzled LDS offsets for the cooperative G::load copies. They
// are computed once from As[0][0] / Bs[0][0] and reused for every buffer,
// since all As (resp. Bs) tiles have the same type.
using T = fp8e4m3;
|
|
constexpr int bpt = ST_A::underlying_subtile_bytes_per_thread;
|
|
constexpr int bpm = bpt * NUM_THREADS;
|
|
// Number of copy instructions per thread per half-tile load.
constexpr int copies_A = HALF_ROW * BLOCK_K * sizeof(T) / bpm;
|
|
constexpr int copies_B = HALF_COL * BLOCK_K * sizeof(T) / bpm;
|
|
uint32_t sw_A[copies_A], sw_B[copies_B];
|
|
G::prefill_swizzled_offsets(As[0][0], A, sw_A);
|
|
G::prefill_swizzled_offsets(Bs[0][0], B, sw_B);
|
|
|
|
// First row (within the block tile) of this warp's sub-tiles in each half.
// Used only to select the matching rows of smem_scales.
int a_row_h0 = warp_m * REG_M;
|
|
int a_row_h1 = HALF_ROW + warp_m * REG_M;
|
|
int b_row_h0 = warp_n * REG_N;
|
|
int b_row_h1 = HALF_COL + warp_n * REG_N;
|
|
|
|
// tic: buffer consumed in the current K step.
// toc: buffer being filled for the next step.
int tic = 0, toc = 1;
|
|
|
|
// Prologue: fetch all four half-tiles of K step 0 into the tic buffers.
// Tile coordinates are {batch, depth, half-tile row index, K-step index}.
G::load(Bs[tic][0], B, {0, 0, block_col * 2, 0}, sw_B);
|
|
G::load(As[tic][0], A, {0, 0, block_row * 2, 0}, sw_A);
|
|
G::load(Bs[tic][1], B, {0, 0, block_col * 2 + 1, 0}, sw_B);
|
|
G::load(As[tic][1], A, {0, 0, block_row * 2 + 1, 0}, sw_A);
|
|
|
|
// Phase-shift the two warp rows: warps with warp_m == 1 execute one extra
// s_barrier here. From now on, each workgroup barrier pairs a warp_m == 0
// wave with a warp_m == 1 wave that is one barrier earlier in program order
// (presumably so one warp row issues MFMAs while the other does memory work).
// NOTE(review): there is no matching `if (warp_m == 0)` barrier before the
// final stores, so warp_m == 1 waves issue one more s_barrier in total than
// warp_m == 0 waves. Confirm that the last barrier is released once the
// warp_m == 0 waves exit.
if (warp_m == 1) __builtin_amdgcn_s_barrier();
|
|
// Wait until at most 4 vector-memory ops are outstanding in this wave.
// Presumably this means the first two loads (Bs[tic][0], As[tic][0]) have
// landed, assuming each G::load issues two instructions (copies_A ==
// copies_B == 2). TODO confirm.
asm volatile("s_waitcnt vmcnt(4)");
|
|
__builtin_amdgcn_s_barrier();
|
|
|
|
// Prefetch K step 1 into the toc buffers. As[toc][1] is skipped; the first
// loop iteration fetches it.
G::load(As[toc][0], A, {0, 0, block_row * 2, 1}, sw_A);
|
|
G::load(Bs[toc][0], B, {0, 0, block_col * 2, 1}, sw_B);
|
|
G::load(Bs[toc][1], B, {0, 0, block_col * 2 + 1, 1}, sw_B);
|
|
|
|
// Leave at most 6 vector-memory ops in flight. Presumably these are the three
// step-1 loads, so every tic tile is resident before the loop starts.
asm volatile("s_waitcnt vmcnt(6)");
|
|
__builtin_amdgcn_s_barrier();
|
|
|
|
// Main loop over steps k = 0 .. k_iters - 3.
// On entry:
//   - The tic buffers hold step k.
//   - The toc buffers hold (or are receiving) step k + 1, except As[toc][1].
// Each iteration:
//   - Fetches step k's scales.
//   - Issues As[toc][1] for step k + 1.
//   - Runs the four quadrant MFMAs (cA, cB, cC, cD).
//   - Refills each tic buffer with step k + 2 once this wave has copied it
//     into registers.
// The last two steps are peeled below because they have no step k + 2 to
// prefetch.
#pragma unroll 2
|
|
for (int k = 0; k < k_iters - 2; k++, tic ^= 1, toc ^= 1) {
|
|
// NOTE(review): smem_scales is single-buffered and rewritten here every
// step, while the two warp rows run one barrier apart (see the warp_m
// staggering above). Confirm that load_scales_to_lds's write pattern and
// completion make this safe for the pack_scales reads below.
load_scales_to_lds(smem_scales, scale_A_iter, scale_B_iter, block_m, block_n, k, M, N);
|
|
|
|
// Phase 1: load this warp's h0 slices of B and A into registers.
auto bs0 = subtile_inplace<REG_N, BLOCK_K>(Bs[tic][0], {warp_n, 0});
|
|
load(b0, bs0);
|
|
auto as0 = subtile_inplace<REG_M, BLOCK_K>(As[tic][0], {warp_m, 0});
|
|
load(a, as0);
|
|
// Fetch the A half the toc buffer is still missing (step k + 1).
G::load(As[toc][1], A, {0, 0, block_row * 2 + 1, k + 1}, sw_A);
|
|
// Allow at most 8 LDS/scalar-memory ops to remain outstanding.
asm volatile("s_waitcnt lgkmcnt(8)");
|
|
__builtin_amdgcn_s_barrier();
|
|
|
|
// Pack the scales for this warp's h0 rows of A and B.
fp8e8m0_4 sa_h0 = pack_scales(smem_scales, 0, a_row_h0);
|
|
fp8e8m0_4 sb_h0 = pack_scales(smem_scales, 1024, b_row_h0);
|
|
|
|
// Quadrant (h0, h0). s_setprio raises this wave's issue priority around
// the MFMA and then restores it.
__builtin_amdgcn_s_setprio(2);
|
|
mma_ABt_scaled(cA, a, b0, cA, &sa_h0, &sb_h0);
|
|
__builtin_amdgcn_s_setprio(0);
|
|
__builtin_amdgcn_s_barrier();
|
|
// sched_barrier(0): the compiler may not move any instruction across here.
__builtin_amdgcn_sched_barrier(0);
|
|
|
|
// Phase 2: load the h1 slice of B.
fp8e8m0_4 sb_h1 = pack_scales(smem_scales, 1024, b_row_h1);
|
|
auto bs1 = subtile_inplace<REG_N, BLOCK_K>(Bs[tic][1], {warp_n, 0});
|
|
load(b1, bs1);
|
|
// Step k's A h0 data is already in `a` for this wave, so refill
// As[tic][0] with step k + 2.
G::load(As[tic][0], A, {0, 0, block_row * 2, k + 2}, sw_A);
|
|
asm volatile("s_waitcnt lgkmcnt(0)");
|
|
__builtin_amdgcn_s_barrier();
|
|
|
|
// Quadrant (h0, h1).
__builtin_amdgcn_s_setprio(2);
|
|
mma_ABt_scaled(cB, a, b1, cB, &sa_h0, &sb_h1);
|
|
__builtin_amdgcn_s_setprio(0);
|
|
__builtin_amdgcn_s_barrier();
|
|
|
|
// Phase 3: replace `a` with the h1 slice of A.
fp8e8m0_4 sa_h1 = pack_scales(smem_scales, 0, a_row_h1);
|
|
auto as1 = subtile_inplace<REG_M, BLOCK_K>(As[tic][1], {warp_m, 0});
|
|
load(a, as1);
|
|
// Step k's B h0 data is already in b0 for this wave, so refill
// Bs[tic][0] with step k + 2.
G::load(Bs[tic][0], B, {0, 0, block_col * 2, k + 2}, sw_B);
|
|
asm volatile("s_waitcnt lgkmcnt(0)");
|
|
__builtin_amdgcn_s_barrier();
|
|
|
|
// Quadrant (h1, h0).
__builtin_amdgcn_s_setprio(2);
|
|
mma_ABt_scaled(cC, a, b0, cC, &sa_h1, &sb_h0);
|
|
__builtin_amdgcn_s_setprio(0);
|
|
__builtin_amdgcn_s_barrier();
|
|
__builtin_amdgcn_sched_barrier(0);
|
|
|
|
// Phase 4: step k's B h1 data is already in b1, so refill Bs[tic][1] with
// step k + 2.
G::load(Bs[tic][1], B, {0, 0, block_col * 2 + 1, k + 2}, sw_B);
|
|
// Leave at most 6 vector-memory ops in flight. Presumably these are the
// three newest step-(k + 2) loads, so all older loads, including
// As[toc][1] for step k + 1, have landed before the next step reads them.
asm volatile("s_waitcnt vmcnt(6)");
|
|
__builtin_amdgcn_s_barrier();
|
|
|
|
// Quadrant (h1, h1).
__builtin_amdgcn_s_setprio(2);
|
|
mma_ABt_scaled(cD, a, b1, cD, &sa_h1, &sb_h1);
|
|
__builtin_amdgcn_s_setprio(0);
|
|
__builtin_amdgcn_s_barrier();
|
|
}
|
|
|
|
// Peeled step k_iters - 2: same MFMA sequence as the loop. The only global
// fetch is As[toc][1] for the final step. The final step's b0 is also
// preloaded here.
{ // Epilogue k = k_iters - 2
|
|
int k = k_iters - 2;
|
|
load_scales_to_lds(smem_scales, scale_A_iter, scale_B_iter, block_m, block_n, k, M, N);
|
|
__builtin_amdgcn_s_barrier();
|
|
|
|
fp8e8m0_4 sa_h0 = pack_scales(smem_scales, 0, a_row_h0);
|
|
fp8e8m0_4 sa_h1 = pack_scales(smem_scales, 0, a_row_h1);
|
|
fp8e8m0_4 sb_h0 = pack_scales(smem_scales, 1024, b_row_h0);
|
|
fp8e8m0_4 sb_h1 = pack_scales(smem_scales, 1024, b_row_h1);
|
|
|
|
auto bs0 = subtile_inplace<REG_N, BLOCK_K>(Bs[tic][0], {warp_n, 0});
|
|
load(b0, bs0);
|
|
auto as0 = subtile_inplace<REG_M, BLOCK_K>(As[tic][0], {warp_m, 0});
|
|
load(a, as0);
|
|
// Last global fetch: the missing A half of the final step.
G::load(As[toc][1], A, {0, 0, block_row * 2 + 1, k + 1}, sw_A);
|
|
__builtin_amdgcn_s_barrier();
|
|
|
|
asm volatile("s_waitcnt lgkmcnt(0)");
|
|
__builtin_amdgcn_s_setprio(2);
|
|
mma_ABt_scaled(cA, a, b0, cA, &sa_h0, &sb_h0);
|
|
__builtin_amdgcn_s_setprio(0);
|
|
__builtin_amdgcn_s_barrier();
|
|
__builtin_amdgcn_sched_barrier(0);
|
|
|
|
auto bs1 = subtile_inplace<REG_N, BLOCK_K>(Bs[tic][1], {warp_n, 0});
|
|
load(b1, bs1);
|
|
__builtin_amdgcn_s_barrier();
|
|
|
|
asm volatile("s_waitcnt lgkmcnt(0)");
|
|
__builtin_amdgcn_s_setprio(2);
|
|
mma_ABt_scaled(cB, a, b1, cB, &sa_h0, &sb_h1);
|
|
__builtin_amdgcn_s_setprio(0);
|
|
__builtin_amdgcn_s_barrier();
|
|
|
|
auto as1 = subtile_inplace<REG_M, BLOCK_K>(As[tic][1], {warp_m, 0});
|
|
load(a, as1);
|
|
__builtin_amdgcn_s_barrier();
|
|
|
|
asm volatile("s_waitcnt lgkmcnt(0)");
|
|
__builtin_amdgcn_s_setprio(2);
|
|
mma_ABt_scaled(cC, a, b0, cC, &sa_h1, &sb_h0);
|
|
__builtin_amdgcn_s_setprio(0);
|
|
__builtin_amdgcn_s_barrier();
|
|
|
|
// Preload the final step's B h0 slice into b0; cA and cC of the last step
// reuse it.
// NOTE(review): this LDS read of Bs[toc][0] is issued before the
// vmcnt(4) wait and barrier just below. In the main loop, the corresponding
// read happens only after the vmcnt + barrier. Confirm that the final
// step's Bs[toc][0] is guaranteed to have landed (in all waves) at this
// point.
auto bs0_next = subtile_inplace<REG_N, BLOCK_K>(Bs[toc][0], {warp_n, 0});
|
|
load(b0, bs0_next);
|
|
asm volatile("s_waitcnt vmcnt(4)");
|
|
__builtin_amdgcn_s_barrier();
|
|
|
|
asm volatile("s_waitcnt lgkmcnt(0)");
|
|
__builtin_amdgcn_s_setprio(2);
|
|
mma_ABt_scaled(cD, a, b1, cD, &sa_h1, &sb_h1);
|
|
__builtin_amdgcn_s_setprio(0);
|
|
__builtin_amdgcn_s_barrier();
|
|
__builtin_amdgcn_sched_barrier(0);
|
|
|
|
// Swap buffers by hand: this block is not part of the for-loop increment.
tic ^= 1; toc ^= 1;
|
|
}
|
|
|
|
// Final step k_iters - 1: no more fetches. b0 already holds this step's
// B h0 slice from the previous epilogue.
{ // Epilogue k = k_iters - 1 (final)
|
|
int k = k_iters - 1;
|
|
load_scales_to_lds(smem_scales, scale_A_iter, scale_B_iter, block_m, block_n, k, M, N);
|
|
__builtin_amdgcn_s_barrier();
|
|
|
|
fp8e8m0_4 sa_h0 = pack_scales(smem_scales, 0, a_row_h0);
|
|
fp8e8m0_4 sa_h1 = pack_scales(smem_scales, 0, a_row_h1);
|
|
fp8e8m0_4 sb_h0 = pack_scales(smem_scales, 1024, b_row_h0);
|
|
fp8e8m0_4 sb_h1 = pack_scales(smem_scales, 1024, b_row_h1);
|
|
|
|
auto as0 = subtile_inplace<REG_M, BLOCK_K>(As[tic][0], {warp_m, 0});
|
|
load(a, as0);
|
|
__builtin_amdgcn_s_barrier();
|
|
|
|
asm volatile("s_waitcnt lgkmcnt(0)");
|
|
__builtin_amdgcn_s_setprio(2);
|
|
mma_ABt_scaled(cA, a, b0, cA, &sa_h0, &sb_h0);
|
|
__builtin_amdgcn_s_setprio(0);
|
|
__builtin_amdgcn_s_barrier();
|
|
|
|
// NOTE(review): as above, this LDS read of Bs[tic][1] is issued before
// the vmcnt(0) drain and barrier that follow it. Confirm that the tile
// has landed in all waves at this point.
auto bs1 = subtile_inplace<REG_N, BLOCK_K>(Bs[tic][1], {warp_n, 0});
|
|
load(b1, bs1);
|
|
// Drain every outstanding global load of this wave.
asm volatile("s_waitcnt vmcnt(0)");
|
|
__builtin_amdgcn_s_barrier();
|
|
__builtin_amdgcn_sched_barrier(0);
|
|
|
|
asm volatile("s_waitcnt lgkmcnt(0)");
|
|
__builtin_amdgcn_s_setprio(2);
|
|
mma_ABt_scaled(cB, a, b1, cB, &sa_h0, &sb_h1);
|
|
__builtin_amdgcn_s_setprio(0);
|
|
__builtin_amdgcn_s_barrier();
|
|
|
|
auto as1 = subtile_inplace<REG_M, BLOCK_K>(As[tic][1], {warp_m, 0});
|
|
load(a, as1);
|
|
__builtin_amdgcn_s_barrier();
|
|
|
|
// Both h1-row quadrants use the same `a`, so they are issued back to back.
asm volatile("s_waitcnt lgkmcnt(0)");
|
|
__builtin_amdgcn_s_setprio(2);
|
|
mma_ABt_scaled(cC, a, b0, cC, &sa_h1, &sb_h0);
|
|
mma_ABt_scaled(cD, a, b1, cD, &sa_h1, &sb_h1);
|
|
__builtin_amdgcn_s_setprio(0);
|
|
__builtin_amdgcn_s_barrier();
|
|
}
|
|
|
|
// Write the four quadrant accumulators to C.
// - Coordinates are in RT_C-tile units (REG_M x REG_N).
// - A block spans 2 * WARPS_ROW tile rows and 2 * WARPS_COL tile columns.
// - The h1 half starts WARPS_ROW (resp. WARPS_COL) tiles into the block,
//   since HALF_ROW / REG_M == WARPS_ROW.
store(C, cA, {0, 0, block_row * WARPS_ROW * 2 + warp_m, block_col * WARPS_COL * 2 + warp_n});
|
|
store(C, cB, {0, 0, block_row * WARPS_ROW * 2 + warp_m, block_col * WARPS_COL * 2 + WARPS_COL + warp_n});
|
|
store(C, cC, {0, 0, block_row * WARPS_ROW * 2 + WARPS_ROW + warp_m, block_col * WARPS_COL * 2 + warp_n});
|
|
store(C, cD, {0, 0, block_row * WARPS_ROW * 2 + WARPS_ROW + warp_m, block_col * WARPS_COL * 2 + WARPS_COL + warp_n});
|
|
}
|