mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
llama: mxfp8 (#16574)
This commit is contained in:
parent
b8aec4cce7
commit
e770805d21
5 changed files with 43 additions and 11 deletions
|
|
@@ -2660,7 +2660,9 @@ def custom_hk_mxfp8_gemm(C:UOp, A:UOp, B:UOp, scale_A:UOp, scale_B:UOp, *extra:U
|
|||
block_size = 256
|
||||
threads = UOp.special(64 * 8, "lidx0")
|
||||
workgroups = UOp.special((M // block_size) * (N // block_size), "gidx0")
|
||||
- sink_inputs = (C.base, A.base, B.base, scale_A.base, scale_B.base, threads, workgroups)
|
||||
+ e_a = extra[0].base if len(extra) >= 1 else scale_A.base
|
||||
+ e_b = extra[1].base if len(extra) >= 2 else scale_B.base
|
||||
+ sink_inputs = (C.base, A.base, B.base, scale_A.base, scale_B.base, e_a, e_b, threads, workgroups)
|
||||
sink = UOp.sink(*sink_inputs,
|
||||
arg=KernelInfo(f"hk_mxfp8_gemm_{M}_{N}_{K}", estimates=Estimates(ops=2*M*N*K, mem=(M*K+N*K)*A.dtype.itemsize+M*N*C.dtype.itemsize)))
|
||||
kittens_path = pathlib.Path(__file__).parent.parent/"thunder"/"amd"
|
||||
|
|
@@ -2876,7 +2878,7 @@ def custom_gemm_bw(gradient:UOp, kernel:UOp, n_scales:int=2, has_grad_amax:bool=
|
|||
|
||||
# ** mxfp8 gemm backward
|
||||
|
||||
- def custom_mx_gemm_bw(gradient:UOp, kernel:UOp, has_w_post:bool):
|
||||
+ def custom_mx_gemm_bw(gradient:UOp, kernel:UOp, has_w_post:bool, w_stored:bool=False):
|
||||
inputs = kernel.src[1:] # (out, a_q, b_q, a_si, b_si, a_e8, b_e8, [w_post])
|
||||
aq, bq = Tensor(inputs[1], device=inputs[1].device), Tensor(inputs[2], device=inputs[2].device)
|
||||
ae8, be8 = Tensor(inputs[5], device=inputs[5].device), Tensor(inputs[6], device=inputs[6].device)
|
||||
|
|
@@ -2890,14 +2892,14 @@ def custom_mx_gemm_bw(gradient:UOp, kernel:UOp, has_w_post:bool):
|
|||
grad_b = asm_gemm(g.T, a_phys, mx=True)
|
||||
|
||||
grad_a = (grad_a * _mx_block_scale(ae8)).reshape(aq.shape)
|
||||
- grad_b = grad_b * _mx_block_scale(be8)
|
||||
+ if not w_stored: grad_b = grad_b * _mx_block_scale(be8)
|
||||
if wp is not None: grad_b = grad_b / wp.reshape(-1, 1)
|
||||
return (None, grad_a.uop, grad_b.uop) + tuple(None for _ in inputs[3:])
|
||||
|
||||
# ** main gemm function
|
||||
|
||||
def asm_gemm(a:Tensor, b:Tensor, x_scale:Tensor|None=None, w_scale:Tensor|None=None, grad_amax_state:Tensor|None=None,
|
||||
- w_post_scale:Tensor|None=None, mx:bool=False, mx_scales:tuple|None=None) -> Tensor:
|
||||
+ w_post_scale:Tensor|None=None, mx:bool=False, mx_scales:tuple|None=None, mx_w_stored:bool=False) -> Tensor:
|
||||
assert can_use_asm_gemm(a, b), f"{counters['todos'][-1]}"
|
||||
counters["used"] += 1
|
||||
unfold_batch = a.ndim == 3 and isinstance(a.device, tuple) and a.uop.axis == 2 and b.uop.axis == 0
|
||||
|
|
@@ -2939,7 +2941,7 @@ def asm_gemm(a:Tensor, b:Tensor, x_scale:Tensor|None=None, w_scale:Tensor|None=N
|
|||
b_q, b_e8, b_si = quantize_mxfp8(b.T)
|
||||
has_w_post = w_post_scale is not None
|
||||
fxn = functools.partial(custom_hk_mxfp8_gemm, dname=dname)
|
||||
- grad_fxn = functools.partial(custom_mx_gemm_bw, has_w_post=has_w_post)
|
||||
+ grad_fxn = functools.partial(custom_mx_gemm_bw, has_w_post=has_w_post, w_stored=mx_w_stored)
|
||||
extra = [w_post_scale] if w_post_scale is not None else []
|
||||
out = Tensor.custom_kernel(out, a_q.reshape(a.shape), b_q, a_si, b_si, a_e8, b_e8, *extra, fxn=fxn, grad_fxn=grad_fxn)[0]
|
||||
# fp8 gemm computes a@b.T, kernel multiplies output by x_scale * w_scale before bf16 store
|
||||
|
|
|
|||
|
|
@@ -28,7 +28,9 @@ using G = kittens::group<NUM_WARPS>;
|
|||
|
||||
__global__ __launch_bounds__(512, 2) void mxfp8_gemm_kernel(bf16 *C_ptr, fp8e4m3 *A_ptr, fp8e4m3 *B_ptr,
|
||||
const uint32_t *__restrict__ scale_A_iter,
|
||||
- const uint32_t *__restrict__ scale_B_iter) {
|
||||
+ const uint32_t *__restrict__ scale_B_iter,
|
||||
+ const uint8_t *__restrict__ a_e8_unused,
|
||||
+ const uint8_t *__restrict__ b_e8_unused) {
|
||||
constexpr int M = GEMM_M, N = GEMM_N, K = GEMM_K;
|
||||
|
||||
kittens::gl<fp8e4m3, 1, 1, M, K> A{A_ptr, nullptr, nullptr, nullptr, nullptr};
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue