#include "kittens.cuh" using namespace kittens; #ifndef GEMM_M constexpr int GEMM_M = 8192; #endif #ifndef GEMM_N constexpr int GEMM_N = 8192; #endif #ifndef GEMM_K constexpr int GEMM_K = 8192; #endif // Kernel constexpr int NUM_WARPS = 8; constexpr int WARPS_ROW = 2; constexpr int WARPS_COL = 4; constexpr int BLOCK_ROW = 256; constexpr int BLOCK_COL = 256; constexpr int BLOCK_K = 128; constexpr int HALF_ROW = BLOCK_ROW / 2; constexpr int HALF_COL = BLOCK_COL / 2; constexpr int REG_M = BLOCK_ROW / WARPS_ROW / 2; constexpr int REG_N = BLOCK_COL / WARPS_COL / 2; using G = kittens::group; __global__ __launch_bounds__(512, 2) void mxfp8_gemm_kernel(bf16 *C_ptr, fp8e4m3 *A_ptr, fp8e4m3 *B_ptr, fp8e8m0 *scale_A_ptr, fp8e8m0 *scale_B_ptr, const uint8_t *__restrict__ a_e8_unused, const uint8_t *__restrict__ b_e8_unused) { constexpr int M = GEMM_M, N = GEMM_N, K = GEMM_K; kittens::gl A{A_ptr, nullptr, nullptr, nullptr, nullptr}; kittens::gl B{B_ptr, nullptr, nullptr, nullptr, nullptr}; kittens::gl C{C_ptr, nullptr, nullptr, nullptr, nullptr}; constexpr int k_iters = K / BLOCK_K; constexpr int NUM_THREADS = NUM_WARPS * WARP_THREADS; kittens::gl scale_A_gl{scale_A_ptr, nullptr, nullptr, nullptr, nullptr}; kittens::gl scale_B_gl{scale_B_ptr, nullptr, nullptr, nullptr, nullptr}; using ST_A = st_fp8e4m3; using ST_B = st_fp8e4m3; using ST_Scale = st; using RT_A = rt_fp8e4m3; using RT_B = rt_fp8e4m3; using RT_C = rt_fl; __shared__ ST_A As[2][2]; __shared__ ST_B Bs[2][2]; __shared__ ST_Scale scale_A_smem[2], scale_B_smem[2]; RT_A a; RT_B b0, b1; RT_C cA, cB, cC, cD; zero(cA); zero(cB); zero(cC); zero(cD); constexpr int tiles_M = M / BLOCK_ROW; constexpr int tiles_N = N / BLOCK_COL; const int NUM_XCDS = 8; const int WGM = 8; int wgid = chiplet_transform_chunked(blockIdx.x, gridDim.x, NUM_XCDS, WGM * WGM); int num_wgid_in_group = WGM * tiles_N; int group_id = wgid / num_wgid_in_group; int first_pid_m = group_id * WGM; int group_size_m = min(tiles_M - first_pid_m, WGM); int block_row = first_pid_m + ((wgid % num_wgid_in_group) % group_size_m); int block_col = (wgid % num_wgid_in_group) / group_size_m; int block_m = block_row * BLOCK_ROW; int block_n = block_col * BLOCK_COL; int warp_m = warpid() / WARPS_COL; int warp_n = warpid() % WARPS_COL; using T = fp8e4m3; constexpr int bpt = ST_A::underlying_subtile_bytes_per_thread; constexpr int bpm = bpt * NUM_THREADS; constexpr int copies_A = HALF_ROW * BLOCK_K * sizeof(T) / bpm; constexpr int copies_B = HALF_COL * BLOCK_K * sizeof(T) / bpm; uint32_t sw_A[copies_A], sw_B[copies_B]; G::prefill_swizzled_offsets(As[0][0], A, sw_A); G::prefill_swizzled_offsets(Bs[0][0], B, sw_B); const T *a_base = (const T *)&A[{0, 0, 0, 0}]; const T *b_base = (const T *)&B[{0, 0, 0, 0}]; const int a_row_stride = A.template stride<2>() * sizeof(T); const int b_row_stride = B.template stride<2>() * sizeof(T); i32x4 a_srd = make_srsrc(a_base, M * a_row_stride, a_row_stride); i32x4 b_srd = make_srsrc(b_base, N * b_row_stride, b_row_stride); const int wid = warpid() % NUM_WARPS; constexpr int elem_per_warp = (16 / sizeof(T)) * kittens::WARP_THREADS; uint32_t a_lds_00 = __builtin_amdgcn_readfirstlane(static_cast(reinterpret_cast(&As[0][0].data[0]) + wid * elem_per_warp * sizeof(T))); uint32_t a_lds_01 = __builtin_amdgcn_readfirstlane(static_cast(reinterpret_cast(&As[0][1].data[0]) + wid * elem_per_warp * sizeof(T))); uint32_t a_lds_10 = __builtin_amdgcn_readfirstlane(static_cast(reinterpret_cast(&As[1][0].data[0]) + wid * elem_per_warp * sizeof(T))); uint32_t a_lds_11 = __builtin_amdgcn_readfirstlane(static_cast(reinterpret_cast(&As[1][1].data[0]) + wid * elem_per_warp * sizeof(T))); uint32_t b_lds_00 = __builtin_amdgcn_readfirstlane(static_cast(reinterpret_cast(&Bs[0][0].data[0]) + wid * elem_per_warp * sizeof(T))); uint32_t b_lds_01 = __builtin_amdgcn_readfirstlane(static_cast(reinterpret_cast(&Bs[0][1].data[0]) + wid * elem_per_warp * sizeof(T))); uint32_t b_lds_10 = __builtin_amdgcn_readfirstlane(static_cast(reinterpret_cast(&Bs[1][0].data[0]) + wid * elem_per_warp * sizeof(T))); uint32_t b_lds_11 = __builtin_amdgcn_readfirstlane(static_cast(reinterpret_cast(&Bs[1][1].data[0]) + wid * elem_per_warp * sizeof(T))); int a_row_h0 = warp_m * REG_M; int a_row_h1 = HALF_ROW + warp_m * REG_M; int b_row_h0 = warp_n * REG_N; int b_row_h1 = HALF_COL + warp_n * REG_N; int tic = 0, toc = 1; int tic_scales = 0, toc_scales = 1; uint32_t a_lds[2][2] = {{a_lds_00, a_lds_01}, {a_lds_10, a_lds_11}}; uint32_t b_lds[2][2] = {{b_lds_00, b_lds_01}, {b_lds_10, b_lds_11}}; if constexpr (k_iters >= 6 && (k_iters % 2 == 0)) { G::load(Bs[tic][0], B, {0, 0, block_col * 2, 0}, sw_B, b_srd, b_base, __builtin_amdgcn_readfirstlane(b_lds[tic][0])); G::load(As[tic][0], A, {0, 0, block_row * 2, 0}, sw_A, a_srd, a_base, __builtin_amdgcn_readfirstlane(a_lds[tic][0])); G::load(Bs[tic][1], B, {0, 0, block_col * 2 + 1, 0}, sw_B, b_srd, b_base, __builtin_amdgcn_readfirstlane(b_lds[tic][1])); G::load(As[tic][1], A, {0, 0, block_row * 2 + 1, 0}, sw_A, a_srd, a_base, __builtin_amdgcn_readfirstlane(a_lds[tic][1])); if (warp_m == 1) __builtin_amdgcn_s_barrier(); asm volatile("s_waitcnt vmcnt(4)"); __builtin_amdgcn_s_barrier(); G::load(As[toc][0], A, {0, 0, block_row * 2, 1}, sw_A, a_srd, a_base, __builtin_amdgcn_readfirstlane(a_lds[toc][0])); G::load(Bs[toc][0], B, {0, 0, block_col * 2, 1}, sw_B, b_srd, b_base, __builtin_amdgcn_readfirstlane(b_lds[toc][0])); G::load(Bs[toc][1], B, {0, 0, block_col * 2 + 1, 1}, sw_B, b_srd, b_base, __builtin_amdgcn_readfirstlane(b_lds[toc][1])); asm volatile("s_waitcnt vmcnt(6)"); __builtin_amdgcn_s_barrier(); G::load(scale_A_smem[0], scale_A_gl, {0 * tiles_M + block_row, 0, 0, 0}); G::load(scale_B_smem[0], scale_B_gl, {0 * tiles_N + block_col, 0, 0, 0}); asm volatile("s_waitcnt vmcnt(0)"); asm volatile("s_waitcnt lgkmcnt(0)"); __builtin_amdgcn_s_barrier(); int sa_idx = block_row, sb_idx = block_col; #pragma unroll 2 for (int k = 0; k < k_iters - 2; k++, tic ^= 1, toc ^= 1, tic_scales ^= 1, toc_scales ^= 1) { if (k + 1 < k_iters) { sa_idx += tiles_M; sb_idx += tiles_N; G::load(scale_A_smem[toc_scales], scale_A_gl, {sa_idx, 0, 0, 0}); G::load(scale_B_smem[toc_scales], scale_B_gl, {sb_idx, 0, 0, 0}); } auto bs0 = subtile_inplace(Bs[tic][0], {warp_n, 0}); load(b0, bs0); auto as0 = subtile_inplace(As[tic][0], {warp_m, 0}); load(a, as0); G::load(As[toc][1], A, {0, 0, block_row * 2 + 1, k + 1}, sw_A, a_srd, a_base, __builtin_amdgcn_readfirstlane(a_lds[toc][1])); asm volatile("s_waitcnt lgkmcnt(8)"); __builtin_amdgcn_s_barrier(); fp8e8m0_4 sa_h0 = pack_scales(scale_A_smem[tic_scales].data, a_row_h0); fp8e8m0_4 sb_h0 = pack_scales(scale_B_smem[tic_scales].data, b_row_h0); fp8e8m0_4 sb_h1 = pack_scales(scale_B_smem[tic_scales].data, b_row_h1); fp8e8m0_4 sa_h1 = pack_scales(scale_A_smem[tic_scales].data, a_row_h1); __builtin_amdgcn_s_setprio(2); mma_ABt_scaled(cA, a, b0, cA, &sa_h0, &sb_h0); __builtin_amdgcn_s_setprio(0); __builtin_amdgcn_s_barrier(); __builtin_amdgcn_sched_barrier(0); auto bs1 = subtile_inplace(Bs[tic][1], {warp_n, 0}); load(b1, bs1); G::load(As[tic][0], A, {0, 0, block_row * 2, k + 2}, sw_A, a_srd, a_base, __builtin_amdgcn_readfirstlane(a_lds[tic][0])); asm volatile("s_waitcnt lgkmcnt(0)"); __builtin_amdgcn_s_barrier(); __builtin_amdgcn_s_setprio(2); mma_ABt_scaled(cB, a, b1, cB, &sa_h0, &sb_h1); __builtin_amdgcn_s_setprio(0); __builtin_amdgcn_s_barrier(); auto as1 = subtile_inplace(As[tic][1], {warp_m, 0}); load(a, as1); G::load(Bs[tic][0], B, {0, 0, block_col * 2, k + 2}, sw_B, b_srd, b_base, __builtin_amdgcn_readfirstlane(b_lds[tic][0])); asm volatile("s_waitcnt lgkmcnt(0)"); __builtin_amdgcn_s_barrier(); __builtin_amdgcn_s_setprio(2); mma_ABt_scaled(cC, a, b0, cC, &sa_h1, &sb_h0); __builtin_amdgcn_s_setprio(0); __builtin_amdgcn_s_barrier(); __builtin_amdgcn_sched_barrier(0); G::load(Bs[tic][1], B, {0, 0, block_col * 2 + 1, k + 2}, sw_B, b_srd, b_base, __builtin_amdgcn_readfirstlane(b_lds[tic][1])); asm volatile("s_waitcnt vmcnt(6)"); __builtin_amdgcn_s_barrier(); __builtin_amdgcn_s_setprio(2); mma_ABt_scaled(cD, a, b1, cD, &sa_h1, &sb_h1); __builtin_amdgcn_s_setprio(0); __builtin_amdgcn_s_barrier(); } { // Epilogue k = k_iters - 2 int k = k_iters - 2; if (k + 1 < k_iters) { sa_idx += tiles_M; sb_idx += tiles_N; G::load(scale_A_smem[toc_scales], scale_A_gl, {sa_idx, 0, 0, 0}); G::load(scale_B_smem[toc_scales], scale_B_gl, {sb_idx, 0, 0, 0}); } asm volatile("s_waitcnt vmcnt(0)"); asm volatile("s_waitcnt lgkmcnt(0)"); __builtin_amdgcn_s_barrier(); fp8e8m0_4 sa_h0 = pack_scales(scale_A_smem[tic_scales].data, a_row_h0); fp8e8m0_4 sa_h1 = pack_scales(scale_A_smem[tic_scales].data, a_row_h1); fp8e8m0_4 sb_h0 = pack_scales(scale_B_smem[tic_scales].data, b_row_h0); fp8e8m0_4 sb_h1 = pack_scales(scale_B_smem[tic_scales].data, b_row_h1); auto bs0 = subtile_inplace(Bs[tic][0], {warp_n, 0}); load(b0, bs0); auto as0 = subtile_inplace(As[tic][0], {warp_m, 0}); load(a, as0); G::load(As[toc][1], A, {0, 0, block_row * 2 + 1, k + 1}, sw_A, a_srd, a_base, __builtin_amdgcn_readfirstlane(a_lds[toc][1])); __builtin_amdgcn_s_barrier(); asm volatile("s_waitcnt lgkmcnt(0)"); __builtin_amdgcn_s_setprio(2); mma_ABt_scaled(cA, a, b0, cA, &sa_h0, &sb_h0); __builtin_amdgcn_s_setprio(0); __builtin_amdgcn_s_barrier(); __builtin_amdgcn_sched_barrier(0); auto bs1 = subtile_inplace(Bs[tic][1], {warp_n, 0}); load(b1, bs1); __builtin_amdgcn_s_barrier(); asm volatile("s_waitcnt lgkmcnt(0)"); __builtin_amdgcn_s_setprio(2); mma_ABt_scaled(cB, a, b1, cB, &sa_h0, &sb_h1); __builtin_amdgcn_s_setprio(0); __builtin_amdgcn_s_barrier(); auto as1 = subtile_inplace(As[tic][1], {warp_m, 0}); load(a, as1); __builtin_amdgcn_s_barrier(); asm volatile("s_waitcnt lgkmcnt(0)"); __builtin_amdgcn_s_setprio(2); mma_ABt_scaled(cC, a, b0, cC, &sa_h1, &sb_h0); __builtin_amdgcn_s_setprio(0); __builtin_amdgcn_s_barrier(); auto bs0_next = subtile_inplace(Bs[toc][0], {warp_n, 0}); load(b0, bs0_next); asm volatile("s_waitcnt vmcnt(4)"); __builtin_amdgcn_s_barrier(); asm volatile("s_waitcnt lgkmcnt(0)"); __builtin_amdgcn_s_setprio(2); mma_ABt_scaled(cD, a, b1, cD, &sa_h1, &sb_h1); __builtin_amdgcn_s_setprio(0); __builtin_amdgcn_s_barrier(); __builtin_amdgcn_sched_barrier(0); tic ^= 1; toc ^= 1; tic_scales ^= 1; toc_scales ^= 1; } { // Epilogue k = k_iters - 1 (final) asm volatile("s_waitcnt vmcnt(0)"); asm volatile("s_waitcnt lgkmcnt(0)"); __builtin_amdgcn_s_barrier(); fp8e8m0_4 sa_h0 = pack_scales(scale_A_smem[tic_scales].data, a_row_h0); fp8e8m0_4 sa_h1 = pack_scales(scale_A_smem[tic_scales].data, a_row_h1); fp8e8m0_4 sb_h0 = pack_scales(scale_B_smem[tic_scales].data, b_row_h0); fp8e8m0_4 sb_h1 = pack_scales(scale_B_smem[tic_scales].data, b_row_h1); auto as0 = subtile_inplace(As[tic][0], {warp_m, 0}); load(a, as0); __builtin_amdgcn_s_barrier(); asm volatile("s_waitcnt lgkmcnt(0)"); __builtin_amdgcn_s_setprio(2); mma_ABt_scaled(cA, a, b0, cA, &sa_h0, &sb_h0); __builtin_amdgcn_s_setprio(0); __builtin_amdgcn_s_barrier(); auto bs1 = subtile_inplace(Bs[tic][1], {warp_n, 0}); load(b1, bs1); asm volatile("s_waitcnt vmcnt(0)"); __builtin_amdgcn_s_barrier(); __builtin_amdgcn_sched_barrier(0); asm volatile("s_waitcnt lgkmcnt(0)"); __builtin_amdgcn_s_setprio(2); mma_ABt_scaled(cB, a, b1, cB, &sa_h0, &sb_h1); __builtin_amdgcn_s_setprio(0); __builtin_amdgcn_s_barrier(); auto as1 = subtile_inplace(As[tic][1], {warp_m, 0}); load(a, as1); __builtin_amdgcn_s_barrier(); asm volatile("s_waitcnt lgkmcnt(0)"); __builtin_amdgcn_s_setprio(2); mma_ABt_scaled(cC, a, b0, cC, &sa_h1, &sb_h0); mma_ABt_scaled(cD, a, b1, cD, &sa_h1, &sb_h1); __builtin_amdgcn_s_setprio(0); __builtin_amdgcn_s_barrier(); } } else { #pragma unroll 1 for (int kk = 0; kk < k_iters; kk++) { G::load(As[0][0], A, {0, 0, block_row * 2, kk}, sw_A, a_srd, a_base, __builtin_amdgcn_readfirstlane(a_lds[0][0])); G::load(As[0][1], A, {0, 0, block_row * 2 + 1, kk}, sw_A, a_srd, a_base, __builtin_amdgcn_readfirstlane(a_lds[0][1])); G::load(Bs[0][0], B, {0, 0, block_col * 2, kk}, sw_B, b_srd, b_base, __builtin_amdgcn_readfirstlane(b_lds[0][0])); G::load(Bs[0][1], B, {0, 0, block_col * 2 + 1, kk}, sw_B, b_srd, b_base, __builtin_amdgcn_readfirstlane(b_lds[0][1])); G::load(scale_A_smem[0], scale_A_gl, {kk * tiles_M + block_row, 0, 0, 0}); G::load(scale_B_smem[0], scale_B_gl, {kk * tiles_N + block_col, 0, 0, 0}); asm volatile("s_waitcnt vmcnt(0)"); asm volatile("s_waitcnt lgkmcnt(0)"); __builtin_amdgcn_s_barrier(); fp8e8m0_4 sa_h0 = pack_scales(scale_A_smem[0].data, a_row_h0); fp8e8m0_4 sa_h1 = pack_scales(scale_A_smem[0].data, a_row_h1); fp8e8m0_4 sb_h0 = pack_scales(scale_B_smem[0].data, b_row_h0); fp8e8m0_4 sb_h1 = pack_scales(scale_B_smem[0].data, b_row_h1); auto bs0 = subtile_inplace(Bs[0][0], {warp_n, 0}); load(b0, bs0); auto bs1 = subtile_inplace(Bs[0][1], {warp_n, 0}); load(b1, bs1); auto as0 = subtile_inplace(As[0][0], {warp_m, 0}); load(a, as0); asm volatile("s_waitcnt lgkmcnt(0)"); mma_ABt_scaled(cA, a, b0, cA, &sa_h0, &sb_h0); mma_ABt_scaled(cB, a, b1, cB, &sa_h0, &sb_h1); auto as1 = subtile_inplace(As[0][1], {warp_m, 0}); load(a, as1); asm volatile("s_waitcnt lgkmcnt(0)"); mma_ABt_scaled(cC, a, b0, cC, &sa_h1, &sb_h0); mma_ABt_scaled(cD, a, b1, cD, &sa_h1, &sb_h1); __builtin_amdgcn_s_barrier(); } } store(C, cA, {0, 0, block_row * WARPS_ROW * 2 + warp_m, block_col * WARPS_COL * 2 + warp_n}); store(C, cB, {0, 0, block_row * WARPS_ROW * 2 + warp_m, block_col * WARPS_COL * 2 + WARPS_COL + warp_n}); store(C, cC, {0, 0, block_row * WARPS_ROW * 2 + WARPS_ROW + warp_m, block_col * WARPS_COL * 2 + warp_n}); store(C, cD, {0, 0, block_row * WARPS_ROW * 2 + WARPS_ROW + warp_m, block_col * WARPS_COL * 2 + WARPS_COL + warp_n}); }