mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
gemm: mxfp8 hipkittens gemm (#16541)
* gemm: mxfp8 hipkittens gemm * feat: update hipkittens * feat: kernel signature * clean: just kernel * feat: from tinygrad * feat: test * fix: add back utils * clean: no diff * clean: no diff
This commit is contained in:
parent
12addee14f
commit
2bdc360606
3 changed files with 337 additions and 0 deletions
|
|
@ -2649,6 +2649,27 @@ def custom_hk_fp8_gemm(C:UOp, A:UOp, B:UOp, *args:UOp, dname:str, scale_mode:int
|
|||
return UOp(Ops.PROGRAM, src=(sink, UOp(Ops.DEVICE, arg=dname), UOp(Ops.LINEAR, src=(*sink.src, sink)), UOp(Ops.SOURCE, arg=src),
|
||||
UOp(Ops.BINARY, arg=lib)))
|
||||
|
||||
# ** MXFP8 GEMM custom kernel
|
||||
|
||||
@functools.cache
|
||||
def custom_hk_mxfp8_gemm(C:UOp, A:UOp, B:UOp, scale_A:UOp, scale_B:UOp, *, dname:str) -> UOp:
|
||||
# mxfp8 block-scaled gemm: A(M,K) @ B(N,K).T, e8m0 1x32 microscales packed (k_iters,dim) uint32
|
||||
M, K = A.shape[0]*A.shape[1], A.shape[2]
|
||||
N, K2 = B.shape[(1 if B.ndim == 3 else 0):]
|
||||
assert K == K2, f"{A.shape} {B.shape}"
|
||||
block_size = 256
|
||||
threads = UOp.special(64 * 8, "lidx0")
|
||||
workgroups = UOp.special((M // block_size) * (N // block_size), "gidx0")
|
||||
sink_inputs = (C.base, A.base, B.base, scale_A.base, scale_B.base, threads, workgroups)
|
||||
sink = UOp.sink(*sink_inputs,
|
||||
arg=KernelInfo(f"hk_mxfp8_gemm_{M}_{N}_{K}", estimates=Estimates(ops=2*M*N*K, mem=(M*K+N*K)*A.dtype.itemsize+M*N*C.dtype.itemsize)))
|
||||
kittens_path = pathlib.Path(__file__).parent.parent/"thunder"/"amd"
|
||||
src = (kittens_path/"gemm_mxfp8.cpp").read_text()
|
||||
lib = HIPCCCompiler("gfx950", [f"-I{(kittens_path/'include').as_posix()}", "-std=c++20", "-DKITTENS_CDNA4", "-ffast-math",
|
||||
"-DHIP_ENABLE_WARP_SYNC_BUILTINS", f"-DGEMM_M={M}", f"-DGEMM_N={N}", f"-DGEMM_K={K}"]).compile_cached(src)
|
||||
return UOp(Ops.PROGRAM, src=(sink, UOp(Ops.DEVICE, arg=dname), UOp(Ops.LINEAR, src=(*sink.src, sink)), UOp(Ops.SOURCE, arg=src),
|
||||
UOp(Ops.BINARY, arg=lib)))
|
||||
|
||||
counters = {"used":0, "todos":[]}
|
||||
def todo(msg:str) -> bool: counters["todos"].append(msg); return False
|
||||
def _asm_gemm_report():
|
||||
|
|
|
|||
268
extra/thunder/amd/gemm_mxfp8.cpp
Normal file
268
extra/thunder/amd/gemm_mxfp8.cpp
Normal file
|
|
@ -0,0 +1,268 @@
|
|||
#include "kittens.cuh"
|
||||
|
||||
using namespace kittens;
|
||||
|
||||
#ifndef GEMM_M
|
||||
constexpr int GEMM_M = 8192;
|
||||
#endif
|
||||
#ifndef GEMM_N
|
||||
constexpr int GEMM_N = 8192;
|
||||
#endif
|
||||
#ifndef GEMM_K
|
||||
constexpr int GEMM_K = 8192;
|
||||
#endif
|
||||
|
||||
// Kernel
|
||||
constexpr int NUM_WARPS = 8;
|
||||
constexpr int WARPS_ROW = 2;
|
||||
constexpr int WARPS_COL = 4;
|
||||
constexpr int BLOCK_ROW = 256;
|
||||
constexpr int BLOCK_COL = 256;
|
||||
constexpr int BLOCK_K = 128;
|
||||
constexpr int HALF_ROW = BLOCK_ROW / 2;
|
||||
constexpr int HALF_COL = BLOCK_COL / 2;
|
||||
constexpr int REG_M = BLOCK_ROW / WARPS_ROW / 2;
|
||||
constexpr int REG_N = BLOCK_COL / WARPS_COL / 2;
|
||||
|
||||
using G = kittens::group<NUM_WARPS>;
|
||||
|
||||
__global__ __launch_bounds__(512, 2) void mxfp8_gemm_kernel(bf16 *C_ptr, fp8e4m3 *A_ptr, fp8e4m3 *B_ptr,
|
||||
const uint32_t *__restrict__ scale_A_iter,
|
||||
const uint32_t *__restrict__ scale_B_iter) {
|
||||
constexpr int M = GEMM_M, N = GEMM_N, K = GEMM_K;
|
||||
|
||||
kittens::gl<fp8e4m3, 1, 1, M, K> A{A_ptr, nullptr, nullptr, nullptr, nullptr};
|
||||
kittens::gl<fp8e4m3, 1, 1, N, K> B{B_ptr, nullptr, nullptr, nullptr, nullptr};
|
||||
kittens::gl<bf16, 1, 1, M, N> C{C_ptr, nullptr, nullptr, nullptr, nullptr};
|
||||
|
||||
constexpr int k_iters = K / BLOCK_K;
|
||||
constexpr int NUM_THREADS = NUM_WARPS * WARP_THREADS;
|
||||
|
||||
using ST_A = st_fp8e4m3<HALF_ROW, BLOCK_K, st_16x128_s>;
|
||||
using ST_B = st_fp8e4m3<HALF_COL, BLOCK_K, st_16x128_s>;
|
||||
using RT_A = rt_fp8e4m3<REG_M, BLOCK_K>;
|
||||
using RT_B = rt_fp8e4m3<REG_N, BLOCK_K>;
|
||||
using RT_C = rt_fl<REG_M, REG_N, col_l, rt_16x16_s>;
|
||||
|
||||
__shared__ ST_A As[2][2];
|
||||
__shared__ ST_B Bs[2][2];
|
||||
__shared__ uint8_t smem_scales[2048];
|
||||
|
||||
RT_A a;
|
||||
RT_B b0, b1;
|
||||
RT_C cA, cB, cC, cD;
|
||||
zero(cA); zero(cB); zero(cC); zero(cD);
|
||||
|
||||
constexpr int tiles_M = M / BLOCK_ROW;
|
||||
constexpr int tiles_N = N / BLOCK_COL;
|
||||
const int NUM_XCDS = 8;
|
||||
const int WGM = 8;
|
||||
int wgid = chiplet_transform_chunked(blockIdx.x, gridDim.x, NUM_XCDS, WGM * WGM);
|
||||
int num_wgid_in_group = WGM * tiles_N;
|
||||
int group_id = wgid / num_wgid_in_group;
|
||||
int first_pid_m = group_id * WGM;
|
||||
int group_size_m = min(tiles_M - first_pid_m, WGM);
|
||||
int block_row = first_pid_m + ((wgid % num_wgid_in_group) % group_size_m);
|
||||
int block_col = (wgid % num_wgid_in_group) / group_size_m;
|
||||
int block_m = block_row * BLOCK_ROW;
|
||||
int block_n = block_col * BLOCK_COL;
|
||||
|
||||
int warp_m = warpid() / WARPS_COL;
|
||||
int warp_n = warpid() % WARPS_COL;
|
||||
|
||||
using T = fp8e4m3;
|
||||
constexpr int bpt = ST_A::underlying_subtile_bytes_per_thread;
|
||||
constexpr int bpm = bpt * NUM_THREADS;
|
||||
constexpr int copies_A = HALF_ROW * BLOCK_K * sizeof(T) / bpm;
|
||||
constexpr int copies_B = HALF_COL * BLOCK_K * sizeof(T) / bpm;
|
||||
uint32_t sw_A[copies_A], sw_B[copies_B];
|
||||
G::prefill_swizzled_offsets(As[0][0], A, sw_A);
|
||||
G::prefill_swizzled_offsets(Bs[0][0], B, sw_B);
|
||||
|
||||
int a_row_h0 = warp_m * REG_M;
|
||||
int a_row_h1 = HALF_ROW + warp_m * REG_M;
|
||||
int b_row_h0 = warp_n * REG_N;
|
||||
int b_row_h1 = HALF_COL + warp_n * REG_N;
|
||||
|
||||
int tic = 0, toc = 1;
|
||||
|
||||
G::load(Bs[tic][0], B, {0, 0, block_col * 2, 0}, sw_B);
|
||||
G::load(As[tic][0], A, {0, 0, block_row * 2, 0}, sw_A);
|
||||
G::load(Bs[tic][1], B, {0, 0, block_col * 2 + 1, 0}, sw_B);
|
||||
G::load(As[tic][1], A, {0, 0, block_row * 2 + 1, 0}, sw_A);
|
||||
|
||||
if (warp_m == 1) __builtin_amdgcn_s_barrier();
|
||||
asm volatile("s_waitcnt vmcnt(4)");
|
||||
__builtin_amdgcn_s_barrier();
|
||||
|
||||
G::load(As[toc][0], A, {0, 0, block_row * 2, 1}, sw_A);
|
||||
G::load(Bs[toc][0], B, {0, 0, block_col * 2, 1}, sw_B);
|
||||
G::load(Bs[toc][1], B, {0, 0, block_col * 2 + 1, 1}, sw_B);
|
||||
|
||||
asm volatile("s_waitcnt vmcnt(6)");
|
||||
__builtin_amdgcn_s_barrier();
|
||||
|
||||
#pragma unroll 2
|
||||
for (int k = 0; k < k_iters - 2; k++, tic ^= 1, toc ^= 1) {
|
||||
load_scales_to_lds(smem_scales, scale_A_iter, scale_B_iter, block_m, block_n, k, M, N);
|
||||
|
||||
auto bs0 = subtile_inplace<REG_N, BLOCK_K>(Bs[tic][0], {warp_n, 0});
|
||||
load(b0, bs0);
|
||||
auto as0 = subtile_inplace<REG_M, BLOCK_K>(As[tic][0], {warp_m, 0});
|
||||
load(a, as0);
|
||||
G::load(As[toc][1], A, {0, 0, block_row * 2 + 1, k + 1}, sw_A);
|
||||
asm volatile("s_waitcnt lgkmcnt(8)");
|
||||
__builtin_amdgcn_s_barrier();
|
||||
|
||||
fp8e8m0_4 sa_h0 = pack_scales(smem_scales, 0, a_row_h0);
|
||||
fp8e8m0_4 sb_h0 = pack_scales(smem_scales, 1024, b_row_h0);
|
||||
|
||||
__builtin_amdgcn_s_setprio(2);
|
||||
mma_ABt_scaled(cA, a, b0, cA, &sa_h0, &sb_h0);
|
||||
__builtin_amdgcn_s_setprio(0);
|
||||
__builtin_amdgcn_s_barrier();
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
fp8e8m0_4 sb_h1 = pack_scales(smem_scales, 1024, b_row_h1);
|
||||
auto bs1 = subtile_inplace<REG_N, BLOCK_K>(Bs[tic][1], {warp_n, 0});
|
||||
load(b1, bs1);
|
||||
G::load(As[tic][0], A, {0, 0, block_row * 2, k + 2}, sw_A);
|
||||
asm volatile("s_waitcnt lgkmcnt(0)");
|
||||
__builtin_amdgcn_s_barrier();
|
||||
|
||||
__builtin_amdgcn_s_setprio(2);
|
||||
mma_ABt_scaled(cB, a, b1, cB, &sa_h0, &sb_h1);
|
||||
__builtin_amdgcn_s_setprio(0);
|
||||
__builtin_amdgcn_s_barrier();
|
||||
|
||||
fp8e8m0_4 sa_h1 = pack_scales(smem_scales, 0, a_row_h1);
|
||||
auto as1 = subtile_inplace<REG_M, BLOCK_K>(As[tic][1], {warp_m, 0});
|
||||
load(a, as1);
|
||||
G::load(Bs[tic][0], B, {0, 0, block_col * 2, k + 2}, sw_B);
|
||||
asm volatile("s_waitcnt lgkmcnt(0)");
|
||||
__builtin_amdgcn_s_barrier();
|
||||
|
||||
__builtin_amdgcn_s_setprio(2);
|
||||
mma_ABt_scaled(cC, a, b0, cC, &sa_h1, &sb_h0);
|
||||
__builtin_amdgcn_s_setprio(0);
|
||||
__builtin_amdgcn_s_barrier();
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
G::load(Bs[tic][1], B, {0, 0, block_col * 2 + 1, k + 2}, sw_B);
|
||||
asm volatile("s_waitcnt vmcnt(6)");
|
||||
__builtin_amdgcn_s_barrier();
|
||||
|
||||
__builtin_amdgcn_s_setprio(2);
|
||||
mma_ABt_scaled(cD, a, b1, cD, &sa_h1, &sb_h1);
|
||||
__builtin_amdgcn_s_setprio(0);
|
||||
__builtin_amdgcn_s_barrier();
|
||||
}
|
||||
|
||||
{ // Epilogue k = k_iters - 2
|
||||
int k = k_iters - 2;
|
||||
load_scales_to_lds(smem_scales, scale_A_iter, scale_B_iter, block_m, block_n, k, M, N);
|
||||
__builtin_amdgcn_s_barrier();
|
||||
|
||||
fp8e8m0_4 sa_h0 = pack_scales(smem_scales, 0, a_row_h0);
|
||||
fp8e8m0_4 sa_h1 = pack_scales(smem_scales, 0, a_row_h1);
|
||||
fp8e8m0_4 sb_h0 = pack_scales(smem_scales, 1024, b_row_h0);
|
||||
fp8e8m0_4 sb_h1 = pack_scales(smem_scales, 1024, b_row_h1);
|
||||
|
||||
auto bs0 = subtile_inplace<REG_N, BLOCK_K>(Bs[tic][0], {warp_n, 0});
|
||||
load(b0, bs0);
|
||||
auto as0 = subtile_inplace<REG_M, BLOCK_K>(As[tic][0], {warp_m, 0});
|
||||
load(a, as0);
|
||||
G::load(As[toc][1], A, {0, 0, block_row * 2 + 1, k + 1}, sw_A);
|
||||
__builtin_amdgcn_s_barrier();
|
||||
|
||||
asm volatile("s_waitcnt lgkmcnt(0)");
|
||||
__builtin_amdgcn_s_setprio(2);
|
||||
mma_ABt_scaled(cA, a, b0, cA, &sa_h0, &sb_h0);
|
||||
__builtin_amdgcn_s_setprio(0);
|
||||
__builtin_amdgcn_s_barrier();
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
auto bs1 = subtile_inplace<REG_N, BLOCK_K>(Bs[tic][1], {warp_n, 0});
|
||||
load(b1, bs1);
|
||||
__builtin_amdgcn_s_barrier();
|
||||
|
||||
asm volatile("s_waitcnt lgkmcnt(0)");
|
||||
__builtin_amdgcn_s_setprio(2);
|
||||
mma_ABt_scaled(cB, a, b1, cB, &sa_h0, &sb_h1);
|
||||
__builtin_amdgcn_s_setprio(0);
|
||||
__builtin_amdgcn_s_barrier();
|
||||
|
||||
auto as1 = subtile_inplace<REG_M, BLOCK_K>(As[tic][1], {warp_m, 0});
|
||||
load(a, as1);
|
||||
__builtin_amdgcn_s_barrier();
|
||||
|
||||
asm volatile("s_waitcnt lgkmcnt(0)");
|
||||
__builtin_amdgcn_s_setprio(2);
|
||||
mma_ABt_scaled(cC, a, b0, cC, &sa_h1, &sb_h0);
|
||||
__builtin_amdgcn_s_setprio(0);
|
||||
__builtin_amdgcn_s_barrier();
|
||||
|
||||
auto bs0_next = subtile_inplace<REG_N, BLOCK_K>(Bs[toc][0], {warp_n, 0});
|
||||
load(b0, bs0_next);
|
||||
asm volatile("s_waitcnt vmcnt(4)");
|
||||
__builtin_amdgcn_s_barrier();
|
||||
|
||||
asm volatile("s_waitcnt lgkmcnt(0)");
|
||||
__builtin_amdgcn_s_setprio(2);
|
||||
mma_ABt_scaled(cD, a, b1, cD, &sa_h1, &sb_h1);
|
||||
__builtin_amdgcn_s_setprio(0);
|
||||
__builtin_amdgcn_s_barrier();
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
tic ^= 1; toc ^= 1;
|
||||
}
|
||||
|
||||
{ // Epilogue k = k_iters - 1 (final)
|
||||
int k = k_iters - 1;
|
||||
load_scales_to_lds(smem_scales, scale_A_iter, scale_B_iter, block_m, block_n, k, M, N);
|
||||
__builtin_amdgcn_s_barrier();
|
||||
|
||||
fp8e8m0_4 sa_h0 = pack_scales(smem_scales, 0, a_row_h0);
|
||||
fp8e8m0_4 sa_h1 = pack_scales(smem_scales, 0, a_row_h1);
|
||||
fp8e8m0_4 sb_h0 = pack_scales(smem_scales, 1024, b_row_h0);
|
||||
fp8e8m0_4 sb_h1 = pack_scales(smem_scales, 1024, b_row_h1);
|
||||
|
||||
auto as0 = subtile_inplace<REG_M, BLOCK_K>(As[tic][0], {warp_m, 0});
|
||||
load(a, as0);
|
||||
__builtin_amdgcn_s_barrier();
|
||||
|
||||
asm volatile("s_waitcnt lgkmcnt(0)");
|
||||
__builtin_amdgcn_s_setprio(2);
|
||||
mma_ABt_scaled(cA, a, b0, cA, &sa_h0, &sb_h0);
|
||||
__builtin_amdgcn_s_setprio(0);
|
||||
__builtin_amdgcn_s_barrier();
|
||||
|
||||
auto bs1 = subtile_inplace<REG_N, BLOCK_K>(Bs[tic][1], {warp_n, 0});
|
||||
load(b1, bs1);
|
||||
asm volatile("s_waitcnt vmcnt(0)");
|
||||
__builtin_amdgcn_s_barrier();
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
asm volatile("s_waitcnt lgkmcnt(0)");
|
||||
__builtin_amdgcn_s_setprio(2);
|
||||
mma_ABt_scaled(cB, a, b1, cB, &sa_h0, &sb_h1);
|
||||
__builtin_amdgcn_s_setprio(0);
|
||||
__builtin_amdgcn_s_barrier();
|
||||
|
||||
auto as1 = subtile_inplace<REG_M, BLOCK_K>(As[tic][1], {warp_m, 0});
|
||||
load(a, as1);
|
||||
__builtin_amdgcn_s_barrier();
|
||||
|
||||
asm volatile("s_waitcnt lgkmcnt(0)");
|
||||
__builtin_amdgcn_s_setprio(2);
|
||||
mma_ABt_scaled(cC, a, b0, cC, &sa_h1, &sb_h0);
|
||||
mma_ABt_scaled(cD, a, b1, cD, &sa_h1, &sb_h1);
|
||||
__builtin_amdgcn_s_setprio(0);
|
||||
__builtin_amdgcn_s_barrier();
|
||||
}
|
||||
|
||||
store(C, cA, {0, 0, block_row * WARPS_ROW * 2 + warp_m, block_col * WARPS_COL * 2 + warp_n});
|
||||
store(C, cB, {0, 0, block_row * WARPS_ROW * 2 + warp_m, block_col * WARPS_COL * 2 + WARPS_COL + warp_n});
|
||||
store(C, cC, {0, 0, block_row * WARPS_ROW * 2 + WARPS_ROW + warp_m, block_col * WARPS_COL * 2 + warp_n});
|
||||
store(C, cD, {0, 0, block_row * WARPS_ROW * 2 + WARPS_ROW + warp_m, block_col * WARPS_COL * 2 + WARPS_COL + warp_n});
|
||||
}
|
||||
Loading…
Add table
Add a link
Reference in a new issue