mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
llama: move llama kernels to llama_kernels (#15952)
This commit is contained in:
parent
987b6dd193
commit
5e861cd2c4
13 changed files with 364 additions and 358 deletions
|
|
@ -1282,7 +1282,7 @@ def train_bert():
|
|||
previous_step = i
|
||||
|
||||
def train_llama3():
|
||||
from examples.mlperf.models.flat_llama import FlatTransformer, apply_grad, FP8, FP8_DTYPE
|
||||
from examples.mlperf.models.flat_llama import FlatTransformer, apply_grad, FP8_DTYPE
|
||||
from examples.llama3 import MODEL_PARAMS
|
||||
from examples.mlperf.lr_schedulers import CosineAnnealingLRWithWarmup
|
||||
from examples.mlperf.optim import GradAccClipAdamW
|
||||
|
|
@ -1432,18 +1432,17 @@ def train_llama3():
|
|||
print(f"loading optim checkpoint from {fn}")
|
||||
load_state_dict(scheduler, safe_load(fn), realize=False)
|
||||
|
||||
fp8_amax = [t for ts in model._fp8_amax.values() for t in ts] if FP8 else []
|
||||
fp8_inv_scales = list(model._fp8_inv_scale.values()) if FP8 else []
|
||||
fp8_amax = [t for ts in model._fp8_amax.values() for t in ts]
|
||||
fp8_inv_scales = list(model._fp8_inv_scale.values())
|
||||
|
||||
if FP8:
|
||||
from tinygrad.nn.state import get_state_dict
|
||||
model_state = get_state_dict(model)
|
||||
for wname in ["wqkv", "wo", "w13", "w2"]:
|
||||
w = model_state[wname]
|
||||
w._inv_scale = model._fp8_inv_scale[wname]
|
||||
if optim.master_params:
|
||||
idx = next(j for j, p in enumerate(optim.params) if p is w)
|
||||
optim.master_params[idx].assign((optim.master_params[idx] * w._inv_scale.reshape(-1, *([1]*(w.ndim-1)))).contiguous())
|
||||
from tinygrad.nn.state import get_state_dict
|
||||
model_state = get_state_dict(model)
|
||||
for wname in ["wqkv", "wo", "w13", "w2"]:
|
||||
w = model_state[wname]
|
||||
w._inv_scale = model._fp8_inv_scale[wname]
|
||||
if optim.master_params:
|
||||
idx = next(j for j, p in enumerate(optim.params) if p is w)
|
||||
optim.master_params[idx].assign((optim.master_params[idx] * w._inv_scale.reshape(-1, *([1]*(w.ndim-1)))).contiguous())
|
||||
|
||||
@TinyJit
|
||||
def minibatch(tokens:Tensor):
|
||||
|
|
@ -1452,7 +1451,7 @@ def train_llama3():
|
|||
if not is_sharding: tokens = tokens.to(None)
|
||||
logits:Tensor = model(tokens[:, :-1])
|
||||
if getenv("FAST_CE", 0):
|
||||
from extra.amax.cast_amax import fused_ce_loss
|
||||
from extra.llama_kernels.fused_ce import fused_ce_loss
|
||||
loss = fused_ce_loss(logits.cast(dtypes.bfloat16), tokens[:, 1:], label_smoothing=0.0)
|
||||
else:
|
||||
loss = vocab_mask.where(-1e9, logits).sparse_categorical_crossentropy(tokens[:, 1:])
|
||||
|
|
@ -1559,7 +1558,7 @@ def train_llama3():
|
|||
|
||||
mem_gb = GlobalCounters.mem_used / 1e9
|
||||
gflops = GlobalCounters.global_ops / 1e9 / dev_time
|
||||
mfu = ((6 * num_params * SEQLEN * GBS) / (dev_time * device_count * (4.6e15 if FP8 else 2.3e15))) * 100
|
||||
mfu = ((6 * num_params * SEQLEN * GBS) / (dev_time * device_count * 4.6e15)) * 100
|
||||
tqdm.write(
|
||||
f"{i:5} {step_time:.3f} s step, {gbs_time:.3f} s gbs, {optim_time:.3f} s optim, {data_time:.3f} s data, {loss:.4f} loss, " \
|
||||
f"{lr:.12f} LR, {grad_norm:.6f} grad_norm, {mem_gb:.2f} GB used, {gflops:9.2f} GFLOPS, {mfu:5.2f}% MFU")
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
import math, os, functools
|
||||
import math, os
|
||||
if __name__ == "__main__":
|
||||
os.environ["DEFAULT_FLOAT"] = "bfloat16"
|
||||
os.environ["OPTIM_DTYPE"] = "bfloat16"
|
||||
|
|
@ -16,65 +16,60 @@ from tinygrad import Tensor, nn, function, getenv, dtypes, TinyJit
|
|||
from tinygrad.helpers import Timing, colored, GlobalCounters, profile_marker
|
||||
from tinygrad.uop.ops import Ops, UOp
|
||||
from extra.models.llama import apply_rotary_emb, precompute_freqs_cis
|
||||
from extra.llama_kernels.rmsnorm import rmsnorm
|
||||
from extra.llama_kernels import FP8_MAX, local_abs_max
|
||||
|
||||
FP8 = getenv("FP8", 0)
|
||||
ASM_GEMM = getenv("ASM_GEMM", 0)
|
||||
|
||||
FP8_DTYPE = dtypes.fp8e4m3
|
||||
FP8_GRAD_DTYPE = dtypes.fp8e5m2
|
||||
FP8_MAX = 448.0
|
||||
|
||||
# per-device abs max without allreduce (matches TE delayed scaling behavior)
|
||||
@functools.cache
|
||||
def _local_abs_max_fxn(x_p, device):
|
||||
x = Tensor(x_p, device=device)
|
||||
inner = Tensor(x.uop.src[0]) if x.uop.op is Ops.MULTI else x
|
||||
return (inner.abs().max(),)
|
||||
|
||||
def _local_abs_max(x:Tensor) -> Tensor:
|
||||
param = x.as_param(0)
|
||||
fxn = _local_abs_max_fxn(param.uop, x.device)
|
||||
return Tensor(fxn[0].uop.call(x.uop).gettuple(0))
|
||||
|
||||
def quantize_fp8(x:Tensor, amax_state:Tensor|None=None):
|
||||
new_amax = (_local_abs_max(x) if isinstance(x.device, tuple) else x.abs().max()).detach()
|
||||
new_amax = (local_abs_max(x) if isinstance(x.device, tuple) else x.abs().max()).detach()
|
||||
scale = FP8_MAX / ((amax_state if amax_state is not None else new_amax) + 1e-8)
|
||||
x_scaled = x * scale
|
||||
x_clamped = x_scaled + (x_scaled.detach().clamp(-FP8_MAX, FP8_MAX) - x_scaled.detach()) # STE
|
||||
return x_clamped.cast(FP8_DTYPE), scale.float().reciprocal(), new_amax
|
||||
|
||||
def matmul(x:Tensor, w:Tensor, fp8=FP8, amax_x:Tensor|None=None, w_inv_scale:Tensor|None=None,
|
||||
def matmul(x:Tensor, w:Tensor, fp8:bool=True, amax_x:Tensor|None=None, w_inv_scale:Tensor|None=None,
|
||||
x_fp8:Tensor|None=None, x_scale:Tensor|None=None, x_new_amax:Tensor|None=None) -> tuple[Tensor,...]:
|
||||
if not fp8:
|
||||
if getenv("ASM_GEMM"):
|
||||
if ASM_GEMM:
|
||||
from extra.gemm.cdna_asm_gemm import can_use_asm_gemm, asm_gemm
|
||||
if can_use_asm_gemm(x, w.T): return (asm_gemm(x, w.T),)
|
||||
return (x @ w.T,)
|
||||
assert w_inv_scale is not None, "fp8 matmul requires w_inv_scale (weights must be stored in fp8 with per-tensor scale)"
|
||||
if x_fp8 is None: x_fp8, x_scale, x_new_amax = quantize_fp8(x, amax_state=amax_x)
|
||||
if getenv("ASM_GEMM"):
|
||||
if ASM_GEMM:
|
||||
from extra.gemm.cdna_asm_gemm import can_use_asm_gemm, asm_gemm
|
||||
if can_use_asm_gemm(x_fp8, w.T): return asm_gemm(x_fp8, w.T, x_scale=x_scale, w_scale=w_inv_scale), x_new_amax, x_fp8, w
|
||||
return x_fp8.dot(w.T, dtype=dtypes.float) * x_scale * w_inv_scale, x_new_amax, x_fp8, w
|
||||
return (x_fp8.dot(w.T, dtype=dtypes.float) * x_scale * w_inv_scale).cast(dtypes.bfloat16), x_new_amax, x_fp8, w
|
||||
|
||||
def _rmsnorm_fwd(x_in:Tensor, eps:float) -> tuple[Tensor, Tensor]:
|
||||
x = x_in.float()
|
||||
rrms = (x.square().mean(-1, keepdim=True) + eps).rsqrt()
|
||||
return (x * rrms).cast(x_in.dtype), rrms
|
||||
def norm_mul_quantize_matmul(x:Tensor, norm:Tensor, amax_x, w_inv_scale, w:Tensor, eps:float):
|
||||
FUSED_NORM_MUL_QUANTIZE = getenv("FUSED_NORM_MUL_QUANTIZE", 0)
|
||||
normed, rrms = rmsnorm(x, eps)
|
||||
if FUSED_NORM_MUL_QUANTIZE:
|
||||
from extra.llama_kernels.fused_mul_quantize_fp8 import fused_mul_quantize_fp8
|
||||
amax_s = amax_x if amax_x is not None else Tensor.full((), 1.0, dtype=dtypes.bfloat16, device=normed.device)
|
||||
x_fp8, x_inv_scale, new_amax = fused_mul_quantize_fp8(normed, norm, amax_s, FP8_DTYPE)
|
||||
out, *ret = matmul(None, w, w_inv_scale=w_inv_scale, x_fp8=x_fp8, x_scale=x_inv_scale, x_new_amax=new_amax)
|
||||
else:
|
||||
x = normed * norm
|
||||
out, *ret = matmul(x, w, amax_x=amax_x, w_inv_scale=w_inv_scale)
|
||||
return out, normed, rrms, ret
|
||||
|
||||
@functools.cache
|
||||
def _rmsnorm_fwd_fxn(x_in_p, eps, device):
|
||||
return _rmsnorm_fwd(Tensor(x_in_p, device=device), eps)
|
||||
|
||||
def _rmsnorm_bwd(grad:UOp, call:UOp) -> tuple:
|
||||
x_normed = Tensor(call.gettuple(0)).float()
|
||||
do_float = Tensor(grad).float()
|
||||
d_x = Tensor(call.gettuple(1)) * (do_float - x_normed * (do_float * x_normed).mean(-1, keepdim=True))
|
||||
return (d_x.cast(call.src[1].dtype).uop,)
|
||||
|
||||
def rmsnorm(x_in:Tensor, eps:float) -> tuple[Tensor, Tensor]:
|
||||
fxn = _rmsnorm_fwd_fxn(x_in.as_param(0).uop, eps, x_in.device)
|
||||
call = UOp.maketuple(fxn[0].uop, fxn[1].uop).call(x_in.uop, grad_fxn=_rmsnorm_bwd)
|
||||
return Tensor(call.gettuple(0)), Tensor(call.gettuple(1))
|
||||
def silu_w13_matmul(x_w13:Tensor, w2:Tensor, amax_x2, s_2):
|
||||
FUSED_SILU_W13 = getenv("FUSED_SILU_W13", 0)
|
||||
if FUSED_SILU_W13:
|
||||
from extra.llama_kernels.cast_amax import fused_quantize_fp8_w13
|
||||
amax_s = amax_x2 if amax_x2 is not None else Tensor.full((), 1.0, dtype=dtypes.bfloat16, device=x_w13.device)
|
||||
x2_fp8, x2_inv_scale, new_amax_x2 = fused_quantize_fp8_w13(x_w13, amax_s, FP8_DTYPE)
|
||||
out, *ret = matmul(None, w2, w_inv_scale=s_2, x_fp8=x2_fp8, x_scale=x2_inv_scale, x_new_amax=new_amax_x2)
|
||||
else:
|
||||
hidden_dim = x_w13.shape[-1] // 2
|
||||
x_w1, x_w3 = x_w13[..., :hidden_dim], x_w13[..., hidden_dim:]
|
||||
out, *ret = matmul(x_w1.silu() * x_w3, w2, amax_x=amax_x2, w_inv_scale=s_2)
|
||||
return out, ret
|
||||
|
||||
class FlatTransformer:
|
||||
def __init__(self, dim:int, hidden_dim:int, n_heads:int, n_layers:int, norm_eps:float, vocab_size:int, n_kv_heads:int|None=None,
|
||||
|
|
@ -90,7 +85,7 @@ class FlatTransformer:
|
|||
scaled_std = 0.02 / math.sqrt(2 * n_layers)
|
||||
|
||||
# Attention
|
||||
self._init_inv_scales = [] # populated by lin_per_layer when FP8
|
||||
self._init_inv_scales = [] # populated by lin_per_layer
|
||||
self.wqkv = self.lin_per_layer(dim, self.n_heads * self.head_dim + self.n_kv_heads * self.head_dim * 2)
|
||||
self.wo = self.lin_per_layer(self.n_heads * self.head_dim, dim, std=scaled_std)
|
||||
|
||||
|
|
@ -109,21 +104,19 @@ class FlatTransformer:
|
|||
self.output = Tensor.normal(1, vocab_size, dim, mean=0.0, std=0.02, dtype=dtypes.bfloat16)
|
||||
self.freqs_cis = precompute_freqs_cis(dim // n_heads, max_context * 2, rope_theta).contiguous().requires_grad_(False)
|
||||
|
||||
if FP8:
|
||||
def _amax(): return Tensor.full((), FP8_MAX).contiguous().requires_grad_(False)
|
||||
names = ["xqkv", "xo", "x13", "x2"]
|
||||
self._fp8_amax = {name: [_amax() for _ in range(n_layers)] for name in names}
|
||||
# per-weight inv_scale: single (n_layers,) float32 tensor per weight (kernel reads float* pointers)
|
||||
w_names = ["wqkv", "wo", "w13", "w2"]
|
||||
self._fp8_inv_scale = {}
|
||||
for wname, inv_scales in zip(w_names, self._init_inv_scales):
|
||||
self._fp8_inv_scale[wname] = inv_scales.float().contiguous().requires_grad_(False)
|
||||
del self._init_inv_scales
|
||||
def _amax(): return Tensor.full((), FP8_MAX).contiguous().requires_grad_(False)
|
||||
names = ["xqkv", "xo", "x13", "x2"]
|
||||
self._fp8_amax = {name: [_amax() for _ in range(n_layers)] for name in names}
|
||||
# per-weight inv_scale: single (n_layers,) float32 tensor per weight (kernel reads float* pointers)
|
||||
w_names = ["wqkv", "wo", "w13", "w2"]
|
||||
self._fp8_inv_scale = {}
|
||||
for wname, inv_scales in zip(w_names, self._init_inv_scales):
|
||||
self._fp8_inv_scale[wname] = inv_scales.float().contiguous().requires_grad_(False)
|
||||
del self._init_inv_scales
|
||||
|
||||
def lin_per_layer(self, in_features:int, out_features:int, std:float=0.02):
|
||||
if getenv("ZEROS"): w = Tensor.zeros(self.n_layers, out_features, in_features)
|
||||
if getenv("ZEROS", 0): w = Tensor.zeros(self.n_layers, out_features, in_features)
|
||||
else: w = Tensor.normal(self.n_layers, out_features, in_features, mean=0.0, std=std)
|
||||
if not FP8: return w
|
||||
# per-layer scaled fp8 cast: fill the fp8 range for best precision
|
||||
amax = w.abs().flatten(1).max(1).detach()
|
||||
scale = FP8_MAX / (amax + 1e-8)
|
||||
|
|
@ -135,18 +128,8 @@ class FlatTransformer:
|
|||
bsz, seqlen, _ = x.shape
|
||||
new_amaxs, saves = [], []
|
||||
|
||||
x, rrms = rmsnorm(x, self.norm_eps)
|
||||
saves.extend([x, rrms])
|
||||
|
||||
if FP8 and getenv("FUSED_NORM_MUL_QUANTIZE", 1):
|
||||
from extra.amax.cast_amax import fused_mul_quantize_fp8
|
||||
amax_s = amax_xqkv if amax_xqkv is not None else Tensor.full((), 1.0, dtype=dtypes.bfloat16, device=x.device)
|
||||
x_fp8, x_inv_scale, new_amax_xqkv = fused_mul_quantize_fp8(x, attention_norm, amax_s, FP8_DTYPE)
|
||||
xqkv, *ret = matmul(None, wqkv, w_inv_scale=s_qkv, x_fp8=x_fp8, x_scale=x_inv_scale, x_new_amax=new_amax_xqkv)
|
||||
else:
|
||||
x = x * attention_norm
|
||||
xqkv, *ret = matmul(x, wqkv, amax_x=amax_xqkv, w_inv_scale=s_qkv)
|
||||
|
||||
xqkv, normed, rrms, ret = norm_mul_quantize_matmul(x, attention_norm, amax_xqkv, s_qkv, wqkv, self.norm_eps)
|
||||
saves.extend([normed, rrms])
|
||||
new_amaxs.extend(ret[:1])
|
||||
saves.extend(ret[1:] + [xqkv])
|
||||
xqkv = xqkv.reshape(bsz, seqlen, self.n_kv_heads, self.n_rep + 2, self.head_dim)
|
||||
|
|
@ -155,7 +138,7 @@ class FlatTransformer:
|
|||
xv = xqkv[:, :, :, self.n_rep+1].reshape(bsz, seqlen, self.n_kv_heads, self.head_dim)
|
||||
|
||||
xq, xk = apply_rotary_emb(xq, xk, freqs_cis)
|
||||
if FP8: xq, xk, xv = xq.cast(dtypes.bfloat16), xk.cast(dtypes.bfloat16), xv.cast(dtypes.bfloat16)
|
||||
xq, xk, xv = xq.cast(dtypes.bfloat16), xk.cast(dtypes.bfloat16), xv.cast(dtypes.bfloat16)
|
||||
xq, xk, xv = xq.transpose(1, 2), xk.transpose(1, 2), xv.transpose(1, 2)
|
||||
if getenv("HK_FLASH_ATTENTION"):
|
||||
from extra.thunder.amd.fa import flash_attention
|
||||
|
|
@ -174,28 +157,12 @@ class FlatTransformer:
|
|||
amax_x13=None, amax_x2=None, s_13=None, s_2=None):
|
||||
new_amaxs, saves = [], []
|
||||
|
||||
x, rrms = rmsnorm(x, self.norm_eps)
|
||||
saves.extend([x, rrms])
|
||||
|
||||
if FP8 and getenv("FUSED_NORM_MUL_QUANTIZE", 1):
|
||||
from extra.amax.cast_amax import fused_mul_quantize_fp8
|
||||
amax_s13 = amax_x13 if amax_x13 is not None else Tensor.full((), 1.0, dtype=dtypes.bfloat16, device=x.device)
|
||||
x_fp8_13, x_inv_scale_13, new_amax_x13 = fused_mul_quantize_fp8(x, ffn_norm, amax_s13, FP8_DTYPE)
|
||||
x_w13, *ret = matmul(None, w13, w_inv_scale=s_13, x_fp8=x_fp8_13, x_scale=x_inv_scale_13, x_new_amax=new_amax_x13)
|
||||
else:
|
||||
x = x * ffn_norm
|
||||
x_w13, *ret = matmul(x, w13, amax_x=amax_x13, w_inv_scale=s_13)
|
||||
x_w13, normed, rrms, ret = norm_mul_quantize_matmul(x, ffn_norm, amax_x13, s_13, w13, self.norm_eps)
|
||||
saves.extend([normed, rrms])
|
||||
new_amaxs.extend(ret[:1])
|
||||
saves.extend(ret[1:] + [x_w13])
|
||||
|
||||
if FP8 and getenv("FUSED_SILU_W13", 1):
|
||||
from extra.amax.cast_amax import fused_quantize_fp8_w13
|
||||
amax_s = amax_x2 if amax_x2 is not None else Tensor.full((), 1.0, dtype=dtypes.bfloat16, device=x_w13.device)
|
||||
x2_fp8, x2_inv_scale, new_amax_x2 = fused_quantize_fp8_w13(x_w13, amax_s, FP8_DTYPE)
|
||||
out, *ret = matmul(None, w2, w_inv_scale=s_2, x_fp8=x2_fp8, x_scale=x2_inv_scale, x_new_amax=new_amax_x2)
|
||||
else:
|
||||
x_w1, x_w3 = x_w13[..., :self.hidden_dim], x_w13[..., self.hidden_dim:]
|
||||
out, *ret = matmul(x_w1.silu() * x_w3, w2, amax_x=amax_x2, w_inv_scale=s_2)
|
||||
out, ret = silu_w13_matmul(x_w13, w2, amax_x2, s_2)
|
||||
new_amaxs.extend(ret[:1])
|
||||
saves.extend(ret[1:] + [out])
|
||||
return (out, *new_amaxs, *saves)
|
||||
|
|
@ -226,41 +193,35 @@ class FlatTransformer:
|
|||
else:
|
||||
# flat per-layer weights: axis 0 is n_layers, so shard axes are +1 vs per-layer Transformer
|
||||
self.wqkv.shard_(device, axis=1).realize() # (n_layers, out, dim) shard out
|
||||
self.wo.shard_(device, axis=2).realize() # (n_layers, dim, in) shard in
|
||||
self.w13.shard_(device, axis=1).realize() # (n_layers, hidden*2, dim) shard out
|
||||
self.w2.shard_(device, axis=2).realize() # (n_layers, dim, hidden) shard in
|
||||
self.wo.shard_(device, axis=2).realize() # (n_layers, dim, in) shard in
|
||||
self.w13.shard_(device, axis=1).realize() # (n_layers, hidden*2, dim) shard out
|
||||
self.w2.shard_(device, axis=2).realize() # (n_layers, dim, hidden) shard in
|
||||
self.attention_norm.shard_(device, axis=None).realize()
|
||||
self.ffn_norm.shard_(device, axis=None).realize()
|
||||
self.norm.weight.shard_(device, axis=None).realize()
|
||||
self.tok_embeddings.weight.shard_(device, axis=0).realize()
|
||||
self.output.shard_(device, axis=1).realize()
|
||||
self.freqs_cis.shard_(device, axis=None).realize()
|
||||
if FP8:
|
||||
for name in self._fp8_amax:
|
||||
for i in range(len(self._fp8_amax[name])):
|
||||
self._fp8_amax[name][i] = self._fp8_amax[name][i].to(device).contiguous().requires_grad_(False)
|
||||
for name in self._fp8_inv_scale:
|
||||
self._fp8_inv_scale[name] = self._fp8_inv_scale[name].to(device).contiguous().requires_grad_(False)
|
||||
for name in self._fp8_amax:
|
||||
for i in range(len(self._fp8_amax[name])):
|
||||
self._fp8_amax[name][i] = self._fp8_amax[name][i].to(device).contiguous().requires_grad_(False)
|
||||
for name in self._fp8_inv_scale:
|
||||
self._fp8_inv_scale[name] = self._fp8_inv_scale[name].to(device).contiguous().requires_grad_(False)
|
||||
|
||||
def __call__(self, tokens:Tensor):
|
||||
h = self.tok_embeddings(tokens)
|
||||
freqs_cis = self.freqs_cis.cast(h.dtype)[:, :tokens.shape[1], :, :, :]
|
||||
a = self._fp8_amax if FP8 else None
|
||||
s = self._fp8_inv_scale if FP8 else None
|
||||
amaxs, inv_scales = self._fp8_amax, self._fp8_inv_scale
|
||||
for i in range(self.n_layers):
|
||||
amax_layer = {"amax_xqkv": a["xqkv"][i], "amax_xo": a["xo"][i],
|
||||
"amax_x13": a["x13"][i], "amax_x2": a["x2"][i]} if a else {}
|
||||
scale_layer = {"s_qkv": s["wqkv"][i], "s_o": s["wo"][i],
|
||||
"s_13": s["w13"][i], "s_2": s["w2"][i]} if s else {}
|
||||
h, *ret = self.run_layer(h, freqs_cis,
|
||||
self.attention_norm[i], self.wqkv[i], self.wo[i],
|
||||
self.ffn_norm[i], self.w13[i], self.w2[i],
|
||||
**amax_layer, **scale_layer)
|
||||
if a:
|
||||
amaxs = ret[:5]
|
||||
amax_names = ["xqkv", "xo", "x13", "x2"]
|
||||
for name, new_val in zip(amax_names, amaxs):
|
||||
a[name][i].assign(new_val)
|
||||
amax_xqkv=amaxs["xqkv"][i], amax_xo=amaxs["xo"][i],
|
||||
amax_x13=amaxs["x13"][i], amax_x2=amaxs["x2"][i],
|
||||
s_qkv=inv_scales["wqkv"][i], s_o=inv_scales["wo"][i],
|
||||
s_13=inv_scales["w13"][i], s_2=inv_scales["w2"][i])
|
||||
for name, new_val in zip(["xqkv", "xo", "x13", "x2"], ret[:5]):
|
||||
amaxs[name][i].assign(new_val)
|
||||
|
||||
logits = matmul(self.norm(h).contiguous().contiguous_backward(), self.output[0], fp8=False)[0].contiguous_backward()
|
||||
return logits
|
||||
|
|
|
|||
|
|
@ -1,238 +0,0 @@
|
|||
from __future__ import annotations
|
||||
import functools, pathlib
|
||||
from tinygrad import Tensor, dtypes
|
||||
from tinygrad.uop.ops import UOp, Ops, KernelInfo
|
||||
from tinygrad.renderer import Estimates
|
||||
from tinygrad.runtime.support.compiler_amd import HIPCCCompiler
|
||||
|
||||
FP8_MAX = 448.0
|
||||
NUM_WG, THREADS_PER_WG = 1024, 256
|
||||
|
||||
def _compile(cpp_name:str, n_elems:int, hidden:int):
  """Read a kernel source file located next to this module and compile it with HIPCC for gfx950.

  n_elems, hidden, NUM_WG and THREADS_PER_WG are handed to the compiler as -D defines.
  Returns (source text, result of compile_cached on that source).
  """
  kernel_src = pathlib.Path(__file__).parent.joinpath(cpp_name).read_text()
  macros = {"N_ELEMS": n_elems, "HIDDEN": hidden, "NUM_WG": NUM_WG, "THREADS_PER_WG": THREADS_PER_WG}
  flags = ["-std=c++20", "-ffast-math"] + [f"-D{key}={val}" for key, val in macros.items()]
  compiler = HIPCCCompiler("gfx950", flags)
  return kernel_src, compiler.compile_cached(kernel_src)
|
||||
|
||||
def _shard_shape(shape:tuple, axis:int, ndev:int) -> list:
|
||||
s = list(shape)
|
||||
s[axis] //= ndev
|
||||
return s
|
||||
|
||||
def _scalar_amax(amax_buf:Tensor) -> Tensor:
  """Reduce the amax buffer to a single detached value.

  When the buffer's device is a tuple, the reduction goes through _local_abs_max;
  otherwise a plain max() is used.
  """
  if not isinstance(amax_buf.device, tuple): return amax_buf.max().detach()
  # imported lazily, exactly as the original did, only on the tuple-device path
  from examples.mlperf.models.flat_llama import _local_abs_max
  return _local_abs_max(amax_buf).detach()
|
||||
|
||||
|
||||
# **** fused silu*mul -> fp8 cast + amax (w13 layout) ****
|
||||
|
||||
@functools.cache
def _custom_fused_bwd_w13(grad_xw13:UOp, xw13:UOp, grad_x2:UOp, amax_state:UOp, dname:str) -> UOp:
  """Build the PROGRAM UOp that wraps the compiled cast_amax_bwd_w13.cpp kernel.

  The kernel is sized from xw13: half of its last axis is the hidden size, and the element count is
  the product of the first two axes and that hidden size. Results are memoized on the arguments.
  """
  half = xw13.shape[2] // 2
  total = xw13.shape[0] * xw13.shape[1] * half
  lidx, gidx = UOp.special(THREADS_PER_WG, "lidx0"), UOp.special(NUM_WG, "gidx0")
  info = KernelInfo(f"fused_silu_mul_bwd_w13_{total}", estimates=Estimates(ops=8*total, mem=total * 2 * 5))
  sink = UOp.sink(grad_xw13.base, xw13.base, grad_x2.base, amax_state.base, lidx, gidx, arg=info)
  src, lib = _compile("cast_amax_bwd_w13.cpp", total, half)
  parts = (sink, UOp(Ops.DEVICE, arg=dname), UOp(Ops.LINEAR, src=(*sink.src, sink)),
           UOp(Ops.SOURCE, arg=src), UOp(Ops.BINARY, arg=lib))
  return UOp(Ops.PROGRAM, src=parts)
|
||||
|
||||
@functools.cache
def _custom_fused_cast_amax_w13(fp8_out:UOp, amax_buf:UOp, xw13:UOp, amax_state:UOp, dname:str) -> UOp:
  """Build the PROGRAM UOp that wraps the compiled cast_amax_fwd_w13.cpp kernel.

  The hidden size is half of xw13's last axis; the element count is the product of xw13's first two
  axes and that hidden size. Results are memoized on the arguments.
  """
  half = xw13.shape[2] // 2
  total = xw13.shape[0] * xw13.shape[1] * half
  lidx, gidx = UOp.special(THREADS_PER_WG, "lidx0"), UOp.special(NUM_WG, "gidx0")
  est = Estimates(ops=5*total, mem=total * 2 * 2 + total + NUM_WG * 2)
  info = KernelInfo(f"fused_silu_mul_cast_amax_w13_{total}", estimates=est)
  sink = UOp.sink(fp8_out.base, amax_buf.base, xw13.base, amax_state.base, lidx, gidx, arg=info)
  src, lib = _compile("cast_amax_fwd_w13.cpp", total, half)
  parts = (sink, UOp(Ops.DEVICE, arg=dname), UOp(Ops.LINEAR, src=(*sink.src, sink)),
           UOp(Ops.SOURCE, arg=src), UOp(Ops.BINARY, arg=lib))
  return UOp(Ops.PROGRAM, src=parts)
|
||||
|
||||
def _fused_quantize_bwd_w13(gradient:UOp, kernel:UOp):
  """Gradient function for the fused silu*mul -> fp8 kernel over the w13 layout.

  kernel.src[1:] holds (fp8_out, amax_buf, xw13, amax_state). A bf16 gradient buffer shaped like
  xw13 is allocated (per-shard shape when the device is a tuple) and filled by the backward custom
  kernel. The returned 4-tuple lines up with those inputs; only the xw13 slot carries a gradient.
  """
  _, _, xw13, amax_state = kernel.src[1:]
  dev = xw13.device
  if isinstance(dev, tuple):
    axis, ndev = xw13.axis, len(dev)
    assert axis in (0, 1), f"unsupported sharding axis={axis}"
    per_shard = Tensor.invalids(*_shard_shape(xw13.shape, axis, ndev), dtype=dtypes.bfloat16, device=dev)
    grad_buf = Tensor(per_shard.uop.multi(axis), device=dev)
    backend = dev[0].split(":")[0]
  else:
    grad_buf = Tensor.invalids(*xw13.shape, dtype=dtypes.bfloat16, device=dev)
    backend = dev.split(":")[0] if isinstance(dev, str) else dev
  upstream = Tensor(gradient, device=dev).cast(dtypes.bfloat16)
  bwd_kernel = functools.partial(_custom_fused_bwd_w13, dname=backend)
  grad_out, *_ = Tensor.custom_kernel(grad_buf, Tensor(xw13, device=dev), upstream,
                                      Tensor(amax_state, device=dev), fxn=bwd_kernel)
  return (None, None, grad_out.uop, None)
|
||||
|
||||
def fused_quantize_fp8_w13(xw13:Tensor, amax_state:Tensor, fp8_dtype) -> tuple[Tensor, Tensor, Tensor]:
  """Fused silu(xw1)*xw3 -> fp8 cast plus amax over the fused xw13 layout.

  xw13 must be bf16 with an even last axis; the fp8 output has half that last axis.
  Returns (fp8, inv_scale, new_amax), where inv_scale is derived from amax_state and FP8_MAX.
  """
  assert xw13.dtype == dtypes.bfloat16, f"expected bf16, got {xw13.dtype}"
  MBS, SEQ, H2 = xw13.shape
  assert H2 % 2 == 0, f"w13 last-axis must be even, got {H2}"
  out_shape = (MBS, SEQ, H2 // 2)
  dev = xw13.device
  if isinstance(dev, tuple):
    axis, ndev = xw13.uop.axis, len(dev)
    assert axis in (0, 1), f"unsupported sharding axis={axis}"
    per_shard = Tensor.invalids(*_shard_shape(out_shape, axis, ndev), dtype=fp8_dtype, device=dev)
    fp8_out = Tensor(per_shard.uop.multi(axis), device=dev)
    amax_buf = Tensor(Tensor.invalids(NUM_WG, dtype=dtypes.bfloat16, device=dev).uop.multi(0), device=dev)
    backend = dev[0].split(":")[0]
  else:
    fp8_out = Tensor.invalids(*out_shape, dtype=fp8_dtype, device=dev)
    amax_buf = Tensor.invalids(NUM_WG, dtype=dtypes.bfloat16, device=dev)
    backend = dev.split(":")[0] if isinstance(dev, str) else dev
  fwd_kernel = functools.partial(_custom_fused_cast_amax_w13, dname=backend)
  fp8_out, amax_buf, *_ = Tensor.custom_kernel(fp8_out, amax_buf, xw13, amax_state, fxn=fwd_kernel,
                                               grad_fxn=_fused_quantize_bwd_w13)
  inv_scale = (amax_state.float() + 1e-8) / FP8_MAX
  new_amax = _scalar_amax(amax_buf)
  return fp8_out, inv_scale, new_amax
|
||||
|
||||
# **** fused (x * weight) -> fp8 cast + amax (norm-mul-quantize) ****
|
||||
|
||||
@functools.cache
def _custom_mul_quantize_fp8(fp8_out:UOp, amax_buf:UOp, x:UOp, weight:UOp, amax_state:UOp, dname:str) -> UOp:
  """Build the PROGRAM UOp that wraps the compiled fused_mul_quantize_fp8.cpp kernel.

  x must be 3-dimensional; the element count is the product of its shape and the hidden size is its
  last axis. Results are memoized on the arguments.
  """
  d0, d1, hid = x.shape
  total = d0 * d1 * hid
  lidx, gidx = UOp.special(THREADS_PER_WG, "lidx0"), UOp.special(NUM_WG, "gidx0")
  est = Estimates(ops=3*total, mem=total * 2 + hid * 2 + total + NUM_WG * 2)
  info = KernelInfo(f"fused_mul_quantize_fp8_{total}_h{hid}", estimates=est)
  sink = UOp.sink(fp8_out.base, amax_buf.base, x.base, weight.base, amax_state.base, lidx, gidx, arg=info)
  src, lib = _compile("fused_mul_quantize_fp8.cpp", total, hid)
  parts = (sink, UOp(Ops.DEVICE, arg=dname), UOp(Ops.LINEAR, src=(*sink.src, sink)),
           UOp(Ops.SOURCE, arg=src), UOp(Ops.BINARY, arg=lib))
  return UOp(Ops.PROGRAM, src=parts)
|
||||
|
||||
def _fused_mul_quantize_fp8_bwd(gradient:UOp, kernel:UOp):
  """Gradient function for the fused (x * weight) -> fp8 kernel.

  kernel.src[1:] holds (fp8_out, amax_buf, x, weight, amax_state). The upstream gradient is scaled
  by FP8_MAX / (amax_state + 1e-8), then multiplied by weight to give grad_x and by x (summed over
  axes 0 and 1) to give grad_weight. Only the x and weight slots of the returned 5-tuple are set.
  """
  _, _, x_u, weight_u, amax_state_u = kernel.src[1:]
  dev = x_u.device
  upstream = Tensor(gradient, device=dev).cast(dtypes.bfloat16)
  x_t = Tensor(x_u, device=dev)
  weight_t = Tensor(weight_u, device=dev)
  amax_f32 = Tensor(amax_state_u, device=dev).float()
  scaled = upstream.float() * (FP8_MAX / (amax_f32 + 1e-8))
  # NOTE: grad_x stays bf16 to avoid CSE materializing a (MBS, SEQ, HIDDEN) fp32 intermediate
  grad_x = (scaled * weight_t.float()).cast(dtypes.bfloat16)
  grad_weight = (scaled * x_t.float()).sum(axis=(0, 1)).cast(dtypes.bfloat16)
  return (None, None, grad_x.uop, grad_weight.uop, None)
|
||||
|
||||
def fused_mul_quantize_fp8(x:Tensor, weight:Tensor, amax_state:Tensor, fp8_dtype) -> tuple[Tensor, Tensor, Tensor]:
  """Fused (x * weight) -> fp8 cast plus amax, using delayed scaling.

  x and weight must be bf16 and agree on their last axis; x must be 3-dimensional.
  Returns (fp8, inv_scale, new_amax), where inv_scale is derived from amax_state and FP8_MAX.
  """
  assert x.dtype == dtypes.bfloat16 and weight.dtype == dtypes.bfloat16
  assert x.shape[-1] == weight.shape[-1], f"HIDDEN mismatch: x={x.shape}, weight={weight.shape}"
  MBS, SEQ, HIDDEN = x.shape
  dev = x.device
  if isinstance(dev, tuple):
    axis, ndev = x.uop.axis, len(dev)
    assert axis in (0, 1), f"unsupported sharding axis={axis}"
    per_shard = Tensor.invalids(*_shard_shape((MBS, SEQ, HIDDEN), axis, ndev), dtype=fp8_dtype, device=dev)
    fp8_out = Tensor(per_shard.uop.multi(axis), device=dev)
    amax_buf = Tensor(Tensor.invalids(NUM_WG, dtype=dtypes.bfloat16, device=dev).uop.multi(0), device=dev)
    backend = dev[0].split(":")[0]
  else:
    fp8_out = Tensor.invalids(MBS, SEQ, HIDDEN, dtype=fp8_dtype, device=dev)
    amax_buf = Tensor.invalids(NUM_WG, dtype=dtypes.bfloat16, device=dev)
    backend = dev.split(":")[0] if isinstance(dev, str) else dev
  fwd_kernel = functools.partial(_custom_mul_quantize_fp8, dname=backend)
  fp8_out, amax_buf, *_ = Tensor.custom_kernel(fp8_out, amax_buf, x, weight, amax_state, fxn=fwd_kernel,
                                               grad_fxn=_fused_mul_quantize_fp8_bwd)
  new_amax = _scalar_amax(amax_buf)
  return fp8_out, (amax_state.float() + 1e-8) / FP8_MAX, new_amax
|
||||
|
||||
# **** fused ce loss ****
|
||||
|
||||
@functools.cache
def _custom_fused_ce_loss_fwd(loss_out:UOp, max_out:UOp, lse_out:UOp, logits:UOp, targets:UOp,
                              dname:str, vocab:int, rows:int, label_smoothing:float) -> UOp:
  """Build the PROGRAM UOp that wraps the compiled fused_ce_loss.cpp kernel.

  One workgroup is launched per row. vocab, THREADS_PER_WG and label_smoothing are passed to the
  compiler as -D defines. Results are memoized on the arguments.
  """
  lidx, gidx = UOp.special(THREADS_PER_WG, "lidx0"), UOp.special(rows, "gidx0")
  est = Estimates(ops=6*rows*vocab, mem=rows * vocab * 2 + rows * 12 + rows * 4)
  sink = UOp.sink(loss_out.base, max_out.base, lse_out.base, logits.base, targets.base, lidx, gidx,
                  arg=KernelInfo("fused_ce_loss_fwd", estimates=est))
  src = pathlib.Path(__file__).parent.joinpath("fused_ce_loss.cpp").read_text()
  flags = ["-std=c++20", "-ffast-math", f"-DVOCAB={vocab}", f"-DTHREADS_PER_WG={THREADS_PER_WG}",
           f"-DLABEL_SMOOTHING={label_smoothing}f"]
  lib = HIPCCCompiler("gfx950", flags).compile_cached(src)
  parts = (sink, UOp(Ops.DEVICE, arg=dname), UOp(Ops.LINEAR, src=(*sink.src, sink)),
           UOp(Ops.SOURCE, arg=src), UOp(Ops.BINARY, arg=lib))
  return UOp(Ops.PROGRAM, src=parts)
|
||||
|
||||
@functools.cache
def _custom_fused_ce_loss_bwd(d_logits:UOp, logits:UOp, lse:UOp, targets:UOp, scale:UOp,
                              dname:str, vocab:int, rows:int, label_smoothing:float) -> UOp:
  """Build the PROGRAM UOp that wraps the compiled fused_ce_loss_bwd.cpp kernel.

  One workgroup is launched per row. vocab, THREADS_PER_WG and label_smoothing are passed to the
  compiler as -D defines. Results are memoized on the arguments.
  """
  lidx, gidx = UOp.special(THREADS_PER_WG, "lidx0"), UOp.special(rows, "gidx0")
  est = Estimates(ops=4*rows*vocab, mem=rows * vocab * 4 + rows * 8 + 4)
  sink = UOp.sink(d_logits.base, logits.base, lse.base, targets.base, scale.base, lidx, gidx,
                  arg=KernelInfo("fused_ce_loss_bwd", estimates=est))
  src = pathlib.Path(__file__).parent.joinpath("fused_ce_loss_bwd.cpp").read_text()
  flags = ["-std=c++20", "-ffast-math", f"-DVOCAB={vocab}", f"-DTHREADS_PER_WG={THREADS_PER_WG}",
           f"-DLABEL_SMOOTHING={label_smoothing}f"]
  lib = HIPCCCompiler("gfx950", flags).compile_cached(src)
  parts = (sink, UOp(Ops.DEVICE, arg=dname), UOp(Ops.LINEAR, src=(*sink.src, sink)),
           UOp(Ops.SOURCE, arg=src), UOp(Ops.BINARY, arg=lib))
  return UOp(Ops.PROGRAM, src=parts)
|
||||
|
||||
def _fused_ce_loss_bwd(gradient:UOp, kernel:UOp, label_smoothing:float=0.1):
  """Gradient function for the fused CE loss kernel.

  kernel.src[1:] holds the forward inputs (loss_out, max_out, lse_out, logits, targets).
  gradient is the upstream grad w.r.t. the per-row loss. Only the logits slot of the returned
  5-tuple carries a gradient.

  label_smoothing must equal the value the forward kernel was compiled with; fused_ce_loss binds it
  via functools.partial. The default of 0.1 keeps the previous behaviour when this function is used
  directly as a grad_fxn.
  """
  _, _, lse_u, logits_u, targets_u = kernel.src[1:]
  device = logits_u.device
  rows, VOCAB = logits_u.shape  # (rows, VOCAB) after reshape
  if isinstance(device, tuple):
    rows_per_dev = rows // len(device)
    d_logits = Tensor(Tensor.invalids(rows_per_dev, VOCAB, dtype=dtypes.bfloat16, device=device).uop.multi(logits_u.axis),
                      device=device)
    dname = device[0].split(":")[0]
  else:
    rows_per_dev = rows
    d_logits = Tensor.invalids(rows, VOCAB, dtype=dtypes.bfloat16, device=device)
    dname = device.split(":")[0] if isinstance(device, str) else device
  grad_t = Tensor(gradient, device=device).float().reshape(-1)  # (rows,) fp32
  # NOTE: .mean() backward gives same grad per row (1/N), so broadcast is safe; take scalar
  scale = grad_t[0:1].contiguous()
  logits_t = Tensor(logits_u.after(kernel), device=device)
  lse_t = Tensor(lse_u.after(kernel), device=device)
  targets_t = Tensor(targets_u, device=device)
  # the backward kernel is compiled with the same smoothing as the forward so the gradient matches the loss
  fxn = functools.partial(_custom_fused_ce_loss_bwd, dname=dname, vocab=VOCAB, rows=rows_per_dev,
                          label_smoothing=label_smoothing)
  d_logits, *_ = Tensor.custom_kernel(d_logits, logits_t, lse_t, targets_t, scale, fxn=fxn)
  return (None, None, None, d_logits.uop, None)

def fused_ce_loss(logits:Tensor, targets:Tensor, label_smoothing:float=0.1) -> Tensor:
  """Fused sparse_categorical_crossentropy with label smoothing; returns the mean loss scalar.

  logits must be bf16 with shape (MBS, SEQ, VOCAB). targets is flattened and cast to int32.
  The same label_smoothing value is used for both the forward and the backward kernel.
  """
  assert logits.dtype == dtypes.bfloat16, f"expected bf16, got {logits.dtype}"
  assert logits.ndim == 3, f"expected (MBS, SEQ, VOCAB), got {logits.shape}"
  # the value is rendered as "-DLABEL_SMOOTHING={value}f"; an int such as 0 would give the invalid literal "0f"
  label_smoothing = float(label_smoothing)
  MBS, SEQ, VOCAB = logits.shape
  rows = MBS * SEQ
  device = logits.device
  if isinstance(device, tuple):
    axis = logits.uop.axis
    assert axis in (0, 1), f"unsupported sharding axis={axis} for CE loss"
    rows_per_dev = rows // len(device)
    def _row_buf() -> Tensor:
      return Tensor(Tensor.invalids(rows_per_dev, dtype=dtypes.float32, device=device).uop.multi(0), device=device)
    dname = device[0].split(":")[0]
  else:
    rows_per_dev = rows
    def _row_buf() -> Tensor:
      return Tensor.invalids(rows, dtype=dtypes.float32, device=device)
    dname = device.split(":")[0] if isinstance(device, str) else device
  loss_out, max_out, lse_out = _row_buf(), _row_buf(), _row_buf()
  logits_flat = logits.reshape(rows, VOCAB)
  targets_flat = targets.reshape(-1).cast(dtypes.int32)
  fxn = functools.partial(_custom_fused_ce_loss_fwd, dname=dname, vocab=VOCAB, rows=rows_per_dev,
                          label_smoothing=label_smoothing)
  # bind the forward's smoothing into the grad function instead of letting it fall back to 0.1
  grad_fxn = functools.partial(_fused_ce_loss_bwd, label_smoothing=label_smoothing)
  loss_out, max_out, lse_out, *_ = Tensor.custom_kernel(
    loss_out, max_out, lse_out, logits_flat, targets_flat,
    fxn=fxn, grad_fxn=grad_fxn)
  return loss_out.mean()
|
||||
|
||||
35
extra/llama_kernels/__init__.py
Normal file
35
extra/llama_kernels/__init__.py
Normal file
|
|
@ -0,0 +1,35 @@
|
|||
from __future__ import annotations
|
||||
import functools, pathlib
|
||||
from tinygrad import Tensor, dtypes
|
||||
from tinygrad.uop.ops import Ops
|
||||
from tinygrad.runtime.support.compiler_amd import HIPCCCompiler
|
||||
|
||||
# divisor/numerator used to convert an amax into an (inverse) scale; presumably the largest finite fp8 e4m3 value — TODO confirm
FP8_MAX = 448.0
|
||||
# launch sizes shared by the custom kernels: NUM_WG bounds "gidx0", THREADS_PER_WG bounds "lidx0"; both are also passed to the C++ sources as -D defines
NUM_WG, THREADS_PER_WG = 1024, 256
|
||||
|
||||
# per-device abs max without allreduce
@functools.cache
def _local_abs_max_fxn(x_p, device):
  # Builds (and caches per (x_p, device)) the abs-max expression as a 1-tuple.
  # When the wrapped uop is a MULTI, the reduction is taken over its first source instead of the multi itself.
  wrapped = Tensor(x_p, device=device)
  if wrapped.uop.op is Ops.MULTI:
    local = Tensor(wrapped.uop.src[0])
  else:
    local = wrapped
  return (local.abs().max(),)
|
||||
|
||||
def local_abs_max(x:Tensor) -> Tensor:
  # Applies the cached abs-max function (built on x's parameter placeholder) to x's uop and returns output 0.
  (amax_fn,) = _local_abs_max_fxn(x.as_param(0).uop, x.device)
  called = amax_fn.uop.call(x.uop)
  return Tensor(called.gettuple(0))
|
||||
|
||||
def scalar_amax(amax_buf:Tensor) -> Tensor:
  # Reduces an amax buffer to its max and detaches it.
  # Buffers living on a tuple of devices go through local_abs_max, everything else uses a plain max().
  on_many_devices = isinstance(amax_buf.device, tuple)
  reduced = local_abs_max(amax_buf) if on_many_devices else amax_buf.max()
  return reduced.detach()
|
||||
|
||||
def shard_shape(shape:tuple, axis:int, ndev:int) -> list:
  # Returns shape as a list with the entry at `axis` floor-divided by `ndev`; the input is left untouched.
  dims = [*shape]
  dims[axis] = dims[axis] // ndev
  return dims
|
||||
|
||||
def compile_cpp(cpp_dir:pathlib.Path, cpp_name:str, n_elems:int, hidden:int):
  # Reads cpp_dir/cpp_name and compiles it with HIPCCCompiler for gfx950.
  # Returns (source text, compiled library); the sizes and launch constants are passed as -D defines.
  src = (cpp_dir/cpp_name).read_text()
  macros = {"N_ELEMS": n_elems, "HIDDEN": hidden, "NUM_WG": NUM_WG, "THREADS_PER_WG": THREADS_PER_WG}
  flags = ["-std=c++20", "-ffast-math"]
  flags.extend(f"-D{name}={value}" for name, value in macros.items())
  lib = HIPCCCompiler("gfx950", flags).compile_cached(src)
  return src, lib
|
||||
73
extra/llama_kernels/cast_amax/__init__.py
Normal file
73
extra/llama_kernels/cast_amax/__init__.py
Normal file
|
|
@ -0,0 +1,73 @@
|
|||
from __future__ import annotations
|
||||
import functools, pathlib
|
||||
from tinygrad import Tensor, dtypes
|
||||
from tinygrad.uop.ops import UOp, Ops, KernelInfo
|
||||
from tinygrad.renderer import Estimates
|
||||
from extra.llama_kernels import FP8_MAX, NUM_WG, THREADS_PER_WG, compile_cpp, shard_shape, scalar_amax
|
||||
|
||||
@functools.cache
def _custom_fused_bwd_w13(grad_xw13:UOp, xw13:UOp, grad_x2:UOp, amax_state:UOp, dname:str) -> UOp:
  # Builds the PROGRAM UOp for the w13 backward kernel compiled from cast_amax_bwd_w13.cpp (cached per argument set).
  # hidden is half of xw13's last axis; n_elems is shape[0] * shape[1] * hidden.
  shape = xw13.shape
  hidden = shape[2] // 2
  n_elems = shape[0] * shape[1] * hidden
  info = KernelInfo(f"fused_silu_mul_bwd_w13_{n_elems}", estimates=Estimates(ops=8*n_elems, mem=n_elems * 2 * 5))
  # launch dims: THREADS_PER_WG local ids ("lidx0") and NUM_WG group ids ("gidx0")
  launch_dims = (UOp.special(THREADS_PER_WG, "lidx0"), UOp.special(NUM_WG, "gidx0"))
  sink = UOp.sink(grad_xw13.base, xw13.base, grad_x2.base, amax_state.base, *launch_dims, arg=info)
  src, lib = compile_cpp(pathlib.Path(__file__).parent, "cast_amax_bwd_w13.cpp", n_elems, hidden)
  program_srcs = (sink, UOp(Ops.DEVICE, arg=dname), UOp(Ops.LINEAR, src=(*sink.src, sink)),
                  UOp(Ops.SOURCE, arg=src), UOp(Ops.BINARY, arg=lib))
  return UOp(Ops.PROGRAM, src=program_srcs)
|
||||
|
||||
@functools.cache
def _custom_fused_cast_amax_w13(fp8_out:UOp, amax_buf:UOp, xw13:UOp, amax_state:UOp, dname:str) -> UOp:
  # Builds the PROGRAM UOp for the w13 forward cast+amax kernel compiled from cast_amax_fwd_w13.cpp (cached per argument set).
  # hidden is half of xw13's last axis; n_elems is shape[0] * shape[1] * hidden.
  shape = xw13.shape
  hidden = shape[2] // 2
  n_elems = shape[0] * shape[1] * hidden
  mem_estimate = n_elems * 2 * 2 + n_elems + NUM_WG * 2
  info = KernelInfo(f"fused_silu_mul_cast_amax_w13_{n_elems}", estimates=Estimates(ops=5*n_elems, mem=mem_estimate))
  # launch dims: THREADS_PER_WG local ids ("lidx0") and NUM_WG group ids ("gidx0")
  launch_dims = (UOp.special(THREADS_PER_WG, "lidx0"), UOp.special(NUM_WG, "gidx0"))
  sink = UOp.sink(fp8_out.base, amax_buf.base, xw13.base, amax_state.base, *launch_dims, arg=info)
  src, lib = compile_cpp(pathlib.Path(__file__).parent, "cast_amax_fwd_w13.cpp", n_elems, hidden)
  program_srcs = (sink, UOp(Ops.DEVICE, arg=dname), UOp(Ops.LINEAR, src=(*sink.src, sink)),
                  UOp(Ops.SOURCE, arg=src), UOp(Ops.BINARY, arg=lib))
  return UOp(Ops.PROGRAM, src=program_srcs)
|
||||
|
||||
def _fused_quantize_bwd_w13(gradient:UOp, kernel:UOp):
  # NOTE: the forward kernel inputs are (fp8_out, amax_buf, xw13, amax_state); only xw13 gets a gradient
  # Returns one entry per forward input, None for everything except xw13.
  _fp8_u, _amax_buf_u, xw13_u, amax_state_u = kernel.src[1:]
  dev = xw13_u.device
  if isinstance(dev, tuple):
    axis, ndev = xw13_u.axis, len(dev)
    assert axis in (0, 1), f"unsupported sharding axis={axis}"
    # allocate the per-device slice, then wrap it as a multi sharded on the same axis as xw13
    local_grad = Tensor.invalids(*shard_shape(xw13_u.shape, axis, ndev), dtype=dtypes.bfloat16, device=dev)
    grad_out = Tensor(local_grad.uop.multi(axis), device=dev)
    dname = dev[0].split(":")[0]
  else:
    grad_out = Tensor.invalids(*xw13_u.shape, dtype=dtypes.bfloat16, device=dev)
    dname = dev.split(":")[0] if isinstance(dev, str) else dev
  upstream = Tensor(gradient, device=dev).cast(dtypes.bfloat16)
  launch = functools.partial(_custom_fused_bwd_w13, dname=dname)
  results = Tensor.custom_kernel(grad_out, Tensor(xw13_u, device=dev), upstream,
                                 Tensor(amax_state_u, device=dev), fxn=launch)
  grad_out, *_ = results
  return (None, None, grad_out.uop, None)
|
||||
|
||||
def fused_quantize_fp8_w13(xw13:Tensor, amax_state:Tensor, fp8_dtype) -> tuple[Tensor, Tensor, Tensor]:
  # NOTE: silu(xw1)*xw3 -> fp8 + amax over fused xw13 layout. Returns (fp8, inv_scale, new_amax)
  # xw13 must be bf16 with an even last axis; the fp8 output keeps the first two axes and halves the last one.
  assert xw13.dtype == dtypes.bfloat16, f"expected bf16, got {xw13.dtype}"
  MBS, SEQ, H2 = xw13.shape
  assert H2 % 2 == 0, f"w13 last-axis must be even, got {H2}"
  out_shape = (MBS, SEQ, H2 // 2)
  dev = xw13.device
  if isinstance(dev, tuple):
    axis, ndev = xw13.uop.axis, len(dev)
    assert axis in (0, 1), f"unsupported sharding axis={axis}"
    # per-device output slice, wrapped as a multi on xw13's sharding axis
    local_fp8 = Tensor.invalids(*shard_shape(out_shape, axis, ndev), dtype=fp8_dtype, device=dev)
    fp8_out = Tensor(local_fp8.uop.multi(axis), device=dev)
    # NUM_WG amax slots, wrapped as a multi on axis 0
    local_amax = Tensor.invalids(NUM_WG, dtype=dtypes.bfloat16, device=dev)
    amax_buf = Tensor(local_amax.uop.multi(0), device=dev)
    dname = dev[0].split(":")[0]
  else:
    fp8_out = Tensor.invalids(*out_shape, dtype=fp8_dtype, device=dev)
    amax_buf = Tensor.invalids(NUM_WG, dtype=dtypes.bfloat16, device=dev)
    dname = dev.split(":")[0] if isinstance(dev, str) else dev
  launch = functools.partial(_custom_fused_cast_amax_w13, dname=dname)
  fp8_out, amax_buf, *_ = Tensor.custom_kernel(fp8_out, amax_buf, xw13, amax_state, fxn=launch,
                                               grad_fxn=_fused_quantize_bwd_w13)
  # the returned inverse scale is derived from the amax_state passed in, not from the freshly computed amax
  inv_scale = (amax_state.float() + 1e-8) / FP8_MAX
  new_amax = scalar_amax(amax_buf)
  return fp8_out, inv_scale, new_amax
|
||||
98
extra/llama_kernels/fused_ce/__init__.py
Normal file
98
extra/llama_kernels/fused_ce/__init__.py
Normal file
|
|
@ -0,0 +1,98 @@
|
|||
from __future__ import annotations
|
||||
import functools, pathlib
|
||||
from tinygrad import Tensor, dtypes
|
||||
from tinygrad.uop.ops import UOp, Ops, KernelInfo
|
||||
from tinygrad.renderer import Estimates
|
||||
from tinygrad.runtime.support.compiler_amd import HIPCCCompiler
|
||||
|
||||
# size of the "lidx0" launch dimension for the fused CE kernels; also passed to the C++ sources as -DTHREADS_PER_WG
THREADS_PER_WG = 256
|
||||
|
||||
@functools.cache
def _custom_fused_ce_loss_fwd(loss_out:UOp, max_out:UOp, lse_out:UOp, logits:UOp, targets:UOp,
                              dname:str, vocab:int, rows:int, label_smoothing:float) -> UOp:
  # Builds the PROGRAM UOp for the fused CE forward kernel compiled from fused_ce_loss.cpp (cached per argument set).
  # The launch uses `rows` group ids ("gidx0") and THREADS_PER_WG local ids ("lidx0").
  # vocab, THREADS_PER_WG and label_smoothing are baked into the binary as -D defines.
  mem_estimate = rows * vocab * 2 + rows * 12 + rows * 4
  info = KernelInfo(f"fused_ce_loss_fwd", estimates=Estimates(ops=6*rows*vocab, mem=mem_estimate))
  launch_dims = (UOp.special(THREADS_PER_WG, "lidx0"), UOp.special(rows, "gidx0"))
  sink = UOp.sink(loss_out.base, max_out.base, lse_out.base, logits.base, targets.base, *launch_dims, arg=info)
  src = (pathlib.Path(__file__).parent/"fused_ce_loss.cpp").read_text()
  flags = ["-std=c++20", "-ffast-math",
           f"-DVOCAB={vocab}", f"-DTHREADS_PER_WG={THREADS_PER_WG}", f"-DLABEL_SMOOTHING={label_smoothing}f"]
  lib = HIPCCCompiler("gfx950", flags).compile_cached(src)
  program_srcs = (sink, UOp(Ops.DEVICE, arg=dname), UOp(Ops.LINEAR, src=(*sink.src, sink)),
                  UOp(Ops.SOURCE, arg=src), UOp(Ops.BINARY, arg=lib))
  return UOp(Ops.PROGRAM, src=program_srcs)
|
||||
|
||||
@functools.cache
def _custom_fused_ce_loss_bwd(d_logits:UOp, logits:UOp, lse:UOp, targets:UOp, scale:UOp,
                              dname:str, vocab:int, rows:int, label_smoothing:float) -> UOp:
  # Builds the PROGRAM UOp for the fused CE backward kernel compiled from fused_ce_loss_bwd.cpp (cached per argument set).
  # The launch uses `rows` group ids ("gidx0") and THREADS_PER_WG local ids ("lidx0").
  # vocab, THREADS_PER_WG and label_smoothing are baked into the binary as -D defines.
  mem_estimate = rows * vocab * 4 + rows * 8 + 4
  info = KernelInfo(f"fused_ce_loss_bwd", estimates=Estimates(ops=4*rows*vocab, mem=mem_estimate))
  launch_dims = (UOp.special(THREADS_PER_WG, "lidx0"), UOp.special(rows, "gidx0"))
  sink = UOp.sink(d_logits.base, logits.base, lse.base, targets.base, scale.base, *launch_dims, arg=info)
  src = (pathlib.Path(__file__).parent/"fused_ce_loss_bwd.cpp").read_text()
  flags = ["-std=c++20", "-ffast-math",
           f"-DVOCAB={vocab}", f"-DTHREADS_PER_WG={THREADS_PER_WG}", f"-DLABEL_SMOOTHING={label_smoothing}f"]
  lib = HIPCCCompiler("gfx950", flags).compile_cached(src)
  program_srcs = (sink, UOp(Ops.DEVICE, arg=dname), UOp(Ops.LINEAR, src=(*sink.src, sink)),
                  UOp(Ops.SOURCE, arg=src), UOp(Ops.BINARY, arg=lib))
  return UOp(Ops.PROGRAM, src=program_srcs)
|
||||
|
||||
def _fused_ce_loss_bwd(gradient:UOp, kernel:UOp, label_smoothing:float=0.1):
  # Backward of the fused CE loss: launches the custom backward kernel that writes d_logits.
  # NOTE: forward inputs are (loss_out, max_out, lse_out, logits, targets)
  # gradient is the upstream grad w.r.t. per-row loss (shape: (rows,) fp32)
  # label_smoothing has to match the value the forward kernel was compiled with. It is a keyword parameter
  # (bound by fused_ce_loss) because this function previously read an undefined name and raised NameError.
  # Returns one entry per forward input; only logits receives a gradient.
  _, _, lse_u, logits_u, targets_u = kernel.src[1:]
  device = logits_u.device
  rows, VOCAB = logits_u.shape  # (rows, VOCAB) after reshape
  if isinstance(device, tuple):
    axis = logits_u.axis
    ndev = len(device)
    d_logits = Tensor(Tensor.invalids(rows // ndev, VOCAB, dtype=dtypes.bfloat16, device=device).uop.multi(axis), device=device)
    dname = device[0].split(":")[0]
    rows_per_dev = rows // ndev
  else:
    d_logits = Tensor.invalids(rows, VOCAB, dtype=dtypes.bfloat16, device=device)
    dname = device.split(":")[0] if isinstance(device, str) else device
    rows_per_dev = rows
  grad_t = Tensor(gradient, device=device).float().reshape(-1)  # (rows,) fp32
  # NOTE: .mean() backward gives same grad per row (1/N), so broadcast is safe; take scalar
  scale = grad_t[0:1].contiguous()
  # logits and lse are read after the forward kernel has run
  logits_t = Tensor(logits_u.after(kernel), device=device)
  lse_t = Tensor(lse_u.after(kernel), device=device)
  targets_t = Tensor(targets_u, device=device)
  fxn = functools.partial(_custom_fused_ce_loss_bwd, dname=dname, vocab=VOCAB, rows=rows_per_dev, label_smoothing=label_smoothing)
  d_logits, *_ = Tensor.custom_kernel(d_logits, logits_t, lse_t, targets_t, scale, fxn=fxn)
  return (None, None, None, d_logits.uop, None)

@functools.cache
def _fused_ce_loss_bwd_for(label_smoothing:float):
  # One stable grad_fxn object per smoothing value, so repeated calls reuse the same callable.
  return functools.partial(_fused_ce_loss_bwd, label_smoothing=label_smoothing)

def fused_ce_loss(logits:Tensor, targets:Tensor, label_smoothing:float=0.1) -> Tensor:
  # NOTE: fused sparse_categorical_crossentropy with label smoothing, returns mean loss scalar
  # logits must be bf16 with shape (MBS, SEQ, VOCAB); targets are flattened and cast to int32.
  assert logits.dtype == dtypes.bfloat16, f"expected bf16, got {logits.dtype}"
  assert logits.ndim == 3, f"expected (MBS, SEQ, VOCAB), got {logits.shape}"
  # the value is rendered as "-DLABEL_SMOOTHING={value}f"; an int such as 0 would become the invalid C++ literal "0f"
  label_smoothing = float(label_smoothing)
  MBS, SEQ, VOCAB = logits.shape
  rows = MBS * SEQ
  if isinstance(logits.device, tuple):
    axis = logits.uop.axis
    assert axis in (0, 1), f"unsupported sharding axis={axis} for CE loss"
    ndev = len(logits.device)
    rows_per_dev = rows // ndev
    # per-device row buffers, each wrapped as a multi on axis 0
    loss_out = Tensor(Tensor.invalids(rows_per_dev, dtype=dtypes.float32, device=logits.device).uop.multi(0),
                      device=logits.device)
    max_out = Tensor(Tensor.invalids(rows_per_dev, dtype=dtypes.float32, device=logits.device).uop.multi(0),
                     device=logits.device)
    lse_out = Tensor(Tensor.invalids(rows_per_dev, dtype=dtypes.float32, device=logits.device).uop.multi(0),
                     device=logits.device)
    dname = logits.device[0].split(":")[0]
  else:
    loss_out = Tensor.invalids(rows, dtype=dtypes.float32, device=logits.device)
    max_out = Tensor.invalids(rows, dtype=dtypes.float32, device=logits.device)
    lse_out = Tensor.invalids(rows, dtype=dtypes.float32, device=logits.device)
    dname = logits.device.split(":")[0] if isinstance(logits.device, str) else logits.device
    rows_per_dev = rows
  logits_flat = logits.reshape(rows, VOCAB)
  targets_flat = targets.reshape(-1).cast(dtypes.int32)
  fxn = functools.partial(_custom_fused_ce_loss_fwd, dname=dname, vocab=VOCAB, rows=rows_per_dev,
                          label_smoothing=label_smoothing)
  # the backward kernel is compiled with the same smoothing value as the forward one
  loss_out, max_out, lse_out, *_ = Tensor.custom_kernel(
    loss_out, max_out, lse_out, logits_flat, targets_flat,
    fxn=fxn, grad_fxn=_fused_ce_loss_bwd_for(label_smoothing))
  return loss_out.mean()
|
||||
54
extra/llama_kernels/fused_mul_quantize_fp8/__init__.py
Normal file
54
extra/llama_kernels/fused_mul_quantize_fp8/__init__.py
Normal file
|
|
@ -0,0 +1,54 @@
|
|||
from __future__ import annotations
|
||||
import functools, pathlib
|
||||
from tinygrad import Tensor, dtypes
|
||||
from tinygrad.uop.ops import UOp, Ops, KernelInfo
|
||||
from tinygrad.renderer import Estimates
|
||||
from extra.llama_kernels import FP8_MAX, NUM_WG, THREADS_PER_WG, compile_cpp, shard_shape, scalar_amax
|
||||
|
||||
@functools.cache
def _custom_mul_quantize_fp8(fp8_out:UOp, amax_buf:UOp, x:UOp, weight:UOp, amax_state:UOp, dname:str) -> UOp:
  # Builds the PROGRAM UOp for the fused multiply+quantize kernel compiled from fused_mul_quantize_fp8.cpp
  # (cached per argument set). x must be 3-dimensional: (MBS, SEQ, HIDDEN).
  MBS, SEQ, HIDDEN = x.shape
  n_elems = MBS * SEQ * HIDDEN
  mem_estimate = n_elems * 2 + HIDDEN * 2 + n_elems + NUM_WG * 2
  info = KernelInfo(f"fused_mul_quantize_fp8_{n_elems}_h{HIDDEN}", estimates=Estimates(ops=3*n_elems, mem=mem_estimate))
  # launch dims: THREADS_PER_WG local ids ("lidx0") and NUM_WG group ids ("gidx0")
  launch_dims = (UOp.special(THREADS_PER_WG, "lidx0"), UOp.special(NUM_WG, "gidx0"))
  sink = UOp.sink(fp8_out.base, amax_buf.base, x.base, weight.base, amax_state.base, *launch_dims, arg=info)
  src, lib = compile_cpp(pathlib.Path(__file__).parent, "fused_mul_quantize_fp8.cpp", n_elems, HIDDEN)
  program_srcs = (sink, UOp(Ops.DEVICE, arg=dname), UOp(Ops.LINEAR, src=(*sink.src, sink)),
                  UOp(Ops.SOURCE, arg=src), UOp(Ops.BINARY, arg=lib))
  return UOp(Ops.PROGRAM, src=program_srcs)
|
||||
|
||||
def _fused_mul_quantize_fp8_bwd(gradient:UOp, kernel:UOp):
  # NOTE: the forward kernel inputs are (fp8_out, amax_buf, x, weight, amax_state); x and weight get gradients
  # Returns one entry per forward input, None for the three that are not differentiated.
  _fp8_u, _amax_buf_u, x_u, weight_u, amax_state_u = kernel.src[1:]
  dev = x_u.device
  upstream = Tensor(gradient, device=dev).cast(dtypes.bfloat16)
  x_t = Tensor(x_u, device=dev)
  weight_t = Tensor(weight_u, device=dev)
  amax_f32 = Tensor(amax_state_u, device=dev).float()
  # same scale the forward pass implies: FP8_MAX / (amax_state + eps)
  scale = FP8_MAX / (amax_f32 + 1e-8)
  grad_scaled = upstream.float() * scale
  # NOTE: grad_x stays bf16 to avoid CSE materializing a (MBS, SEQ, HIDDEN) fp32 intermediate
  grad_x = (grad_scaled * weight_t.float()).cast(dtypes.bfloat16)
  # weight grad is reduced over the first two axes
  grad_weight = (grad_scaled * x_t.float()).sum(axis=(0, 1)).cast(dtypes.bfloat16)
  return (None, None, grad_x.uop, grad_weight.uop, None)
|
||||
|
||||
def fused_mul_quantize_fp8(x:Tensor, weight:Tensor, amax_state:Tensor, fp8_dtype) -> tuple[Tensor, Tensor, Tensor]:
  # NOTE: (x * weight) -> fp8 + amax, delayed scaling. Returns (fp8, inv_scale, new_amax)
  # x and weight must both be bf16 and agree on the last axis; x must be (MBS, SEQ, HIDDEN).
  assert x.dtype == dtypes.bfloat16 and weight.dtype == dtypes.bfloat16
  assert x.shape[-1] == weight.shape[-1], f"HIDDEN mismatch: x={x.shape}, weight={weight.shape}"
  MBS, SEQ, HIDDEN = x.shape
  dev = x.device
  if isinstance(dev, tuple):
    axis, ndev = x.uop.axis, len(dev)
    assert axis in (0, 1), f"unsupported sharding axis={axis}"
    # per-device output slice, wrapped as a multi on x's sharding axis
    local_fp8 = Tensor.invalids(*shard_shape((MBS, SEQ, HIDDEN), axis, ndev), dtype=fp8_dtype, device=dev)
    fp8_out = Tensor(local_fp8.uop.multi(axis), device=dev)
    # NUM_WG amax slots, wrapped as a multi on axis 0
    local_amax = Tensor.invalids(NUM_WG, dtype=dtypes.bfloat16, device=dev)
    amax_buf = Tensor(local_amax.uop.multi(0), device=dev)
    dname = dev[0].split(":")[0]
  else:
    fp8_out = Tensor.invalids(MBS, SEQ, HIDDEN, dtype=fp8_dtype, device=dev)
    amax_buf = Tensor.invalids(NUM_WG, dtype=dtypes.bfloat16, device=dev)
    dname = dev.split(":")[0] if isinstance(dev, str) else dev
  launch = functools.partial(_custom_mul_quantize_fp8, dname=dname)
  fp8_out, amax_buf, *_ = Tensor.custom_kernel(fp8_out, amax_buf, x, weight, amax_state, fxn=launch,
                                               grad_fxn=_fused_mul_quantize_fp8_bwd)
  new_amax = scalar_amax(amax_buf)
  # the returned inverse scale is derived from the amax_state passed in, not from new_amax
  inv_scale = (amax_state.float() + 1e-8) / FP8_MAX
  return fp8_out, inv_scale, new_amax
|
||||
24
extra/llama_kernels/rmsnorm/__init__.py
Normal file
24
extra/llama_kernels/rmsnorm/__init__.py
Normal file
|
|
@ -0,0 +1,24 @@
|
|||
from __future__ import annotations
|
||||
import functools
|
||||
from tinygrad import Tensor
|
||||
from tinygrad.uop.ops import UOp
|
||||
|
||||
def rmsnorm_fwd(x_in:Tensor, eps:float) -> tuple[Tensor, Tensor]:
  # RMS-normalizes x_in over its last axis. The math runs on x_in.float().
  # Returns (normalized tensor cast back to x_in.dtype, the reciprocal-RMS factor with the last axis kept).
  x_f = x_in.float()
  mean_sq = x_f.square().mean(-1, keepdim=True)
  rrms = (mean_sq + eps).rsqrt()
  normed = (x_f * rrms).cast(x_in.dtype)
  return normed, rrms
|
||||
|
||||
@functools.cache
def _rmsnorm_fwd_fxn(x_in_p, eps, device):
  # Cached (per placeholder uop, eps, device) forward expression pair produced by rmsnorm_fwd.
  placeholder = Tensor(x_in_p, device=device)
  return rmsnorm_fwd(placeholder, eps)
|
||||
|
||||
def _rmsnorm_bwd(grad:UOp, call:UOp) -> tuple:
  # Gradient of rmsnorm w.r.t. its input, rebuilt from the call's two outputs:
  # output 0 is the normalized tensor, output 1 is the reciprocal-RMS factor.
  x_normed = Tensor(call.gettuple(0)).float()
  rrms = Tensor(call.gettuple(1))
  do_float = Tensor(grad).float()
  # subtract the part of the upstream grad that lies along x_normed (mean over the last axis)
  projected = (do_float * x_normed).mean(-1, keepdim=True)
  d_x = rrms * (do_float - x_normed * projected)
  # cast back to the dtype of the call's input (call.src[1])
  in_dtype = call.src[1].dtype
  return (d_x.cast(in_dtype).uop,)
|
||||
|
||||
def rmsnorm(x_in:Tensor, eps:float) -> tuple[Tensor, Tensor]:
  # Wraps the cached forward expressions in a call on x_in with _rmsnorm_bwd attached as its gradient function.
  # Returns the call's two outputs: (normalized tensor, reciprocal-RMS factor).
  normed_fn, rrms_fn = _rmsnorm_fwd_fxn(x_in.as_param(0).uop, eps, x_in.device)
  packed = UOp.maketuple(normed_fn.uop, rrms_fn.uop)
  call = packed.call(x_in.uop, grad_fxn=_rmsnorm_bwd)
  normed, rrms = Tensor(call.gettuple(0)), Tensor(call.gettuple(1))
  return normed, rrms
|
||||
Loading…
Add table
Add a link
Reference in a new issue