mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
simpler
This commit is contained in:
parent
8325dcea43
commit
bf75d235a1
1 changed file with 1 addition and 7 deletions
|
|
@@ -2810,13 +2810,11 @@ def custom_gemm_bw(gradient:UOp, kernel:UOp, n_scales:int=2, has_grad_amax:bool=
|
|||
s_x = inputs[i]; i += 1
|
||||
has_w = n_scales >= 2
|
||||
s_w = inputs[i] if has_w else None; i += has_w
|
||||
s_extra = inputs[i] if n_scales == 3 else None; i += (n_scales == 3)
|
||||
grad_amax_state = inputs[i] if has_grad_amax else None; i += has_grad_amax
|
||||
w_post = inputs[i] if has_w_post else None
|
||||
a_t, b_t, g_t = Tensor(a, device=a.device), Tensor(b, device=a.device), Tensor(gradient, device=a.device)
|
||||
s_x_t = Tensor(s_x, device=a.device)
|
||||
s_w_t = Tensor(s_w, device=a.device) if has_w else None
|
||||
s_g_t = Tensor(s_extra, device=a.device) if s_extra is not None else None
|
||||
w_post_t = Tensor(w_post, device=a.device) if has_w_post else None
|
||||
g_t = g_t[:a.shape[0]]
|
||||
from extra.llama_kernels.cast_amax import _grad_fp8_mailbox
|
||||
|
|
@@ -2841,10 +2839,7 @@ def custom_gemm_bw(gradient:UOp, kernel:UOp, n_scales:int=2, has_grad_amax:bool=
|
|||
store_effect = grad_amax_state.store(new_grad_amax.uop)
|
||||
g_fp8 = Tensor(g_fp8.contiguous().uop.after(store_effect), device=a.device)
|
||||
# dgrad: uses g_scale * x_scale * w_scale (only when scalar)
|
||||
if s_g_t is not None: g_scale = g_scale * s_g_t
|
||||
if has_w: grad_a = asm_gemm(g_fp8, b_t, x_scale=s_x_t, w_scale=s_w_t, g_scale=g_scale)
|
||||
# do x * g scale in the gemm
|
||||
else: grad_a = asm_gemm(g_fp8, b_t, x_scale=s_x_t, w_scale=g_scale)
|
||||
grad_a = asm_gemm(g_fp8, b_t, x_scale=s_x_t, w_scale=s_w_t, g_scale=g_scale) if has_w else asm_gemm(g_fp8, b_t, x_scale=s_x_t, w_scale=g_scale)
|
||||
# wgrad: no w_scale
|
||||
g_fp8_2d = g_fp8.reshape(-1, g_fp8.shape[-1])
|
||||
if getenv("FAST_FP8_TRANSPOSE", 0) and g_fp8_2d.shape[0] % 64 == 0 and g_fp8_2d.shape[1] % 64 == 0:
|
||||
|
|
@@ -2852,7 +2847,6 @@ def custom_gemm_bw(gradient:UOp, kernel:UOp, n_scales:int=2, has_grad_amax:bool=
|
|||
g_fp8_T = fast_fp8_transpose(g_fp8_2d)
|
||||
else:
|
||||
g_fp8_T = g_fp8.permute(2, 0, 1).reshape(g_t.shape[-1], -1)
|
||||
# do x * g scale in the gemm
|
||||
grad_b = asm_gemm(g_fp8_T, a_t.reshape(-1, a_t.shape[-1]), x_scale=s_x_t, w_scale=g_scale)
|
||||
# wgrad: rescale if not scalar
|
||||
if w_post_t is not None:
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue