mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
work
This commit is contained in:
parent
f777ff7180
commit
7e4e895daa
2 changed files with 4 additions and 2 deletions
|
|
@ -2629,7 +2629,7 @@ def custom_asm_gemm(C:UOp, A:UOp, B:UOp, dname:str) -> UOp:
|
|||
|
||||
@functools.cache
|
||||
def custom_hk_fp8_gemm(C:UOp, A:UOp, B:UOp, *args:UOp, dname:str, scale_mode:int=3) -> UOp:
|
||||
# scale_mode: 0=no scale, 1=x only, 2=w only, 3=both, 4=g_scale (bw)
|
||||
# scale_mode: 0=no scale, 1=x only, 2=w only, 3=both, 4=all 3
|
||||
n_scales = (1 if scale_mode & 1 else 0) + (1 if scale_mode & 2 else 0) + (1 if scale_mode & 4 else 0)
|
||||
scales, extra = args[:n_scales], args[n_scales:]
|
||||
M, K = A.shape[0]*A.shape[1], A.shape[2]
|
||||
|
|
@ -2843,6 +2843,7 @@ def custom_gemm_bw(gradient:UOp, kernel:UOp, n_scales:int=2, has_grad_amax:bool=
|
|||
# dgrad: uses g_scale * x_scale * w_scale (only when scalar)
|
||||
if s_extra_t is not None: g_scale = g_scale * s_extra_t
|
||||
if has_w: grad_a = asm_gemm(g_fp8, b_t, x_scale=s_x_t, w_scale=s_w_t, g_scale=g_scale)
|
||||
# do x * g scale in the gemm
|
||||
else: grad_a = asm_gemm(g_fp8, b_t, x_scale=s_x_t, w_scale=g_scale)
|
||||
# wgrad: no w_scale
|
||||
g_fp8_2d = g_fp8.reshape(-1, g_fp8.shape[-1])
|
||||
|
|
@ -2851,6 +2852,7 @@ def custom_gemm_bw(gradient:UOp, kernel:UOp, n_scales:int=2, has_grad_amax:bool=
|
|||
g_fp8_T = fast_fp8_transpose(g_fp8_2d)
|
||||
else:
|
||||
g_fp8_T = g_fp8.permute(2, 0, 1).reshape(g_t.shape[-1], -1)
|
||||
# do x * g scale in the gemm
|
||||
grad_b = asm_gemm(g_fp8_T, a_t.reshape(-1, a_t.shape[-1]), x_scale=s_x_t, w_scale=g_scale)
|
||||
# wgrad: rescale if not scalar
|
||||
if w_post_t is not None:
|
||||
|
|
|
|||
|
|
@ -93,7 +93,7 @@ constexpr int NUM_WARPS = 8;
|
|||
|
||||
using G = kittens::group<NUM_WARPS>;
|
||||
|
||||
// scale_mode: 0=no scale, 1=x only, 2=w only, 3=both, 4=g_scale (bw)
|
||||
// scale_mode: 0=no scale, 1=x only, 2=w only, 3=both, 4=all 3
|
||||
#ifndef SCALE_MODE
|
||||
#define SCALE_MODE 3
|
||||
#endif
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue