llama: mxfp8 (#16574)

This commit is contained in:
wozeparrot 2026-06-12 01:15:24 -04:00 committed by GitHub
commit e770805d21
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 43 additions and 11 deletions

View file

@ -2660,7 +2660,9 @@ def custom_hk_mxfp8_gemm(C:UOp, A:UOp, B:UOp, scale_A:UOp, scale_B:UOp, *extra:U
block_size = 256
threads = UOp.special(64 * 8, "lidx0")
workgroups = UOp.special((M // block_size) * (N // block_size), "gidx0")
sink_inputs = (C.base, A.base, B.base, scale_A.base, scale_B.base, threads, workgroups)
e_a = extra[0].base if len(extra) >= 1 else scale_A.base
e_b = extra[1].base if len(extra) >= 2 else scale_B.base
sink_inputs = (C.base, A.base, B.base, scale_A.base, scale_B.base, e_a, e_b, threads, workgroups)
sink = UOp.sink(*sink_inputs,
arg=KernelInfo(f"hk_mxfp8_gemm_{M}_{N}_{K}", estimates=Estimates(ops=2*M*N*K, mem=(M*K+N*K)*A.dtype.itemsize+M*N*C.dtype.itemsize)))
kittens_path = pathlib.Path(__file__).parent.parent/"thunder"/"amd"
@ -2876,7 +2878,7 @@ def custom_gemm_bw(gradient:UOp, kernel:UOp, n_scales:int=2, has_grad_amax:bool=
# ** mxfp8 gemm backward
def custom_mx_gemm_bw(gradient:UOp, kernel:UOp, has_w_post:bool):
def custom_mx_gemm_bw(gradient:UOp, kernel:UOp, has_w_post:bool, w_stored:bool=False):
inputs = kernel.src[1:] # (out, a_q, b_q, a_si, b_si, a_e8, b_e8, [w_post])
aq, bq = Tensor(inputs[1], device=inputs[1].device), Tensor(inputs[2], device=inputs[2].device)
ae8, be8 = Tensor(inputs[5], device=inputs[5].device), Tensor(inputs[6], device=inputs[6].device)
@ -2890,14 +2892,14 @@ def custom_mx_gemm_bw(gradient:UOp, kernel:UOp, has_w_post:bool):
grad_b = asm_gemm(g.T, a_phys, mx=True)
grad_a = (grad_a * _mx_block_scale(ae8)).reshape(aq.shape)
grad_b = grad_b * _mx_block_scale(be8)
if not w_stored: grad_b = grad_b * _mx_block_scale(be8)
if wp is not None: grad_b = grad_b / wp.reshape(-1, 1)
return (None, grad_a.uop, grad_b.uop) + tuple(None for _ in inputs[3:])
# ** main gemm function
def asm_gemm(a:Tensor, b:Tensor, x_scale:Tensor|None=None, w_scale:Tensor|None=None, grad_amax_state:Tensor|None=None,
w_post_scale:Tensor|None=None, mx:bool=False, mx_scales:tuple|None=None) -> Tensor:
w_post_scale:Tensor|None=None, mx:bool=False, mx_scales:tuple|None=None, mx_w_stored:bool=False) -> Tensor:
assert can_use_asm_gemm(a, b), f"{counters['todos'][-1]}"
counters["used"] += 1
unfold_batch = a.ndim == 3 and isinstance(a.device, tuple) and a.uop.axis == 2 and b.uop.axis == 0
@ -2939,7 +2941,7 @@ def asm_gemm(a:Tensor, b:Tensor, x_scale:Tensor|None=None, w_scale:Tensor|None=N
b_q, b_e8, b_si = quantize_mxfp8(b.T)
has_w_post = w_post_scale is not None
fxn = functools.partial(custom_hk_mxfp8_gemm, dname=dname)
grad_fxn = functools.partial(custom_mx_gemm_bw, has_w_post=has_w_post)
grad_fxn = functools.partial(custom_mx_gemm_bw, has_w_post=has_w_post, w_stored=mx_w_stored)
extra = [w_post_scale] if w_post_scale is not None else []
out = Tensor.custom_kernel(out, a_q.reshape(a.shape), b_q, a_si, b_si, a_e8, b_e8, *extra, fxn=fxn, grad_fxn=grad_fxn)[0]
# fp8 gemm computes a@b.T, kernel multiplies output by x_scale * w_scale before bf16 store

View file

@ -28,7 +28,9 @@ using G = kittens::group<NUM_WARPS>;
__global__ __launch_bounds__(512, 2) void mxfp8_gemm_kernel(bf16 *C_ptr, fp8e4m3 *A_ptr, fp8e4m3 *B_ptr,
const uint32_t *__restrict__ scale_A_iter,
const uint32_t *__restrict__ scale_B_iter) {
const uint32_t *__restrict__ scale_B_iter,
const uint8_t *__restrict__ a_e8_unused,
const uint8_t *__restrict__ b_e8_unused) {
constexpr int M = GEMM_M, N = GEMM_N, K = GEMM_K;
kittens::gl<fp8e4m3, 1, 1, M, K> A{A_ptr, nullptr, nullptr, nullptr, nullptr};