This commit is contained in:
qazal 2026-06-15 15:47:27 +08:00
commit 7e4e895daa
2 changed files with 4 additions and 2 deletions

View file

@ -2629,7 +2629,7 @@ def custom_asm_gemm(C:UOp, A:UOp, B:UOp, dname:str) -> UOp:
@functools.cache
def custom_hk_fp8_gemm(C:UOp, A:UOp, B:UOp, *args:UOp, dname:str, scale_mode:int=3) -> UOp:
# scale_mode: 0=no scale, 1=x only, 2=w only, 3=both, 4=g_scale (bw)
# scale_mode: 0=no scale, 1=x only, 2=w only, 3=both, 4=all 3
n_scales = (1 if scale_mode & 1 else 0) + (1 if scale_mode & 2 else 0) + (1 if scale_mode & 4 else 0)
scales, extra = args[:n_scales], args[n_scales:]
M, K = A.shape[0]*A.shape[1], A.shape[2]
@ -2843,6 +2843,7 @@ def custom_gemm_bw(gradient:UOp, kernel:UOp, n_scales:int=2, has_grad_amax:bool=
# dgrad: uses g_scale * x_scale * w_scale (only when scalar)
if s_extra_t is not None: g_scale = g_scale * s_extra_t
if has_w: grad_a = asm_gemm(g_fp8, b_t, x_scale=s_x_t, w_scale=s_w_t, g_scale=g_scale)
# do x * g scale in the gemm
else: grad_a = asm_gemm(g_fp8, b_t, x_scale=s_x_t, w_scale=g_scale)
# wgrad: no w_scale
g_fp8_2d = g_fp8.reshape(-1, g_fp8.shape[-1])
@ -2851,6 +2852,7 @@ def custom_gemm_bw(gradient:UOp, kernel:UOp, n_scales:int=2, has_grad_amax:bool=
g_fp8_T = fast_fp8_transpose(g_fp8_2d)
else:
g_fp8_T = g_fp8.permute(2, 0, 1).reshape(g_t.shape[-1], -1)
# do x * g scale in the gemm
grad_b = asm_gemm(g_fp8_T, a_t.reshape(-1, a_t.shape[-1]), x_scale=s_x_t, w_scale=g_scale)
# wgrad: rescale if not scalar
if w_post_t is not None:

View file

@ -93,7 +93,7 @@ constexpr int NUM_WARPS = 8;
using G = kittens::group<NUM_WARPS>;
// scale_mode: 0=no scale, 1=x only, 2=w only, 3=both, 4=g_scale (bw)
// scale_mode: 0=no scale, 1=x only, 2=w only, 3=both, 4=all 3
#ifndef SCALE_MODE
#define SCALE_MODE 3
#endif