remove (some) kernelize from llama and test schedule speed (#10939)

* remove kernelize from llama

* 405B

* space
This commit is contained in:
George Hotz 2025-06-23 15:07:31 -07:00 committed by GitHub
commit e15754db28
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 45 additions and 7 deletions

View file

@ -1,5 +1,5 @@
from typing import Union, Optional, Any
import collections
import collections, math
from tinygrad import Tensor, Variable, TinyJit, dtypes, nn, Device
from tinygrad.helpers import getenv, DEBUG
@ -166,27 +166,29 @@ def sample(logits: Tensor, temp: float, k: int, p: float, af: float, ap: float):
class Transformer:
def __init__(self, dim:int, hidden_dim:int, n_heads:int, n_layers:int, norm_eps:float, vocab_size, linear=nn.Linear, embedding=nn.Embedding,
n_kv_heads=None, rope_theta=10000, max_context=1024, jit=True, feed_forward=FeedForward, qk_norm=None):
self.layers = [TransformerBlock(dim, hidden_dim, n_heads, n_kv_heads, norm_eps, max_context, linear, feed_forward=feed_forward, qk_norm=qk_norm) for _ in range(n_layers)]
n_kv_heads=None, rope_theta=10000, max_context=1024, jit=True, feed_forward=FeedForward, qk_norm=None, disable_kv_cache=False):
self.layers = [TransformerBlock(dim, hidden_dim, n_heads, n_kv_heads, norm_eps, 0 if disable_kv_cache else max_context,
linear, feed_forward=feed_forward, qk_norm=qk_norm) for _ in range(n_layers)]
self.norm = nn.RMSNorm(dim, norm_eps)
self.tok_embeddings = embedding(vocab_size, dim)
self.output = nn.Linear(dim, vocab_size, bias=False) if embedding == nn.Embedding else linear(dim, vocab_size, bias=False)
self.max_context = max_context
self.freqs_cis = precompute_freqs_cis(dim // n_heads, self.max_context * 2, rope_theta).contiguous()
self.freqs_cis = precompute_freqs_cis(dim // n_heads, self.max_context * 2, rope_theta).contiguous().requires_grad_(False)
self.forward_jit = TinyJit(self.forward) if jit else None
def forward(self, tokens:Tensor, start_pos:Union[Variable,int], temperature:float, top_k:int, top_p:float, alpha_f:float, alpha_p:float):
_bsz, seqlen = tokens.shape
h = self.tok_embeddings(tokens)
self.freqs_cis = self.freqs_cis.cast(h.dtype).kernelize()
self.freqs_cis = self.freqs_cis.cast(h.dtype).contiguous()
freqs_cis = self.freqs_cis[:, start_pos:start_pos+seqlen, :, :, :]
mask = Tensor.full((1, 1, seqlen, start_pos+seqlen), float("-inf"), dtype=h.dtype, device=h.device).triu(start_pos+1).kernelize() if seqlen > 1 else None
mask = Tensor.full((1, 1, seqlen, start_pos+seqlen), float("-inf"), dtype=h.dtype, device=h.device).triu(start_pos+1) if seqlen > 1 else None
for layer in self.layers: h = layer(h, start_pos, freqs_cis, mask)
logits = self.output(self.norm(h)).float()[:, -1, :]
if math.isnan(temperature): return logits
return sample(logits.flatten(), temperature, top_k, top_p, alpha_f, alpha_p).kernelize()
return sample(logits.flatten(), temperature, top_k, top_p, alpha_f, alpha_p)
def __call__(self, tokens:Tensor, start_pos:int, temperature:float=0.0, top_k:int=0, top_p:float=0.8, alpha_f:float=0.0, alpha_p:float=0.0):
# TODO: better way to handle the first call v.s. the rest?

View file

@ -0,0 +1,36 @@
from tinygrad import nn, Tensor, Device, dtypes
from tinygrad.helpers import Timing
from extra.models.llama import Transformer
from examples.llama3 import MODEL_PARAMS
if __name__ == "__main__":
Device.DEFAULT = "NULL"
Tensor.training = True
#model_size = "8B"
model_size = "405B"
with Timing("total "):
with Timing("***** create model in "):
# NOTE: max_context=None means no kv cache. kv cache has realize in the model
model = Transformer(**MODEL_PARAMS[model_size]["args"], linear=nn.Linear, embedding=nn.Embedding,
max_context=1024, jit=True, disable_kv_cache=True)
with Timing("***** fake state in "):
Tensor.realize(*[p.assign(Tensor.empty(*p.shape, device=p.device, dtype=p.dtype)) for p in nn.state.get_parameters(model)])
with Timing("***** create optim in "):
opt = nn.optim.AdamW(nn.state.get_parameters(model))
with Timing("***** run model in "):
toks = Tensor.empty(1, 1024, dtype=dtypes.int)
out = model(toks, 0, temperature=float('nan'))
with Timing("***** backward in "):
out.mean().backward()
with Timing("***** realize in "):
out.realize()
with Timing("***** step in "):
opt.step()