This commit is contained in:
qazal 2026-06-15 15:31:35 +08:00
commit 122bb10d4d
3 changed files with 16 additions and 21 deletions

View file

@ -66,9 +66,9 @@ def matmul(x:Tensor, w:Tensor, fp8:bool=True, amax_x:Tensor|None=None, w_inv_sca
if can_use_asm_gemm(x_fp8, w.T):
assert amax_x is not None
if COLUMNWISE_WEIGHT_SCALE:
out = asm_gemm(x_fp8, w.T, x_scale=amax_x, grad_amax_state=grad_amax_state, w_post_scale=w_inv_scale, x_scale_is_amax=True)
out = asm_gemm(x_fp8, w.T, x_scale=amax_x, grad_amax_state=grad_amax_state, w_post_scale=w_inv_scale)
else:
out = asm_gemm(x_fp8, w.T, x_scale=amax_x, w_scale=w_inv_scale, grad_amax_state=grad_amax_state, x_scale_is_amax=True)
out = asm_gemm(x_fp8, w.T, x_scale=amax_x, w_scale=w_inv_scale, grad_amax_state=grad_amax_state)
return out, x_new_amax, x_fp8
return (x_fp8.dot(w.T, dtype=dtypes.float) * ((amax_state.float() + 1e-8) / FP8_MAX) * w_inv_scale).cast(dtypes.bfloat16), x_new_amax, x_fp8

View file

@ -2628,8 +2628,8 @@ def custom_asm_gemm(C:UOp, A:UOp, B:UOp, dname:str) -> UOp:
# ** FP8 GEMM custom kernel
@functools.cache
def custom_hk_fp8_gemm(C:UOp, A:UOp, B:UOp, *args:UOp, dname:str, scale_mode:int=3, x_scale_is_amax:bool=False) -> UOp:
# scale_mode: 0=no scale, 1=x only, 2=w only, 3=both, 4=extra
def custom_hk_fp8_gemm(C:UOp, A:UOp, B:UOp, *args:UOp, dname:str, scale_mode:int=3) -> UOp:
# scale_mode: 0=no scale, 1=x only, 2=w only, 3=both, 4=g_scale (bw)
n_scales = (1 if scale_mode & 1 else 0) + (1 if scale_mode & 2 else 0) + (1 if scale_mode & 4 else 0)
scales, extra = args[:n_scales], args[n_scales:]
M, K = A.shape[0]*A.shape[1], A.shape[2]
@ -2645,7 +2645,7 @@ def custom_hk_fp8_gemm(C:UOp, A:UOp, B:UOp, *args:UOp, dname:str, scale_mode:int
src = (kittens_path/"gemm_fp8.cpp").read_text()
lib = HIPCCCompiler("gfx950", [f"-I{(kittens_path/'include').as_posix()}", "-std=c++20", "-DKITTENS_CDNA4", "-ffast-math",
"-DHIP_ENABLE_WARP_SYNC_BUILTINS", f"-DGEMM_M={M}", f"-DGEMM_N={N}", f"-DGEMM_K={K}",
f"-DSCALE_MODE={scale_mode}", f"-DX_SCALE_IS_AMAX={int(x_scale_is_amax)}"]).compile_cached(src)
f"-DSCALE_MODE={scale_mode}"]).compile_cached(src)
return UOp(Ops.PROGRAM, src=(sink, UOp(Ops.DEVICE, arg=dname), UOp(Ops.LINEAR, src=(*sink.src, sink)), UOp(Ops.SOURCE, arg=src),
UOp(Ops.BINARY, arg=lib)))
@ -2802,7 +2802,7 @@ def hk_bf16_atb_gemm(a:Tensor, b:Tensor) -> Tensor:
# ** backward gemm, might use the asm gemm
def custom_gemm_bw(gradient:UOp, kernel:UOp, n_scales:int=2, has_grad_amax:bool=False, has_w_post:bool=False, x_scale_is_amax:bool=False):
def custom_gemm_bw(gradient:UOp, kernel:UOp, n_scales:int=2, has_grad_amax:bool=False, has_w_post:bool=False):
inputs = kernel.src[1:]
if inputs[1].dtype == FP8_DTYPE:
out, a, b = inputs[:3]
@ -2842,8 +2842,8 @@ def custom_gemm_bw(gradient:UOp, kernel:UOp, n_scales:int=2, has_grad_amax:bool=
g_fp8 = Tensor(g_fp8.contiguous().uop.after(store_effect), device=a.device)
# dgrad: uses g_scale * x_scale * w_scale (only when scalar)
if s_extra_t is not None: g_scale = g_scale * s_extra_t
if has_w: grad_a = asm_gemm(g_fp8, b_t, x_scale=s_x_t, w_scale=s_w_t, extra_scale=g_scale, x_scale_is_amax=x_scale_is_amax)
else: grad_a = asm_gemm(g_fp8, b_t, x_scale=s_x_t, w_scale=g_scale, x_scale_is_amax=x_scale_is_amax)
if has_w: grad_a = asm_gemm(g_fp8, b_t, x_scale=s_x_t, w_scale=s_w_t, g_scale=g_scale)
else: grad_a = asm_gemm(g_fp8, b_t, x_scale=s_x_t, w_scale=g_scale)
# wgrad: no w_scale
g_fp8_2d = g_fp8.reshape(-1, g_fp8.shape[-1])
if getenv("FAST_FP8_TRANSPOSE", 0) and g_fp8_2d.shape[0] % 64 == 0 and g_fp8_2d.shape[1] % 64 == 0:
@ -2851,7 +2851,7 @@ def custom_gemm_bw(gradient:UOp, kernel:UOp, n_scales:int=2, has_grad_amax:bool=
g_fp8_T = fast_fp8_transpose(g_fp8_2d)
else:
g_fp8_T = g_fp8.permute(2, 0, 1).reshape(g_t.shape[-1], -1)
grad_b = asm_gemm(g_fp8_T, a_t.reshape(-1, a_t.shape[-1]), x_scale=s_x_t, w_scale=g_scale, x_scale_is_amax=x_scale_is_amax)
grad_b = asm_gemm(g_fp8_T, a_t.reshape(-1, a_t.shape[-1]), x_scale=s_x_t, w_scale=g_scale)
# wgrad: rescale if not scalar
if w_post_t is not None:
grad_b = grad_b / w_post_t.reshape(*w_post_t.shape, *([1]*(grad_b.ndim - w_post_t.ndim)))
@ -2905,7 +2905,7 @@ def custom_mx_gemm_bw(gradient:UOp, kernel:UOp, has_w_post:bool, w_stored:bool=F
def asm_gemm(a:Tensor, b:Tensor, x_scale:Tensor|None=None, w_scale:Tensor|None=None, grad_amax_state:Tensor|None=None,
w_post_scale:Tensor|None=None, mx:bool=False, mx_scales:tuple|None=None, mx_w_stored:bool=False,
x_scale_is_amax:bool=False, extra_scale:Tensor|None=None) -> Tensor:
g_scale:Tensor|None=None) -> Tensor:
assert can_use_asm_gemm(a, b), f"{counters['todos'][-1]}"
counters["used"] += 1
unfold_batch = a.ndim == 3 and isinstance(a.device, tuple) and a.uop.axis == 2 and b.uop.axis == 0
@ -2952,12 +2952,11 @@ def asm_gemm(a:Tensor, b:Tensor, x_scale:Tensor|None=None, w_scale:Tensor|None=N
out = Tensor.custom_kernel(out, a_q.reshape(a.shape), b_q, a_si, b_si, a_e8, b_e8, *extra, fxn=fxn, grad_fxn=grad_fxn)[0]
# fp8 gemm computes a@b.T, kernel multiplies output by x_scale * w_scale before bf16 store
elif a.dtype == FP8_DTYPE:
scales = tuple(s for s in (x_scale, w_scale, extra_scale) if s is not None)
scale_mode = (1 if x_scale is not None else 0) | (2 if w_scale is not None else 0) | (4 if extra_scale is not None else 0)
scales = tuple(s for s in (x_scale, w_scale, g_scale) if s is not None)
scale_mode = (1 if x_scale is not None else 0) | (2 if w_scale is not None else 0) | (4 if g_scale is not None else 0)
extra = ([grad_amax_state] if grad_amax_state is not None else []) + ([w_post_scale] if w_post_scale is not None else [])
fxn = functools.partial(custom_hk_fp8_gemm, dname=dname, scale_mode=scale_mode, x_scale_is_amax=x_scale_is_amax)
bw = functools.partial(custom_gemm_bw, n_scales=len(scales), has_grad_amax=grad_amax_state is not None,
has_w_post=w_post_scale is not None, x_scale_is_amax=x_scale_is_amax)
fxn = functools.partial(custom_hk_fp8_gemm, dname=dname, scale_mode=scale_mode)
bw = functools.partial(custom_gemm_bw, n_scales=len(scales), has_grad_amax=grad_amax_state is not None, has_w_post=w_post_scale is not None)
out = Tensor.custom_kernel(out, a, b.T, *scales, *extra, fxn=fxn, grad_fxn=bw)[0]
elif a.dtype == dtypes.bfloat16 and getenv("USE_HK_BF16_GEMM"):
out = Tensor.custom_kernel(out, a, b.T, b, fxn=functools.partial(custom_hk_bf16_gemm, dname=dname), grad_fxn=custom_gemm_bw)[0]

View file

@ -93,7 +93,7 @@ constexpr int NUM_WARPS = 8;
using G = kittens::group<NUM_WARPS>;
// scale_mode: 0=no scale, 1=x only, 2=w only, 3=both, 4=extra
// scale_mode: 0=no scale, 1=x only, 2=w only, 3=both, 4=g_scale (bw)
#ifndef SCALE_MODE
#define SCALE_MODE 3
#endif
@ -351,11 +351,7 @@ __global__ __launch_bounds__(512, 2) void hk_fp8_gemm(bf16 *C_ptr, fp8e4m3 *A_pt
#if SCALE_MODE != 0
float scale = 1.0f;
#if SCALE_MODE & 1
float x_scale =
#if X_SCALE_IS_AMAX
(*x_scale_ptr + 1e-08f) * (1.0f / 448.0f);
#else
*x_scale_ptr;
float x_scale = (*x_scale_ptr + 1e-08f) * (1.0f / 448.0f);
#endif
scale *= x_scale;
#endif