This commit is contained in:
qazal 2026-06-15 16:22:04 +08:00
commit 910ad33f87
4 changed files with 14 additions and 16 deletions

View file

@ -58,7 +58,7 @@ def matmul(x:Tensor, w:Tensor, fp8:bool=True, amax_x:Tensor|None=None, w_inv_sca
if x_fp8 is None:
if FUSED_INPUT_QUANTIZE and amax_x is not None:
from extra.llama_kernels.quantize_fp8_delayed import quantize_fp8_delayed
x_fp8, _, x_new_amax, _ = quantize_fp8_delayed(x, amax_x, FP8_DTYPE)
x_fp8, x_new_amax, _ = quantize_fp8_delayed(x, amax_x, FP8_DTYPE)
else:
x_fp8, x_new_amax = quantize_fp8(x, amax_state=amax_x)
if ASM_GEMM:

View file

@ -2815,6 +2815,7 @@ def custom_gemm_bw(gradient:UOp, kernel:UOp, n_scales:int=2, has_grad_amax:bool=
a_t, b_t, g_t = Tensor(a, device=a.device), Tensor(b, device=a.device), Tensor(gradient, device=a.device)
s_x_t = Tensor(s_x, device=a.device)
s_w_t = Tensor(s_w, device=a.device) if has_w else None
g_x_t = Tensor(grad_amax_state, device=a.device) if grad_amax_state is not None else None
w_post_t = Tensor(w_post, device=a.device) if has_w_post else None
g_t = g_t[:a.shape[0]]
from extra.llama_kernels.cast_amax import _grad_fp8_mailbox
@ -2822,24 +2823,24 @@ def custom_gemm_bw(gradient:UOp, kernel:UOp, n_scales:int=2, has_grad_amax:bool=
gbase = gradient.base if hasattr(gradient, "base") else gradient
mailbox_entry = _grad_fp8_mailbox.pop(gbase, None) or _grad_fp8_mailbox.pop(gradient, None)
if mailbox_entry is not None:
g_fp8_u, inv_scale_u = mailbox_entry
g_fp8_u, grad_amax_state_u = mailbox_entry
g_fp8 = Tensor(g_fp8_u, device=a.device)[:a.shape[0]]
g_scale = Tensor(inv_scale_u, device=a.device)
g_scale = Tensor(grad_amax_state_u, device=a.device)
else:
assert grad_amax_state is not None, "fp8 matmul bwd needs either a mailbox entry or a grad_amax_state"
if getenv("CURRENT_GRAD_SCALE", 0):
g_fp8, g_scale, _ = quantize_fp8(g_t, amax_state=None)
elif getenv("FUSED_GRAD_QUANTIZE", 0):
g_fp8, g_scale, _, store_effect = quantize_fp8_delayed(g_t, Tensor(grad_amax_state, device=a.device))
# if getenv("CURRENT_GRAD_SCALE", 0):
# g_fp8, _ = quantize_fp8(g_t, amax_state=None)
if getenv("FUSED_GRAD_QUANTIZE", 0):
g_fp8, _, store_effect = quantize_fp8_delayed(g_t, g_x_t)
assert g_fp8.uop.op is Ops.AFTER, f"expected AFTER, got {g_fp8.uop.op}"
g_fp8 = Tensor(g_fp8.uop.replace(src=g_fp8.uop.src + (store_effect,)), device=a.device)
else:
grad_amax_t = Tensor(grad_amax_state, device=a.device)
g_fp8, g_scale, new_grad_amax = quantize_fp8(g_t, amax_state=grad_amax_t)
g_fp8, new_grad_amax = quantize_fp8(g_t, amax_state=grad_amax_t)
store_effect = grad_amax_state.store(new_grad_amax.uop)
g_fp8 = Tensor(g_fp8.contiguous().uop.after(store_effect), device=a.device)
# dgrad: uses g_scale * x_scale * w_scale (only when scalar)
grad_a = asm_gemm(g_fp8, b_t, x_scale=s_x_t, w_scale=s_w_t, g_scale=g_scale) if has_w else asm_gemm(g_fp8, b_t, x_scale=s_x_t, w_scale=g_scale)
grad_a = asm_gemm(g_fp8, b_t, x_scale=s_x_t, w_scale=s_w_t, g_scale=g_x_t) if has_w else asm_gemm(g_fp8, b_t, x_scale=s_x_t, w_scale=g_x_t)
# wgrad: no w_scale
g_fp8_2d = g_fp8.reshape(-1, g_fp8.shape[-1])
if getenv("FAST_FP8_TRANSPOSE", 0) and g_fp8_2d.shape[0] % 64 == 0 and g_fp8_2d.shape[1] % 64 == 0:
@ -2847,7 +2848,7 @@ def custom_gemm_bw(gradient:UOp, kernel:UOp, n_scales:int=2, has_grad_amax:bool=
g_fp8_T = fast_fp8_transpose(g_fp8_2d)
else:
g_fp8_T = g_fp8.permute(2, 0, 1).reshape(g_t.shape[-1], -1)
grad_b = asm_gemm(g_fp8_T, a_t.reshape(-1, a_t.shape[-1]), x_scale=s_x_t, w_scale=g_scale)
grad_b = asm_gemm(g_fp8_T, a_t.reshape(-1, a_t.shape[-1]), x_scale=s_x_t, w_scale=g_x_t)
# wgrad: rescale if not scalar
if w_post_t is not None:
grad_b = grad_b / w_post_t.reshape(*w_post_t.shape, *([1]*(grad_b.ndim - w_post_t.ndim)))

View file

@ -50,13 +50,12 @@ def _fused_quantize_bwd_w13(gradient:UOp, kernel:UOp):
Tensor(xw13, device=device), Tensor(gradient, device=device).cast(dtypes.bfloat16),
Tensor(amax_state, device=device), grad_amax_state_t, fxn=fxn)
grad_xw13_uop = grad_xw13_fp8.uop.cast(dtypes.bfloat16)
inv_scale = (grad_amax_state_t.float() + 1e-8) / FP8_MAX
new_grad_amax = scalar_amax(grad_amax_buf)
store_effect = grad_amax_state_t.uop.store(new_grad_amax.uop)
assert grad_xw13_fp8.uop.op is Ops.AFTER, f"expected AFTER, got {grad_xw13_fp8.uop.op}"
grad_xw13_fp8_uop = grad_xw13_fp8.uop.replace(src=grad_xw13_fp8.uop.src + (store_effect,))
# Stash fp8 companion for cdna_asm_gemm's bwd to attach to grad_a.
_grad_fp8_mailbox[grad_xw13_uop] = (grad_xw13_fp8_uop, inv_scale.uop)
_grad_fp8_mailbox[grad_xw13_uop] = (grad_xw13_fp8_uop, grad_amax_state_t.uop)
return (None, None, grad_xw13_uop, None, None)
def fused_quantize_fp8_w13(xw13:Tensor, amax_state:Tensor, fp8_dtype, grad_amax_state:Tensor) -> tuple[Tensor, Tensor]:

View file

@ -69,7 +69,7 @@ def _quantize_fp8_delayed_bwd(gradient:UOp, kernel:UOp):
grad_x = (Tensor(gradient, device=device).float() * scale).cast(dtypes.bfloat16)
return (None, None, grad_x.uop, None)
def quantize_fp8_delayed(x:Tensor, amax_state:Tensor, fp8_dtype=dtypes.fp8e4m3) -> tuple[Tensor, Tensor, Tensor, UOp]:
def quantize_fp8_delayed(x:Tensor, amax_state:Tensor, fp8_dtype=dtypes.fp8e4m3) -> tuple[Tensor, Tensor, UOp]:
# NOTE: one-pass bf16 -> fp8 quantize with delayed scaling. Returns (fp8, inv_scale, new_amax, store_effect).
# Fused kernel reads x once and writes fp8 + per-WG |x| partials (then a small reduce produces scalar new_amax).
# store_effect writes new_amax into amax_state's buffer — the caller must thread it into a realized
@ -85,10 +85,8 @@ def quantize_fp8_delayed(x:Tensor, amax_state:Tensor, fp8_dtype=dtypes.fp8e4m3)
fp8_out, amax_partial, *_ = Tensor.custom_kernel(fp8_out, amax_partial, x, amax_state,
fxn=fxn, grad_fxn=_quantize_fp8_delayed_bwd)
new_amax = scalar_amax(amax_partial)
# NOTE: this exists for the fp8 gemm bw
inv_scale = (amax_state.float() + 1e-8) / FP8_MAX
store_effect = amax_state.uop.store(new_amax.uop)
return fp8_out, inv_scale, new_amax, store_effect
return fp8_out, new_amax, store_effect
def quantize_fp8_scalar(x:Tensor, amax_state:Tensor, fp8_dtype=dtypes.fp8e4m3) -> Tensor:
# NOTE: pure one-pass bf16 -> fp8 quantize with delayed scalar scale. No amax computation.