llama: move llama kernels to llama_kernels (#15952)

This commit is contained in:
wozeparrot 2026-04-28 13:48:53 +08:00 committed by GitHub
commit 5e861cd2c4
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
13 changed files with 364 additions and 358 deletions

View file

@ -1282,7 +1282,7 @@ def train_bert():
previous_step = i
def train_llama3():
from examples.mlperf.models.flat_llama import FlatTransformer, apply_grad, FP8, FP8_DTYPE
from examples.mlperf.models.flat_llama import FlatTransformer, apply_grad, FP8_DTYPE
from examples.llama3 import MODEL_PARAMS
from examples.mlperf.lr_schedulers import CosineAnnealingLRWithWarmup
from examples.mlperf.optim import GradAccClipAdamW
@ -1432,18 +1432,17 @@ def train_llama3():
print(f"loading optim checkpoint from {fn}")
load_state_dict(scheduler, safe_load(fn), realize=False)
fp8_amax = [t for ts in model._fp8_amax.values() for t in ts] if FP8 else []
fp8_inv_scales = list(model._fp8_inv_scale.values()) if FP8 else []
fp8_amax = [t for ts in model._fp8_amax.values() for t in ts]
fp8_inv_scales = list(model._fp8_inv_scale.values())
if FP8:
from tinygrad.nn.state import get_state_dict
model_state = get_state_dict(model)
for wname in ["wqkv", "wo", "w13", "w2"]:
w = model_state[wname]
w._inv_scale = model._fp8_inv_scale[wname]
if optim.master_params:
idx = next(j for j, p in enumerate(optim.params) if p is w)
optim.master_params[idx].assign((optim.master_params[idx] * w._inv_scale.reshape(-1, *([1]*(w.ndim-1)))).contiguous())
from tinygrad.nn.state import get_state_dict
model_state = get_state_dict(model)
for wname in ["wqkv", "wo", "w13", "w2"]:
w = model_state[wname]
w._inv_scale = model._fp8_inv_scale[wname]
if optim.master_params:
idx = next(j for j, p in enumerate(optim.params) if p is w)
optim.master_params[idx].assign((optim.master_params[idx] * w._inv_scale.reshape(-1, *([1]*(w.ndim-1)))).contiguous())
@TinyJit
def minibatch(tokens:Tensor):
@ -1452,7 +1451,7 @@ def train_llama3():
if not is_sharding: tokens = tokens.to(None)
logits:Tensor = model(tokens[:, :-1])
if getenv("FAST_CE", 0):
from extra.amax.cast_amax import fused_ce_loss
from extra.llama_kernels.fused_ce import fused_ce_loss
loss = fused_ce_loss(logits.cast(dtypes.bfloat16), tokens[:, 1:], label_smoothing=0.0)
else:
loss = vocab_mask.where(-1e9, logits).sparse_categorical_crossentropy(tokens[:, 1:])
@ -1559,7 +1558,7 @@ def train_llama3():
mem_gb = GlobalCounters.mem_used / 1e9
gflops = GlobalCounters.global_ops / 1e9 / dev_time
mfu = ((6 * num_params * SEQLEN * GBS) / (dev_time * device_count * (4.6e15 if FP8 else 2.3e15))) * 100
mfu = ((6 * num_params * SEQLEN * GBS) / (dev_time * device_count * 4.6e15)) * 100
tqdm.write(
f"{i:5} {step_time:.3f} s step, {gbs_time:.3f} s gbs, {optim_time:.3f} s optim, {data_time:.3f} s data, {loss:.4f} loss, " \
f"{lr:.12f} LR, {grad_norm:.6f} grad_norm, {mem_gb:.2f} GB used, {gflops:9.2f} GFLOPS, {mfu:5.2f}% MFU")

View file

@ -1,4 +1,4 @@
import math, os, functools
import math, os
if __name__ == "__main__":
os.environ["DEFAULT_FLOAT"] = "bfloat16"
os.environ["OPTIM_DTYPE"] = "bfloat16"
@ -16,65 +16,60 @@ from tinygrad import Tensor, nn, function, getenv, dtypes, TinyJit
from tinygrad.helpers import Timing, colored, GlobalCounters, profile_marker
from tinygrad.uop.ops import Ops, UOp
from extra.models.llama import apply_rotary_emb, precompute_freqs_cis
from extra.llama_kernels.rmsnorm import rmsnorm
from extra.llama_kernels import FP8_MAX, local_abs_max
FP8 = getenv("FP8", 0)
ASM_GEMM = getenv("ASM_GEMM", 0)
FP8_DTYPE = dtypes.fp8e4m3
FP8_GRAD_DTYPE = dtypes.fp8e5m2
FP8_MAX = 448.0
# per-device abs max without allreduce (matches TE delayed scaling behavior)
@functools.cache
def _local_abs_max_fxn(x_p, device):
  """Build (memoized on placeholder uop and device) the graph for a per-device abs-max.

  If the placeholder is a MULTI uop, the reduction runs on its first source rather than on
  the MULTI itself, so no cross-device reduce is involved. Returns a 1-tuple of Tensor.
  """
  placeholder = Tensor(x_p, device=device)
  if placeholder.uop.op is Ops.MULTI:
    shard = Tensor(placeholder.uop.src[0])
  else:
    shard = placeholder
  return (shard.abs().max(),)
def _local_abs_max(x:Tensor) -> Tensor:
  """Return abs-max of `x` by calling the cached per-device graph on `x`'s uop."""
  (graph,) = _local_abs_max_fxn(x.as_param(0).uop, x.device)
  called = graph.uop.call(x.uop)
  return Tensor(called.gettuple(0))
# Quantize `x` to FP8_DTYPE with a single per-tensor scale.
#   amax_state: optional previously recorded amax; when given it (not the fresh amax) sets the scale.
#   returns: (x cast to FP8_DTYPE, float reciprocal of the scale, freshly measured detached amax).
# On a multi-device tensor (x.device is a tuple) the amax comes from the local-abs-max helper.
# The clamp is applied via detached terms so the gradient passes straight through to x_scaled.
def quantize_fp8(x:Tensor, amax_state:Tensor|None=None):
new_amax = (_local_abs_max(x) if isinstance(x.device, tuple) else x.abs().max()).detach()
new_amax = (local_abs_max(x) if isinstance(x.device, tuple) else x.abs().max()).detach()
scale = FP8_MAX / ((amax_state if amax_state is not None else new_amax) + 1e-8)
x_scaled = x * scale
x_clamped = x_scaled + (x_scaled.detach().clamp(-FP8_MAX, FP8_MAX) - x_scaled.detach()) # STE
return x_clamped.cast(FP8_DTYPE), scale.float().reciprocal(), new_amax
# Compute x @ w.T, optionally through an fp8 path.
#   fp8 falsy: returns a 1-tuple (out,); uses asm_gemm when ASM_GEMM is set and can_use_asm_gemm accepts the operands.
#   fp8 truthy: requires w_inv_scale; quantizes x with quantize_fp8 unless a pre-quantized x_fp8 (with x_scale and
#               x_new_amax) is supplied; returns (out, x_new_amax, x_fp8, w).
def matmul(x:Tensor, w:Tensor, fp8=FP8, amax_x:Tensor|None=None, w_inv_scale:Tensor|None=None,
def matmul(x:Tensor, w:Tensor, fp8:bool=True, amax_x:Tensor|None=None, w_inv_scale:Tensor|None=None,
x_fp8:Tensor|None=None, x_scale:Tensor|None=None, x_new_amax:Tensor|None=None) -> tuple[Tensor,...]:
if not fp8:
if getenv("ASM_GEMM"):
if ASM_GEMM:
from extra.gemm.cdna_asm_gemm import can_use_asm_gemm, asm_gemm
if can_use_asm_gemm(x, w.T): return (asm_gemm(x, w.T),)
return (x @ w.T,)
assert w_inv_scale is not None, "fp8 matmul requires w_inv_scale (weights must be stored in fp8 with per-tensor scale)"
# quantize the activation only when the caller did not already provide an fp8 tensor
if x_fp8 is None: x_fp8, x_scale, x_new_amax = quantize_fp8(x, amax_state=amax_x)
if getenv("ASM_GEMM"):
if ASM_GEMM:
from extra.gemm.cdna_asm_gemm import can_use_asm_gemm, asm_gemm
if can_use_asm_gemm(x_fp8, w.T): return asm_gemm(x_fp8, w.T, x_scale=x_scale, w_scale=w_inv_scale), x_new_amax, x_fp8, w
# fallback: accumulate in float, then undo both quantization scales
return x_fp8.dot(w.T, dtype=dtypes.float) * x_scale * w_inv_scale, x_new_amax, x_fp8, w
return (x_fp8.dot(w.T, dtype=dtypes.float) * x_scale * w_inv_scale).cast(dtypes.bfloat16), x_new_amax, x_fp8, w
def _rmsnorm_fwd(x_in:Tensor, eps:float) -> tuple[Tensor, Tensor]:
  """RMS-normalize `x_in` over its last axis in float.

  Returns (normalized tensor cast back to x_in.dtype, rrms) where
  rrms = rsqrt(mean(x^2, axis=-1, keepdim=True) + eps).
  """
  x_f = x_in.float()
  mean_sq = x_f.square().mean(-1, keepdim=True)
  rrms = (mean_sq + eps).rsqrt()
  normed = (x_f * rrms).cast(x_in.dtype)
  return normed, rrms
def norm_mul_quantize_matmul(x:Tensor, norm:Tensor, amax_x, w_inv_scale, w:Tensor, eps:float):
  """rmsnorm(x) * norm, quantized and multiplied by `w`.

  With FUSED_NORM_MUL_QUANTIZE set, the multiply + fp8 quantize runs in the fused kernel and the
  pre-quantized activation is handed to matmul; otherwise matmul gets the plain product.
  Returns (out, normed, rrms, ret) where `ret` is the list of matmul's remaining outputs.
  """
  use_fused = getenv("FUSED_NORM_MUL_QUANTIZE", 0)
  normed, rrms = rmsnorm(x, eps)
  if not use_fused:
    out, *ret = matmul(normed * norm, w, amax_x=amax_x, w_inv_scale=w_inv_scale)
    return out, normed, rrms, ret
  from extra.llama_kernels.fused_mul_quantize_fp8 import fused_mul_quantize_fp8
  # the fused kernel always needs an amax tensor; fall back to a scalar 1.0 when none is tracked
  if amax_x is None: amax_state = Tensor.full((), 1.0, dtype=dtypes.bfloat16, device=normed.device)
  else: amax_state = amax_x
  x_fp8, x_inv_scale, new_amax = fused_mul_quantize_fp8(normed, norm, amax_state, FP8_DTYPE)
  out, *ret = matmul(None, w, w_inv_scale=w_inv_scale, x_fp8=x_fp8, x_scale=x_inv_scale, x_new_amax=new_amax)
  return out, normed, rrms, ret
@functools.cache
def _rmsnorm_fwd_fxn(x_in_p, eps, device):
  """Memoized forward RMSNorm graph for a given placeholder uop, eps and device."""
  placeholder = Tensor(x_in_p, device=device)
  return _rmsnorm_fwd(placeholder, eps)
def _rmsnorm_bwd(grad:UOp, call:UOp) -> tuple:
  """Backward of rmsnorm, computed from the saved forward outputs.

  d_x = rrms * (g - x_normed * mean(g * x_normed, axis=-1)), evaluated in float and cast to
  the dtype of call.src[1]. Returns a 1-tuple holding the gradient uop.
  """
  rrms = Tensor(call.gettuple(1))
  normed = Tensor(call.gettuple(0)).float()
  upstream = Tensor(grad).float()
  projection = (upstream * normed).mean(-1, keepdim=True)
  d_x = rrms * (upstream - normed * projection)
  return (d_x.cast(call.src[1].dtype).uop,)
def rmsnorm(x_in:Tensor, eps:float) -> tuple[Tensor, Tensor]:
  """RMSNorm as a single call with a custom gradient. Returns (normalized, rrms)."""
  normed_graph, rrms_graph = _rmsnorm_fwd_fxn(x_in.as_param(0).uop, eps, x_in.device)
  packed = UOp.maketuple(normed_graph.uop, rrms_graph.uop)
  call = packed.call(x_in.uop, grad_fxn=_rmsnorm_bwd)
  normed = Tensor(call.gettuple(0))
  rrms = Tensor(call.gettuple(1))
  return normed, rrms
def silu_w13_matmul(x_w13:Tensor, w2:Tensor, amax_x2, s_2):
  """silu(x_w1) * x_w3 followed by the w2 matmul, where x_w13 packs [x_w1 | x_w3] on the last axis.

  With FUSED_SILU_W13 set, the activation + fp8 quantize runs in the fused kernel.
  Returns (out, ret) where `ret` is the list of matmul's remaining outputs.
  """
  use_fused = getenv("FUSED_SILU_W13", 0)
  if not use_fused:
    half = x_w13.shape[-1] // 2
    gate, up = x_w13[..., :half], x_w13[..., half:]
    out, *ret = matmul(gate.silu() * up, w2, amax_x=amax_x2, w_inv_scale=s_2)
    return out, ret
  from extra.llama_kernels.cast_amax import fused_quantize_fp8_w13
  # the fused kernel always needs an amax tensor; fall back to a scalar 1.0 when none is tracked
  if amax_x2 is None: amax_state = Tensor.full((), 1.0, dtype=dtypes.bfloat16, device=x_w13.device)
  else: amax_state = amax_x2
  x2_fp8, x2_inv_scale, new_amax_x2 = fused_quantize_fp8_w13(x_w13, amax_state, FP8_DTYPE)
  out, *ret = matmul(None, w2, w_inv_scale=s_2, x_fp8=x2_fp8, x_scale=x2_inv_scale, x_new_amax=new_amax_x2)
  return out, ret
class FlatTransformer:
def __init__(self, dim:int, hidden_dim:int, n_heads:int, n_layers:int, norm_eps:float, vocab_size:int, n_kv_heads:int|None=None,
@ -90,7 +85,7 @@ class FlatTransformer:
scaled_std = 0.02 / math.sqrt(2 * n_layers)
# Attention
self._init_inv_scales = [] # populated by lin_per_layer when FP8
self._init_inv_scales = [] # populated by lin_per_layer
self.wqkv = self.lin_per_layer(dim, self.n_heads * self.head_dim + self.n_kv_heads * self.head_dim * 2)
self.wo = self.lin_per_layer(self.n_heads * self.head_dim, dim, std=scaled_std)
@ -109,21 +104,19 @@ class FlatTransformer:
self.output = Tensor.normal(1, vocab_size, dim, mean=0.0, std=0.02, dtype=dtypes.bfloat16)
self.freqs_cis = precompute_freqs_cis(dim // n_heads, max_context * 2, rope_theta).contiguous().requires_grad_(False)
if FP8:
def _amax(): return Tensor.full((), FP8_MAX).contiguous().requires_grad_(False)
names = ["xqkv", "xo", "x13", "x2"]
self._fp8_amax = {name: [_amax() for _ in range(n_layers)] for name in names}
# per-weight inv_scale: single (n_layers,) float32 tensor per weight (kernel reads float* pointers)
w_names = ["wqkv", "wo", "w13", "w2"]
self._fp8_inv_scale = {}
for wname, inv_scales in zip(w_names, self._init_inv_scales):
self._fp8_inv_scale[wname] = inv_scales.float().contiguous().requires_grad_(False)
del self._init_inv_scales
def _amax(): return Tensor.full((), FP8_MAX).contiguous().requires_grad_(False)
names = ["xqkv", "xo", "x13", "x2"]
self._fp8_amax = {name: [_amax() for _ in range(n_layers)] for name in names}
# per-weight inv_scale: single (n_layers,) float32 tensor per weight (kernel reads float* pointers)
w_names = ["wqkv", "wo", "w13", "w2"]
self._fp8_inv_scale = {}
for wname, inv_scales in zip(w_names, self._init_inv_scales):
self._fp8_inv_scale[wname] = inv_scales.float().contiguous().requires_grad_(False)
del self._init_inv_scales
def lin_per_layer(self, in_features:int, out_features:int, std:float=0.02):
if getenv("ZEROS"): w = Tensor.zeros(self.n_layers, out_features, in_features)
if getenv("ZEROS", 0): w = Tensor.zeros(self.n_layers, out_features, in_features)
else: w = Tensor.normal(self.n_layers, out_features, in_features, mean=0.0, std=std)
if not FP8: return w
# per-layer scaled fp8 cast: fill the fp8 range for best precision
amax = w.abs().flatten(1).max(1).detach()
scale = FP8_MAX / (amax + 1e-8)
@ -135,18 +128,8 @@ class FlatTransformer:
bsz, seqlen, _ = x.shape
new_amaxs, saves = [], []
x, rrms = rmsnorm(x, self.norm_eps)
saves.extend([x, rrms])
if FP8 and getenv("FUSED_NORM_MUL_QUANTIZE", 1):
from extra.amax.cast_amax import fused_mul_quantize_fp8
amax_s = amax_xqkv if amax_xqkv is not None else Tensor.full((), 1.0, dtype=dtypes.bfloat16, device=x.device)
x_fp8, x_inv_scale, new_amax_xqkv = fused_mul_quantize_fp8(x, attention_norm, amax_s, FP8_DTYPE)
xqkv, *ret = matmul(None, wqkv, w_inv_scale=s_qkv, x_fp8=x_fp8, x_scale=x_inv_scale, x_new_amax=new_amax_xqkv)
else:
x = x * attention_norm
xqkv, *ret = matmul(x, wqkv, amax_x=amax_xqkv, w_inv_scale=s_qkv)
xqkv, normed, rrms, ret = norm_mul_quantize_matmul(x, attention_norm, amax_xqkv, s_qkv, wqkv, self.norm_eps)
saves.extend([normed, rrms])
new_amaxs.extend(ret[:1])
saves.extend(ret[1:] + [xqkv])
xqkv = xqkv.reshape(bsz, seqlen, self.n_kv_heads, self.n_rep + 2, self.head_dim)
@ -155,7 +138,7 @@ class FlatTransformer:
xv = xqkv[:, :, :, self.n_rep+1].reshape(bsz, seqlen, self.n_kv_heads, self.head_dim)
xq, xk = apply_rotary_emb(xq, xk, freqs_cis)
if FP8: xq, xk, xv = xq.cast(dtypes.bfloat16), xk.cast(dtypes.bfloat16), xv.cast(dtypes.bfloat16)
xq, xk, xv = xq.cast(dtypes.bfloat16), xk.cast(dtypes.bfloat16), xv.cast(dtypes.bfloat16)
xq, xk, xv = xq.transpose(1, 2), xk.transpose(1, 2), xv.transpose(1, 2)
if getenv("HK_FLASH_ATTENTION"):
from extra.thunder.amd.fa import flash_attention
@ -174,28 +157,12 @@ class FlatTransformer:
amax_x13=None, amax_x2=None, s_13=None, s_2=None):
new_amaxs, saves = [], []
x, rrms = rmsnorm(x, self.norm_eps)
saves.extend([x, rrms])
if FP8 and getenv("FUSED_NORM_MUL_QUANTIZE", 1):
from extra.amax.cast_amax import fused_mul_quantize_fp8
amax_s13 = amax_x13 if amax_x13 is not None else Tensor.full((), 1.0, dtype=dtypes.bfloat16, device=x.device)
x_fp8_13, x_inv_scale_13, new_amax_x13 = fused_mul_quantize_fp8(x, ffn_norm, amax_s13, FP8_DTYPE)
x_w13, *ret = matmul(None, w13, w_inv_scale=s_13, x_fp8=x_fp8_13, x_scale=x_inv_scale_13, x_new_amax=new_amax_x13)
else:
x = x * ffn_norm
x_w13, *ret = matmul(x, w13, amax_x=amax_x13, w_inv_scale=s_13)
x_w13, normed, rrms, ret = norm_mul_quantize_matmul(x, ffn_norm, amax_x13, s_13, w13, self.norm_eps)
saves.extend([normed, rrms])
new_amaxs.extend(ret[:1])
saves.extend(ret[1:] + [x_w13])
if FP8 and getenv("FUSED_SILU_W13", 1):
from extra.amax.cast_amax import fused_quantize_fp8_w13
amax_s = amax_x2 if amax_x2 is not None else Tensor.full((), 1.0, dtype=dtypes.bfloat16, device=x_w13.device)
x2_fp8, x2_inv_scale, new_amax_x2 = fused_quantize_fp8_w13(x_w13, amax_s, FP8_DTYPE)
out, *ret = matmul(None, w2, w_inv_scale=s_2, x_fp8=x2_fp8, x_scale=x2_inv_scale, x_new_amax=new_amax_x2)
else:
x_w1, x_w3 = x_w13[..., :self.hidden_dim], x_w13[..., self.hidden_dim:]
out, *ret = matmul(x_w1.silu() * x_w3, w2, amax_x=amax_x2, w_inv_scale=s_2)
out, ret = silu_w13_matmul(x_w13, w2, amax_x2, s_2)
new_amaxs.extend(ret[:1])
saves.extend(ret[1:] + [out])
return (out, *new_amaxs, *saves)
@ -226,41 +193,35 @@ class FlatTransformer:
else:
# flat per-layer weights: axis 0 is n_layers, so shard axes are +1 vs per-layer Transformer
self.wqkv.shard_(device, axis=1).realize() # (n_layers, out, dim) shard out
self.wo.shard_(device, axis=2).realize() # (n_layers, dim, in) shard in
self.w13.shard_(device, axis=1).realize() # (n_layers, hidden*2, dim) shard out
self.w2.shard_(device, axis=2).realize() # (n_layers, dim, hidden) shard in
self.wo.shard_(device, axis=2).realize() # (n_layers, dim, in) shard in
self.w13.shard_(device, axis=1).realize() # (n_layers, hidden*2, dim) shard out
self.w2.shard_(device, axis=2).realize() # (n_layers, dim, hidden) shard in
self.attention_norm.shard_(device, axis=None).realize()
self.ffn_norm.shard_(device, axis=None).realize()
self.norm.weight.shard_(device, axis=None).realize()
self.tok_embeddings.weight.shard_(device, axis=0).realize()
self.output.shard_(device, axis=1).realize()
self.freqs_cis.shard_(device, axis=None).realize()
if FP8:
for name in self._fp8_amax:
for i in range(len(self._fp8_amax[name])):
self._fp8_amax[name][i] = self._fp8_amax[name][i].to(device).contiguous().requires_grad_(False)
for name in self._fp8_inv_scale:
self._fp8_inv_scale[name] = self._fp8_inv_scale[name].to(device).contiguous().requires_grad_(False)
for name in self._fp8_amax:
for i in range(len(self._fp8_amax[name])):
self._fp8_amax[name][i] = self._fp8_amax[name][i].to(device).contiguous().requires_grad_(False)
for name in self._fp8_inv_scale:
self._fp8_inv_scale[name] = self._fp8_inv_scale[name].to(device).contiguous().requires_grad_(False)
def __call__(self, tokens:Tensor):
h = self.tok_embeddings(tokens)
freqs_cis = self.freqs_cis.cast(h.dtype)[:, :tokens.shape[1], :, :, :]
a = self._fp8_amax if FP8 else None
s = self._fp8_inv_scale if FP8 else None
amaxs, inv_scales = self._fp8_amax, self._fp8_inv_scale
for i in range(self.n_layers):
amax_layer = {"amax_xqkv": a["xqkv"][i], "amax_xo": a["xo"][i],
"amax_x13": a["x13"][i], "amax_x2": a["x2"][i]} if a else {}
scale_layer = {"s_qkv": s["wqkv"][i], "s_o": s["wo"][i],
"s_13": s["w13"][i], "s_2": s["w2"][i]} if s else {}
h, *ret = self.run_layer(h, freqs_cis,
self.attention_norm[i], self.wqkv[i], self.wo[i],
self.ffn_norm[i], self.w13[i], self.w2[i],
**amax_layer, **scale_layer)
if a:
amaxs = ret[:5]
amax_names = ["xqkv", "xo", "x13", "x2"]
for name, new_val in zip(amax_names, amaxs):
a[name][i].assign(new_val)
amax_xqkv=amaxs["xqkv"][i], amax_xo=amaxs["xo"][i],
amax_x13=amaxs["x13"][i], amax_x2=amaxs["x2"][i],
s_qkv=inv_scales["wqkv"][i], s_o=inv_scales["wo"][i],
s_13=inv_scales["w13"][i], s_2=inv_scales["w2"][i])
for name, new_val in zip(["xqkv", "xo", "x13", "x2"], ret[:5]):
amaxs[name][i].assign(new_val)
logits = matmul(self.norm(h).contiguous().contiguous_backward(), self.output[0], fp8=False)[0].contiguous_backward()
return logits

View file

@ -1,238 +0,0 @@
from __future__ import annotations
import functools, pathlib
from tinygrad import Tensor, dtypes
from tinygrad.uop.ops import UOp, Ops, KernelInfo
from tinygrad.renderer import Estimates
from tinygrad.runtime.support.compiler_amd import HIPCCCompiler
# Divisor used to turn an amax into the fp8 inverse scale: inv_scale = (amax + 1e-8) / FP8_MAX.
FP8_MAX = 448.0
# Launch geometry of the fused kernels: number of workgroups (gidx0) and threads per workgroup (lidx0).
# Both are also passed to the C++ sources as -D defines.
NUM_WG, THREADS_PER_WG = 1024, 256
def _compile(cpp_name:str, n_elems:int, hidden:int):
  """Read a C++ kernel that lives next to this file and compile it for gfx950.

  The problem sizes and launch geometry are baked in as -D defines.
  Returns (source text, result of HIPCCCompiler.compile_cached).
  """
  src = (pathlib.Path(__file__).parent / cpp_name).read_text()
  macros = {"N_ELEMS": n_elems, "HIDDEN": hidden, "NUM_WG": NUM_WG, "THREADS_PER_WG": THREADS_PER_WG}
  flags = ["-std=c++20", "-ffast-math"] + [f"-D{key}={val}" for key, val in macros.items()]
  return src, HIPCCCompiler("gfx950", flags).compile_cached(src)
def _shard_shape(shape:tuple, axis:int, ndev:int) -> list:
  """Return `shape` as a new list with dimension `axis` floor-divided by `ndev`."""
  dims = [*shape]
  dims[axis] = dims[axis] // ndev
  return dims
def _scalar_amax(amax_buf:Tensor) -> Tensor:
  """Reduce the per-workgroup amax buffer to a detached scalar.

  A multi-device buffer (device is a tuple) goes through _local_abs_max;
  a single-device buffer uses a plain max.
  """
  if not isinstance(amax_buf.device, tuple):
    return amax_buf.max().detach()
  # imported at call time; presumably avoids a circular import with flat_llama — TODO confirm
  from examples.mlperf.models.flat_llama import _local_abs_max
  return _local_abs_max(amax_buf).detach()
# **** fused silu*mul -> fp8 cast + amax (w13 layout) ****
@functools.cache
def _custom_fused_bwd_w13(grad_xw13:UOp, xw13:UOp, grad_x2:UOp, amax_state:UOp, dname:str) -> UOp:
  """Build the PROGRAM uop for the backward of the fused silu*mul fp8 quantize (w13 layout).

  The element count is taken over half of xw13's last axis; the kernel source is cast_amax_bwd_w13.cpp.
  """
  hidden = xw13.shape[2] // 2
  n_elems = xw13.shape[0] * xw13.shape[1] * hidden
  info = KernelInfo(f"fused_silu_mul_bwd_w13_{n_elems}", estimates=Estimates(ops=8*n_elems, mem=n_elems * 2 * 5))
  lidx, gidx = UOp.special(THREADS_PER_WG, "lidx0"), UOp.special(NUM_WG, "gidx0")
  sink = UOp.sink(grad_xw13.base, xw13.base, grad_x2.base, amax_state.base, lidx, gidx, arg=info)
  src, lib = _compile("cast_amax_bwd_w13.cpp", n_elems, hidden)
  parts = (sink, UOp(Ops.DEVICE, arg=dname), UOp(Ops.LINEAR, src=(*sink.src, sink)),
           UOp(Ops.SOURCE, arg=src), UOp(Ops.BINARY, arg=lib))
  return UOp(Ops.PROGRAM, src=parts)
@functools.cache
def _custom_fused_cast_amax_w13(fp8_out:UOp, amax_buf:UOp, xw13:UOp, amax_state:UOp, dname:str) -> UOp:
  """Build the PROGRAM uop for the forward fused silu*mul -> fp8 cast + amax (w13 layout).

  The element count is taken over half of xw13's last axis; the kernel source is cast_amax_fwd_w13.cpp.
  """
  hidden = xw13.shape[2] // 2
  n_elems = xw13.shape[0] * xw13.shape[1] * hidden
  mem_estimate = n_elems * 2 * 2 + n_elems + NUM_WG * 2
  info = KernelInfo(f"fused_silu_mul_cast_amax_w13_{n_elems}", estimates=Estimates(ops=5*n_elems, mem=mem_estimate))
  lidx, gidx = UOp.special(THREADS_PER_WG, "lidx0"), UOp.special(NUM_WG, "gidx0")
  sink = UOp.sink(fp8_out.base, amax_buf.base, xw13.base, amax_state.base, lidx, gidx, arg=info)
  src, lib = _compile("cast_amax_fwd_w13.cpp", n_elems, hidden)
  parts = (sink, UOp(Ops.DEVICE, arg=dname), UOp(Ops.LINEAR, src=(*sink.src, sink)),
           UOp(Ops.SOURCE, arg=src), UOp(Ops.BINARY, arg=lib))
  return UOp(Ops.PROGRAM, src=parts)
def _fused_quantize_bwd_w13(gradient:UOp, kernel:UOp):
  """Gradient function for fused_quantize_fp8_w13: runs the backward kernel to get the grad for xw13.

  Returns one entry per forward input (fp8_out, amax_buf, xw13, amax_state); only xw13 gets a gradient.
  """
  _, _, xw13, amax_state = kernel.src[1:]
  device = xw13.device
  if not isinstance(device, tuple):
    grad_buf = Tensor.invalids(*xw13.shape, dtype=dtypes.bfloat16, device=device)
    dname = device.split(":")[0] if isinstance(device, str) else device
  else:
    axis, ndev = xw13.axis, len(device)
    assert axis in (0, 1), f"unsupported sharding axis={axis}"
    # allocate the per-device shard, then wrap it as a multi-device tensor on the same axis
    local = Tensor.invalids(*_shard_shape(xw13.shape, axis, ndev), dtype=dtypes.bfloat16, device=device)
    grad_buf = Tensor(local.uop.multi(axis), device=device)
    dname = device[0].split(":")[0]
  upstream = Tensor(gradient, device=device).cast(dtypes.bfloat16)
  kernel_fxn = functools.partial(_custom_fused_bwd_w13, dname=dname)
  grad_buf, *_ = Tensor.custom_kernel(grad_buf, Tensor(xw13, device=device), upstream,
                                      Tensor(amax_state, device=device), fxn=kernel_fxn)
  return (None, None, grad_buf.uop, None)
def fused_quantize_fp8_w13(xw13:Tensor, amax_state:Tensor, fp8_dtype) -> tuple[Tensor, Tensor, Tensor]:
  """Fused silu(xw1)*xw3 -> fp8 cast + amax over the packed xw13 layout.

  Returns (fp8, inv_scale, new_amax). inv_scale is derived from the incoming `amax_state`,
  not from the amax measured by this call.
  """
  assert xw13.dtype == dtypes.bfloat16, f"expected bf16, got {xw13.dtype}"
  MBS, SEQ, H2 = xw13.shape
  assert H2 % 2 == 0, f"w13 last-axis must be even, got {H2}"
  out_shape = (MBS, SEQ, H2 // 2)
  dev = xw13.device
  if not isinstance(dev, tuple):
    fp8_out = Tensor.invalids(*out_shape, dtype=fp8_dtype, device=dev)
    amax_buf = Tensor.invalids(NUM_WG, dtype=dtypes.bfloat16, device=dev)
    dname = dev.split(":")[0] if isinstance(dev, str) else dev
  else:
    axis, ndev = xw13.uop.axis, len(dev)
    assert axis in (0, 1), f"unsupported sharding axis={axis}"
    fp8_local = Tensor.invalids(*_shard_shape(out_shape, axis, ndev), dtype=fp8_dtype, device=dev)
    fp8_out = Tensor(fp8_local.uop.multi(axis), device=dev)
    amax_local = Tensor.invalids(NUM_WG, dtype=dtypes.bfloat16, device=dev)
    amax_buf = Tensor(amax_local.uop.multi(0), device=dev)
    dname = dev[0].split(":")[0]
  kernel_fxn = functools.partial(_custom_fused_cast_amax_w13, dname=dname)
  fp8_out, amax_buf, *_ = Tensor.custom_kernel(fp8_out, amax_buf, xw13, amax_state, fxn=kernel_fxn,
                                               grad_fxn=_fused_quantize_bwd_w13)
  new_amax = _scalar_amax(amax_buf)
  inv_scale = (amax_state.float() + 1e-8) / FP8_MAX
  return fp8_out, inv_scale, new_amax
# **** fused (x * weight) -> fp8 cast + amax (norm-mul-quantize) ****
@functools.cache
def _custom_mul_quantize_fp8(fp8_out:UOp, amax_buf:UOp, x:UOp, weight:UOp, amax_state:UOp, dname:str) -> UOp:
  """Build the PROGRAM uop for the fused (x * weight) -> fp8 cast + amax kernel.

  The kernel source is fused_mul_quantize_fp8.cpp, compiled for x's three-dimensional shape.
  """
  MBS, SEQ, HIDDEN = x.shape
  n_elems = MBS * SEQ * HIDDEN
  mem_estimate = n_elems * 2 + HIDDEN * 2 + n_elems + NUM_WG * 2
  info = KernelInfo(f"fused_mul_quantize_fp8_{n_elems}_h{HIDDEN}", estimates=Estimates(ops=3*n_elems, mem=mem_estimate))
  lidx, gidx = UOp.special(THREADS_PER_WG, "lidx0"), UOp.special(NUM_WG, "gidx0")
  sink = UOp.sink(fp8_out.base, amax_buf.base, x.base, weight.base, amax_state.base, lidx, gidx, arg=info)
  src, lib = _compile("fused_mul_quantize_fp8.cpp", n_elems, HIDDEN)
  parts = (sink, UOp(Ops.DEVICE, arg=dname), UOp(Ops.LINEAR, src=(*sink.src, sink)),
           UOp(Ops.SOURCE, arg=src), UOp(Ops.BINARY, arg=lib))
  return UOp(Ops.PROGRAM, src=parts)
def _fused_mul_quantize_fp8_bwd(gradient:UOp, kernel:UOp):
  """Gradient function for fused_mul_quantize_fp8.

  Returns one entry per forward input (fp8_out, amax_buf, x, weight, amax_state);
  only x and weight receive gradients.
  """
  _, _, x_u, weight_u, amax_state_u = kernel.src[1:]
  device = x_u.device
  upstream = Tensor(gradient, device=device).cast(dtypes.bfloat16)
  x_t = Tensor(x_u, device=device)
  weight_t = Tensor(weight_u, device=device)
  # same scale the forward applied: FP8_MAX / (amax_state + 1e-8)
  scale = FP8_MAX / (Tensor(amax_state_u, device=device).float() + 1e-8)
  scaled_grad = upstream.float() * scale
  # NOTE: grad_x stays bf16 to avoid CSE materializing a (MBS, SEQ, HIDDEN) fp32 intermediate
  grad_x = (scaled_grad * weight_t.float()).cast(dtypes.bfloat16)
  grad_weight = (scaled_grad * x_t.float()).sum(axis=(0, 1)).cast(dtypes.bfloat16)
  return (None, None, grad_x.uop, grad_weight.uop, None)
def fused_mul_quantize_fp8(x:Tensor, weight:Tensor, amax_state:Tensor, fp8_dtype) -> tuple[Tensor, Tensor, Tensor]:
  """Fused (x * weight) -> fp8 cast + amax with delayed scaling.

  Returns (fp8, inv_scale, new_amax). inv_scale is derived from the incoming `amax_state`,
  not from the amax measured by this call.
  """
  assert x.dtype == dtypes.bfloat16 and weight.dtype == dtypes.bfloat16
  assert x.shape[-1] == weight.shape[-1], f"HIDDEN mismatch: x={x.shape}, weight={weight.shape}"
  MBS, SEQ, HIDDEN = x.shape
  dev = x.device
  if not isinstance(dev, tuple):
    fp8_out = Tensor.invalids(MBS, SEQ, HIDDEN, dtype=fp8_dtype, device=dev)
    amax_buf = Tensor.invalids(NUM_WG, dtype=dtypes.bfloat16, device=dev)
    dname = dev.split(":")[0] if isinstance(dev, str) else dev
  else:
    axis, ndev = x.uop.axis, len(dev)
    assert axis in (0, 1), f"unsupported sharding axis={axis}"
    fp8_local = Tensor.invalids(*_shard_shape((MBS, SEQ, HIDDEN), axis, ndev), dtype=fp8_dtype, device=dev)
    fp8_out = Tensor(fp8_local.uop.multi(axis), device=dev)
    amax_local = Tensor.invalids(NUM_WG, dtype=dtypes.bfloat16, device=dev)
    amax_buf = Tensor(amax_local.uop.multi(0), device=dev)
    dname = dev[0].split(":")[0]
  kernel_fxn = functools.partial(_custom_mul_quantize_fp8, dname=dname)
  fp8_out, amax_buf, *_ = Tensor.custom_kernel(fp8_out, amax_buf, x, weight, amax_state, fxn=kernel_fxn,
                                               grad_fxn=_fused_mul_quantize_fp8_bwd)
  new_amax = _scalar_amax(amax_buf)
  inv_scale = (amax_state.float() + 1e-8) / FP8_MAX
  return fp8_out, inv_scale, new_amax
# **** fused ce loss ****
@functools.cache
def _custom_fused_ce_loss_fwd(loss_out:UOp, max_out:UOp, lse_out:UOp, logits:UOp, targets:UOp,
                              dname:str, vocab:int, rows:int, label_smoothing:float) -> UOp:
  """Build the PROGRAM uop for the fused cross-entropy forward kernel (fused_ce_loss.cpp).

  One workgroup is launched per row. `vocab`, THREADS_PER_WG and `label_smoothing` are baked
  into the binary as -D defines.
  """
  threads, workgroups = UOp.special(THREADS_PER_WG, "lidx0"), UOp.special(rows, "gidx0")
  mem = rows * vocab * 2 + rows * 12 + rows * 4
  sink = UOp.sink(loss_out.base, max_out.base, lse_out.base, logits.base, targets.base,
                  threads, workgroups,
                  arg=KernelInfo("fused_ce_loss_fwd", estimates=Estimates(ops=6*rows*vocab, mem=mem)))
  src = (pathlib.Path(__file__).parent/"fused_ce_loss.cpp").read_text()
  # float() so an int label_smoothing (e.g. 0) still renders as a valid C++ float literal ("0.0f", not "0f")
  defines = [f"-DVOCAB={vocab}", f"-DTHREADS_PER_WG={THREADS_PER_WG}",
             f"-DLABEL_SMOOTHING={float(label_smoothing)}f"]
  lib = HIPCCCompiler("gfx950", ["-std=c++20", "-ffast-math", *defines]).compile_cached(src)
  return UOp(Ops.PROGRAM, src=(sink, UOp(Ops.DEVICE, arg=dname), UOp(Ops.LINEAR, src=(*sink.src, sink)),
                               UOp(Ops.SOURCE, arg=src), UOp(Ops.BINARY, arg=lib)))
@functools.cache
def _custom_fused_ce_loss_bwd(d_logits:UOp, logits:UOp, lse:UOp, targets:UOp, scale:UOp,
                              dname:str, vocab:int, rows:int, label_smoothing:float) -> UOp:
  """Build the PROGRAM uop for the fused cross-entropy backward kernel (fused_ce_loss_bwd.cpp).

  One workgroup is launched per row. `vocab`, THREADS_PER_WG and `label_smoothing` are baked
  into the binary as -D defines.
  """
  threads, workgroups = UOp.special(THREADS_PER_WG, "lidx0"), UOp.special(rows, "gidx0")
  mem = rows * vocab * 4 + rows * 8 + 4
  sink = UOp.sink(d_logits.base, logits.base, lse.base, targets.base, scale.base,
                  threads, workgroups,
                  arg=KernelInfo("fused_ce_loss_bwd", estimates=Estimates(ops=4*rows*vocab, mem=mem)))
  src = (pathlib.Path(__file__).parent/"fused_ce_loss_bwd.cpp").read_text()
  # float() so an int label_smoothing (e.g. 0) still renders as a valid C++ float literal ("0.0f", not "0f")
  defines = [f"-DVOCAB={vocab}", f"-DTHREADS_PER_WG={THREADS_PER_WG}",
             f"-DLABEL_SMOOTHING={float(label_smoothing)}f"]
  lib = HIPCCCompiler("gfx950", ["-std=c++20", "-ffast-math", *defines]).compile_cached(src)
  return UOp(Ops.PROGRAM, src=(sink, UOp(Ops.DEVICE, arg=dname), UOp(Ops.LINEAR, src=(*sink.src, sink)),
                               UOp(Ops.SOURCE, arg=src), UOp(Ops.BINARY, arg=lib)))
def _fused_ce_loss_bwd(gradient:UOp, kernel:UOp, label_smoothing:float=0.1):
  """Gradient function for fused_ce_loss: runs the backward kernel to get d_logits.

  `label_smoothing` must match the value the forward kernel was compiled with; fused_ce_loss binds
  it through _fused_ce_loss_grad_fxn. The default keeps the previous behavior for direct callers.
  Returns one entry per forward input (loss_out, max_out, lse_out, logits, targets); only logits
  receives a gradient.
  """
  # NOTE: forward inputs are (loss_out, max_out, lse_out, logits, targets)
  # gradient is the upstream grad w.r.t. per-row loss (shape: (rows,) fp32)
  _, _, lse_u, logits_u, targets_u = kernel.src[1:]
  device = logits_u.device
  rows, VOCAB = logits_u.shape # (rows, VOCAB) after reshape
  if isinstance(device, tuple):
    axis = logits_u.axis
    ndev = len(device)
    d_logits = Tensor(Tensor.invalids(rows // ndev, VOCAB, dtype=dtypes.bfloat16, device=device).uop.multi(axis), device=device)
    dname = device[0].split(":")[0]
    rows_per_dev = rows // ndev
  else:
    d_logits = Tensor.invalids(rows, VOCAB, dtype=dtypes.bfloat16, device=device)
    dname = device.split(":")[0] if isinstance(device, str) else device
    rows_per_dev = rows
  grad_t = Tensor(gradient, device=device).float().reshape(-1) # (rows,) fp32
  # NOTE: .mean() backward gives same grad per row (1/N), so broadcast is safe; take scalar
  scale = grad_t[0:1].contiguous()
  logits_t = Tensor(logits_u.after(kernel), device=device)
  lse_t = Tensor(lse_u.after(kernel), device=device)
  targets_t = Tensor(targets_u, device=device)
  # use the forward's smoothing value instead of a hard-coded 0.1 so the gradient matches the loss
  fxn = functools.partial(_custom_fused_ce_loss_bwd, dname=dname, vocab=VOCAB, rows=rows_per_dev,
                          label_smoothing=label_smoothing)
  d_logits, *_ = Tensor.custom_kernel(d_logits, logits_t, lse_t, targets_t, scale, fxn=fxn)
  return (None, None, None, d_logits.uop, None)

@functools.cache
def _fused_ce_loss_grad_fxn(label_smoothing:float):
  """Return the backward callable bound to `label_smoothing` (one shared object per distinct value)."""
  return functools.partial(_fused_ce_loss_bwd, label_smoothing=label_smoothing)

def fused_ce_loss(logits:Tensor, targets:Tensor, label_smoothing:float=0.1) -> Tensor:
  """Fused sparse categorical cross-entropy with label smoothing; returns the mean loss.

  `logits` must be bf16 with three dimensions (MBS, SEQ, VOCAB); multi-device logits must be
  sharded on axis 0 or 1. The same `label_smoothing` is used for the forward and backward kernels.
  """
  assert logits.dtype == dtypes.bfloat16, f"expected bf16, got {logits.dtype}"
  assert logits.ndim == 3, f"expected (MBS, SEQ, VOCAB), got {logits.shape}"
  MBS, SEQ, VOCAB = logits.shape
  rows = MBS * SEQ
  if isinstance(logits.device, tuple):
    axis = logits.uop.axis
    assert axis in (0, 1), f"unsupported sharding axis={axis} for CE loss"
    ndev = len(logits.device)
    loss_out = Tensor(Tensor.invalids(rows // ndev, dtype=dtypes.float32, device=logits.device).uop.multi(0),
                      device=logits.device)
    max_out = Tensor(Tensor.invalids(rows // ndev, dtype=dtypes.float32, device=logits.device).uop.multi(0),
                     device=logits.device)
    lse_out = Tensor(Tensor.invalids(rows // ndev, dtype=dtypes.float32, device=logits.device).uop.multi(0),
                     device=logits.device)
    dname = logits.device[0].split(":")[0]
    rows_per_dev = rows // ndev
  else:
    loss_out = Tensor.invalids(rows, dtype=dtypes.float32, device=logits.device)
    max_out = Tensor.invalids(rows, dtype=dtypes.float32, device=logits.device)
    lse_out = Tensor.invalids(rows, dtype=dtypes.float32, device=logits.device)
    dname = logits.device.split(":")[0] if isinstance(logits.device, str) else logits.device
    rows_per_dev = rows
  logits_flat = logits.reshape(rows, VOCAB)
  targets_flat = targets.reshape(-1).cast(dtypes.int32)
  fxn = functools.partial(_custom_fused_ce_loss_fwd, dname=dname, vocab=VOCAB, rows=rows_per_dev,
                          label_smoothing=label_smoothing)
  loss_out, max_out, lse_out, *_ = Tensor.custom_kernel(
    loss_out, max_out, lse_out, logits_flat, targets_flat,
    fxn=fxn, grad_fxn=_fused_ce_loss_grad_fxn(label_smoothing))
  return loss_out.mean()

View file

@ -0,0 +1,35 @@
from __future__ import annotations
import functools, pathlib
from tinygrad import Tensor, dtypes
from tinygrad.uop.ops import Ops
from tinygrad.runtime.support.compiler_amd import HIPCCCompiler
# Shared fp8 range constant, imported by the model code and the fused kernels for scale computation.
FP8_MAX = 448.0
# Launch geometry for the fused kernels: workgroup count and threads per workgroup.
# compile_cpp also passes both to the C++ sources as -D defines.
NUM_WG, THREADS_PER_WG = 1024, 256
# per-device abs max without allreduce
@functools.cache
def _local_abs_max_fxn(x_p, device):
  """Build (memoized on placeholder uop and device) the graph for a per-device abs-max.

  If the placeholder is a MULTI uop, the reduction runs on its first source rather than on
  the MULTI itself, so no allreduce is involved. Returns a 1-tuple of Tensor.
  """
  placeholder = Tensor(x_p, device=device)
  if placeholder.uop.op is Ops.MULTI:
    shard = Tensor(placeholder.uop.src[0])
  else:
    shard = placeholder
  return (shard.abs().max(),)
def local_abs_max(x:Tensor) -> Tensor:
  """Return abs-max of `x` by calling the cached per-device graph on `x`'s uop."""
  (graph,) = _local_abs_max_fxn(x.as_param(0).uop, x.device)
  called = graph.uop.call(x.uop)
  return Tensor(called.gettuple(0))
def scalar_amax(amax_buf:Tensor) -> Tensor:
  """Reduce an amax buffer to a detached scalar.

  A multi-device buffer (device is a tuple) uses local_abs_max; a single-device buffer uses a plain max.
  """
  on_multi = isinstance(amax_buf.device, tuple)
  reduced = local_abs_max(amax_buf) if on_multi else amax_buf.max()
  return reduced.detach()
def shard_shape(shape:tuple, axis:int, ndev:int) -> list:
  """Return `shape` as a new list with dimension `axis` floor-divided by `ndev`."""
  dims = [*shape]
  dims[axis] = dims[axis] // ndev
  return dims
def compile_cpp(cpp_dir:pathlib.Path, cpp_name:str, n_elems:int, hidden:int):
  """Read `cpp_dir/cpp_name` and compile it for gfx950.

  The problem sizes and launch geometry are baked in as -D defines.
  Returns (source text, result of HIPCCCompiler.compile_cached).
  """
  src = (cpp_dir / cpp_name).read_text()
  macros = {"N_ELEMS": n_elems, "HIDDEN": hidden, "NUM_WG": NUM_WG, "THREADS_PER_WG": THREADS_PER_WG}
  flags = ["-std=c++20", "-ffast-math"] + [f"-D{key}={val}" for key, val in macros.items()]
  return src, HIPCCCompiler("gfx950", flags).compile_cached(src)

View file

@ -0,0 +1,73 @@
from __future__ import annotations
import functools, pathlib
from tinygrad import Tensor, dtypes
from tinygrad.uop.ops import UOp, Ops, KernelInfo
from tinygrad.renderer import Estimates
from extra.llama_kernels import FP8_MAX, NUM_WG, THREADS_PER_WG, compile_cpp, shard_shape, scalar_amax
@functools.cache
def _custom_fused_bwd_w13(grad_xw13:UOp, xw13:UOp, grad_x2:UOp, amax_state:UOp, dname:str) -> UOp:
  """Build the PROGRAM uop for the backward of the fused silu*mul fp8 quantize (w13 layout).

  The element count is taken over half of xw13's last axis; the kernel source is cast_amax_bwd_w13.cpp,
  read from this file's directory.
  """
  hidden = xw13.shape[2] // 2
  n_elems = xw13.shape[0] * xw13.shape[1] * hidden
  info = KernelInfo(f"fused_silu_mul_bwd_w13_{n_elems}", estimates=Estimates(ops=8*n_elems, mem=n_elems * 2 * 5))
  lidx, gidx = UOp.special(THREADS_PER_WG, "lidx0"), UOp.special(NUM_WG, "gidx0")
  sink = UOp.sink(grad_xw13.base, xw13.base, grad_x2.base, amax_state.base, lidx, gidx, arg=info)
  here = pathlib.Path(__file__).parent
  src, lib = compile_cpp(here, "cast_amax_bwd_w13.cpp", n_elems, hidden)
  parts = (sink, UOp(Ops.DEVICE, arg=dname), UOp(Ops.LINEAR, src=(*sink.src, sink)),
           UOp(Ops.SOURCE, arg=src), UOp(Ops.BINARY, arg=lib))
  return UOp(Ops.PROGRAM, src=parts)
@functools.cache
def _custom_fused_cast_amax_w13(fp8_out:UOp, amax_buf:UOp, xw13:UOp, amax_state:UOp, dname:str) -> UOp:
  """Build the PROGRAM uop for the forward fused silu*mul -> fp8 cast + amax (w13 layout).

  The element count is taken over half of xw13's last axis; the kernel source is cast_amax_fwd_w13.cpp,
  read from this file's directory.
  """
  hidden = xw13.shape[2] // 2
  n_elems = xw13.shape[0] * xw13.shape[1] * hidden
  mem_estimate = n_elems * 2 * 2 + n_elems + NUM_WG * 2
  info = KernelInfo(f"fused_silu_mul_cast_amax_w13_{n_elems}", estimates=Estimates(ops=5*n_elems, mem=mem_estimate))
  lidx, gidx = UOp.special(THREADS_PER_WG, "lidx0"), UOp.special(NUM_WG, "gidx0")
  sink = UOp.sink(fp8_out.base, amax_buf.base, xw13.base, amax_state.base, lidx, gidx, arg=info)
  here = pathlib.Path(__file__).parent
  src, lib = compile_cpp(here, "cast_amax_fwd_w13.cpp", n_elems, hidden)
  parts = (sink, UOp(Ops.DEVICE, arg=dname), UOp(Ops.LINEAR, src=(*sink.src, sink)),
           UOp(Ops.SOURCE, arg=src), UOp(Ops.BINARY, arg=lib))
  return UOp(Ops.PROGRAM, src=parts)
def _fused_quantize_bwd_w13(gradient:UOp, kernel:UOp):
  """Gradient function for the fused w13 quantize kernel: runs the backward custom kernel.

  Returns one entry per forward input; only the xw13 slot carries a gradient, the rest are None.
  """
  # NOTE: inputs are (fp8_out, amax_buf, xw13, amax_state); grad for xw13 only
  _fp8_out, _amax_buf, xw13, amax_state = kernel.src[1:]
  device = xw13.device
  if not isinstance(device, tuple):
    grad_buf = Tensor.invalids(*xw13.shape, dtype=dtypes.bfloat16, device=device)
    dname = device.split(":")[0] if isinstance(device, str) else device
  else:
    # multi-device: allocate the per-device shard shape and rewrap it as a sharded tensor on the same axis
    axis, ndev = xw13.axis, len(device)
    assert axis in (0, 1), f"unsupported sharding axis={axis}"
    shard = Tensor.invalids(*shard_shape(xw13.shape, axis, ndev), dtype=dtypes.bfloat16, device=device)
    grad_buf = Tensor(shard.uop.multi(axis), device=device)
    dname = device[0].split(":")[0]
  upstream = Tensor(gradient, device=device).cast(dtypes.bfloat16)
  bwd_kernel = functools.partial(_custom_fused_bwd_w13, dname=dname)
  grad_out, *_unused = Tensor.custom_kernel(grad_buf, Tensor(xw13, device=device), upstream,
                                            Tensor(amax_state, device=device), fxn=bwd_kernel)
  return (None, None, grad_out.uop, None)
def fused_quantize_fp8_w13(xw13:Tensor, amax_state:Tensor, fp8_dtype) -> tuple[Tensor, Tensor, Tensor]:
  """Run the fused w13 cast+amax kernel on a bf16 `xw13` of shape (MBS, SEQ, 2*HIDDEN).

  Returns (fp8_out, inv_scale, new_amax): the (MBS, SEQ, HIDDEN) output in `fp8_dtype`,
  `(amax_state + 1e-8) / FP8_MAX` as float, and `scalar_amax` of the per-workgroup amax buffer.
  """
  # NOTE: silu(xw1)*xw3 -> fp8 + amax over fused xw13 layout. Returns (fp8, inv_scale, new_amax)
  assert xw13.dtype == dtypes.bfloat16, f"expected bf16, got {xw13.dtype}"
  MBS, SEQ, H2 = xw13.shape
  assert H2 % 2 == 0, f"w13 last-axis must be even, got {H2}"
  out_shape = (MBS, SEQ, H2 // 2)
  dev = xw13.device
  if not isinstance(dev, tuple):
    fp8_out = Tensor.invalids(*out_shape, dtype=fp8_dtype, device=dev)
    amax_buf = Tensor.invalids(NUM_WG, dtype=dtypes.bfloat16, device=dev)
    dname = dev.split(":")[0] if isinstance(dev, str) else dev
  else:
    # multi-device: the output follows the input's shard axis, the amax buffer is multi'd on axis 0
    axis, ndev = xw13.uop.axis, len(dev)
    assert axis in (0, 1), f"unsupported sharding axis={axis}"
    out_shard = Tensor.invalids(*shard_shape(out_shape, axis, ndev), dtype=fp8_dtype, device=dev)
    fp8_out = Tensor(out_shard.uop.multi(axis), device=dev)
    amax_shard = Tensor.invalids(NUM_WG, dtype=dtypes.bfloat16, device=dev)
    amax_buf = Tensor(amax_shard.uop.multi(0), device=dev)
    dname = dev[0].split(":")[0]
  fwd_kernel = functools.partial(_custom_fused_cast_amax_w13, dname=dname)
  fp8_out, amax_buf, *_unused = Tensor.custom_kernel(fp8_out, amax_buf, xw13, amax_state, fxn=fwd_kernel,
                                                     grad_fxn=_fused_quantize_bwd_w13)
  # the returned scale is derived from the incoming amax_state, not from the amax computed by this call
  inv_scale = (amax_state.float() + 1e-8) / FP8_MAX
  return fp8_out, inv_scale, scalar_amax(amax_buf)

View file

@ -0,0 +1,98 @@
from __future__ import annotations
import functools, pathlib
from tinygrad import Tensor, dtypes
from tinygrad.uop.ops import UOp, Ops, KernelInfo
from tinygrad.renderer import Estimates
from tinygrad.runtime.support.compiler_amd import HIPCCCompiler
# Threads per workgroup: the "lidx0" extent of both CE kernels and the value passed as -DTHREADS_PER_WG.
THREADS_PER_WG = 256
@functools.cache
def _custom_fused_ce_loss_fwd(loss_out:UOp, max_out:UOp, lse_out:UOp, logits:UOp, targets:UOp,
                              dname:str, vocab:int, rows:int, label_smoothing:float) -> UOp:
  """Build the PROGRAM UOp for the forward fused cross-entropy kernel (`fused_ce_loss.cpp`).

  One workgroup ("gidx0") is launched per row, with THREADS_PER_WG threads ("lidx0") each.
  `vocab`, THREADS_PER_WG and `label_smoothing` are baked into the binary as -D defines.
  Results are memoized on the argument tuple by `functools.cache`.
  """
  lidx, gidx = UOp.special(THREADS_PER_WG, "lidx0"), UOp.special(rows, "gidx0")
  mem_estimate = rows * vocab * 2 + rows * 12 + rows * 4
  info = KernelInfo("fused_ce_loss_fwd", estimates=Estimates(ops=6*rows*vocab, mem=mem_estimate))
  sink = UOp.sink(loss_out.base, max_out.base, lse_out.base, logits.base, targets.base, lidx, gidx, arg=info)
  kernel_src = (pathlib.Path(__file__).parent/"fused_ce_loss.cpp").read_text()
  compile_args = [
    "-std=c++20", "-ffast-math",
    f"-DVOCAB={vocab}", f"-DTHREADS_PER_WG={THREADS_PER_WG}", f"-DLABEL_SMOOTHING={label_smoothing}f",
  ]
  kernel_lib = HIPCCCompiler("gfx950", compile_args).compile_cached(kernel_src)
  program_parts = (
    sink,
    UOp(Ops.DEVICE, arg=dname),
    UOp(Ops.LINEAR, src=(*sink.src, sink)),
    UOp(Ops.SOURCE, arg=kernel_src),
    UOp(Ops.BINARY, arg=kernel_lib),
  )
  return UOp(Ops.PROGRAM, src=program_parts)
@functools.cache
def _custom_fused_ce_loss_bwd(d_logits:UOp, logits:UOp, lse:UOp, targets:UOp, scale:UOp,
                              dname:str, vocab:int, rows:int, label_smoothing:float) -> UOp:
  """Build the PROGRAM UOp for the backward fused cross-entropy kernel (`fused_ce_loss_bwd.cpp`).

  One workgroup ("gidx0") is launched per row, with THREADS_PER_WG threads ("lidx0") each.
  `vocab`, THREADS_PER_WG and `label_smoothing` are baked into the binary as -D defines.
  Results are memoized on the argument tuple by `functools.cache`.
  """
  lidx, gidx = UOp.special(THREADS_PER_WG, "lidx0"), UOp.special(rows, "gidx0")
  mem_estimate = rows * vocab * 4 + rows * 8 + 4
  info = KernelInfo("fused_ce_loss_bwd", estimates=Estimates(ops=4*rows*vocab, mem=mem_estimate))
  sink = UOp.sink(d_logits.base, logits.base, lse.base, targets.base, scale.base, lidx, gidx, arg=info)
  kernel_src = (pathlib.Path(__file__).parent/"fused_ce_loss_bwd.cpp").read_text()
  compile_args = [
    "-std=c++20", "-ffast-math",
    f"-DVOCAB={vocab}", f"-DTHREADS_PER_WG={THREADS_PER_WG}", f"-DLABEL_SMOOTHING={label_smoothing}f",
  ]
  kernel_lib = HIPCCCompiler("gfx950", compile_args).compile_cached(kernel_src)
  program_parts = (
    sink,
    UOp(Ops.DEVICE, arg=dname),
    UOp(Ops.LINEAR, src=(*sink.src, sink)),
    UOp(Ops.SOURCE, arg=kernel_src),
    UOp(Ops.BINARY, arg=kernel_lib),
  )
  return UOp(Ops.PROGRAM, src=program_parts)
def _fused_ce_loss_bwd(gradient:UOp, kernel:UOp, label_smoothing:float=0.1):
  """Gradient function for the fused cross-entropy kernel: computes d_logits with the backward kernel.

  `label_smoothing` must match the value the forward kernel was built with; `fused_ce_loss` binds it
  via `functools.partial`. (Previously this name was read without being defined anywhere in scope,
  which raised NameError as soon as the backward pass ran.)

  Returns one entry per forward input; only the logits slot carries a gradient, the rest are None.
  """
  # NOTE: forward inputs are (loss_out, max_out, lse_out, logits, targets)
  # gradient is the upstream grad w.r.t. per-row loss (shape: (rows,) fp32)
  _, _, lse_u, logits_u, targets_u = kernel.src[1:]
  device = logits_u.device
  rows, VOCAB = logits_u.shape  # (rows, VOCAB) after reshape
  if isinstance(device, tuple):
    axis, ndev = logits_u.axis, len(device)
    rows_per_dev = rows // ndev
    d_logits = Tensor(Tensor.invalids(rows_per_dev, VOCAB, dtype=dtypes.bfloat16, device=device).uop.multi(axis), device=device)
    dname = device[0].split(":")[0]
  else:
    rows_per_dev = rows
    d_logits = Tensor.invalids(rows, VOCAB, dtype=dtypes.bfloat16, device=device)
    dname = device.split(":")[0] if isinstance(device, str) else device
  grad_t = Tensor(gradient, device=device).float().reshape(-1)  # (rows,) fp32
  # NOTE: .mean() backward gives same grad per row (1/N), so broadcast is safe; take scalar
  scale = grad_t[0:1].contiguous()
  # logits and lse are ordered after the forward kernel before being handed to the backward kernel
  logits_t = Tensor(logits_u.after(kernel), device=device)
  lse_t = Tensor(lse_u.after(kernel), device=device)
  targets_t = Tensor(targets_u, device=device)
  fxn = functools.partial(_custom_fused_ce_loss_bwd, dname=dname, vocab=VOCAB, rows=rows_per_dev,
                          label_smoothing=label_smoothing)
  d_logits, *_ = Tensor.custom_kernel(d_logits, logits_t, lse_t, targets_t, scale, fxn=fxn)
  return (None, None, None, d_logits.uop, None)
def fused_ce_loss(logits:Tensor, targets:Tensor, label_smoothing:float=0.1) -> Tensor:
  """Fused cross-entropy with label smoothing over bf16 logits of shape (MBS, SEQ, VOCAB).

  Logits are flattened to (MBS*SEQ, VOCAB) and targets to int32 of length MBS*SEQ; the custom
  kernel writes one fp32 loss per row and the mean of those is returned.
  """
  # NOTE: fused sparse_categorical_crossentropy with label smoothing, returns mean loss scalar
  assert logits.dtype == dtypes.bfloat16, f"expected bf16, got {logits.dtype}"
  assert logits.ndim == 3, f"expected (MBS, SEQ, VOCAB), got {logits.shape}"
  # the value is pasted into a C++ define as "<value>f"; an int such as 0 would render as the invalid literal "0f"
  label_smoothing = float(label_smoothing)
  MBS, SEQ, VOCAB = logits.shape
  rows = MBS * SEQ
  device = logits.device
  if isinstance(device, tuple):
    axis = logits.uop.axis
    assert axis in (0, 1), f"unsupported sharding axis={axis} for CE loss"
    rows_per_dev = rows // len(device)
    def _row_buffer() -> Tensor:
      # per-device fp32 row buffer, rewrapped as a tensor sharded on axis 0
      return Tensor(Tensor.invalids(rows_per_dev, dtype=dtypes.float32, device=device).uop.multi(0), device=device)
    dname = device[0].split(":")[0]
  else:
    rows_per_dev = rows
    def _row_buffer() -> Tensor:
      return Tensor.invalids(rows, dtype=dtypes.float32, device=device)
    dname = device.split(":")[0] if isinstance(device, str) else device
  loss_out, max_out, lse_out = _row_buffer(), _row_buffer(), _row_buffer()
  logits_flat = logits.reshape(rows, VOCAB)
  targets_flat = targets.reshape(-1).cast(dtypes.int32)
  fxn = functools.partial(_custom_fused_ce_loss_fwd, dname=dname, vocab=VOCAB, rows=rows_per_dev,
                          label_smoothing=label_smoothing)
  # bind label_smoothing so the backward kernel is compiled with the same value as the forward one
  grad_fxn = functools.partial(_fused_ce_loss_bwd, label_smoothing=label_smoothing)
  loss_out, max_out, lse_out, *_ = Tensor.custom_kernel(
    loss_out, max_out, lse_out, logits_flat, targets_flat,
    fxn=fxn, grad_fxn=grad_fxn)
  return loss_out.mean()

View file

@ -0,0 +1,54 @@
from __future__ import annotations
import functools, pathlib
from tinygrad import Tensor, dtypes
from tinygrad.uop.ops import UOp, Ops, KernelInfo
from tinygrad.renderer import Estimates
from extra.llama_kernels import FP8_MAX, NUM_WG, THREADS_PER_WG, compile_cpp, shard_shape, scalar_amax
@functools.cache
def _custom_mul_quantize_fp8(fp8_out:UOp, amax_buf:UOp, x:UOp, weight:UOp, amax_state:UOp, dname:str) -> UOp:
  """Build the PROGRAM UOp wrapping the compiled `fused_mul_quantize_fp8.cpp` kernel.

  The kernel arguments are the `.base` of (fp8_out, amax_buf, x, weight, amax_state), in that order,
  followed by the local ("lidx0") and global ("gidx0") launch specials. `dname` becomes the DEVICE arg.
  Results are memoized on the argument tuple by `functools.cache`.
  """
  MBS, SEQ, HIDDEN = x.shape
  n_elems = MBS * SEQ * HIDDEN
  # NOTE(review): terms look like bf16 x + bf16 weight + 1-byte fp8 output + per-workgroup amax -- confirm
  mem_estimate = n_elems * 2 + HIDDEN * 2 + n_elems + NUM_WG * 2
  info = KernelInfo(f"fused_mul_quantize_fp8_{n_elems}_h{HIDDEN}", estimates=Estimates(ops=3*n_elems, mem=mem_estimate))
  lidx, gidx = UOp.special(THREADS_PER_WG, "lidx0"), UOp.special(NUM_WG, "gidx0")
  sink = UOp.sink(fp8_out.base, amax_buf.base, x.base, weight.base, amax_state.base, lidx, gidx, arg=info)
  kernel_src, kernel_lib = compile_cpp(pathlib.Path(__file__).parent, "fused_mul_quantize_fp8.cpp", n_elems, HIDDEN)
  program_parts = (
    sink,
    UOp(Ops.DEVICE, arg=dname),
    UOp(Ops.LINEAR, src=(*sink.src, sink)),
    UOp(Ops.SOURCE, arg=kernel_src),
    UOp(Ops.BINARY, arg=kernel_lib),
  )
  return UOp(Ops.PROGRAM, src=program_parts)
def _fused_mul_quantize_fp8_bwd(gradient:UOp, kernel:UOp):
  """Gradient function for the fused mul+quantize kernel, written with regular tensor ops.

  The upstream gradient is multiplied by FP8_MAX / (amax_state + 1e-8), then by `weight` for grad_x
  and by `x` (summed over axes 0 and 1) for grad_weight. Both results are cast to bf16.
  Returns one entry per forward input; slots without a gradient are None.
  """
  # NOTE: inputs are (fp8_out, amax_buf, x, weight, amax_state); grads for x and weight
  _fp8_out, _amax_buf, x_u, weight_u, amax_state_u = kernel.src[1:]
  dev = x_u.device
  upstream = Tensor(gradient, device=dev).cast(dtypes.bfloat16)
  x_f32 = Tensor(x_u, device=dev).float()
  w_f32 = Tensor(weight_u, device=dev).float()
  quant_scale = FP8_MAX / (Tensor(amax_state_u, device=dev).float() + 1e-8)
  scaled = upstream.float() * quant_scale
  # NOTE: grad_x stays bf16 to avoid CSE materializing a (MBS, SEQ, HIDDEN) fp32 intermediate
  grad_x = (scaled * w_f32).cast(dtypes.bfloat16)
  grad_weight = (scaled * x_f32).sum(axis=(0, 1)).cast(dtypes.bfloat16)
  return (None, None, grad_x.uop, grad_weight.uop, None)
def fused_mul_quantize_fp8(x:Tensor, weight:Tensor, amax_state:Tensor, fp8_dtype) -> tuple[Tensor, Tensor, Tensor]:
  """Run the fused multiply+quantize kernel on bf16 `x` of shape (MBS, SEQ, HIDDEN) and bf16 `weight`.

  Returns (fp8_out, inv_scale, new_amax): the output in `fp8_dtype` with the shape of `x`,
  `(amax_state + 1e-8) / FP8_MAX` as float, and `scalar_amax` of the per-workgroup amax buffer.
  """
  # NOTE: (x * weight) -> fp8 + amax, delayed scaling. Returns (fp8, inv_scale, new_amax)
  assert x.dtype == dtypes.bfloat16 and weight.dtype == dtypes.bfloat16
  assert x.shape[-1] == weight.shape[-1], f"HIDDEN mismatch: x={x.shape}, weight={weight.shape}"
  MBS, SEQ, HIDDEN = x.shape
  dev = x.device
  if not isinstance(dev, tuple):
    fp8_out = Tensor.invalids(MBS, SEQ, HIDDEN, dtype=fp8_dtype, device=dev)
    amax_buf = Tensor.invalids(NUM_WG, dtype=dtypes.bfloat16, device=dev)
    dname = dev.split(":")[0] if isinstance(dev, str) else dev
  else:
    # multi-device: the output follows the input's shard axis, the amax buffer is multi'd on axis 0
    axis, ndev = x.uop.axis, len(dev)
    assert axis in (0, 1), f"unsupported sharding axis={axis}"
    out_shard = Tensor.invalids(*shard_shape((MBS, SEQ, HIDDEN), axis, ndev), dtype=fp8_dtype, device=dev)
    fp8_out = Tensor(out_shard.uop.multi(axis), device=dev)
    amax_shard = Tensor.invalids(NUM_WG, dtype=dtypes.bfloat16, device=dev)
    amax_buf = Tensor(amax_shard.uop.multi(0), device=dev)
    dname = dev[0].split(":")[0]
  fwd_kernel = functools.partial(_custom_mul_quantize_fp8, dname=dname)
  fp8_out, amax_buf, *_unused = Tensor.custom_kernel(fp8_out, amax_buf, x, weight, amax_state, fxn=fwd_kernel,
                                                     grad_fxn=_fused_mul_quantize_fp8_bwd)
  new_amax = scalar_amax(amax_buf)
  # the returned scale is derived from the incoming amax_state, not from the amax computed by this call
  inv_scale = (amax_state.float() + 1e-8) / FP8_MAX
  return fp8_out, inv_scale, new_amax

View file

@ -0,0 +1,24 @@
from __future__ import annotations
import functools
from tinygrad import Tensor
from tinygrad.uop.ops import UOp
def rmsnorm_fwd(x_in:Tensor, eps:float) -> tuple[Tensor, Tensor]:
  """RMS-normalize `x_in` over its last axis, computing in float.

  Returns (normalized, rrms): the normalized tensor cast back to `x_in.dtype`, and the float
  reciprocal root-mean-square `rsqrt(mean(x^2, -1, keepdim=True) + eps)`.
  """
  x_f32 = x_in.float()
  mean_sq = x_f32.square().mean(-1, keepdim=True)
  rrms = (mean_sq + eps).rsqrt()
  normalized = (x_f32 * rrms).cast(x_in.dtype)
  return normalized, rrms
@functools.cache
def _rmsnorm_fwd_fxn(x_in_p, eps, device):
  """Trace `rmsnorm_fwd` on a Tensor wrapping `x_in_p`; memoized on (x_in_p, eps, device)."""
  traced_input = Tensor(x_in_p, device=device)
  return rmsnorm_fwd(traced_input, eps)
def _rmsnorm_bwd(grad:UOp, call:UOp) -> tuple:
  """Gradient function for `rmsnorm`: returns a 1-tuple holding the input-gradient UOp.

  Uses the call's two outputs: element 0 (normalized x) and element 1 (rrms).
  """
  normed = Tensor(call.gettuple(0)).float()
  rrms = Tensor(call.gettuple(1))
  upstream = Tensor(grad).float()
  # subtract the component of the upstream grad along the normalized vector, then rescale by rrms
  projection = (upstream * normed).mean(-1, keepdim=True)
  d_x = rrms * (upstream - normed * projection)
  # NOTE(review): call.src[1] is presumably the original input, so this restores its dtype -- confirm
  return (d_x.cast(call.src[1].dtype).uop,)
def rmsnorm(x_in:Tensor, eps:float) -> tuple[Tensor, Tensor]:
  """RMSNorm as a single call UOp with `_rmsnorm_bwd` as its gradient; returns (normalized, rrms)."""
  # the forward graph is built once per (param uop, eps, device) and reused through the cache
  normed_fxn, rrms_fxn = _rmsnorm_fwd_fxn(x_in.as_param(0).uop, eps, x_in.device)
  body = UOp.maketuple(normed_fxn.uop, rrms_fxn.uop)
  call = body.call(x_in.uop, grad_fxn=_rmsnorm_bwd)
  normalized, rrms = Tensor(call.gettuple(0)), Tensor(call.gettuple(1))
  return normalized, rrms