mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
fp8 gemm inv_scale in epilogue (#16625)
* fuse scale * remove python inv_scale * more inv_scale removal * more cleanups * cleaner * diff polish * work * rename * simpler * simpler * compute * c * Revert "c" This reverts commit8941fec7ca. * Revert "compute" This reverts commit9db573a6d3. * Revert "simpler" This reverts commit910ad33f87. * Revert "simpler" This reverts commitbf75d235a1. * s_g * update types * less diff noise * remove
This commit is contained in:
parent
4dc51aff6e
commit
f998b9930a
5 changed files with 48 additions and 41 deletions
|
|
@ -37,7 +37,7 @@ def quantize_fp8(x:Tensor, amax_state:Tensor|None=None):
|
|||
return x_clamped.cast(FP8_DTYPE), scale.float().reciprocal(), new_amax
|
||||
|
||||
def matmul(x:Tensor, w:Tensor, fp8:bool=True, amax_x:Tensor|None=None, w_inv_scale:Tensor|None=None,
|
||||
x_fp8:Tensor|None=None, x_scale:Tensor|None=None, x_new_amax:Tensor|None=None,
|
||||
x_fp8:Tensor|None=None, x_new_amax:Tensor|None=None,
|
||||
grad_amax_state:Tensor|None=None) -> tuple[Tensor,...]:
|
||||
if not fp8:
|
||||
if ASM_GEMM:
|
||||
|
|
@ -58,24 +58,25 @@ def matmul(x:Tensor, w:Tensor, fp8:bool=True, amax_x:Tensor|None=None, w_inv_sca
|
|||
if x_fp8 is None:
|
||||
if FUSED_INPUT_QUANTIZE and amax_x is not None:
|
||||
from extra.llama_kernels.quantize_fp8_delayed import quantize_fp8_delayed
|
||||
x_fp8, x_scale, x_new_amax, _ = quantize_fp8_delayed(x, amax_x, FP8_DTYPE)
|
||||
x_fp8, _, x_new_amax, _ = quantize_fp8_delayed(x, amax_x, FP8_DTYPE)
|
||||
else:
|
||||
x_fp8, x_scale, x_new_amax = quantize_fp8(x, amax_state=amax_x)
|
||||
x_fp8, _, x_new_amax = quantize_fp8(x, amax_state=amax_x)
|
||||
if ASM_GEMM:
|
||||
from extra.gemm.cdna_asm_gemm import can_use_asm_gemm, asm_gemm
|
||||
if can_use_asm_gemm(x_fp8, w.T):
|
||||
assert amax_x is not None
|
||||
if COLUMNWISE_WEIGHT_SCALE:
|
||||
out = asm_gemm(x_fp8, w.T, x_scale=x_scale, grad_amax_state=grad_amax_state, w_post_scale=w_inv_scale)
|
||||
out = asm_gemm(x_fp8, w.T, x_scale=amax_x, grad_amax_state=grad_amax_state, w_post_scale=w_inv_scale)
|
||||
else:
|
||||
out = asm_gemm(x_fp8, w.T, x_scale=x_scale, w_scale=w_inv_scale, grad_amax_state=grad_amax_state)
|
||||
out = asm_gemm(x_fp8, w.T, x_scale=amax_x, w_scale=w_inv_scale, grad_amax_state=grad_amax_state)
|
||||
return out, x_new_amax, x_fp8
|
||||
return (x_fp8.dot(w.T, dtype=dtypes.float) * x_scale * w_inv_scale).cast(dtypes.bfloat16), x_new_amax, x_fp8
|
||||
return (x_fp8.dot(w.T, dtype=dtypes.float) * ((amax_x.float() + 1e-8) / FP8_MAX) * w_inv_scale).cast(dtypes.bfloat16), x_new_amax, x_fp8
|
||||
|
||||
def norm_quantize_matmul(x:Tensor, norm:Tensor, w:Tensor, w_inv_scale:Tensor, eps:float, amax_x:Tensor, grad_amax_state:Tensor):
|
||||
if FUSED_ADD_NORM_MUL_QUANTIZE:
|
||||
from extra.llama_kernels.fused_rmsnorm_mul_quantize_fp8 import fused_rmsnorm_mul_quantize_fp8
|
||||
x_fp8, x_inv_scale, new_amax, x_normed, rrms = fused_rmsnorm_mul_quantize_fp8(x, norm, amax_x, eps, FP8_DTYPE)
|
||||
out, *ret = matmul(None, w, w_inv_scale=w_inv_scale, x_fp8=x_fp8, x_scale=x_inv_scale, x_new_amax=new_amax, grad_amax_state=grad_amax_state)
|
||||
x_fp8, new_amax, x_normed, rrms = fused_rmsnorm_mul_quantize_fp8(x, norm, amax_x, eps, FP8_DTYPE)
|
||||
out, *ret = matmul(None, w, w_inv_scale=w_inv_scale, x_fp8=x_fp8, amax_x=amax_x, x_new_amax=new_amax, grad_amax_state=grad_amax_state)
|
||||
return out, x_normed, rrms, ret
|
||||
x_normed, rrms = rmsnorm(x, eps)
|
||||
out, *ret = matmul(x_normed * norm, w, amax_x=amax_x, w_inv_scale=w_inv_scale, grad_amax_state=grad_amax_state)
|
||||
|
|
@ -85,8 +86,8 @@ def add_norm_quantize_matmul(x:Tensor, residual:Tensor, norm:Tensor, w:Tensor, w
|
|||
grad_amax_state:Tensor|None=None):
|
||||
if FUSED_ADD_NORM_MUL_QUANTIZE:
|
||||
from extra.llama_kernels.fused_rmsnorm_mul_quantize_fp8 import fused_add_rmsnorm_mul_quantize_fp8
|
||||
x_fp8, x_inv_scale, new_amax, h, x_normed, rrms = fused_add_rmsnorm_mul_quantize_fp8(x, residual, norm, amax_x, eps, FP8_DTYPE)
|
||||
out, *ret = matmul(None, w, w_inv_scale=w_inv_scale, x_fp8=x_fp8, x_scale=x_inv_scale, x_new_amax=new_amax, grad_amax_state=grad_amax_state)
|
||||
x_fp8, new_amax, h, x_normed, rrms = fused_add_rmsnorm_mul_quantize_fp8(x, residual, norm, amax_x, eps, FP8_DTYPE)
|
||||
out, *ret = matmul(None, w, w_inv_scale=w_inv_scale, x_fp8=x_fp8, amax_x=amax_x, x_new_amax=new_amax, grad_amax_state=grad_amax_state)
|
||||
return out, h, x_normed, rrms, ret
|
||||
h = x + residual
|
||||
x_normed, rrms = rmsnorm(h, eps)
|
||||
|
|
@ -98,8 +99,8 @@ def silu_w13_quantize_matmul(x_w13:Tensor, w2:Tensor, s_2:Tensor,
|
|||
grad_amax_xw13:Tensor, grad_amax_xout:Tensor):
|
||||
if FUSED_SILU_W13:
|
||||
from extra.llama_kernels.cast_amax import fused_quantize_fp8_w13
|
||||
x2_fp8, x2_inv_scale, new_amax_x2 = fused_quantize_fp8_w13(x_w13, amax_x2, FP8_DTYPE, grad_amax_state=grad_amax_xw13)
|
||||
out, *ret = matmul(None, w2, w_inv_scale=s_2, x_fp8=x2_fp8, x_scale=x2_inv_scale, x_new_amax=new_amax_x2, grad_amax_state=grad_amax_xout)
|
||||
x2_fp8, new_amax_x2 = fused_quantize_fp8_w13(x_w13, amax_x2, FP8_DTYPE, grad_amax_state=grad_amax_xw13)
|
||||
out, *ret = matmul(None, w2, w_inv_scale=s_2, x_fp8=x2_fp8, amax_x=amax_x2, x_new_amax=new_amax_x2, grad_amax_state=grad_amax_xout)
|
||||
return out, ret
|
||||
hidden = x_w13.shape[-1] // 2
|
||||
x_w1, x_w3 = x_w13[..., :hidden], x_w13[..., hidden:]
|
||||
|
|
|
|||
|
|
@ -2630,7 +2630,7 @@ def custom_asm_gemm(C:UOp, A:UOp, B:UOp, dname:str) -> UOp:
|
|||
@functools.cache
|
||||
def custom_hk_fp8_gemm(C:UOp, A:UOp, B:UOp, *args:UOp, dname:str, scale_mode:int=3) -> UOp:
|
||||
# scale_mode: 0=no scale, 1=x only, 2=w only, 3=both
|
||||
n_scales = (1 if scale_mode & 1 else 0) + (1 if scale_mode & 2 else 0)
|
||||
n_scales = (1 if scale_mode & 1 else 0) + (1 if scale_mode & 2 else 0) + (1 if scale_mode & 4 else 0)
|
||||
scales, extra = args[:n_scales], args[n_scales:]
|
||||
M, K = A.shape[0]*A.shape[1], A.shape[2]
|
||||
N, K2 = B.shape[(1 if B.ndim == 3 else 0):]
|
||||
|
|
@ -2808,13 +2808,15 @@ def custom_gemm_bw(gradient:UOp, kernel:UOp, n_scales:int=2, has_grad_amax:bool=
|
|||
out, a, b = inputs[:3]
|
||||
i = 3
|
||||
s_x = inputs[i]; i += 1
|
||||
has_w = n_scales == 2
|
||||
has_w = n_scales >= 2
|
||||
s_w = inputs[i] if has_w else None; i += has_w
|
||||
s_g = inputs[i] if n_scales == 3 else None; i += (n_scales == 3)
|
||||
grad_amax_state = inputs[i] if has_grad_amax else None; i += has_grad_amax
|
||||
w_post = inputs[i] if has_w_post else None
|
||||
a_t, b_t, g_t = Tensor(a, device=a.device), Tensor(b, device=a.device), Tensor(gradient, device=a.device)
|
||||
s_x_t = Tensor(s_x, device=a.device)
|
||||
s_w_t = Tensor(s_w, device=a.device) if has_w else None
|
||||
s_g_t = Tensor(s_g, device=a.device) if s_g is not None else None
|
||||
w_post_t = Tensor(w_post, device=a.device) if has_w_post else None
|
||||
g_t = g_t[:a.shape[0]]
|
||||
from extra.llama_kernels.cast_amax import _grad_fp8_mailbox
|
||||
|
|
@ -2839,7 +2841,8 @@ def custom_gemm_bw(gradient:UOp, kernel:UOp, n_scales:int=2, has_grad_amax:bool=
|
|||
store_effect = grad_amax_state.store(new_grad_amax.uop)
|
||||
g_fp8 = Tensor(g_fp8.contiguous().uop.after(store_effect), device=a.device)
|
||||
# dgrad: uses g_scale * x_scale * w_scale (only when scalar)
|
||||
grad_a = asm_gemm(g_fp8, b_t, x_scale=g_scale * s_x_t, w_scale=s_w_t) if has_w else asm_gemm(g_fp8, b_t, x_scale=g_scale * s_x_t)
|
||||
if s_g_t is not None: g_scale = g_scale * s_g_t
|
||||
grad_a = asm_gemm(g_fp8, b_t, x_scale=s_x_t, w_scale=s_w_t, g_scale=g_scale) if has_w else asm_gemm(g_fp8, b_t, x_scale=s_x_t, w_scale=g_scale)
|
||||
# wgrad: no w_scale
|
||||
g_fp8_2d = g_fp8.reshape(-1, g_fp8.shape[-1])
|
||||
if getenv("FAST_FP8_TRANSPOSE", 0) and g_fp8_2d.shape[0] % 64 == 0 and g_fp8_2d.shape[1] % 64 == 0:
|
||||
|
|
@ -2847,7 +2850,7 @@ def custom_gemm_bw(gradient:UOp, kernel:UOp, n_scales:int=2, has_grad_amax:bool=
|
|||
g_fp8_T = fast_fp8_transpose(g_fp8_2d)
|
||||
else:
|
||||
g_fp8_T = g_fp8.permute(2, 0, 1).reshape(g_t.shape[-1], -1)
|
||||
grad_b = asm_gemm(g_fp8_T, a_t.reshape(-1, a_t.shape[-1]), x_scale=g_scale * s_x_t)
|
||||
grad_b = asm_gemm(g_fp8_T, a_t.reshape(-1, a_t.shape[-1]), x_scale=s_x_t, w_scale=g_scale)
|
||||
# wgrad: rescale if not scalar
|
||||
if w_post_t is not None:
|
||||
grad_b = grad_b / w_post_t.reshape(*w_post_t.shape, *([1]*(grad_b.ndim - w_post_t.ndim)))
|
||||
|
|
@ -2900,7 +2903,7 @@ def custom_mx_gemm_bw(gradient:UOp, kernel:UOp, has_w_post:bool, w_stored:bool=F
|
|||
# ** main gemm function
|
||||
|
||||
def asm_gemm(a:Tensor, b:Tensor, x_scale:Tensor|None=None, w_scale:Tensor|None=None, grad_amax_state:Tensor|None=None,
|
||||
w_post_scale:Tensor|None=None, mx:bool=False, mx_scales:tuple|None=None, mx_w_stored:bool=False) -> Tensor:
|
||||
w_post_scale:Tensor|None=None, mx:bool=False, mx_scales:tuple|None=None, mx_w_stored:bool=False, g_scale:Tensor|None=None) -> Tensor:
|
||||
assert can_use_asm_gemm(a, b), f"{counters['todos'][-1]}"
|
||||
counters["used"] += 1
|
||||
unfold_batch = a.ndim == 3 and isinstance(a.device, tuple) and a.uop.axis == 2 and b.uop.axis == 0
|
||||
|
|
@ -2947,8 +2950,8 @@ def asm_gemm(a:Tensor, b:Tensor, x_scale:Tensor|None=None, w_scale:Tensor|None=N
|
|||
out = Tensor.custom_kernel(out, a_q.reshape(a.shape), b_q, a_si, b_si, a_e8, b_e8, *extra, fxn=fxn, grad_fxn=grad_fxn)[0]
|
||||
# fp8 gemm computes a@b.T, kernel multiplies output by x_scale * w_scale before bf16 store
|
||||
elif a.dtype == FP8_DTYPE:
|
||||
scales = tuple(s for s in (x_scale, w_scale) if s is not None)
|
||||
scale_mode = (1 if x_scale is not None else 0) | (2 if w_scale is not None else 0)
|
||||
scales = tuple(s for s in (x_scale, w_scale, g_scale) if s is not None)
|
||||
scale_mode = (1 if x_scale is not None else 0) | (2 if w_scale is not None else 0) | (4 if g_scale is not None else 0)
|
||||
extra = ([grad_amax_state] if grad_amax_state is not None else []) + ([w_post_scale] if w_post_scale is not None else [])
|
||||
fxn = functools.partial(custom_hk_fp8_gemm, dname=dname, scale_mode=scale_mode)
|
||||
bw = functools.partial(custom_gemm_bw, n_scales=len(scales), has_grad_amax=grad_amax_state is not None, has_w_post=w_post_scale is not None)
|
||||
|
|
|
|||
|
|
@ -59,8 +59,8 @@ def _fused_quantize_bwd_w13(gradient:UOp, kernel:UOp):
|
|||
_grad_fp8_mailbox[grad_xw13_uop] = (grad_xw13_fp8_uop, inv_scale.uop)
|
||||
return (None, None, grad_xw13_uop, None, None)
|
||||
|
||||
def fused_quantize_fp8_w13(xw13:Tensor, amax_state:Tensor, fp8_dtype, grad_amax_state:Tensor) -> tuple[Tensor, Tensor, Tensor]:
|
||||
# NOTE: silu(xw1)*xw3 -> fp8 + amax over fused xw13 layout. Returns (fp8, inv_scale, new_amax)
|
||||
def fused_quantize_fp8_w13(xw13:Tensor, amax_state:Tensor, fp8_dtype, grad_amax_state:Tensor) -> tuple[Tensor, Tensor]:
|
||||
# NOTE: silu(xw1)*xw3 -> fp8 + amax over fused xw13 layout. Returns (fp8, new_amax)
|
||||
# grad_amax_state: delayed amax for grad_xw13 fp8 quantization in the backward.
|
||||
assert xw13.dtype == dtypes.bfloat16, f"expected bf16, got {xw13.dtype}"
|
||||
MBS, SEQ, H2 = xw13.shape
|
||||
|
|
@ -72,5 +72,4 @@ def fused_quantize_fp8_w13(xw13:Tensor, amax_state:Tensor, fp8_dtype, grad_amax_
|
|||
fxn = functools.partial(_custom_fused_cast_amax_w13, dname=dname_of(xw13.device))
|
||||
fp8_out, amax_buf, *_ = Tensor.custom_kernel(fp8_out, amax_buf, xw13, amax_state, grad_amax_state,
|
||||
fxn=fxn, grad_fxn=_fused_quantize_bwd_w13)
|
||||
inv_scale = (amax_state.float() + 1e-8) / FP8_MAX
|
||||
return fp8_out, inv_scale, scalar_amax(amax_buf)
|
||||
return fp8_out, scalar_amax(amax_buf)
|
||||
|
|
|
|||
|
|
@ -112,8 +112,8 @@ def _fused_add_bwd(*args, **kwargs):
|
|||
grad_h, grad_w = _bwd_common(fp8_grad_u, h_grad_u, x_u, x_normed_u, rrms_u, weight_u, amax_state_u, kernel)
|
||||
return (None, None, None, None, None, grad_h, grad_h, grad_w, None)
|
||||
|
||||
def fused_rmsnorm_mul_quantize_fp8(x:Tensor, weight:Tensor, amax_state:Tensor, eps:float, fp8_dtype) -> tuple[Tensor, Tensor, Tensor, Tensor, Tensor]:
|
||||
# NOTE: rmsnorm(x) * weight -> fp8 + amax. Returns (fp8, inv_scale, new_amax, x_normed, rrms).
|
||||
def fused_rmsnorm_mul_quantize_fp8(x:Tensor, weight:Tensor, amax_state:Tensor, eps:float, fp8_dtype) -> tuple[Tensor, Tensor, Tensor, Tensor]:
|
||||
# NOTE: rmsnorm(x) * weight -> fp8 + amax. Returns (fp8, new_amax, x_normed, rrms).
|
||||
# x_normed + rrms are saved for the rmsnorm backward (also recomputed here from x regs).
|
||||
assert x.dtype == dtypes.bfloat16 and weight.dtype == dtypes.bfloat16
|
||||
assert x.shape[-1] == weight.shape[-1], f"HIDDEN mismatch: x={x.shape}, weight={weight.shape}"
|
||||
|
|
@ -127,13 +127,12 @@ def fused_rmsnorm_mul_quantize_fp8(x:Tensor, weight:Tensor, amax_state:Tensor, e
|
|||
fxn = functools.partial(_custom_fwd, dname=dname_of(x.device), eps_val=eps)
|
||||
fp8_out, x_normed_out, rrms_out, amax_buf, *_ = Tensor.custom_kernel(
|
||||
fp8_out, x_normed_out, rrms_out, amax_buf, x, weight, amax_state, fxn=fxn, grad_fxn=_fused_bwd)
|
||||
inv_scale = (amax_state.float() + 1e-8) / FP8_MAX
|
||||
return fp8_out, inv_scale, scalar_amax(amax_buf), x_normed_out, rrms_out
|
||||
return fp8_out, scalar_amax(amax_buf), x_normed_out, rrms_out
|
||||
|
||||
def fused_add_rmsnorm_mul_quantize_fp8(x:Tensor, residual:Tensor, weight:Tensor, amax_state:Tensor,
|
||||
eps:float, fp8_dtype) -> tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]:
|
||||
eps:float, fp8_dtype) -> tuple[Tensor, Tensor, Tensor, Tensor, Tensor]:
|
||||
# NOTE: h = x + residual; y_normed = rmsnorm(h); fp8 = quantize(y_normed * weight).
|
||||
# Returns (fp8, inv_scale, new_amax, h, x_normed, rrms). h is also written so downstream can
|
||||
# Returns (fp8, new_amax, h, x_normed, rrms). h is also written so downstream can
|
||||
# reuse it without recomputing x+residual — eliminates the separate residual-add kernel.
|
||||
assert x.dtype == dtypes.bfloat16 and residual.dtype == dtypes.bfloat16 and weight.dtype == dtypes.bfloat16
|
||||
assert x.shape == residual.shape
|
||||
|
|
@ -149,5 +148,4 @@ def fused_add_rmsnorm_mul_quantize_fp8(x:Tensor, residual:Tensor, weight:Tensor,
|
|||
fp8_out, h_out, x_normed_out, rrms_out, amax_buf, *_ = Tensor.custom_kernel(
|
||||
fp8_out, h_out, x_normed_out, rrms_out, amax_buf, x, residual, weight, amax_state,
|
||||
fxn=fxn, grad_fxn=_fused_add_bwd)
|
||||
inv_scale = (amax_state.float() + 1e-8) / FP8_MAX
|
||||
return fp8_out, inv_scale, scalar_amax(amax_buf), h_out, x_normed_out, rrms_out
|
||||
return fp8_out, scalar_amax(amax_buf), h_out, x_normed_out, rrms_out
|
||||
|
|
|
|||
|
|
@ -99,12 +99,14 @@ using G = kittens::group<NUM_WARPS>;
|
|||
#endif
|
||||
|
||||
__global__ __launch_bounds__(512, 2) void hk_fp8_gemm(bf16 *C_ptr, fp8e4m3 *A_ptr, fp8e4m3 *B_ptr
|
||||
#if SCALE_MODE == 1
|
||||
#if SCALE_MODE & 1
|
||||
, float *x_scale_ptr
|
||||
#elif SCALE_MODE == 2
|
||||
#endif
|
||||
#if SCALE_MODE & 2
|
||||
, float *w_scale_ptr
|
||||
#elif SCALE_MODE == 3
|
||||
, float *x_scale_ptr, float *w_scale_ptr
|
||||
#endif
|
||||
#if SCALE_MODE & 4
|
||||
, float *g_scale_ptr
|
||||
#endif
|
||||
) {
|
||||
constexpr int M = GEMM_M, N = GEMM_N, K = GEMM_K;
|
||||
|
|
@ -347,12 +349,16 @@ __global__ __launch_bounds__(512, 2) void hk_fp8_gemm(bf16 *C_ptr, fp8e4m3 *A_pt
|
|||
|
||||
// apply x_scale * w_scale before bf16 store to prevent overflow
|
||||
#if SCALE_MODE != 0
|
||||
#if SCALE_MODE == 1
|
||||
float scale = *x_scale_ptr;
|
||||
#elif SCALE_MODE == 2
|
||||
float scale = *w_scale_ptr;
|
||||
#elif SCALE_MODE == 3
|
||||
float scale = *x_scale_ptr * *w_scale_ptr;
|
||||
float scale = 1.0f;
|
||||
#if SCALE_MODE & 1
|
||||
float x_scale = (*x_scale_ptr + 1e-08f) * (1.0f / 448.0f);
|
||||
scale *= x_scale;
|
||||
#endif
|
||||
#if SCALE_MODE & 2
|
||||
scale *= *w_scale_ptr;
|
||||
#endif
|
||||
#if SCALE_MODE & 4
|
||||
scale *= *g_scale_ptr;
|
||||
#endif
|
||||
|
||||
mul(cA, cA, scale);
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue