gemm: mxfp8 hipkittens gemm (#16541)

* gemm: mxfp8 hipkittens gemm

* feat: update hipkittens

* feat: kernel signature

* clean: just kernel

* feat: from tinygrad

* feat: test

* fix: add back utils

* clean: no diff

* clean: no diff
This commit is contained in:
wozeparrot 2026-06-09 18:20:05 -04:00 committed by GitHub
commit 2bdc360606
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 337 additions and 0 deletions

View file

@ -2649,6 +2649,27 @@ def custom_hk_fp8_gemm(C:UOp, A:UOp, B:UOp, *args:UOp, dname:str, scale_mode:int
return UOp(Ops.PROGRAM, src=(sink, UOp(Ops.DEVICE, arg=dname), UOp(Ops.LINEAR, src=(*sink.src, sink)), UOp(Ops.SOURCE, arg=src),
UOp(Ops.BINARY, arg=lib)))
# ** MXFP8 GEMM custom kernel
@functools.cache
def custom_hk_mxfp8_gemm(C:UOp, A:UOp, B:UOp, scale_A:UOp, scale_B:UOp, *, dname:str) -> UOp:
# mxfp8 block-scaled gemm: A(M,K) @ B(N,K).T, e8m0 1x32 microscales packed (k_iters,dim) uint32
M, K = A.shape[0]*A.shape[1], A.shape[2]
N, K2 = B.shape[(1 if B.ndim == 3 else 0):]
assert K == K2, f"{A.shape} {B.shape}"
block_size = 256
threads = UOp.special(64 * 8, "lidx0")
workgroups = UOp.special((M // block_size) * (N // block_size), "gidx0")
sink_inputs = (C.base, A.base, B.base, scale_A.base, scale_B.base, threads, workgroups)
sink = UOp.sink(*sink_inputs,
arg=KernelInfo(f"hk_mxfp8_gemm_{M}_{N}_{K}", estimates=Estimates(ops=2*M*N*K, mem=(M*K+N*K)*A.dtype.itemsize+M*N*C.dtype.itemsize)))
kittens_path = pathlib.Path(__file__).parent.parent/"thunder"/"amd"
src = (kittens_path/"gemm_mxfp8.cpp").read_text()
lib = HIPCCCompiler("gfx950", [f"-I{(kittens_path/'include').as_posix()}", "-std=c++20", "-DKITTENS_CDNA4", "-ffast-math",
"-DHIP_ENABLE_WARP_SYNC_BUILTINS", f"-DGEMM_M={M}", f"-DGEMM_N={N}", f"-DGEMM_K={K}"]).compile_cached(src)
return UOp(Ops.PROGRAM, src=(sink, UOp(Ops.DEVICE, arg=dname), UOp(Ops.LINEAR, src=(*sink.src, sink)), UOp(Ops.SOURCE, arg=src),
UOp(Ops.BINARY, arg=lib)))
counters = {"used":0, "todos":[]}
def todo(msg:str) -> bool: counters["todos"].append(msg); return False
def _asm_gemm_report():

View file

@ -0,0 +1,268 @@
#include "kittens.cuh"
using namespace kittens;
#ifndef GEMM_M
constexpr int GEMM_M = 8192;
#endif
#ifndef GEMM_N
constexpr int GEMM_N = 8192;
#endif
#ifndef GEMM_K
constexpr int GEMM_K = 8192;
#endif
// Kernel
constexpr int NUM_WARPS = 8;
constexpr int WARPS_ROW = 2;
constexpr int WARPS_COL = 4;
constexpr int BLOCK_ROW = 256;
constexpr int BLOCK_COL = 256;
constexpr int BLOCK_K = 128;
constexpr int HALF_ROW = BLOCK_ROW / 2;
constexpr int HALF_COL = BLOCK_COL / 2;
constexpr int REG_M = BLOCK_ROW / WARPS_ROW / 2;
constexpr int REG_N = BLOCK_COL / WARPS_COL / 2;
using G = kittens::group<NUM_WARPS>;
__global__ __launch_bounds__(512, 2) void mxfp8_gemm_kernel(bf16 *C_ptr, fp8e4m3 *A_ptr, fp8e4m3 *B_ptr,
const uint32_t *__restrict__ scale_A_iter,
const uint32_t *__restrict__ scale_B_iter) {
constexpr int M = GEMM_M, N = GEMM_N, K = GEMM_K;
kittens::gl<fp8e4m3, 1, 1, M, K> A{A_ptr, nullptr, nullptr, nullptr, nullptr};
kittens::gl<fp8e4m3, 1, 1, N, K> B{B_ptr, nullptr, nullptr, nullptr, nullptr};
kittens::gl<bf16, 1, 1, M, N> C{C_ptr, nullptr, nullptr, nullptr, nullptr};
constexpr int k_iters = K / BLOCK_K;
constexpr int NUM_THREADS = NUM_WARPS * WARP_THREADS;
using ST_A = st_fp8e4m3<HALF_ROW, BLOCK_K, st_16x128_s>;
using ST_B = st_fp8e4m3<HALF_COL, BLOCK_K, st_16x128_s>;
using RT_A = rt_fp8e4m3<REG_M, BLOCK_K>;
using RT_B = rt_fp8e4m3<REG_N, BLOCK_K>;
using RT_C = rt_fl<REG_M, REG_N, col_l, rt_16x16_s>;
__shared__ ST_A As[2][2];
__shared__ ST_B Bs[2][2];
__shared__ uint8_t smem_scales[2048];
RT_A a;
RT_B b0, b1;
RT_C cA, cB, cC, cD;
zero(cA); zero(cB); zero(cC); zero(cD);
constexpr int tiles_M = M / BLOCK_ROW;
constexpr int tiles_N = N / BLOCK_COL;
const int NUM_XCDS = 8;
const int WGM = 8;
int wgid = chiplet_transform_chunked(blockIdx.x, gridDim.x, NUM_XCDS, WGM * WGM);
int num_wgid_in_group = WGM * tiles_N;
int group_id = wgid / num_wgid_in_group;
int first_pid_m = group_id * WGM;
int group_size_m = min(tiles_M - first_pid_m, WGM);
int block_row = first_pid_m + ((wgid % num_wgid_in_group) % group_size_m);
int block_col = (wgid % num_wgid_in_group) / group_size_m;
int block_m = block_row * BLOCK_ROW;
int block_n = block_col * BLOCK_COL;
int warp_m = warpid() / WARPS_COL;
int warp_n = warpid() % WARPS_COL;
using T = fp8e4m3;
constexpr int bpt = ST_A::underlying_subtile_bytes_per_thread;
constexpr int bpm = bpt * NUM_THREADS;
constexpr int copies_A = HALF_ROW * BLOCK_K * sizeof(T) / bpm;
constexpr int copies_B = HALF_COL * BLOCK_K * sizeof(T) / bpm;
uint32_t sw_A[copies_A], sw_B[copies_B];
G::prefill_swizzled_offsets(As[0][0], A, sw_A);
G::prefill_swizzled_offsets(Bs[0][0], B, sw_B);
int a_row_h0 = warp_m * REG_M;
int a_row_h1 = HALF_ROW + warp_m * REG_M;
int b_row_h0 = warp_n * REG_N;
int b_row_h1 = HALF_COL + warp_n * REG_N;
int tic = 0, toc = 1;
G::load(Bs[tic][0], B, {0, 0, block_col * 2, 0}, sw_B);
G::load(As[tic][0], A, {0, 0, block_row * 2, 0}, sw_A);
G::load(Bs[tic][1], B, {0, 0, block_col * 2 + 1, 0}, sw_B);
G::load(As[tic][1], A, {0, 0, block_row * 2 + 1, 0}, sw_A);
if (warp_m == 1) __builtin_amdgcn_s_barrier();
asm volatile("s_waitcnt vmcnt(4)");
__builtin_amdgcn_s_barrier();
G::load(As[toc][0], A, {0, 0, block_row * 2, 1}, sw_A);
G::load(Bs[toc][0], B, {0, 0, block_col * 2, 1}, sw_B);
G::load(Bs[toc][1], B, {0, 0, block_col * 2 + 1, 1}, sw_B);
asm volatile("s_waitcnt vmcnt(6)");
__builtin_amdgcn_s_barrier();
#pragma unroll 2
for (int k = 0; k < k_iters - 2; k++, tic ^= 1, toc ^= 1) {
load_scales_to_lds(smem_scales, scale_A_iter, scale_B_iter, block_m, block_n, k, M, N);
auto bs0 = subtile_inplace<REG_N, BLOCK_K>(Bs[tic][0], {warp_n, 0});
load(b0, bs0);
auto as0 = subtile_inplace<REG_M, BLOCK_K>(As[tic][0], {warp_m, 0});
load(a, as0);
G::load(As[toc][1], A, {0, 0, block_row * 2 + 1, k + 1}, sw_A);
asm volatile("s_waitcnt lgkmcnt(8)");
__builtin_amdgcn_s_barrier();
fp8e8m0_4 sa_h0 = pack_scales(smem_scales, 0, a_row_h0);
fp8e8m0_4 sb_h0 = pack_scales(smem_scales, 1024, b_row_h0);
__builtin_amdgcn_s_setprio(2);
mma_ABt_scaled(cA, a, b0, cA, &sa_h0, &sb_h0);
__builtin_amdgcn_s_setprio(0);
__builtin_amdgcn_s_barrier();
__builtin_amdgcn_sched_barrier(0);
fp8e8m0_4 sb_h1 = pack_scales(smem_scales, 1024, b_row_h1);
auto bs1 = subtile_inplace<REG_N, BLOCK_K>(Bs[tic][1], {warp_n, 0});
load(b1, bs1);
G::load(As[tic][0], A, {0, 0, block_row * 2, k + 2}, sw_A);
asm volatile("s_waitcnt lgkmcnt(0)");
__builtin_amdgcn_s_barrier();
__builtin_amdgcn_s_setprio(2);
mma_ABt_scaled(cB, a, b1, cB, &sa_h0, &sb_h1);
__builtin_amdgcn_s_setprio(0);
__builtin_amdgcn_s_barrier();
fp8e8m0_4 sa_h1 = pack_scales(smem_scales, 0, a_row_h1);
auto as1 = subtile_inplace<REG_M, BLOCK_K>(As[tic][1], {warp_m, 0});
load(a, as1);
G::load(Bs[tic][0], B, {0, 0, block_col * 2, k + 2}, sw_B);
asm volatile("s_waitcnt lgkmcnt(0)");
__builtin_amdgcn_s_barrier();
__builtin_amdgcn_s_setprio(2);
mma_ABt_scaled(cC, a, b0, cC, &sa_h1, &sb_h0);
__builtin_amdgcn_s_setprio(0);
__builtin_amdgcn_s_barrier();
__builtin_amdgcn_sched_barrier(0);
G::load(Bs[tic][1], B, {0, 0, block_col * 2 + 1, k + 2}, sw_B);
asm volatile("s_waitcnt vmcnt(6)");
__builtin_amdgcn_s_barrier();
__builtin_amdgcn_s_setprio(2);
mma_ABt_scaled(cD, a, b1, cD, &sa_h1, &sb_h1);
__builtin_amdgcn_s_setprio(0);
__builtin_amdgcn_s_barrier();
}
{ // Epilogue k = k_iters - 2
int k = k_iters - 2;
load_scales_to_lds(smem_scales, scale_A_iter, scale_B_iter, block_m, block_n, k, M, N);
__builtin_amdgcn_s_barrier();
fp8e8m0_4 sa_h0 = pack_scales(smem_scales, 0, a_row_h0);
fp8e8m0_4 sa_h1 = pack_scales(smem_scales, 0, a_row_h1);
fp8e8m0_4 sb_h0 = pack_scales(smem_scales, 1024, b_row_h0);
fp8e8m0_4 sb_h1 = pack_scales(smem_scales, 1024, b_row_h1);
auto bs0 = subtile_inplace<REG_N, BLOCK_K>(Bs[tic][0], {warp_n, 0});
load(b0, bs0);
auto as0 = subtile_inplace<REG_M, BLOCK_K>(As[tic][0], {warp_m, 0});
load(a, as0);
G::load(As[toc][1], A, {0, 0, block_row * 2 + 1, k + 1}, sw_A);
__builtin_amdgcn_s_barrier();
asm volatile("s_waitcnt lgkmcnt(0)");
__builtin_amdgcn_s_setprio(2);
mma_ABt_scaled(cA, a, b0, cA, &sa_h0, &sb_h0);
__builtin_amdgcn_s_setprio(0);
__builtin_amdgcn_s_barrier();
__builtin_amdgcn_sched_barrier(0);
auto bs1 = subtile_inplace<REG_N, BLOCK_K>(Bs[tic][1], {warp_n, 0});
load(b1, bs1);
__builtin_amdgcn_s_barrier();
asm volatile("s_waitcnt lgkmcnt(0)");
__builtin_amdgcn_s_setprio(2);
mma_ABt_scaled(cB, a, b1, cB, &sa_h0, &sb_h1);
__builtin_amdgcn_s_setprio(0);
__builtin_amdgcn_s_barrier();
auto as1 = subtile_inplace<REG_M, BLOCK_K>(As[tic][1], {warp_m, 0});
load(a, as1);
__builtin_amdgcn_s_barrier();
asm volatile("s_waitcnt lgkmcnt(0)");
__builtin_amdgcn_s_setprio(2);
mma_ABt_scaled(cC, a, b0, cC, &sa_h1, &sb_h0);
__builtin_amdgcn_s_setprio(0);
__builtin_amdgcn_s_barrier();
auto bs0_next = subtile_inplace<REG_N, BLOCK_K>(Bs[toc][0], {warp_n, 0});
load(b0, bs0_next);
asm volatile("s_waitcnt vmcnt(4)");
__builtin_amdgcn_s_barrier();
asm volatile("s_waitcnt lgkmcnt(0)");
__builtin_amdgcn_s_setprio(2);
mma_ABt_scaled(cD, a, b1, cD, &sa_h1, &sb_h1);
__builtin_amdgcn_s_setprio(0);
__builtin_amdgcn_s_barrier();
__builtin_amdgcn_sched_barrier(0);
tic ^= 1; toc ^= 1;
}
{ // Epilogue k = k_iters - 1 (final)
int k = k_iters - 1;
load_scales_to_lds(smem_scales, scale_A_iter, scale_B_iter, block_m, block_n, k, M, N);
__builtin_amdgcn_s_barrier();
fp8e8m0_4 sa_h0 = pack_scales(smem_scales, 0, a_row_h0);
fp8e8m0_4 sa_h1 = pack_scales(smem_scales, 0, a_row_h1);
fp8e8m0_4 sb_h0 = pack_scales(smem_scales, 1024, b_row_h0);
fp8e8m0_4 sb_h1 = pack_scales(smem_scales, 1024, b_row_h1);
auto as0 = subtile_inplace<REG_M, BLOCK_K>(As[tic][0], {warp_m, 0});
load(a, as0);
__builtin_amdgcn_s_barrier();
asm volatile("s_waitcnt lgkmcnt(0)");
__builtin_amdgcn_s_setprio(2);
mma_ABt_scaled(cA, a, b0, cA, &sa_h0, &sb_h0);
__builtin_amdgcn_s_setprio(0);
__builtin_amdgcn_s_barrier();
auto bs1 = subtile_inplace<REG_N, BLOCK_K>(Bs[tic][1], {warp_n, 0});
load(b1, bs1);
asm volatile("s_waitcnt vmcnt(0)");
__builtin_amdgcn_s_barrier();
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt lgkmcnt(0)");
__builtin_amdgcn_s_setprio(2);
mma_ABt_scaled(cB, a, b1, cB, &sa_h0, &sb_h1);
__builtin_amdgcn_s_setprio(0);
__builtin_amdgcn_s_barrier();
auto as1 = subtile_inplace<REG_M, BLOCK_K>(As[tic][1], {warp_m, 0});
load(a, as1);
__builtin_amdgcn_s_barrier();
asm volatile("s_waitcnt lgkmcnt(0)");
__builtin_amdgcn_s_setprio(2);
mma_ABt_scaled(cC, a, b0, cC, &sa_h1, &sb_h0);
mma_ABt_scaled(cD, a, b1, cD, &sa_h1, &sb_h1);
__builtin_amdgcn_s_setprio(0);
__builtin_amdgcn_s_barrier();
}
store(C, cA, {0, 0, block_row * WARPS_ROW * 2 + warp_m, block_col * WARPS_COL * 2 + warp_n});
store(C, cB, {0, 0, block_row * WARPS_ROW * 2 + warp_m, block_col * WARPS_COL * 2 + WARPS_COL + warp_n});
store(C, cC, {0, 0, block_row * WARPS_ROW * 2 + WARPS_ROW + warp_m, block_col * WARPS_COL * 2 + warp_n});
store(C, cD, {0, 0, block_row * WARPS_ROW * 2 + WARPS_ROW + warp_m, block_col * WARPS_COL * 2 + WARPS_COL + warp_n});
}

View file

@ -230,6 +230,54 @@ def has_hipcc():
@unittest.skipUnless(has_hipcc(), "FP8 gemm requires hipcc to compile")
class TestGemmLlamaFP8(TestGemmLlama): dtype = FP8_DTYPE
# mxfp8: 1x32 block scaling along K, e8m0 scales packed iteration-major (K/128, dim) uint32
def quantize_mxfp8(x:Tensor) -> tuple[Tensor, Tensor, Tensor]:
rows, K = x.shape
scale_K, k_iters = K // 32, K // 128
xb = x.reshape(rows, scale_K, 32).float()
amax = xb.abs().max(axis=-1)
e8 = (amax.log2().floor() + 127).clamp(0, 254)
e8 = (amax == 0).where(Tensor.zeros_like(e8), e8).cast(dtypes.uint8)
xq = (xb * (127.0 - e8.cast(dtypes.float32)).exp2().reshape(rows, scale_K, 1)).cast(FP8_DTYPE).reshape(rows, K)
packed = e8.reshape(rows, k_iters, 4).bitcast(dtypes.uint32).reshape(rows, k_iters).permute(1, 0)
return xq.contiguous(), e8, packed.contiguous()
def dequant_mxfp8(xq:Tensor, e8:Tensor) -> Tensor:
rows, K = xq.shape
scale = (e8.cast(dtypes.float32) - 127.0).exp2()
return (xq.float().reshape(rows, K // 32, 32) * scale.reshape(rows, K // 32, 1)).reshape(rows, K)
def run_mxfp8_gemm(M:int, N:int, K:int) -> None:
import functools
from extra.gemm.cdna_asm_gemm import custom_hk_mxfp8_gemm
Tensor.manual_seed(0)
a = (Tensor.randn(M, K, dtype=dtypes.float) * 0.5).realize()
b = (Tensor.randn(N, K, dtype=dtypes.float) * 0.5).realize()
a_q, a_e8, a_si = quantize_mxfp8(a)
b_q, b_e8, b_si = quantize_mxfp8(b)
Tensor.realize(a_q, a_e8, a_si, b_q, b_e8, b_si)
out = Tensor.invalids(1, M, N, dtype=dtypes.bfloat16, device=a.device)
tst = out.custom_kernel(a_q.unsqueeze(0), b_q, a_si, b_si, fxn=functools.partial(custom_hk_mxfp8_gemm, dname=a.device))[0].squeeze(0)
ref_mx = dequant_mxfp8(a_q, a_e8) @ dequant_mxfp8(b_q, b_e8).T
ref = a @ b.T
Tensor.realize(tst, ref_mx, ref)
if a.device.startswith("NULL"): return
err_mx = ((tst.float() - ref_mx).abs().mean() / ref_mx.abs().mean()).item()
err = ((tst.float() - ref).abs().mean() / ref.abs().mean()).item()
assert err_mx < 1e-2, f"kernel vs mxfp8 reference rel err {err_mx}"
assert err < 6e-2, f"kernel vs fp32 rel err {err}"
@unittest.skipUnless(has_hipcc(), "MXFP8 gemm requires hipcc to compile")
class TestGemmMXFP8(unittest.TestCase):
def setUp(self):
if not is_cdna4() or DEV.interface.startswith("MOCK"): self.skipTest("mxfp8 gemm is only for cdna4")
def test_simple(self): run_mxfp8_gemm(N:=getenv("N", 256), N, 2*128)
def test_rect(self): run_mxfp8_gemm(512, 256, 512)
def test_llama_ffn(self): run_mxfp8_gemm(8192, 14336, 4096)
def test_llama_ffn2(self): run_mxfp8_gemm(8192, 4096, 14336)
def test_llama_qkv(self): run_mxfp8_gemm(8192, 4096, 4096)
class TestMagicGu(unittest.TestCase):
def test_magicgu_matches_old(self):
from extra.gemm.cdna_asm_gemm import _magicgu_mulhi, TILE_M, TILE_N, TILE_K