mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
simpler
This commit is contained in:
parent
bf75d235a1
commit
910ad33f87
4 changed files with 14 additions and 16 deletions
|
|
@ -58,7 +58,7 @@ def matmul(x:Tensor, w:Tensor, fp8:bool=True, amax_x:Tensor|None=None, w_inv_sca
|
|||
if x_fp8 is None:
|
||||
if FUSED_INPUT_QUANTIZE and amax_x is not None:
|
||||
from extra.llama_kernels.quantize_fp8_delayed import quantize_fp8_delayed
|
||||
x_fp8, _, x_new_amax, _ = quantize_fp8_delayed(x, amax_x, FP8_DTYPE)
|
||||
x_fp8, x_new_amax, _ = quantize_fp8_delayed(x, amax_x, FP8_DTYPE)
|
||||
else:
|
||||
x_fp8, x_new_amax = quantize_fp8(x, amax_state=amax_x)
|
||||
if ASM_GEMM:
|
||||
|
|
|
|||
|
|
@ -2815,6 +2815,7 @@ def custom_gemm_bw(gradient:UOp, kernel:UOp, n_scales:int=2, has_grad_amax:bool=
|
|||
a_t, b_t, g_t = Tensor(a, device=a.device), Tensor(b, device=a.device), Tensor(gradient, device=a.device)
|
||||
s_x_t = Tensor(s_x, device=a.device)
|
||||
s_w_t = Tensor(s_w, device=a.device) if has_w else None
|
||||
g_x_t = Tensor(grad_amax_state, device=a.device) if grad_amax_state is not None else None
|
||||
w_post_t = Tensor(w_post, device=a.device) if has_w_post else None
|
||||
g_t = g_t[:a.shape[0]]
|
||||
from extra.llama_kernels.cast_amax import _grad_fp8_mailbox
|
||||
|
|
@ -2822,24 +2823,24 @@ def custom_gemm_bw(gradient:UOp, kernel:UOp, n_scales:int=2, has_grad_amax:bool=
|
|||
gbase = gradient.base if hasattr(gradient, "base") else gradient
|
||||
mailbox_entry = _grad_fp8_mailbox.pop(gbase, None) or _grad_fp8_mailbox.pop(gradient, None)
|
||||
if mailbox_entry is not None:
|
||||
g_fp8_u, inv_scale_u = mailbox_entry
|
||||
g_fp8_u, grad_amax_state_u = mailbox_entry
|
||||
g_fp8 = Tensor(g_fp8_u, device=a.device)[:a.shape[0]]
|
||||
g_scale = Tensor(inv_scale_u, device=a.device)
|
||||
g_scale = Tensor(grad_amax_state_u, device=a.device)
|
||||
else:
|
||||
assert grad_amax_state is not None, "fp8 matmul bwd needs either a mailbox entry or a grad_amax_state"
|
||||
if getenv("CURRENT_GRAD_SCALE", 0):
|
||||
g_fp8, g_scale, _ = quantize_fp8(g_t, amax_state=None)
|
||||
elif getenv("FUSED_GRAD_QUANTIZE", 0):
|
||||
g_fp8, g_scale, _, store_effect = quantize_fp8_delayed(g_t, Tensor(grad_amax_state, device=a.device))
|
||||
# if getenv("CURRENT_GRAD_SCALE", 0):
|
||||
# g_fp8, _ = quantize_fp8(g_t, amax_state=None)
|
||||
if getenv("FUSED_GRAD_QUANTIZE", 0):
|
||||
g_fp8, _, store_effect = quantize_fp8_delayed(g_t, g_x_t)
|
||||
assert g_fp8.uop.op is Ops.AFTER, f"expected AFTER, got {g_fp8.uop.op}"
|
||||
g_fp8 = Tensor(g_fp8.uop.replace(src=g_fp8.uop.src + (store_effect,)), device=a.device)
|
||||
else:
|
||||
grad_amax_t = Tensor(grad_amax_state, device=a.device)
|
||||
g_fp8, g_scale, new_grad_amax = quantize_fp8(g_t, amax_state=grad_amax_t)
|
||||
g_fp8, new_grad_amax = quantize_fp8(g_t, amax_state=grad_amax_t)
|
||||
store_effect = grad_amax_state.store(new_grad_amax.uop)
|
||||
g_fp8 = Tensor(g_fp8.contiguous().uop.after(store_effect), device=a.device)
|
||||
# dgrad: uses g_scale * x_scale * w_scale (only when scalar)
|
||||
grad_a = asm_gemm(g_fp8, b_t, x_scale=s_x_t, w_scale=s_w_t, g_scale=g_scale) if has_w else asm_gemm(g_fp8, b_t, x_scale=s_x_t, w_scale=g_scale)
|
||||
grad_a = asm_gemm(g_fp8, b_t, x_scale=s_x_t, w_scale=s_w_t, g_scale=g_x_t) if has_w else asm_gemm(g_fp8, b_t, x_scale=s_x_t, w_scale=g_x_t)
|
||||
# wgrad: no w_scale
|
||||
g_fp8_2d = g_fp8.reshape(-1, g_fp8.shape[-1])
|
||||
if getenv("FAST_FP8_TRANSPOSE", 0) and g_fp8_2d.shape[0] % 64 == 0 and g_fp8_2d.shape[1] % 64 == 0:
|
||||
|
|
@ -2847,7 +2848,7 @@ def custom_gemm_bw(gradient:UOp, kernel:UOp, n_scales:int=2, has_grad_amax:bool=
|
|||
g_fp8_T = fast_fp8_transpose(g_fp8_2d)
|
||||
else:
|
||||
g_fp8_T = g_fp8.permute(2, 0, 1).reshape(g_t.shape[-1], -1)
|
||||
grad_b = asm_gemm(g_fp8_T, a_t.reshape(-1, a_t.shape[-1]), x_scale=s_x_t, w_scale=g_scale)
|
||||
grad_b = asm_gemm(g_fp8_T, a_t.reshape(-1, a_t.shape[-1]), x_scale=s_x_t, w_scale=g_x_t)
|
||||
# wgrad: rescale if not scalar
|
||||
if w_post_t is not None:
|
||||
grad_b = grad_b / w_post_t.reshape(*w_post_t.shape, *([1]*(grad_b.ndim - w_post_t.ndim)))
|
||||
|
|
|
|||
|
|
@ -50,13 +50,12 @@ def _fused_quantize_bwd_w13(gradient:UOp, kernel:UOp):
|
|||
Tensor(xw13, device=device), Tensor(gradient, device=device).cast(dtypes.bfloat16),
|
||||
Tensor(amax_state, device=device), grad_amax_state_t, fxn=fxn)
|
||||
grad_xw13_uop = grad_xw13_fp8.uop.cast(dtypes.bfloat16)
|
||||
inv_scale = (grad_amax_state_t.float() + 1e-8) / FP8_MAX
|
||||
new_grad_amax = scalar_amax(grad_amax_buf)
|
||||
store_effect = grad_amax_state_t.uop.store(new_grad_amax.uop)
|
||||
assert grad_xw13_fp8.uop.op is Ops.AFTER, f"expected AFTER, got {grad_xw13_fp8.uop.op}"
|
||||
grad_xw13_fp8_uop = grad_xw13_fp8.uop.replace(src=grad_xw13_fp8.uop.src + (store_effect,))
|
||||
# Stash fp8 companion for cdna_asm_gemm's bwd to attach to grad_a.
|
||||
_grad_fp8_mailbox[grad_xw13_uop] = (grad_xw13_fp8_uop, inv_scale.uop)
|
||||
_grad_fp8_mailbox[grad_xw13_uop] = (grad_xw13_fp8_uop, grad_amax_state_t.uop)
|
||||
return (None, None, grad_xw13_uop, None, None)
|
||||
|
||||
def fused_quantize_fp8_w13(xw13:Tensor, amax_state:Tensor, fp8_dtype, grad_amax_state:Tensor) -> tuple[Tensor, Tensor]:
|
||||
|
|
|
|||
|
|
@ -69,7 +69,7 @@ def _quantize_fp8_delayed_bwd(gradient:UOp, kernel:UOp):
|
|||
grad_x = (Tensor(gradient, device=device).float() * scale).cast(dtypes.bfloat16)
|
||||
return (None, None, grad_x.uop, None)
|
||||
|
||||
def quantize_fp8_delayed(x:Tensor, amax_state:Tensor, fp8_dtype=dtypes.fp8e4m3) -> tuple[Tensor, Tensor, Tensor, UOp]:
|
||||
def quantize_fp8_delayed(x:Tensor, amax_state:Tensor, fp8_dtype=dtypes.fp8e4m3) -> tuple[Tensor, Tensor, UOp]:
|
||||
# NOTE: one-pass bf16 -> fp8 quantize with delayed scaling. Returns (fp8, new_amax, store_effect).
|
||||
# Fused kernel reads x once and writes fp8 + per-WG |x| partials (then a small reduce produces scalar new_amax).
|
||||
# store_effect writes new_amax into amax_state's buffer — the caller must thread it into a realized
|
||||
|
|
@ -85,10 +85,8 @@ def quantize_fp8_delayed(x:Tensor, amax_state:Tensor, fp8_dtype=dtypes.fp8e4m3)
|
|||
fp8_out, amax_partial, *_ = Tensor.custom_kernel(fp8_out, amax_partial, x, amax_state,
|
||||
fxn=fxn, grad_fxn=_quantize_fp8_delayed_bwd)
|
||||
new_amax = scalar_amax(amax_partial)
|
||||
# NOTE: this exists for the fp8 gemm bw
|
||||
inv_scale = (amax_state.float() + 1e-8) / FP8_MAX
|
||||
store_effect = amax_state.uop.store(new_amax.uop)
|
||||
return fp8_out, inv_scale, new_amax, store_effect
|
||||
return fp8_out, new_amax, store_effect
|
||||
|
||||
def quantize_fp8_scalar(x:Tensor, amax_state:Tensor, fp8_dtype=dtypes.fp8e4m3) -> Tensor:
|
||||
# NOTE: pure one-pass bf16 -> fp8 quantize with delayed scalar scale. No amax computation.
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue