mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
llama: don't save weight (#16252)
This commit is contained in:
parent
18b102f355
commit
a3d59faef6
1 changed files with 2 additions and 2 deletions
|
|
@ -52,8 +52,8 @@ def matmul(x:Tensor, w:Tensor, fp8:bool=True, amax_x:Tensor|None=None, w_inv_sca
|
|||
if ASM_GEMM:
|
||||
from extra.gemm.cdna_asm_gemm import can_use_asm_gemm, asm_gemm
|
||||
if can_use_asm_gemm(x_fp8, w.T):
|
||||
return asm_gemm(x_fp8, w.T, x_scale=x_scale, w_scale=w_inv_scale, grad_amax_state=grad_amax_state), x_new_amax, x_fp8, w
|
||||
return (x_fp8.dot(w.T, dtype=dtypes.float) * x_scale * w_inv_scale).cast(dtypes.bfloat16), x_new_amax, x_fp8, w
|
||||
return asm_gemm(x_fp8, w.T, x_scale=x_scale, w_scale=w_inv_scale, grad_amax_state=grad_amax_state), x_new_amax, x_fp8
|
||||
return (x_fp8.dot(w.T, dtype=dtypes.float) * x_scale * w_inv_scale).cast(dtypes.bfloat16), x_new_amax, x_fp8
|
||||
|
||||
def norm_quantize_matmul(x:Tensor, norm:Tensor, w:Tensor, w_inv_scale:Tensor, eps:float, amax_x:Tensor, grad_amax_state:Tensor):
|
||||
if FUSED_ADD_NORM_MUL_QUANTIZE:
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue