mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
remove (some) kernelize from llama and test schedule speed (#10939)
* remove kernelize from llama * 405B * space
This commit is contained in:
parent
3699d1d3ba
commit
e15754db28
2 changed files with 45 additions and 7 deletions
|
|
@ -1,5 +1,5 @@
|
|||
from typing import Union, Optional, Any
|
||||
import collections
|
||||
import collections, math
|
||||
from tinygrad import Tensor, Variable, TinyJit, dtypes, nn, Device
|
||||
from tinygrad.helpers import getenv, DEBUG
|
||||
|
||||
|
|
@ -166,27 +166,29 @@ def sample(logits: Tensor, temp: float, k: int, p: float, af: float, ap: float):
|
|||
|
||||
class Transformer:
|
||||
def __init__(self, dim:int, hidden_dim:int, n_heads:int, n_layers:int, norm_eps:float, vocab_size, linear=nn.Linear, embedding=nn.Embedding,
|
||||
n_kv_heads=None, rope_theta=10000, max_context=1024, jit=True, feed_forward=FeedForward, qk_norm=None):
|
||||
self.layers = [TransformerBlock(dim, hidden_dim, n_heads, n_kv_heads, norm_eps, max_context, linear, feed_forward=feed_forward, qk_norm=qk_norm) for _ in range(n_layers)]
|
||||
n_kv_heads=None, rope_theta=10000, max_context=1024, jit=True, feed_forward=FeedForward, qk_norm=None, disable_kv_cache=False):
|
||||
self.layers = [TransformerBlock(dim, hidden_dim, n_heads, n_kv_heads, norm_eps, 0 if disable_kv_cache else max_context,
|
||||
linear, feed_forward=feed_forward, qk_norm=qk_norm) for _ in range(n_layers)]
|
||||
self.norm = nn.RMSNorm(dim, norm_eps)
|
||||
self.tok_embeddings = embedding(vocab_size, dim)
|
||||
self.output = nn.Linear(dim, vocab_size, bias=False) if embedding == nn.Embedding else linear(dim, vocab_size, bias=False)
|
||||
self.max_context = max_context
|
||||
self.freqs_cis = precompute_freqs_cis(dim // n_heads, self.max_context * 2, rope_theta).contiguous()
|
||||
self.freqs_cis = precompute_freqs_cis(dim // n_heads, self.max_context * 2, rope_theta).contiguous().requires_grad_(False)
|
||||
self.forward_jit = TinyJit(self.forward) if jit else None
|
||||
|
||||
def forward(self, tokens:Tensor, start_pos:Union[Variable,int], temperature:float, top_k:int, top_p:float, alpha_f:float, alpha_p:float):
|
||||
_bsz, seqlen = tokens.shape
|
||||
h = self.tok_embeddings(tokens)
|
||||
|
||||
self.freqs_cis = self.freqs_cis.cast(h.dtype).kernelize()
|
||||
self.freqs_cis = self.freqs_cis.cast(h.dtype).contiguous()
|
||||
freqs_cis = self.freqs_cis[:, start_pos:start_pos+seqlen, :, :, :]
|
||||
|
||||
mask = Tensor.full((1, 1, seqlen, start_pos+seqlen), float("-inf"), dtype=h.dtype, device=h.device).triu(start_pos+1).kernelize() if seqlen > 1 else None
|
||||
mask = Tensor.full((1, 1, seqlen, start_pos+seqlen), float("-inf"), dtype=h.dtype, device=h.device).triu(start_pos+1) if seqlen > 1 else None
|
||||
for layer in self.layers: h = layer(h, start_pos, freqs_cis, mask)
|
||||
logits = self.output(self.norm(h)).float()[:, -1, :]
|
||||
if math.isnan(temperature): return logits
|
||||
|
||||
return sample(logits.flatten(), temperature, top_k, top_p, alpha_f, alpha_p).kernelize()
|
||||
return sample(logits.flatten(), temperature, top_k, top_p, alpha_f, alpha_p)
|
||||
|
||||
def __call__(self, tokens:Tensor, start_pos:int, temperature:float=0.0, top_k:int=0, top_p:float=0.8, alpha_f:float=0.0, alpha_p:float=0.0):
|
||||
# TODO: better way to handle the first call v.s. the rest?
|
||||
|
|
|
|||
36
test/external/external_benchmark_llama_schedule.py
vendored
Normal file
36
test/external/external_benchmark_llama_schedule.py
vendored
Normal file
|
|
@ -0,0 +1,36 @@
|
|||
from tinygrad import nn, Tensor, Device, dtypes
|
||||
from tinygrad.helpers import Timing
|
||||
|
||||
from extra.models.llama import Transformer
|
||||
from examples.llama3 import MODEL_PARAMS
|
||||
|
||||
if __name__ == "__main__":
|
||||
Device.DEFAULT = "NULL"
|
||||
Tensor.training = True
|
||||
#model_size = "8B"
|
||||
model_size = "405B"
|
||||
|
||||
with Timing("total "):
|
||||
with Timing("***** create model in "):
|
||||
# NOTE: max_context=None means no kv cache. kv cache has realize in the model
|
||||
model = Transformer(**MODEL_PARAMS[model_size]["args"], linear=nn.Linear, embedding=nn.Embedding,
|
||||
max_context=1024, jit=True, disable_kv_cache=True)
|
||||
|
||||
with Timing("***** fake state in "):
|
||||
Tensor.realize(*[p.assign(Tensor.empty(*p.shape, device=p.device, dtype=p.dtype)) for p in nn.state.get_parameters(model)])
|
||||
|
||||
with Timing("***** create optim in "):
|
||||
opt = nn.optim.AdamW(nn.state.get_parameters(model))
|
||||
|
||||
with Timing("***** run model in "):
|
||||
toks = Tensor.empty(1, 1024, dtype=dtypes.int)
|
||||
out = model(toks, 0, temperature=float('nan'))
|
||||
|
||||
with Timing("***** backward in "):
|
||||
out.mean().backward()
|
||||
|
||||
with Timing("***** realize in "):
|
||||
out.realize()
|
||||
|
||||
with Timing("***** step in "):
|
||||
opt.step()
|
||||
Loading…
Add table
Add a link
Reference in a new issue