mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
llama: mxfp8 (#16574)
This commit is contained in:
parent
b8aec4cce7
commit
e770805d21
5 changed files with 43 additions and 11 deletions
|
|
@ -1282,7 +1282,7 @@ def train_bert():
|
|||
previous_step = i
|
||||
|
||||
def train_llama3():
|
||||
from examples.mlperf.models.flat_llama import FlatTransformer, apply_grad, FP8_DTYPE
|
||||
from examples.mlperf.models.flat_llama import FlatTransformer, apply_grad, FP8_DTYPE, MXFP8
|
||||
from examples.llama3 import MODEL_PARAMS
|
||||
from examples.mlperf.lr_schedulers import CosineAnnealingLRWithWarmup
|
||||
from examples.mlperf.optim import GradAccClipAdamW
|
||||
|
|
@ -1447,7 +1447,12 @@ def train_llama3():
|
|||
idx = next(j for j, p in enumerate(optim.params) if p is w)
|
||||
master = optim.master_params[idx]
|
||||
inv = w._inv_scale if w._inv_scale.device == master.device else w._inv_scale.to(master.device)
|
||||
master.assign((master * inv.reshape(*inv.shape, *([1]*(w.ndim-inv.ndim)))).contiguous())
|
||||
if MXFP8:
|
||||
from extra.gemm.cdna_asm_gemm import _mx_block_scale
|
||||
bs = _mx_block_scale(inv.reshape(-1, inv.shape[-1])).reshape(w.shape)
|
||||
master.assign((master * bs).contiguous())
|
||||
else:
|
||||
master.assign((master * inv.reshape(*inv.shape, *([1]*(w.ndim-inv.ndim)))).contiguous())
|
||||
|
||||
# realize everything here
|
||||
if optim.master_params: Tensor.realize(*optim.master_params)
|
||||
|
|
|
|||
|
|
@ -24,6 +24,7 @@ FUSED_ADD_NORM_MUL_QUANTIZE = getenv("FUSED_ADD_NORM_MUL_QUANTIZE", 0)
|
|||
FUSED_SILU_W13 = getenv("FUSED_SILU_W13", 0)
|
||||
SPLIT_W13 = getenv("SPLIT_W13", 0)
|
||||
COLUMNWISE_WEIGHT_SCALE = getenv("COLUMNWISE_WEIGHT_SCALE", 0)
|
||||
MXFP8 = getenv("MXFP8", 0)
|
||||
|
||||
FP8_DTYPE = dtypes.fp8e4m3
|
||||
FP8_GRAD_DTYPE = dtypes.fp8e5m2
|
||||
|
|
@ -44,6 +45,16 @@ def matmul(x:Tensor, w:Tensor, fp8:bool=True, amax_x:Tensor|None=None, w_inv_sca
|
|||
if can_use_asm_gemm(x, w.T): return (asm_gemm(x, w.T),)
|
||||
return (x @ w.T,)
|
||||
assert w_inv_scale is not None, "fp8 matmul requires w_inv_scale (weights must be stored in fp8 with per-tensor scale)"
|
||||
if MXFP8:
|
||||
from extra.gemm.cdna_asm_gemm import asm_gemm, quantize_mxfp8, mx_pack, can_use_asm_gemm, _mx_block_scale
|
||||
x_q, x_e8, x_si = quantize_mxfp8(x.reshape(-1, x.shape[-1]))
|
||||
if can_use_asm_gemm(x_q, w.T):
|
||||
out = asm_gemm(x_q, w.T, mx=True, mx_scales=(x_si, x_e8, mx_pack(w_inv_scale), w_inv_scale),
|
||||
mx_w_stored=True).reshape(*x.shape[:-1], w.shape[0])
|
||||
else:
|
||||
x_phys = (x_q.cast(dtypes.bfloat16) * _mx_block_scale(x_e8)).reshape(*x.shape[:-1], x.shape[-1])
|
||||
out = x_phys @ (w.cast(dtypes.bfloat16) * _mx_block_scale(w_inv_scale)).T
|
||||
return out, (amax_x.detach() if amax_x is not None else None), x_q
|
||||
if x_fp8 is None:
|
||||
if FUSED_INPUT_QUANTIZE and amax_x is not None:
|
||||
from extra.llama_kernels.quantize_fp8_delayed import quantize_fp8_delayed
|
||||
|
|
@ -140,12 +151,16 @@ class FlatTransformer:
|
|||
self._fp8_grad_amax = {name: [_amax() for _ in range(n_layers)] for name in grad_names}
|
||||
w_scales = [("wqkv", s_qkv), ("wo", s_o), ("w2", s_2)]
|
||||
w_scales += [("w1", s_1), ("w3", s_3)] if SPLIT_W13 else [("w13", s_13)]
|
||||
self._fp8_inv_scale = {name: s.float().contiguous().is_param_(False) for name, s in w_scales}
|
||||
self._fp8_next_inv_scale = {name: s.float().contiguous().is_param_(False) for name, s in w_scales}
|
||||
self._fp8_inv_scale = {name: (s if MXFP8 else s.float()).contiguous().is_param_(False) for name, s in w_scales}
|
||||
self._fp8_next_inv_scale = {name: (s if MXFP8 else s.float()).contiguous().is_param_(False) for name, s in w_scales}
|
||||
|
||||
def lin_per_layer(self, in_features:int, out_features:int, std:float=0.02):
|
||||
if getenv("ZEROS"): w = Tensor.zeros(self.n_layers, out_features, in_features)
|
||||
else: w = Tensor.normal(self.n_layers, out_features, in_features, mean=0.0, std=std)
|
||||
if MXFP8:
|
||||
from extra.gemm.cdna_asm_gemm import quantize_mxfp8
|
||||
w_q, w_e8, _ = quantize_mxfp8(w.reshape(self.n_layers * out_features, in_features))
|
||||
return w_q.reshape(self.n_layers, out_features, in_features), w_e8.reshape(self.n_layers, out_features, in_features // 32)
|
||||
amax = (w.abs().max(axis=2) if COLUMNWISE_WEIGHT_SCALE else w.abs().flatten(1).max(1)).detach()
|
||||
scale = FP8_MAX / (amax + 1e-8)
|
||||
inv_scale = (amax + 1e-8) / FP8_MAX
|
||||
|
|
@ -230,7 +245,7 @@ class FlatTransformer:
|
|||
# flat per-layer weights: axis 0 is n_layers, so shard axes are +1 vs per-layer Transformer
|
||||
def _shard_fp8(name:str, axis:int):
|
||||
getattr(self, name).shard_(device, axis=axis)
|
||||
scale_axis = (1 if axis == 1 else None) if COLUMNWISE_WEIGHT_SCALE else None
|
||||
scale_axis = axis if MXFP8 else (1 if axis == 1 else None) if COLUMNWISE_WEIGHT_SCALE else None
|
||||
self._fp8_inv_scale[name] = self._fp8_inv_scale[name].shard(device, axis=scale_axis).contiguous().is_param_(False)
|
||||
self._fp8_next_inv_scale[name] = self._fp8_next_inv_scale[name].shard(device, axis=scale_axis).contiguous().is_param_(False)
|
||||
Tensor.realize(getattr(self, name), self._fp8_inv_scale[name], self._fp8_next_inv_scale[name])
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@ STOCHASTIC_ROUND = getenv("STOCHASTIC_ROUND", 0)
|
|||
MASTER_WEIGHTS = getenv("MASTER_WEIGHTS", 0)
|
||||
FP8_AMAX_MARGIN = getenv("FP8_AMAX_MARGIN", 1.1)
|
||||
IMMEDIATE_SCALE = getenv("IMMEDIATE_SCALE", 0)
|
||||
MXFP8 = getenv("MXFP8", 0)
|
||||
|
||||
def stochastic_round_bf16(x:Tensor) -> Tensor:
|
||||
bits = x.bitcast(dtypes.uint32)
|
||||
|
|
@ -90,6 +91,13 @@ class GradAccClipAdamW(Optimizer):
|
|||
out = stochastic_round_bf16(new_w)
|
||||
return out.shard_like(t) if offloaded else out
|
||||
if t.dtype in dtypes.fp8s:
|
||||
if MXFP8:
|
||||
from extra.gemm.cdna_asm_gemm import quantize_mxfp8
|
||||
w_q, w_e8, _ = quantize_mxfp8(new_w.reshape(-1, new_w.shape[-1]))
|
||||
new_e8 = w_e8.reshape(t._inv_scale.shape)
|
||||
t._inv_scale.assign(new_e8.shard_like(t._inv_scale) if offloaded else new_e8)
|
||||
ret = w_q.reshape(new_w.shape)
|
||||
return ret.shard_like(t) if offloaded else ret
|
||||
from examples.mlperf.models.flat_llama import FP8_MAX
|
||||
if IMMEDIATE_SCALE:
|
||||
amax_axis = tuple(range(t._inv_scale.ndim, new_w.ndim))
|
||||
|
|
|
|||
|
|
@ -2660,7 +2660,9 @@ def custom_hk_mxfp8_gemm(C:UOp, A:UOp, B:UOp, scale_A:UOp, scale_B:UOp, *extra:U
|
|||
block_size = 256
|
||||
threads = UOp.special(64 * 8, "lidx0")
|
||||
workgroups = UOp.special((M // block_size) * (N // block_size), "gidx0")
|
||||
sink_inputs = (C.base, A.base, B.base, scale_A.base, scale_B.base, threads, workgroups)
|
||||
e_a = extra[0].base if len(extra) >= 1 else scale_A.base
|
||||
e_b = extra[1].base if len(extra) >= 2 else scale_B.base
|
||||
sink_inputs = (C.base, A.base, B.base, scale_A.base, scale_B.base, e_a, e_b, threads, workgroups)
|
||||
sink = UOp.sink(*sink_inputs,
|
||||
arg=KernelInfo(f"hk_mxfp8_gemm_{M}_{N}_{K}", estimates=Estimates(ops=2*M*N*K, mem=(M*K+N*K)*A.dtype.itemsize+M*N*C.dtype.itemsize)))
|
||||
kittens_path = pathlib.Path(__file__).parent.parent/"thunder"/"amd"
|
||||
|
|
@ -2876,7 +2878,7 @@ def custom_gemm_bw(gradient:UOp, kernel:UOp, n_scales:int=2, has_grad_amax:bool=
|
|||
|
||||
# ** mxfp8 gemm backward
|
||||
|
||||
def custom_mx_gemm_bw(gradient:UOp, kernel:UOp, has_w_post:bool):
|
||||
def custom_mx_gemm_bw(gradient:UOp, kernel:UOp, has_w_post:bool, w_stored:bool=False):
|
||||
inputs = kernel.src[1:] # (out, a_q, b_q, a_si, b_si, a_e8, b_e8, [w_post])
|
||||
aq, bq = Tensor(inputs[1], device=inputs[1].device), Tensor(inputs[2], device=inputs[2].device)
|
||||
ae8, be8 = Tensor(inputs[5], device=inputs[5].device), Tensor(inputs[6], device=inputs[6].device)
|
||||
|
|
@ -2890,14 +2892,14 @@ def custom_mx_gemm_bw(gradient:UOp, kernel:UOp, has_w_post:bool):
|
|||
grad_b = asm_gemm(g.T, a_phys, mx=True)
|
||||
|
||||
grad_a = (grad_a * _mx_block_scale(ae8)).reshape(aq.shape)
|
||||
grad_b = grad_b * _mx_block_scale(be8)
|
||||
if not w_stored: grad_b = grad_b * _mx_block_scale(be8)
|
||||
if wp is not None: grad_b = grad_b / wp.reshape(-1, 1)
|
||||
return (None, grad_a.uop, grad_b.uop) + tuple(None for _ in inputs[3:])
|
||||
|
||||
# ** main gemm function
|
||||
|
||||
def asm_gemm(a:Tensor, b:Tensor, x_scale:Tensor|None=None, w_scale:Tensor|None=None, grad_amax_state:Tensor|None=None,
|
||||
w_post_scale:Tensor|None=None, mx:bool=False, mx_scales:tuple|None=None) -> Tensor:
|
||||
w_post_scale:Tensor|None=None, mx:bool=False, mx_scales:tuple|None=None, mx_w_stored:bool=False) -> Tensor:
|
||||
assert can_use_asm_gemm(a, b), f"{counters['todos'][-1]}"
|
||||
counters["used"] += 1
|
||||
unfold_batch = a.ndim == 3 and isinstance(a.device, tuple) and a.uop.axis == 2 and b.uop.axis == 0
|
||||
|
|
@ -2939,7 +2941,7 @@ def asm_gemm(a:Tensor, b:Tensor, x_scale:Tensor|None=None, w_scale:Tensor|None=N
|
|||
b_q, b_e8, b_si = quantize_mxfp8(b.T)
|
||||
has_w_post = w_post_scale is not None
|
||||
fxn = functools.partial(custom_hk_mxfp8_gemm, dname=dname)
|
||||
grad_fxn = functools.partial(custom_mx_gemm_bw, has_w_post=has_w_post)
|
||||
grad_fxn = functools.partial(custom_mx_gemm_bw, has_w_post=has_w_post, w_stored=mx_w_stored)
|
||||
extra = [w_post_scale] if w_post_scale is not None else []
|
||||
out = Tensor.custom_kernel(out, a_q.reshape(a.shape), b_q, a_si, b_si, a_e8, b_e8, *extra, fxn=fxn, grad_fxn=grad_fxn)[0]
|
||||
# fp8 gemm computes a@b.T, kernel multiplies output by x_scale * w_scale before bf16 store
|
||||
|
|
|
|||
|
|
@ -28,7 +28,9 @@ using G = kittens::group<NUM_WARPS>;
|
|||
|
||||
__global__ __launch_bounds__(512, 2) void mxfp8_gemm_kernel(bf16 *C_ptr, fp8e4m3 *A_ptr, fp8e4m3 *B_ptr,
|
||||
const uint32_t *__restrict__ scale_A_iter,
|
||||
const uint32_t *__restrict__ scale_B_iter) {
|
||||
const uint32_t *__restrict__ scale_B_iter,
|
||||
const uint8_t *__restrict__ a_e8_unused,
|
||||
const uint8_t *__restrict__ b_e8_unused) {
|
||||
constexpr int M = GEMM_M, N = GEMM_N, K = GEMM_K;
|
||||
|
||||
kittens::gl<fp8e4m3, 1, 1, M, K> A{A_ptr, nullptr, nullptr, nullptr, nullptr};
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue