refactor llm into files (#15780)

* refactor llm into files

* chat.html

* tokenizer cleanup

* cleanup

* tests
This commit is contained in:
George Hotz 2026-04-17 12:33:11 +08:00 committed by GitHub
commit a9b6cfece0
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
10 changed files with 493 additions and 501 deletions

View file

@ -1,6 +1,6 @@
import unittest
from tinygrad import Tensor, dtypes, TinyJit, UOp
from tinygrad.llm.cli import apply_rope as apply_rope_new, precompute_freqs_cis
from tinygrad.llm.model import apply_rope as apply_rope_new, precompute_freqs_cis
from test.helpers import assert_jit_cache_len
def apply_rope(x:Tensor, start_pos:int):

View file

@ -14,23 +14,18 @@ class TestLLMServer(unittest.TestCase):
cls.mock_tok.end_turn = Mock(return_value=[998])
cls.mock_tok.prefix = Mock(return_value=[1])
cls.mock_tok.preset = "llama3"
cls.mock_tok.bos_id = 1
cls.mock_tok.eos_id = 999
cls.mock_tok.eot_id = None
cls.mock_tok.is_end = Mock(side_effect=lambda tid: tid in (999,))
cls.mock_model = Mock()
cls.mock_model.generate = Mock(side_effect=lambda ids, **kwargs: iter([300, 301, 999]))
cls.mock_model.get_start_pos = Mock(return_value=0)
cls.bos_id = 1
cls.eos_id = 999
from tinygrad.llm.cli import LLMServer
from tinygrad.llm.cli import Handler, LLMServer
cls.server = LLMServer(('127.0.0.1', 0), Handler)
cls.server.model = cls.mock_model
cls.server.model_name = "test-model"
cls.server.tok = cls.mock_tok
cls.server.bos_id = cls.bos_id
cls.server.eos_id = cls.eos_id
cls.server.eot_id = None
cls.server = LLMServer(('127.0.0.1', 0), cls.mock_model, "test-model", cls.mock_tok)
cls.port = cls.server.server_address[1]
cls.server_thread = threading.Thread(target=cls.server.serve_forever, daemon=True)
cls.server_thread.start()

View file

@ -51,11 +51,12 @@ class TestLLMTokenizer(unittest.TestCase):
"tokenizer.ggml.tokens": ["<unk>", "<s>", "</s>", "[INST]", "[/INST]", "hello"],
"tokenizer.ggml.token_type": [3, 3, 3, 3, 3, 1],
"tokenizer.ggml.pre": "tekken",
"tokenizer.ggml.eos_token_id": 2,
}
tok = SimpleTokenizer.from_gguf_kv(kv)
self.assertEqual(tok.role("user"), [3])
self.assertEqual(tok.encode("hello"), [5])
self.assertEqual(tok.end_turn(2), [4])
self.assertEqual(tok.end_turn(), [4])
self.assertEqual(tok.role("assistant"), [])
def test_stream_decoder(self):

View file

@ -1,7 +1,7 @@
import unittest
import numpy as np
from tinygrad import Tensor, dtypes
from tinygrad.llm.cli import (
from tinygrad.llm.model import (
GatedDeltaNetBlock, SSMConfig, TransformerBlock, TransformerConfig,
apply_rope as apply_rope_new, precompute_freqs_cis, pairwise_topk,
)

View file

@ -1,7 +1,7 @@
import unittest
import numpy as np
from tinygrad import Tensor
from tinygrad.llm.cli import Transformer, TransformerConfig, apply_rope
from tinygrad.llm.model import Transformer, TransformerConfig, apply_rope, MLATransformerBlock, precompute_freqs_cis
class TestMLA(unittest.TestCase):
def _make_config(self, **kwargs):
@ -13,7 +13,6 @@ class TestMLA(unittest.TestCase):
def test_mla_attention_matches_naive(self):
config = self._make_config(max_context=16)
from tinygrad.llm.cli import MLATransformerBlock, precompute_freqs_cis
block = MLATransformerBlock(config)
c = config

View file

@ -2,7 +2,7 @@ import unittest
import numpy as np
from dataclasses import replace
from tinygrad import Tensor
from tinygrad.llm.cli import TransformerBlock, TransformerConfig
from tinygrad.llm.model import TransformerBlock, TransformerConfig
def _moe_config(dim=8, hidden=16, n_heads=2, num_experts=4, num_experts_per_tok=2):
return TransformerConfig(

View file

@ -2,7 +2,7 @@ import unittest
from unittest.mock import patch
from tinygrad import Tensor, UOp
from tinygrad.schedule import schedule_cache
from tinygrad.llm.cli import Transformer, TransformerConfig
from tinygrad.llm.model import Transformer, TransformerConfig
TEST_CONFIG = TransformerConfig(num_blocks=1, dim=64, hidden_dim=128, n_heads=2, n_kv_heads=2,
norm_eps=1e-5, vocab_size=100, head_dim=32, rope_theta=10000.0, rope_dim=32, v_head_dim=32, max_context=32)

38
tinygrad/llm/chat.html Normal file
View file

@ -0,0 +1,38 @@
<!DOCTYPE html><html><head><title>tinygrad chat</title><style>
* { margin: 0 }
body { background: #212121; color: #e3e3e3; font-family: system-ui;
height: 100vh; display: flex; flex-direction: column }
#chat { flex: 1; overflow-y: auto; padding: 20px }
.msg { padding: 10px 16px; margin: 8px 0; white-space: pre-wrap; border-radius: 18px }
.user { background: #2f2f2f; margin-left: auto; width: fit-content; max-width: 70% }
#input { max-width: 768px; width: 100%; margin: 20px auto; padding: 14px 20px;
background: #2f2f2f; color: inherit; font: inherit;
border: none; outline: none; resize: none; border-radius: 24px; field-sizing: content }
</style></head><body><div id="chat"></div>
<textarea id="input" rows="1" placeholder="Ask anything" autofocus></textarea>
<script>
// Enter submits; Shift+Enter keeps the default newline, and Enter during IME composition is left alone.
input.onkeydown = (e) => { if (e.key === 'Enter' && !e.shiftKey && !e.isComposing) { e.preventDefault(); send() } }
// Whole conversation so far; the complete list is posted with every request.
const msgs = [];
// Post the current input as a user message and stream the reply into a new bubble.
async function send() {
  const text = input.value.trim();
  if (!text) return;
  msgs.push({role: 'user', content: text});
  // Build the bubble as a node and set textContent instead of appending to innerHTML:
  // the text is shown verbatim (typed entities like "&lt;" are not interpreted), and the
  // bubbles already in #chat are not re-parsed, so a reply that is still streaming keeps its node.
  const u = document.createElement('div'); u.className = 'msg user'; u.textContent = text; chat.appendChild(u);
  input.value = '';
  const d = document.createElement('div'); d.className = 'msg'; chat.appendChild(d);
  const r = await fetch('/v1/chat/completions', {method: 'POST', headers: {'Content-Type': 'application/json'},
    body: JSON.stringify({model: 'llama', messages: msgs, stream: true, temperature: 0.7})});
  let buf = '';
  for (const rd = r.body.getReader(), dec = new TextDecoder();;) {
    const {done, value} = await rd.read();
    if (done) break;
    buf += dec.decode(value, {stream: true});
    // Only complete lines are parsed; the trailing partial line stays in buf for the next chunk.
    const lines = buf.split('\n');
    buf = lines.pop();
    // Each "data: " line carries one JSON chunk; lines that fail to parse are skipped.
    for (const ln of lines)
      if (ln.startsWith('data: ') && !ln.includes('[DONE]'))
        try { d.textContent += JSON.parse(ln.slice(6)).choices[0]?.delta?.content || '' } catch {}
    chat.scrollTop = chat.scrollHeight;
  }
  msgs.push({role: 'assistant', content: d.textContent});
}
</script></body></html>

View file

@ -1,13 +1,13 @@
from __future__ import annotations
import sys, argparse, codecs, typing, re, unicodedata, json, uuid, time, functools, itertools
from dataclasses import dataclass, replace
from tinygrad import Tensor, nn, UOp, TinyJit, getenv, function
from tinygrad.uop.ops import resolve
import sys, argparse, codecs, typing, re, unicodedata, json, uuid, time, pathlib
from tinygrad import Tensor, nn
from tinygrad.helpers import partition, DEBUG, Timing, GlobalCounters, stderr_log, colored, Context
from tinygrad.viz.serve import TCPServerWithReuse, HTTPRequestHandler
from tinygrad.llm.model import Transformer
class SimpleTokenizer:
def __init__(self, normal_tokens:dict[str, int], special_tokens:dict[str, int], preset:str="llama3"):
def __init__(self, normal_tokens:dict[str, int], special_tokens:dict[str, int], preset:str="llama3",
bos_id:int|None=None, eos_id:int=0, eot_id:int|None=None):
preset = {"qwen35":"qwen2","qwen35moe":"qwen2"}.get(preset, preset)
if preset not in ("llama3","llama-v3","llama-bpe","qwen2","olmo","kimi-k2","tekken","glm4"):
raise ValueError(f"Invalid tokenizer preset '{preset}'")
@ -27,13 +27,16 @@ class SimpleTokenizer:
self._special_tokens = special_tokens
self._tok2bytes = {tid: tok for tok, tid in self._normal_tokens.items()} | {tid: tok.encode() for tok, tid in self._special_tokens.items()}
self.preset = preset
self.bos_id, self.eos_id, self.eot_id = bos_id, eos_id, eot_id
@staticmethod
def from_gguf_kv(kv:dict):
# https://github.com/ggml-org/llama.cpp/blob/94933c8c2eeaa9a7983e3f6c08af76bd86724094/src/llama-vocab.cpp#L1818-L1820
vocab: typing.Iterable[tuple[str, int]] = ((tok, idx) for idx, tok in enumerate(kv["tokenizer.ggml.tokens"]))
normal_tokens, special_tokens = partition(vocab, lambda e: kv["tokenizer.ggml.token_type"][e[1]] == 1)
return SimpleTokenizer(dict(normal_tokens), dict(special_tokens), kv["tokenizer.ggml.pre"])
return SimpleTokenizer(dict(normal_tokens), dict(special_tokens), kv["tokenizer.ggml.pre"],
bos_id=kv.get('tokenizer.ggml.bos_token_id') if kv.get('tokenizer.ggml.add_bos_token', True) else None,
eos_id=kv.get('tokenizer.ggml.eos_token_id', 0), eot_id=kv.get('tokenizer.ggml.eot_token_id'))
def _encode_word(self, word:bytes) -> list[int]:
if (early_token:=self._normal_tokens.get(word)) is not None: return [early_token]
@ -70,419 +73,16 @@ class SimpleTokenizer:
if role == 'assistant': return []
raise ValueError(f"Unsupported role '{role}' for tokenizer preset '{self.preset}'")
return self.encode("<|start_header_id|>" + role + "<|end_header_id|>\n\n")
def end_turn(self, eos_id:int):
def end_turn(self):
if self.preset == 'olmo': return self.encode("\n")
if self.preset == 'kimi-k2': return [eos_id]
if self.preset == 'qwen2': return [eos_id] + self.encode("\n")
if self.preset == 'kimi-k2': return [self.eos_id]
if self.preset == 'qwen2': return [self.eos_id] + self.encode("\n")
if self.preset == 'glm4': return []
if self.preset == 'tekken': return self.encode("[/INST]")
return [eos_id]
def prefix(self, bos_id:int|None) -> list[int]:
  """Return the tokens that open a prompt: `bos_id` when given, then the encoded "<sop>" marker for the glm4 preset."""
  ids = [] if bos_id is None else [bos_id]
  if self.preset == 'glm4': ids = ids + self.encode("<sop>")
  return ids
@functools.cache
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0) -> Tensor:
  """Build the rotary-embedding table for positions 0..end-1.

  Each row holds the cosines of that position's angles followed by the sines, so the result
  has `end` rows and `2 * (dim // 2)` columns. Results are memoized per argument tuple.
  """
  inv_freq = 1.0 / (theta ** (Tensor.arange(0, dim, 2)[:(dim // 2)] / dim))
  # outer product: one row per position, one column per frequency
  angles = Tensor.arange(end).unsqueeze(dim=1) * inv_freq.unsqueeze(dim=0)
  table = angles.cos().cat(angles.sin(), dim=-1)
  return table.contiguous()
class ExpertWeights:
  """Per-expert linear weights, stored as one tensor of shape (num_experts, out_features, in_features).

  Calling the object gathers the weight matrices of the selected experts and applies them to `x`.
  """
  def __init__(self, num_experts:int, in_features:int, out_features:int):
    self.weight = Tensor.zeros(num_experts, out_features, in_features)

  def __call__(self, sel:Tensor, x:Tensor) -> Tensor:
    # sel: (B, T, k), x: (B, T, 1, in) or (B, T, k, in) -> output: (B, T, k, out)
    row = x.unsqueeze(-2)          # add a row axis so the matmul contracts over `in`
    chosen = self.weight[sel]      # one weight matrix per selected expert
    projected = row @ chosen.transpose(-1, -2)
    return projected.squeeze(-2)
def apply_rope(x:Tensor, freqs_cis:Tensor) -> Tensor:
  """Rotate `x` by the angles in `freqs_cis` using the half-split layout.

  The last axis of `x` is split into two halves that are rotated against each other;
  `freqs_cis` is reshaped to (1, 1, x.shape[2], -1) and split into its cos half and sin half.
  The last axis of `x` must have even length.
  """
  assert x.shape[-1] % 2 == 0
  table = freqs_cis.reshape(1, 1, x.shape[2], -1)
  cos, sin = table.chunk(2, dim=-1)
  lo, hi = x.chunk(2, dim=-1)
  rot_lo = lo * cos - hi * sin
  rot_hi = hi * cos + lo * sin
  return rot_lo.cat(rot_hi, dim=-1)
def pairwise_topk(x: Tensor, k: int) -> tuple[Tensor, Tensor]:
  """Return the k largest entries along the last axis of a 3-D tensor, and their indices.

  Every entry is compared against every other one; the number of entries it beats is its slot
  in ascending order. Equal values are ordered by position, the lower index counting as larger.
  The last k slots are returned, so results are in ascending order of value.
  """
  n = x.shape[-1]
  positions = Tensor.arange(n).reshape(1,1,n).cast(x.dtype).expand(x.shape)
  lhs, rhs = x.unsqueeze(-1), x.unsqueeze(-2)
  lower_index = Tensor.arange(n).reshape(1,1,n,1) < Tensor.arange(n).reshape(1,1,1,n)
  beats = (lhs > rhs) | ((lhs == rhs) & lower_index)
  # how many entries each element beats == where it lands when sorted ascending
  slot = beats.sum(axis=-1).cast('int32')
  ordered = Tensor.zeros_like(x).scatter(-1, slot, positions)
  sel = ordered[:,:,n-k:].cast('int32')
  return x.gather(-1, sel), sel
@dataclass(frozen=True)
class SSMConfig:
  """Immutable sizes for the gated delta-net (recurrent) blocks."""
  # width of the causal convolution window
  conv_kernel: int
  # per-head key dimension
  state_size: int
  # number of key heads
  group_count: int
  # number of value heads; must be a multiple of group_count
  time_step_rank: int
  # total value width; the per-head value dimension is inner_size // time_step_rank
  inner_size: int
@dataclass(frozen=True)
class TransformerConfig:
  """Immutable model hyper-parameters; the first eleven fields are required, the rest default to 'off'."""
  # number of blocks in the stack
  num_blocks: int
  # embedding / residual width
  dim: int
  # feed-forward hidden width (per expert when num_experts > 0)
  hidden_dim: int
  # query heads and key/value heads
  n_heads: int
  n_kv_heads: int
  # epsilon passed to every RMSNorm
  norm_eps: float
  vocab_size: int
  # per-head query/key width
  head_dim: int
  # RoPE base and the number of per-head channels that are rotated
  rope_theta: float
  rope_dim: int
  # per-head value width (TransformerBlock requires it to equal head_dim)
  v_head_dim: int
  # sequence length the caches and the RoPE table are allocated for
  max_context: int = 0
  # width of the q/k RMSNorm; 0 disables it
  qk_norm: int = 0
  # mixture-of-experts: 0 experts selects the dense feed-forward
  num_experts: int = 0
  num_experts_per_tok: int = 0
  norm_topk_prob: bool = False
  # low-rank attention: kv_lora_rank > 0 selects MLATransformerBlock
  q_lora_rank: int = 0
  kv_lora_rank: int = 0
  # > 0 adds a shared expert next to the routed ones
  shared_expert_dim: int = 0
  # with ssm set, every block whose 1-based index is a multiple of this is a full-attention block
  full_attention_interval: int = 0
  # doubles the q projection and multiplies the attention output by a sigmoid gate
  attn_output_gate: bool = False
  ssm: SSMConfig|None = None
  # whether the shared expert output is scaled by its own sigmoid gate
  shared_expert_gate: bool = True
  # the first N blocks use a dense feed-forward of width dense_hidden_dim (hidden_dim when 0)
  leading_dense_blocks: int = 0
  dense_hidden_dim: int = 0
  # multiplier applied to the routed expert probabilities
  routed_scaling_factor: float = 1.0
class FFNBlock:
  """Base block: pre-norm residual attention step followed by a pre-norm residual feed-forward.

  This class owns the two RMSNorms and the feed-forward weights (dense, or mixture-of-experts when
  config.num_experts > 0). Subclasses provide `_init_state` and `_attention`.
  """
  def __init__(self, config:TransformerConfig):
    self.config = config

    # --- RMSNorms --------------------------------------------------------
    self.attn_norm = nn.RMSNorm(config.dim, config.norm_eps)
    self.ffn_norm = nn.RMSNorm(config.dim, config.norm_eps)

    # --- feed-forward (MoE or dense) -------------------------------------
    if config.num_experts > 0:
      self.ffn_gate_inp = nn.Linear(config.dim, config.num_experts, bias=False) # router
      # selection bias for the router; only created for the low-rank-attention (kv_lora_rank > 0) configs
      if config.kv_lora_rank > 0: self.exp_probs_b = {"bias": Tensor.zeros(config.num_experts)}
      self.ffn_gate_exps = ExpertWeights(config.num_experts, config.dim, config.hidden_dim)
      self.ffn_up_exps = ExpertWeights(config.num_experts, config.dim, config.hidden_dim)
      self.ffn_down_exps = ExpertWeights(config.num_experts, config.hidden_dim, config.dim)
      if config.shared_expert_dim > 0:
        self.ffn_gate_shexp = nn.Linear(config.dim, config.shared_expert_dim, bias=False)
        self.ffn_up_shexp = nn.Linear(config.dim, config.shared_expert_dim, bias=False)
        self.ffn_down_shexp = nn.Linear(config.shared_expert_dim, config.dim, bias=False)
        if config.shared_expert_gate: self.ffn_gate_inp_shexp = {"weight": Tensor.zeros(config.dim)}
    else:
      self.ffn_gate = nn.Linear(config.dim, config.hidden_dim, bias=False)
      self.ffn_up = nn.Linear(config.dim, config.hidden_dim, bias=False)
      self.ffn_down = nn.Linear(config.hidden_dim, config.dim, bias=False)

  def _feed_forward(self, x:Tensor) -> Tensor:
    """Apply the feed-forward; which path runs is decided by the attributes `__init__` created."""
    if hasattr(self, 'ffn_gate_exps'):
      h = x.unsqueeze(2) # (B, T, 1, D) - add expert dim for broadcasting
      logits = self.ffn_gate_inp(x)
      if hasattr(self, 'exp_probs_b'):
        # sigmoid routing: the bias only influences which experts are selected, the weights are the unbiased probs
        probs = logits.sigmoid()
        _, sel = pairwise_topk(probs + self.exp_probs_b["bias"], self.config.num_experts_per_tok)
        probs = probs.gather(-1, sel)
        if self.config.norm_topk_prob: probs = probs / probs.sum(axis=-1, keepdim=True)
      else:
        # softmax routing: normalize over the selected experts only, or over all experts and then pick
        vals, sel = pairwise_topk(logits, self.config.num_experts_per_tok)
        probs = vals.softmax(-1) if self.config.norm_topk_prob else logits.softmax(-1).gather(-1, sel)
      probs = probs * self.config.routed_scaling_factor
      x_down = self.ffn_down_exps(sel, self.ffn_gate_exps(sel, h).silu() * self.ffn_up_exps(sel, h)) # (B, T, k, D)
      out = (x_down * probs.unsqueeze(-1)).sum(axis=2) # (B, T, D)
      if hasattr(self, 'ffn_gate_shexp'):
        shexp = self.ffn_down_shexp(self.ffn_gate_shexp(x).silu().contiguous() * self.ffn_up_shexp(x))
        # optional per-token scalar gate on the shared expert
        if hasattr(self, 'ffn_gate_inp_shexp'): shexp = shexp * (x * self.ffn_gate_inp_shexp["weight"]).sum(axis=-1, keepdim=True).sigmoid()
        out = out + shexp
      return out
    # TODO: remove the need for this contiguous
    return self.ffn_down(self.ffn_gate(x).silu().contiguous() * self.ffn_up(x))

  # given the token-prefix match, return how much cached state this block can still reuse
  def _reusable_prefix_len(self, prefix_len:int, cached_len:int) -> int: return prefix_len
  # return writes that reset this block's state after a cache mismatch
  def _state_reset_ops(self) -> list[Tensor]: return []

  # subclass hooks: allocate per-block state on first use, and compute the attention output
  def _init_state(self, x:Tensor): raise NotImplementedError
  def _attention(self, x:Tensor, start_pos:int|UOp) -> Tensor: raise NotImplementedError

  def __call__(self, x: Tensor, start_pos: int|UOp):
    """Run the block on `x` at position `start_pos` and return the new residual stream."""
    self._init_state(x)
    # we pass in the weights implicitly so we unpack the GGUF on the fly
    @function(precompile=True, allow_implicit=True)
    def _run(x:Tensor, start_pos:int|UOp):
      h = x + self._attention(self.attn_norm(x), start_pos)
      return (h + self._feed_forward(self.ffn_norm(h))).contiguous()
    return _run(x, start_pos)
class TransformerBlock(FFNBlock):
  """Attention block with separate q/k/v projections, RoPE and a key/value cache."""
  def __init__(self, config:TransformerConfig):
    super().__init__(config)
    assert config.v_head_dim == config.head_dim, "TransformerBlock requires v_head_dim == head_dim"

    # --- attention projections (all linear, bias-free) ------------------
    # with attn_output_gate the q projection also produces the gate, hence twice the width
    q_proj_out = config.head_dim * config.n_heads * (2 if config.attn_output_gate else 1)
    kv_proj_out = config.head_dim * config.n_kv_heads
    self.attn_q = nn.Linear(config.dim, q_proj_out, bias=False)
    self.attn_k = nn.Linear(config.dim, kv_proj_out, bias=False)
    self.attn_v = nn.Linear(config.dim, kv_proj_out, bias=False)
    self.attn_output = nn.Linear(config.head_dim * config.n_heads, config.dim, bias=False)
    if config.qk_norm: self.attn_q_norm, self.attn_k_norm = nn.RMSNorm(config.qk_norm, config.norm_eps), nn.RMSNorm(config.qk_norm, config.norm_eps)

  def _attention(self, x:Tensor, start_pos:int|UOp) -> Tensor:
    """Attend over the cached positions 0..start_pos+T-1, writing the T new keys/values into the cache."""
    q, k, v = self.attn_q(x), self.attn_k(x), self.attn_v(x)
    # a qk_norm width different from head_dim is applied before the per-head reshape
    if self.config.qk_norm and self.config.qk_norm != self.config.head_dim: q, k = self.attn_q_norm(q), self.attn_k_norm(k)

    B, T, _ = x.shape
    if self.config.attn_output_gate:
      # each head's projection holds the query followed by the gate
      qg = q.reshape(B, T, self.config.n_heads, 2, self.config.head_dim)
      q, gate = qg[:, :, :, 0, :], qg[:, :, :, 1, :].reshape(B, T, self.config.n_heads * self.config.head_dim)
    q = q.reshape(B, T, self.config.n_heads, self.config.head_dim).transpose(1, 2) # (B,H,T,Hd)
    k = k.reshape(B, T, self.config.n_kv_heads, self.config.head_dim).transpose(1, 2) # (B,KvH,T,Hd)
    v = v.reshape(B, T, self.config.n_kv_heads, self.config.head_dim).transpose(1, 2) # (B,KvH,T,Hd)
    # a qk_norm width equal to head_dim is applied per head
    if self.config.qk_norm == self.config.head_dim: q, k = self.attn_q_norm(q), self.attn_k_norm(k)

    # only the first rope_dim channels of each head are rotated, the rest pass through
    q = apply_rope(q[..., :self.config.rope_dim], self.freqs_cis[start_pos:start_pos+T]).cat(q[..., self.config.rope_dim:], dim=-1)
    k = apply_rope(k[..., :self.config.rope_dim], self.freqs_cis[start_pos:start_pos+T]).cat(k[..., self.config.rope_dim:], dim=-1)

    # NOTE: we don't want to change self.cache_kv, the function API doesn't support this well
    assigned_kv = Tensor(self.cache_kv.uop.after(self.cache_kv[:, :, :, start_pos:start_pos+T, :].uop.store(Tensor.stack(k, v).uop)))
    k = assigned_kv[0, :, :, 0:start_pos+T, :]
    v = assigned_kv[1, :, :, 0:start_pos+T, :]
    #self.cache_kv[:, :, :, start_pos:start_pos+T, :].assign(Tensor.stack(k, v))
    #k = self.cache_kv[0, :, :, 0:start_pos+T, :]
    #v = self.cache_kv[1, :, :, 0:start_pos+T, :]

    # NOTE: this mask is causal_lower_right, not the causal_upper_left generated by is_causal = True
    # TODO: this if statement should be removed and it shouldn't generate extra kernels
    mask = Tensor.full((1, 1, T, start_pos+T), float("-inf"), dtype=x.dtype, device=x.device).triu(start_pos+1) if resolve(T != 1) else None
    attn = q.scaled_dot_product_attention(k, v, attn_mask=mask, enable_gqa=True) # (B,H,T,Hd)
    attn = attn.transpose(1, 2).reshape(B, T, -1) # back to (B,T,D)
    return self.attn_output(attn if not self.config.attn_output_gate else (attn * gate.sigmoid()))

  def _init_state(self, x:Tensor):
    """Allocate the key/value cache and fetch the RoPE table the first time the block is called."""
    if not hasattr(self, "cache_kv"):
      # TODO: how is the dtype of this determined?
      self.cache_kv = Tensor.empty(2, x.shape[0], self.config.n_kv_heads, self.config.max_context, self.config.head_dim, device=x.device)
      self.freqs_cis = precompute_freqs_cis(self.config.rope_dim, self.config.max_context, self.config.rope_theta)
class MLATransformerBlock(FFNBlock):
  """Attention block with low-rank key/value projections; the cache holds the compressed kv, not per-head keys/values."""
  def __init__(self, config:TransformerConfig):
    super().__init__(config)
    qk_nope_head_dim = config.head_dim - config.rope_dim
    if config.q_lora_rank > 0:
      # low-rank query: down-project, normalize, up-project
      self.attn_q_a = nn.Linear(config.dim, config.q_lora_rank, bias=False)
      self.attn_q_a_norm = nn.RMSNorm(config.q_lora_rank, config.norm_eps)
      self.attn_q_b = nn.Linear(config.q_lora_rank, config.n_heads * config.head_dim, bias=False)
    else:
      self.attn_q = nn.Linear(config.dim, config.n_heads * config.head_dim, bias=False)
    # one projection yields the compressed kv (kv_lora_rank channels) followed by the rotary key part (rope_dim channels)
    self.attn_kv_a_mqa = nn.Linear(config.dim, config.kv_lora_rank + config.rope_dim, bias=False)
    self.attn_kv_a_norm = nn.RMSNorm(config.kv_lora_rank, config.norm_eps)
    self.attn_k_b = {"weight": Tensor.zeros(config.n_heads, config.kv_lora_rank, qk_nope_head_dim)}
    self.attn_v_b = {"weight": Tensor.zeros(config.n_heads, config.v_head_dim, config.kv_lora_rank)}
    self.attn_output = nn.Linear(config.n_heads * config.v_head_dim, config.dim, bias=False)

  def _attention(self, x:Tensor, start_pos:int|UOp) -> Tensor:
    """Attend over the cached positions 0..start_pos+T-1 in the compressed kv space."""
    B, T, _ = x.shape
    q_nope_head_dim = self.config.head_dim - self.config.rope_dim

    q_proj = self.attn_q_b(self.attn_q_a_norm(self.attn_q_a(x))) if self.config.q_lora_rank > 0 else self.attn_q(x)
    q = q_proj.reshape(B, T, self.config.n_heads, self.config.head_dim).transpose(1, 2)
    # per head: the leading channels skip RoPE, the trailing rope_dim channels are rotated
    q_nope, q_rope = q[..., :q_nope_head_dim], q[..., q_nope_head_dim:]
    # the non-rotary query part is mapped through attn_k_b so it can be compared against the compressed kv directly
    q = (q_nope @ self.attn_k_b["weight"].transpose(-1, -2)).cat(apply_rope(q_rope, self.freqs_cis[start_pos:start_pos+T]), dim=-1)

    kv_a = self.attn_kv_a_mqa(x)
    c_kv = self.attn_kv_a_norm(kv_a[..., :self.config.kv_lora_rank])
    k_rope = apply_rope(
      kv_a[..., self.config.kv_lora_rank:].reshape(B, T, 1, self.config.rope_dim).transpose(1, 2),
      self.freqs_cis[start_pos:start_pos+T])

    # keys are the compressed kv plus the rotary part; values are the compressed kv alone
    k_store = c_kv.reshape(B, 1, T, self.config.kv_lora_rank).cat(k_rope.reshape(B, 1, T, self.config.rope_dim), dim=-1)
    v_store = c_kv.reshape(B, 1, T, self.config.kv_lora_rank)
    # write the T new entries, then read back everything up to start_pos+T
    k = Tensor(self.cache_k.uop.after(self.cache_k[:, :, start_pos:start_pos+T, :].uop.store(k_store.uop)))[:, :, 0:start_pos+T, :]
    v = Tensor(self.cache_v.uop.after(self.cache_v[:, :, start_pos:start_pos+T, :].uop.store(v_store.uop)))[:, :, 0:start_pos+T, :]

    # the mask is only built when more than one position is processed at once
    mask = Tensor.full((1, 1, T, start_pos+T), float("-inf"), dtype=x.dtype, device=x.device).triu(start_pos+1) if resolve(T != 1) else None
    attn = q @ k.transpose(-1, -2) * (1.0 / self.config.head_dim ** 0.5)
    if mask is not None: attn = attn + mask
    attn = attn.softmax(-1)
    # attn_v_b expands the attended compressed values to v_head_dim per head
    attn = ((attn @ v) @ self.attn_v_b["weight"].transpose(-1, -2)).transpose(1, 2).reshape(B, T, -1)
    return self.attn_output(attn)

  def _init_state(self, x:Tensor):
    """Allocate the key and value caches and fetch the RoPE table the first time the block is called."""
    if not hasattr(self, "cache_k"):
      self.cache_k = Tensor.empty(x.shape[0], 1, self.config.max_context, self.config.kv_lora_rank + self.config.rope_dim, device=x.device)
      self.cache_v = Tensor.empty(x.shape[0], 1, self.config.max_context, self.config.kv_lora_rank, device=x.device)
      self.freqs_cis = precompute_freqs_cis(self.config.rope_dim, self.config.max_context, self.config.rope_theta)
class GatedDeltaNetBlock(FFNBlock):
  """Recurrent block: keeps a convolution window and a per-head state matrix instead of a key/value cache.

  Only single-token steps are supported (T == 1).
  """
  def __init__(self, config:TransformerConfig, ssm:SSMConfig):
    super().__init__(config)
    self.head_k_dim, self.num_k_heads, self.num_v_heads = ssm.state_size, ssm.group_count, ssm.time_step_rank
    assert self.num_v_heads % self.num_k_heads == 0
    self.head_v_dim, self.ssm_conv_kernel = ssm.inner_size // ssm.time_step_rank, ssm.conv_kernel
    # the convolution runs over q, k and v together: q and k take q_dim channels each, v takes inner_size
    self.conv_channels, self.q_dim = ssm.inner_size + 2*ssm.group_count*ssm.state_size, ssm.state_size*ssm.group_count
    self.attn_qkv, self.attn_gate = nn.Linear(config.dim, self.conv_channels, bias=False), nn.Linear(config.dim, ssm.inner_size, bias=False)
    self.ssm_alpha, self.ssm_beta = nn.Linear(config.dim, self.num_v_heads, bias=False), nn.Linear(config.dim, self.num_v_heads, bias=False)
    self.ssm_conv1d = {"weight": Tensor.zeros(self.conv_channels, self.ssm_conv_kernel)}
    self.ssm_dt = {"bias": Tensor.zeros(self.num_v_heads)}
    self.ssm_a = Tensor.zeros(self.num_v_heads)
    self.ssm_norm, self.ssm_out = nn.RMSNorm(self.head_v_dim, config.norm_eps), nn.Linear(ssm.inner_size, config.dim, bias=False)

  def _attention(self, x:Tensor, start_pos:int|UOp) -> Tensor:
    """Advance the recurrent state by one token and return the block's mixing output; `start_pos` is not used."""
    B, T, _ = x.shape
    assert T == 1, "GatedDeltaNetBlock currently only supports T=1"

    # input processing
    x = x.half()
    out_gate = self.attn_gate(x).reshape(B, 1, self.num_v_heads, self.head_v_dim)
    # beta scales the state update, alpha scales (decays) the previous state; both are per value head
    beta = self.ssm_beta(x).sigmoid().reshape(B, self.num_v_heads, 1, 1)
    alpha = ((self.ssm_alpha(x).float() + self.ssm_dt["bias"]).softplus() * self.ssm_a).reshape(B, self.num_v_heads, 1, 1).exp()

    # qkv conv
    # the stored window plus the new token forms a full kernel-width window
    conv_window = self.conv_state.cat(self.attn_qkv(x), dim=1)
    conv_out = (conv_window * self.ssm_conv1d["weight"].T.unsqueeze(0)).sum(1).silu()

    q, k, v = conv_out.split([self.q_dim, self.q_dim, self.conv_channels - 2*self.q_dim], dim=-1)
    # q and k are normalized per head and repeated so there is one per value head
    q = q.reshape(B, self.num_k_heads, self.head_k_dim).normalize(dim=-1).repeat(1, self.num_v_heads//self.num_k_heads, 1)
    k = k.reshape(B, self.num_k_heads, self.head_k_dim).normalize(dim=-1).repeat(1, self.num_v_heads//self.num_k_heads, 1)
    v = v.reshape(B, self.num_v_heads, self.head_v_dim)
    q, k, v = q.mul(self.head_k_dim**-0.5).unsqueeze(-1), k.unsqueeze(-1), v.unsqueeze(-1)

    # recurrent
    recurrent_state = self.recurrent_state * alpha
    recurrent_state = recurrent_state + ((v - recurrent_state@k) * beta)@k.transpose(-1, -2)

    # store the updated state
    # the window drops its oldest entry; the output below is read from the state after both stores
    conv_state_store = self.conv_state.uop.store(conv_window[:, 1:, :].cast(self.conv_state.dtype).uop)
    recurrent_state_store = self.recurrent_state.uop.store(recurrent_state.cast(self.recurrent_state.dtype).uop)
    recurrent_state = Tensor(self.recurrent_state.uop.after(recurrent_state_store, conv_state_store))

    # output
    core_attn_out = self.ssm_norm((recurrent_state@q).squeeze(-1).reshape(B, 1, self.num_v_heads, self.head_v_dim))
    return self.ssm_out((core_attn_out * out_gate.silu()).reshape(B, 1, -1).cast(x.dtype))

  # recurrent state can't be partially reused after divergence, force a full rebuild
  def _state_reset_ops(self):
    return [self.conv_state.assign(Tensor.zeros_like(self.conv_state)),
            self.recurrent_state.assign(Tensor.zeros_like(self.recurrent_state))] if hasattr(self, "conv_state") else []
  # reuse is all-or-nothing: anything short of a full match with the cached tokens restarts from position 0
  def _reusable_prefix_len(self, prefix_len:int, cached_len:int) -> int: return 0 if prefix_len != cached_len else prefix_len

  def _init_state(self, x):
    """Create the zeroed convolution window and recurrent state the first time the block is called."""
    if not hasattr(self, "conv_state"):
      self.conv_state = Tensor.zeros(x.shape[0], self.ssm_conv_kernel-1, self.conv_channels, device=x.device).clone()
      self.recurrent_state = Tensor.zeros(x.shape[0], self.num_v_heads, self.head_v_dim, self.head_v_dim, device=x.device).clone()
class Transformer:
  """Stack of blocks with token embedding, output head, sampling, GGUF loading and a token generator."""
  def __init__(self, config:TransformerConfig):
    # leading dense blocks use the same config with every MoE setting switched off
    dense_config = replace(config, num_experts=0, num_experts_per_tok=0, shared_expert_dim=0, hidden_dim=config.dense_hidden_dim or config.hidden_dim)
    if config.ssm: config = replace(config, qk_norm=config.head_dim)
    block_cls = MLATransformerBlock if config.kv_lora_rank > 0 else TransformerBlock
    # with ssm set, every block is recurrent except those whose 1-based index is a multiple of full_attention_interval
    self.blk:list[FFNBlock] = [GatedDeltaNetBlock(config, config.ssm) if config.ssm and (i+1) % config.full_attention_interval != 0 else
                               block_cls(dense_config if i < config.leading_dense_blocks else config) for i in range(config.num_blocks)]
    self.token_embd = nn.Embedding(config.vocab_size, config.dim)
    self.output_norm = nn.RMSNorm(config.dim, config.norm_eps)
    self.output = nn.Linear(config.dim, config.vocab_size, bias=False)
    self.max_context = config.max_context
    self.has_recurrent_block = any(isinstance(b, GatedDeltaNetBlock) for b in self.blk)
    # tokens whose state is currently held by the blocks; used to find a reusable prefix
    self._cached_tokens: list[int] = []
    # we specialize the JIT for prefill and rollout
    self.prefill_jit = TinyJit(self.forward)
    self.rollout_jit = TinyJit(self.forward)

  def forward(self, tokens:Tensor, start_pos:int|UOp, temperature:Tensor) -> Tensor:
    """Run all blocks on `tokens` and sample one next-token id from the logits of the last position."""
    x = self.token_embd(tokens).float() # (B, T, D)
    for block in self.blk: x = block(x, start_pos)
    logits = self.output(self.output_norm(x))[:, -1, :]
    # Gumbel-max trick: argmax(logits/temp - log(-log(uniform))) is equivalent to sampling from softmax(logits/temp)
    return (logits / temperature.maximum(1e-12) - (Tensor.rand_like(logits).maximum(1e-12).log().neg()).log()).argmax(-1, keepdim=True)

  def __call__(self, tokens:Tensor, start_pos:int|UOp, temperature:Tensor) -> Tensor:
    # multi-token calls go through the prefill JIT, single-token calls through the rollout JIT
    return (self.prefill_jit if resolve(tokens.shape[1] != 1) else self.rollout_jit)(tokens.contiguous(), start_pos, temperature)

  @staticmethod
  def from_gguf(gguf:Tensor, max_context:int|None=None, realize=bool(getenv("REALIZE", 0))) -> tuple[Transformer, dict]:
    """Build a model from a GGUF tensor and return it together with the GGUF key/value metadata.

    `max_context` caps the context length stored in the file. With `realize` the weights are made
    contiguous and realized up front; its default is read from the REALIZE environment variable
    once, when the function is defined.
    """
    # TODO: remove the need for copy to default device
    kv, state_dict = nn.state.gguf_load(gguf.to(None).realize())

    # all state items should be float16, not float32
    state_dict = {k:v.cast('float16') if getenv("HALF", 1) else v for k,v in state_dict.items()}

    # some models like Llama 3.2 don't have an output.weight, they just tie to the token_embd.weight
    if 'output.weight' not in state_dict: state_dict['output.weight'] = state_dict['token_embd.weight']

    arch = kv['general.architecture']
    max_context = min(max_context, kv[f'{arch}.context_length']) if max_context is not None else kv[f'{arch}.context_length']
    n_heads, n_kv_heads = kv[f'{arch}.attention.head_count'], kv[f'{arch}.attention.head_count_kv']

    ssm = None
    if arch in ('qwen35', 'qwen35moe'):
      ssm = SSMConfig(**{k: kv[f'{arch}.ssm.{k}'] for k in ('conv_kernel','state_size','group_count','time_step_rank','inner_size')})
      # rename so the keys match the block's ffn_norm attribute
      state_dict = {k.replace('post_attention_norm', 'ffn_norm'):v for k,v in state_dict.items()}

    kv_lora_rank = kv.get(f'{arch}.attention.kv_lora_rank', 0)
    head_dim = kv.get(f'{arch}.attention.key_length_mla', kv.get(f'{arch}.attention.key_length', kv[f'{arch}.embedding_length'] // n_heads))
    rope_dim = kv.get(f'{arch}.rope.dimension_count', head_dim)

    # Permute RoPE weights from interleaved to half-split layout.
    # Only values of existing keys are replaced, so iterating the dict while assigning is safe.
    for name in state_dict:
      if ('attn_q.weight' in name or 'attn_q_b.weight' in name) and (arch == 'llama' or kv_lora_rank):
        w = state_dict[name].reshape(n_heads, state_dict[name].shape[0]//n_heads, -1)
        # rows before `prefix` are the non-rotary part of each head and keep their order
        prefix = head_dim-rope_dim
        state_dict[name] = w[:, :prefix].cat(w[:, prefix:].rearrange("n (h two) d -> n (two h) d", two=2), dim=1).reshape(-1, w.shape[-1])
      elif arch == 'llama' and 'attn_k.weight' in name:
        w = state_dict[name].reshape(n_kv_heads, state_dict[name].shape[0]//n_kv_heads, -1)
        state_dict[name] = w.rearrange("n (h two) d -> n (two h) d", two=2).reshape(-1, w.shape[-1])
      elif kv_lora_rank and 'attn_kv_a_mqa.weight' in name:
        state_dict[name] = state_dict[name][:kv_lora_rank].cat(state_dict[name][kv_lora_rank:].rearrange("(h two) d -> (two h) d", two=2), dim=0)

    config = TransformerConfig(
      num_blocks=kv[f'{arch}.block_count'], dim=kv[f'{arch}.embedding_length'],
      hidden_dim=kv.get(f'{arch}.expert_feed_forward_length', kv.get(f'{arch}.feed_forward_length', 0)),
      n_heads=n_heads, n_kv_heads=n_kv_heads, norm_eps=kv[f'{arch}.attention.layer_norm_rms_epsilon'],
      vocab_size=len(kv['tokenizer.ggml.tokens']),
      head_dim=head_dim,
      rope_theta=kv[f'{arch}.rope.freq_base'],
      rope_dim=rope_dim,
      v_head_dim=kv.get(f'{arch}.attention.value_length_mla', kv.get(f'{arch}.attention.value_length', head_dim)),
      max_context=max_context,
      qk_norm=int(state_dict['blk.0.attn_q_norm.weight'].shape[0]) if 'blk.0.attn_q_norm.weight' in state_dict else 0,
      num_experts=kv.get(f'{arch}.expert_count', 0), num_experts_per_tok=kv.get(f'{arch}.expert_used_count', 0),
      norm_topk_prob=kv.get(f'{arch}.expert_weights_norm', arch in ('qwen3moe', 'qwen35moe')),
      kv_lora_rank=kv_lora_rank, q_lora_rank=kv.get(f'{arch}.attention.q_lora_rank', 0),
      leading_dense_blocks=kv.get(f'{arch}.leading_dense_block_count', 0),
      shared_expert_dim=kv.get(
        f'{arch}.expert_shared_feed_forward_length',
        kv.get(f'{arch}.expert_shared_count', 0) * kv.get(f'{arch}.expert_feed_forward_length', 0)),
      shared_expert_gate=f"blk.{kv.get(f'{arch}.leading_dense_block_count', 0)}.ffn_gate_inp_shexp.weight" in state_dict,
      dense_hidden_dim=kv.get(f'{arch}.feed_forward_length', 0) if kv.get(f'{arch}.leading_dense_block_count', 0) else 0,
      routed_scaling_factor=kv.get(f'{arch}.expert_weights_scale', 1.0), attn_output_gate=arch in ('qwen35', 'qwen35moe'), ssm=ssm,
      full_attention_interval=kv.get(f'{arch}.full_attention_interval', 0))
    model = Transformer(config)
    nn.state.load_state_dict(model, state_dict, verbose=False, consume=True, realize=False) # NOTE: rope_freqs.weight (32,) is unused
    # NOTE: without this contiguous, it unpacks the weights from the model every time. we shouldn't need this, but for now it's faster
    if realize:
      for s in (params:=nn.state.get_parameters(model)): s.replace(s.contiguous())
      Tensor.realize(*params)
    return model, kv

  def get_start_pos(self, tokens:list[int]) -> int:
    """Return the position generation can resume from, given what the blocks still hold.

    The last token of `tokens` is never counted as cached, so at least one token is always run.
    """
    prefix_len = sum(1 for _ in itertools.takewhile(lambda ab: ab[0] == ab[1], zip(tokens[:-1], self._cached_tokens)))
    # every block must be able to resume from the chosen position
    return min(block._reusable_prefix_len(prefix_len, len(self._cached_tokens)) for block in self.blk)

  def generate(self, tokens:list[int], chunk_size:int=32, temperature:float=0.0):
    """Yield sampled token ids one at a time until `tokens` reaches max_context.

    `tokens` is extended in place with every yielded id. The prompt is processed in chunks of at
    most `chunk_size` tokens (forced to 1 when the model has a recurrent block).
    """
    if self.has_recurrent_block: chunk_size = 1
    v_start_pos = UOp.variable("start_pos", 0, self.max_context-1)
    v_toks = UOp.variable("toks", 1, chunk_size)
    # TODO: use UOp.variable for temperature once float variables are supported
    temp = Tensor(temperature).contiguous()
    # assign all input tokens once, then slice from start_pos for the model call
    t = Tensor(tokens + [0] * (self.max_context - len(tokens)), dtype="int32").reshape(1, self.max_context)
    # recompute start_pos from what's currently valid in the caches
    start_pos = self.get_start_pos(tokens)
    # resuming before the end of the cached tokens: let blocks that can't rewind reset their state
    if start_pos < len(self._cached_tokens) and (resets := [r for b in self.blk for r in b._state_reset_ops()]): Tensor.realize(*resets)
    out, prompt_len = None, len(tokens)
    while len(tokens) < self.max_context:
      sp, nt = v_start_pos.bind(start_pos), v_toks.bind(min(chunk_size, len(tokens) - start_pos))
      # during prefill the input is a slice of the prompt; afterwards it is the previously sampled token
      out = self(t[:, sp:sp+nt] if start_pos < prompt_len or out is None else out, sp, temp).realize()
      start_pos += nt.val
      # chunked prefill: keep processing until all prompt tokens are consumed
      if start_pos < len(tokens): continue
      tokens.append(int(out.item()))
      # the token just sampled has not been run through the model yet, so it is not part of the cached state
      self._cached_tokens = tokens[:-1]
      yield tokens[-1]
return [self.eos_id]
def prefix(self) -> list[int]:
  """Return the token ids that open every prompt.

  That is the BOS id when `self.bos_id` is set, followed by the encoding of "<sop>" for the 'glm4' preset.
  A new list is returned on every call.
  """
  ids = [] if self.bos_id is None else [self.bos_id]
  if self.preset == 'glm4': ids = ids + self.encode("<sop>")
  return ids
def is_end(self, token_id:int) -> bool:
  """Tell whether `token_id` is one of this tokenizer's stop ids (`self.eos_id` or `self.eot_id`)."""
  stop_ids = (self.eos_id, self.eot_id)
  return token_id in stop_ids
models = {
"llama3.2:1b": "https://huggingface.co/bartowski/Llama-3.2-1B-Instruct-GGUF/resolve/main/Llama-3.2-1B-Instruct-Q6_K.gguf",
@ -506,62 +106,14 @@ models = {
# *** simple OpenAI API compatible server with web interface on http://localhost:8000/ ***
# Single-page chat UI served for GET requests. It keeps the conversation in `msgs`, POSTs it to
# /v1/chat/completions with stream=true and appends the streamed "delta.content" pieces to the newest reply bubble.
# The user bubble is built with createElement/textContent instead of `innerHTML +=`: re-assigning innerHTML
# re-parses the whole chat, which detaches the <div> of a reply that is still streaming, and escaping only '<'
# made typed entities such as "&lt;" show up decoded.
CHAT_HTML = b'''<!DOCTYPE html><html><head><title>tinygrad chat</title><style>
  * { margin: 0 }
  body { background: #212121; color: #e3e3e3; font-family: system-ui;
         height: 100vh; display: flex; flex-direction: column }
  #chat { flex: 1; overflow-y: auto; padding: 20px }
  .msg { padding: 10px 16px; margin: 8px 0; white-space: pre-wrap; border-radius: 18px }
  .user { background: #2f2f2f; margin-left: auto; width: fit-content; max-width: 70% }
  #input { max-width: 768px; width: 100%; margin: 20px auto; padding: 14px 20px;
           background: #2f2f2f; color: inherit; font: inherit;
           border: none; outline: none; resize: none; border-radius: 24px; field-sizing: content }
</style></head><body><div id="chat"></div>
<textarea id="input" rows="1" placeholder="Ask anything" autofocus></textarea>
<script>
  input.onkeydown = (e) => { if (e.key === 'Enter' && !e.shiftKey && !e.isComposing) { e.preventDefault(); send() } }
  const msgs = [];
  async function send() {
    const text = input.value.trim();
    if (!text) return;
    msgs.push({role: 'user', content: text});
    const u = document.createElement('div'); u.className = 'msg user'; u.textContent = text; chat.appendChild(u);
    input.value = '';
    const d = document.createElement('div'); d.className = 'msg'; chat.appendChild(d);
    const r = await fetch('/v1/chat/completions', {method: 'POST', headers: {'Content-Type': 'application/json'},
      body: JSON.stringify({model: 'llama', messages: msgs, stream: true, temperature: 0.7})});
    let buf = '';
    for (const rd = r.body.getReader(), dec = new TextDecoder();;) {
      const {done, value} = await rd.read();
      if (done) break;
      buf += dec.decode(value, {stream: true});
      const lines = buf.split('\\n');
      buf = lines.pop();
      for (const ln of lines)
        if (ln.startsWith('data: ') && !ln.includes('[DONE]'))
          try { d.textContent += JSON.parse(ln.slice(6)).choices[0]?.delta?.content || '' } catch {}
      chat.scrollTop = chat.scrollHeight;
    }
    msgs.push({role: 'assistant', content: d.textContent});
  }
</script></body></html>'''
class LLMServer(TCPServerWithReuse):
  """Server type whose instances carry the model and tokenizer used by the request handler.

  This class only declares the attributes; the values are assigned on the instance after construction.
  """
  model: Transformer      # model used to produce completions
  model_name: str         # name reported by the /v1/models endpoint
  tok: SimpleTokenizer    # tokenizer matching `model`
  # TODO: tastefully move these into tokenizer
  bos_id: int|None        # beginning-of-sequence id, None when no BOS should be added
  eos_id: int             # end-of-sequence id
  eot_id: int|None        # end-of-turn id, if the model has one
class Handler(HTTPRequestHandler):
server: LLMServer
def log_request(self, code='-', size='-'): pass
def do_GET(self):
if self.path == "/v1/models": self.send_data(json.dumps({"object":"list","data":[{"id":self.server.model_name,"object":"model"}]}).encode())
else: self.send_data(CHAT_HTML, content_type="text/html")
else: self.send_data((pathlib.Path(__file__).parent / "chat.html").read_bytes(), content_type="text/html")
def run_model(self, ids:list[int], model_name:str, include_usage=False, max_tokens:int|None=None, temperature:float=0.0):
model, tok, eos_id, eot_id = self.server.model, self.server.tok, self.server.eos_id, self.server.eot_id
model, tok = self.server.model, self.server.tok
cache_start_pos = model.get_start_pos(ids)
stderr_log(f"{self.path} {colored('--', 'BLACK')} "
f"in:{colored(f'{cache_start_pos:5d}', 'green')} +{len(ids)-cache_start_pos:5d} {colored('--', 'BLACK')} ")
@ -573,7 +125,7 @@ class Handler(HTTPRequestHandler):
dec = tok.stream_decoder()
for next_id in model.generate(ids, temperature=temperature):
if len(out) == 0: stderr_log(f"prefill:{(len(ids)-cache_start_pos)/((pt:=time.perf_counter())-st):4.0f} tok/s {colored('--', 'BLACK')} ")
if next_id in (eos_id, eot_id): break
if tok.is_end(next_id): break
out.append(next_id)
yield {"choices": [{"index":0, "delta":{"content":dec(next_id)}, "finish_reason":None}], **tmpl}
if max_tokens is not None and len(out) >= max_tokens:
@ -588,13 +140,13 @@ class Handler(HTTPRequestHandler):
f"out:{len(out):5d} {colored('--', 'BLACK')} total:{et-st:6.2f}s\n")
def do_POST(self):
tok, bos_id, eos_id = self.server.tok, self.server.bos_id, self.server.eos_id
tok = self.server.tok
raw_body = self.rfile.read(int(self.headers.get("Content-Length", "0")))
body: dict[str, typing.Any] = json.loads(raw_body.decode("utf-8"))
if DEBUG >= 1: print(json.dumps(body, indent=2))
if self.path == "/v1/chat/completions":
# extract tokens, last assistant message is treated as prefill
ids: list[int] = tok.prefix(bos_id)
ids: list[int] = tok.prefix()
for i, msg in enumerate(body["messages"]):
ids += tok.role(msg["role"])
content = msg["content"]
@ -605,7 +157,7 @@ class Handler(HTTPRequestHandler):
else: raise RuntimeError(f"unhandled type: {c['type']}")
else: raise RuntimeError(f"unknown content type: {type(content)}")
if msg["role"] == "assistant" and i == len(body["messages"]) - 1: break
ids += tok.end_turn(eos_id)
ids += tok.end_turn()
else: ids += tok.role("assistant")
# reply
@ -623,6 +175,11 @@ class Handler(HTTPRequestHandler):
else:
raise RuntimeError(f"unhandled path {self.path}")
class LLMServer(TCPServerWithReuse):
  """Server that carries the model, its display name and the tokenizer for `Handler` to use."""
  def __init__(self, server_address:tuple, model:Transformer, model_name:str, tok:SimpleTokenizer):
    """Store the serving state, then initialise the base server on `server_address` with `Handler`."""
    # state is attached before the base class is initialised
    self.model = model
    self.model_name = model_name
    self.tok = tok
    super().__init__(server_address, Handler)
def main():
parser = argparse.ArgumentParser()
parser.add_argument("--model", "-m", default=list(models.keys())[0], help=f"Model choice ({', '.join(models.keys())}) or path to a local GGUF file")
@ -644,9 +201,6 @@ def main():
gc.collect()
tok = SimpleTokenizer.from_gguf_kv(kv)
bos_id: int|None = kv.get('tokenizer.ggml.bos_token_id') if kv.get('tokenizer.ggml.add_bos_token', True) else None
eos_id: int = kv['tokenizer.ggml.eos_token_id']
eot_id: int|None = kv.get('tokenizer.ggml.eot_token_id')
# warmup the JIT
if args.warmup or args.serve:
@ -655,15 +209,11 @@ def main():
for _ in range(2): list(zip(range(2), model.generate([0])))
# start server
if args.serve:
server = LLMServer(('', args.serve), Handler)
server.model, server.model_name, server.tok = model, model_name, tok
server.bos_id, server.eos_id, server.eot_id = bos_id, eos_id, eot_id
server.serve_forever()
if args.serve: LLMServer(('', args.serve), model, model_name, tok).serve_forever()
# do benchmark
if args.benchmark is not None:
gen = model.generate(toks:=[bos_id or 0])
gen = model.generate(toks:=[tok.bos_id or 0])
for _ in range(args.benchmark):
GlobalCounters.reset()
with Timing(on_exit=lambda x: f", {1e9/x:6.2f} tok/s, {GlobalCounters.global_mem/x:7.2f} GB/s,"
@ -672,16 +222,16 @@ def main():
exit(0)
# interactive chat
ids: list[int] = tok.prefix(bos_id)
ids: list[int] = tok.prefix()
while 1:
try:
ids += tok.role("user") + tok.encode(input('>>> ')) + tok.end_turn(eos_id) + tok.role("assistant")
ids += tok.role("user") + tok.encode(input('>>> ')) + tok.end_turn() + tok.role("assistant")
except EOFError:
break
dec = tok.stream_decoder()
for next_id in model.generate(ids):
sys.stdout.write(dec(next_id) if next_id not in (eos_id, eot_id) else dec() + "\n\n")
sys.stdout.write(dec(next_id) if not tok.is_end(next_id) else dec() + "\n\n")
sys.stdout.flush()
if next_id in (eos_id, eot_id): break
if tok.is_end(next_id): break
if __name__ == "__main__": main()

409
tinygrad/llm/model.py Normal file
View file

@ -0,0 +1,409 @@
from __future__ import annotations
import functools, itertools
from dataclasses import dataclass, replace
from tinygrad import Tensor, nn, UOp, TinyJit, getenv, function
from tinygrad.uop.ops import resolve
@functools.cache
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0) -> Tensor:
  """Build the rotary-embedding table for positions 0..end-1.

  Each row holds the cosines of that position's angles followed by the sines, so the result has
  `end` rows and `2 * (dim // 2)` columns. Results are memoised per `(dim, end, theta)` by `functools.cache`.
  """
  exponents = Tensor.arange(0, dim, 2)[:(dim // 2)] / dim
  inv_freq = 1.0 / (theta ** exponents)
  # outer product position x inverse frequency
  angles = Tensor.arange(end).unsqueeze(dim=1) * inv_freq.unsqueeze(dim=0)
  return angles.cos().cat(angles.sin(), dim=-1).contiguous()
class ExpertWeights:
  """A bank of linear layers, one per expert, without bias.

  The weight is stored as (num_experts, out_features, in_features); calling the object applies the
  experts picked by `sel` to `x`.
  """
  def __init__(self, num_experts:int, in_features:int, out_features:int):
    self.weight = Tensor.zeros(num_experts, out_features, in_features)
  def __call__(self, sel:Tensor, x:Tensor) -> Tensor:
    # sel: (B, T, k), x: (B, T, 1, in) or (B, T, k, in) -> output: (B, T, k, out)
    row = x.unsqueeze(-2)
    picked = self.weight[sel]
    return (row @ picked.transpose(-1, -2)).squeeze(-2)
def apply_rope(x:Tensor, freqs_cis:Tensor) -> Tensor:
  """Rotate `x` with rotary position embeddings in the half-split layout.

  The last axis of `x` is split into a first and a second half which are rotated against each other.
  `freqs_cis` holds cosines followed by sines along its last axis and is reshaped to
  (1, 1, x.shape[2], -1), so axis 2 of `x` is the position axis.
  """
  assert x.shape[-1] % 2 == 0
  table = freqs_cis.reshape(1, 1, x.shape[2], -1)
  cos, sin = table.chunk(2, dim=-1)
  lo, hi = x.chunk(2, dim=-1)
  rotated_lo = lo * cos - hi * sin
  rotated_hi = hi * cos + lo * sin
  return rotated_lo.cat(rotated_hi, dim=-1)
def pairwise_topk(x: Tensor, k: int) -> tuple[Tensor, Tensor]:
  """Top-k along the last axis of a 3-D tensor, computed by ranking every element against every other.

  Returns `(values, indices)`. The k winners come out ordered from the smallest to the largest of them.
  Among equal values the element with the lower index is ranked higher.
  """
  n = x.shape[-1]
  positions = Tensor.arange(n).reshape(1,1,n).cast(x.dtype).expand(x.shape)
  lhs, rhs = x.unsqueeze(-1), x.unsqueeze(-2)
  # beats[..., i, j] is true when element i outranks element j (ties go to the lower index)
  index_tiebreak = Tensor.arange(n).reshape(1,1,n,1) < Tensor.arange(n).reshape(1,1,1,n)
  beats = (lhs > rhs) | ((lhs == rhs) & index_tiebreak)
  # the number of elements beaten is the ascending rank; writing each index at its rank sorts the indices
  rank = beats.sum(axis=-1).cast('int32')
  ordered = Tensor.zeros_like(x).scatter(-1, rank, positions)
  sel = ordered[:,:,n-k:].cast('int32')
  return x.gather(-1, sel), sel
@dataclass(frozen=True)
class SSMConfig:
  """Immutable hyper-parameters of the recurrent (gated delta net) blocks.

  The field names follow the `{arch}.ssm.*` GGUF metadata keys they are read from.
  """
  conv_kernel: int      # length of the conv window applied to the qkv projection
  state_size: int       # per-head width of q and k
  group_count: int      # number of q/k heads
  time_step_rank: int   # number of v heads
  inner_size: int       # total v width (v heads * per-head v width)
@dataclass(frozen=True)
class TransformerConfig:
  """Immutable description of a model architecture; the fields without defaults are always required."""
  # --- core sizes ---------------------------------------------------------
  num_blocks: int     # how many blocks the model stacks
  dim: int            # width of the residual stream / token embedding
  hidden_dim: int     # hidden width of the feed-forward (per expert when MoE)
  n_heads: int        # query heads
  n_kv_heads: int     # key/value heads
  norm_eps: float     # epsilon of every RMSNorm
  vocab_size: int
  head_dim: int       # per-head width of q and k
  rope_theta: float   # base of the rotary frequencies
  rope_dim: int       # how many dims of each head are rotated
  v_head_dim: int     # per-head width of v
  max_context: int = 0
  # size of the q/k RMSNorm, 0 disables it
  qk_norm: int = 0
  # --- mixture of experts (num_experts == 0 means a dense feed-forward) ---
  num_experts: int = 0
  num_experts_per_tok: int = 0
  norm_topk_prob: bool = False
  # --- low-rank attention projections (kv_lora_rank > 0 selects MLA) ------
  q_lora_rank: int = 0
  kv_lora_rank: int = 0
  shared_expert_dim: int = 0
  # with `ssm` set, every full_attention_interval-th block is an attention block
  full_attention_interval: int = 0
  attn_output_gate: bool = False
  ssm: SSMConfig|None = None
  shared_expert_gate: bool = True
  # the first leading_dense_blocks blocks use a dense feed-forward of width dense_hidden_dim
  leading_dense_blocks: int = 0
  dense_hidden_dim: int = 0
  # multiplier applied to the routing weights
  routed_scaling_factor: float = 1.0
class FFNBlock:
  """Common part of every block: the two RMSNorms, the feed-forward (dense or MoE) and the residual wiring.

  Subclasses supply the token mixer through `_attention` and allocate their caches in `_init_state`.
  """
  def __init__(self, config:TransformerConfig):
    self.config = config

    # --- RMSNorms --------------------------------------------------------
    self.attn_norm = nn.RMSNorm(config.dim, config.norm_eps)
    self.ffn_norm = nn.RMSNorm(config.dim, config.norm_eps)

    # --- feed-forward (MoE or dense) -------------------------------------
    if config.num_experts > 0:
      self.ffn_gate_inp = nn.Linear(config.dim, config.num_experts, bias=False)  # router
      # a routing bias is only created for models with low-rank kv attention
      if config.kv_lora_rank > 0: self.exp_probs_b = {"bias": Tensor.zeros(config.num_experts)}
      self.ffn_gate_exps = ExpertWeights(config.num_experts, config.dim, config.hidden_dim)
      self.ffn_up_exps = ExpertWeights(config.num_experts, config.dim, config.hidden_dim)
      self.ffn_down_exps = ExpertWeights(config.num_experts, config.hidden_dim, config.dim)
      if config.shared_expert_dim > 0:
        self.ffn_gate_shexp = nn.Linear(config.dim, config.shared_expert_dim, bias=False)
        self.ffn_up_shexp = nn.Linear(config.dim, config.shared_expert_dim, bias=False)
        self.ffn_down_shexp = nn.Linear(config.shared_expert_dim, config.dim, bias=False)
        if config.shared_expert_gate: self.ffn_gate_inp_shexp = {"weight": Tensor.zeros(config.dim)}
    else:
      self.ffn_gate = nn.Linear(config.dim, config.hidden_dim, bias=False)
      self.ffn_up = nn.Linear(config.dim, config.hidden_dim, bias=False)
      self.ffn_down = nn.Linear(config.hidden_dim, config.dim, bias=False)

  def _feed_forward(self, x:Tensor) -> Tensor:
    """Gated (SiLU) feed-forward: dense when the block has no experts, otherwise routed top-k experts."""
    if not hasattr(self, 'ffn_gate_exps'):
      # TODO: remove the need for this contiguous
      return self.ffn_down(self.ffn_gate(x).silu().contiguous() * self.ffn_up(x))

    top_k = self.config.num_experts_per_tok
    h = x.unsqueeze(2)  # (B, T, 1, D) - add expert dim for broadcasting
    logits = self.ffn_gate_inp(x)
    if hasattr(self, 'exp_probs_b'):
      # sigmoid routing: the bias only influences which experts win, not the weights they get
      probs = logits.sigmoid()
      _, sel = pairwise_topk(probs + self.exp_probs_b["bias"], top_k)
      probs = probs.gather(-1, sel)
      if self.config.norm_topk_prob: probs = probs / probs.sum(axis=-1, keepdim=True)
    else:
      # softmax routing: normalise over the winners only, or over all experts before picking
      vals, sel = pairwise_topk(logits, top_k)
      if self.config.norm_topk_prob: probs = vals.softmax(-1)
      else: probs = logits.softmax(-1).gather(-1, sel)
    probs = probs * self.config.routed_scaling_factor
    gated = self.ffn_gate_exps(sel, h).silu() * self.ffn_up_exps(sel, h)
    x_down = self.ffn_down_exps(sel, gated)  # (B, T, k, D)
    out = (x_down * probs.unsqueeze(-1)).sum(axis=2)  # (B, T, D)
    if hasattr(self, 'ffn_gate_shexp'):
      shexp = self.ffn_down_shexp(self.ffn_gate_shexp(x).silu().contiguous() * self.ffn_up_shexp(x))
      if hasattr(self, 'ffn_gate_inp_shexp'):
        shexp = shexp * (x * self.ffn_gate_inp_shexp["weight"]).sum(axis=-1, keepdim=True).sigmoid()
      out = out + shexp
    return out

  # given the token-prefix match, return how much cached state this block can still reuse
  def _reusable_prefix_len(self, prefix_len:int, cached_len:int) -> int:
    return prefix_len

  # return writes that reset this block's state after a cache mismatch
  def _state_reset_ops(self) -> list[Tensor]:
    return []

  # hooks every concrete block has to provide
  def _init_state(self, x:Tensor): raise NotImplementedError
  def _attention(self, x:Tensor, start_pos:int|UOp) -> Tensor: raise NotImplementedError

  def __call__(self, x: Tensor, start_pos: int|UOp):
    """Run the block: x + attention(norm(x)), then the feed-forward with its own residual."""
    self._init_state(x)

    # we pass in the weights implicitly so we unpack the GGUF on the fly
    @function(precompile=True, allow_implicit=True)
    def _run(x:Tensor, start_pos:int|UOp):
      mixed = x + self._attention(self.attn_norm(x), start_pos)
      ffn_out = self._feed_forward(self.ffn_norm(mixed))
      return (mixed + ffn_out).contiguous()

    return _run(x, start_pos)
class TransformerBlock(FFNBlock):
  """Block with standard multi-head attention (grouped KV heads) and a KV cache."""
  def __init__(self, config:TransformerConfig):
    super().__init__(config)
    assert config.v_head_dim == config.head_dim, "TransformerBlock requires v_head_dim == head_dim"

    # --- attention projections (all linear, bias-free) ------------------
    # with an output gate the q projection is twice as wide: every head carries its query and its gate
    q_width = config.head_dim * config.n_heads * (2 if config.attn_output_gate else 1)
    kv_width = config.head_dim * config.n_kv_heads
    self.attn_q = nn.Linear(config.dim, q_width, bias=False)
    self.attn_k = nn.Linear(config.dim, kv_width, bias=False)
    self.attn_v = nn.Linear(config.dim, kv_width, bias=False)
    self.attn_output = nn.Linear(config.head_dim * config.n_heads, config.dim, bias=False)
    if config.qk_norm:
      self.attn_q_norm, self.attn_k_norm = nn.RMSNorm(config.qk_norm, config.norm_eps), nn.RMSNorm(config.qk_norm, config.norm_eps)

  def _attention(self, x:Tensor, start_pos:int|UOp) -> Tensor:
    """Attend the T new positions of `x` (B, T, D) against everything cached up to start_pos+T."""
    cfg = self.config
    q, k, v = self.attn_q(x), self.attn_k(x), self.attn_v(x)
    # a norm whose size differs from head_dim is applied to the flat projections
    if cfg.qk_norm and cfg.qk_norm != cfg.head_dim: q, k = self.attn_q_norm(q), self.attn_k_norm(k)

    B, T, _ = x.shape
    if cfg.attn_output_gate:
      # split every head of the widened projection into its query and its gate
      q_and_gate = q.reshape(B, T, cfg.n_heads, 2, cfg.head_dim)
      q = q_and_gate[:, :, :, 0, :]
      gate = q_and_gate[:, :, :, 1, :].reshape(B, T, cfg.n_heads * cfg.head_dim)
    q = q.reshape(B, T, cfg.n_heads, cfg.head_dim).transpose(1, 2)     # (B,H,T,Hd)
    k = k.reshape(B, T, cfg.n_kv_heads, cfg.head_dim).transpose(1, 2)  # (B,KvH,T,Hd)
    v = v.reshape(B, T, cfg.n_kv_heads, cfg.head_dim).transpose(1, 2)  # (B,KvH,T,Hd)
    # a per-head norm (size == head_dim) is applied after the split into heads
    if cfg.qk_norm == cfg.head_dim: q, k = self.attn_q_norm(q), self.attn_k_norm(k)

    # only the first rope_dim dims of each head are rotated, the rest pass through
    freqs = self.freqs_cis[start_pos:start_pos+T]
    q = apply_rope(q[..., :cfg.rope_dim], freqs).cat(q[..., cfg.rope_dim:], dim=-1)
    k = apply_rope(k[..., :cfg.rope_dim], freqs).cat(k[..., cfg.rope_dim:], dim=-1)

    # NOTE: we don't want to change self.cache_kv, the function API doesn't support this well.
    # The store into the cache is attached to the read below, which stands in for the plain form:
    #   self.cache_kv[:, :, :, start_pos:start_pos+T, :].assign(Tensor.stack(k, v))
    new_rows = self.cache_kv[:, :, :, start_pos:start_pos+T, :].uop.store(Tensor.stack(k, v).uop)
    cache_after = Tensor(self.cache_kv.uop.after(new_rows))
    k = cache_after[0, :, :, 0:start_pos+T, :]
    v = cache_after[1, :, :, 0:start_pos+T, :]

    # NOTE: this mask is causal_lower_right, not the causal_upper_left generated by is_causal=True
    # TODO: this if statement should be removed and it shouldn't generate extra kernels
    if resolve(T != 1): mask = Tensor.full((1, 1, T, start_pos+T), float("-inf"), dtype=x.dtype, device=x.device).triu(start_pos+1)
    else: mask = None
    attn = q.scaled_dot_product_attention(k, v, attn_mask=mask, enable_gqa=True)  # (B,H,T,Hd)
    attn = attn.transpose(1, 2).reshape(B, T, -1)                                 # back to (B,T,D)
    if cfg.attn_output_gate: attn = attn * gate.sigmoid()
    return self.attn_output(attn)

  def _init_state(self, x:Tensor):
    """Allocate the KV cache and fetch the rotary table the first time the block runs."""
    if hasattr(self, "cache_kv"): return
    # TODO: how is the dtype of this determined?
    self.cache_kv = Tensor.empty(2, x.shape[0], self.config.n_kv_heads, self.config.max_context, self.config.head_dim, device=x.device)
    self.freqs_cis = precompute_freqs_cis(self.config.rope_dim, self.config.max_context, self.config.rope_theta)
class MLATransformerBlock(FFNBlock):
  """Attention block with low-rank key/value projections.

  The cache holds one shared row per position: the normalised kv_lora_rank-wide latent (plus the rotated
  rope part for keys). Queries are projected into the latent space with `attn_k_b`, and the attention
  output is projected back out of it with `attn_v_b`.
  """
  def __init__(self, config:TransformerConfig):
    super().__init__(config)
    nope_dim = config.head_dim - config.rope_dim
    if config.q_lora_rank > 0:
      # low-rank query: down projection, norm, up projection
      self.attn_q_a = nn.Linear(config.dim, config.q_lora_rank, bias=False)
      self.attn_q_a_norm = nn.RMSNorm(config.q_lora_rank, config.norm_eps)
      self.attn_q_b = nn.Linear(config.q_lora_rank, config.n_heads * config.head_dim, bias=False)
    else:
      self.attn_q = nn.Linear(config.dim, config.n_heads * config.head_dim, bias=False)
    self.attn_kv_a_mqa = nn.Linear(config.dim, config.kv_lora_rank + config.rope_dim, bias=False)
    self.attn_kv_a_norm = nn.RMSNorm(config.kv_lora_rank, config.norm_eps)
    self.attn_k_b = {"weight": Tensor.zeros(config.n_heads, config.kv_lora_rank, nope_dim)}
    self.attn_v_b = {"weight": Tensor.zeros(config.n_heads, config.v_head_dim, config.kv_lora_rank)}
    self.attn_output = nn.Linear(config.n_heads * config.v_head_dim, config.dim, bias=False)

  def _attention(self, x:Tensor, start_pos:int|UOp) -> Tensor:
    """Attend the T new positions of `x` (B, T, D) against the latent cache up to start_pos+T."""
    cfg = self.config
    B, T, _ = x.shape
    nope_dim = cfg.head_dim - cfg.rope_dim
    freqs = self.freqs_cis[start_pos:start_pos+T]

    # queries: the non-rotated part goes through attn_k_b into the latent space, the rope part is rotated
    if cfg.q_lora_rank > 0: q_proj = self.attn_q_b(self.attn_q_a_norm(self.attn_q_a(x)))
    else: q_proj = self.attn_q(x)
    q = q_proj.reshape(B, T, cfg.n_heads, cfg.head_dim).transpose(1, 2)
    q_nope, q_rope = q[..., :nope_dim], q[..., nope_dim:]
    q = (q_nope @ self.attn_k_b["weight"].transpose(-1, -2)).cat(apply_rope(q_rope, freqs), dim=-1)

    # keys/values: a single shared head built from the normalised latent (+ rotated rope part for keys)
    kv_a = self.attn_kv_a_mqa(x)
    c_kv = self.attn_kv_a_norm(kv_a[..., :cfg.kv_lora_rank])
    k_rope = apply_rope(kv_a[..., cfg.kv_lora_rank:].reshape(B, T, 1, cfg.rope_dim).transpose(1, 2), freqs)
    k_store = c_kv.reshape(B, 1, T, cfg.kv_lora_rank).cat(k_rope.reshape(B, 1, T, cfg.rope_dim), dim=-1)
    v_store = c_kv.reshape(B, 1, T, cfg.kv_lora_rank)

    # write the new rows and read back everything up to start_pos+T without rebinding the cache tensors
    k_written = self.cache_k[:, :, start_pos:start_pos+T, :].uop.store(k_store.uop)
    v_written = self.cache_v[:, :, start_pos:start_pos+T, :].uop.store(v_store.uop)
    k = Tensor(self.cache_k.uop.after(k_written))[:, :, 0:start_pos+T, :]
    v = Tensor(self.cache_v.uop.after(v_written))[:, :, 0:start_pos+T, :]

    if resolve(T != 1): mask = Tensor.full((1, 1, T, start_pos+T), float("-inf"), dtype=x.dtype, device=x.device).triu(start_pos+1)
    else: mask = None
    scale = 1.0 / cfg.head_dim ** 0.5
    attn = (q @ k.transpose(-1, -2)) * scale
    if mask is not None: attn = attn + mask
    attn = attn.softmax(-1)
    # leave the latent space per head, then merge the heads
    attn = ((attn @ v) @ self.attn_v_b["weight"].transpose(-1, -2)).transpose(1, 2).reshape(B, T, -1)
    return self.attn_output(attn)

  def _init_state(self, x:Tensor):
    """Allocate the key/value caches and fetch the rotary table the first time the block runs."""
    if hasattr(self, "cache_k"): return
    self.cache_k = Tensor.empty(x.shape[0], 1, self.config.max_context, self.config.kv_lora_rank + self.config.rope_dim, device=x.device)
    self.cache_v = Tensor.empty(x.shape[0], 1, self.config.max_context, self.config.kv_lora_rank, device=x.device)
    self.freqs_cis = precompute_freqs_cis(self.config.rope_dim, self.config.max_context, self.config.rope_theta)
class GatedDeltaNetBlock(FFNBlock):
  """Recurrent block: a short conv over the qkv projection followed by a gated delta-rule state update.

  It keeps a conv window and a per-head recurrent matrix instead of a KV cache, and processes exactly one
  token per call.
  """
  def __init__(self, config:TransformerConfig, ssm:SSMConfig):
    super().__init__(config)
    self.head_k_dim, self.num_k_heads, self.num_v_heads = ssm.state_size, ssm.group_count, ssm.time_step_rank
    assert self.num_v_heads % self.num_k_heads == 0
    self.head_v_dim, self.ssm_conv_kernel = ssm.inner_size // ssm.time_step_rank, ssm.conv_kernel
    # the conv runs over q, k and v together: v is inner_size wide, q and k are q_dim wide each
    self.conv_channels, self.q_dim = ssm.inner_size + 2*ssm.group_count*ssm.state_size, ssm.state_size*ssm.group_count
    self.attn_qkv, self.attn_gate = nn.Linear(config.dim, self.conv_channels, bias=False), nn.Linear(config.dim, ssm.inner_size, bias=False)
    self.ssm_alpha, self.ssm_beta = nn.Linear(config.dim, self.num_v_heads, bias=False), nn.Linear(config.dim, self.num_v_heads, bias=False)
    self.ssm_conv1d = {"weight": Tensor.zeros(self.conv_channels, self.ssm_conv_kernel)}
    self.ssm_dt = {"bias": Tensor.zeros(self.num_v_heads)}
    self.ssm_a = Tensor.zeros(self.num_v_heads)
    self.ssm_norm, self.ssm_out = nn.RMSNorm(self.head_v_dim, config.norm_eps), nn.Linear(ssm.inner_size, config.dim, bias=False)

  def _attention(self, x:Tensor, start_pos:int|UOp) -> Tensor:
    """Advance the recurrent state by the single token in `x` (B, 1, D) and return the block's mixer output."""
    B, T, _ = x.shape
    assert T == 1, "GatedDeltaNetBlock currently only supports T=1"
    n_v, rep = self.num_v_heads, self.num_v_heads//self.num_k_heads

    # input processing
    xh = x.half()
    out_gate = self.attn_gate(xh).reshape(B, 1, n_v, self.head_v_dim)
    beta = self.ssm_beta(xh).sigmoid().reshape(B, n_v, 1, 1)
    alpha = ((self.ssm_alpha(xh).float() + self.ssm_dt["bias"]).softplus() * self.ssm_a).reshape(B, n_v, 1, 1).exp()

    # qkv conv
    conv_window = self.conv_state.cat(self.attn_qkv(xh), dim=1)
    conv_out = (conv_window * self.ssm_conv1d["weight"].T.unsqueeze(0)).sum(1).silu()
    q, k, v = conv_out.split([self.q_dim, self.q_dim, self.conv_channels - 2*self.q_dim], dim=-1)
    # q and k are unit-normalised per head and repeated to line up with the v heads
    q = q.reshape(B, self.num_k_heads, self.head_k_dim).normalize(dim=-1).repeat(1, rep, 1)
    k = k.reshape(B, self.num_k_heads, self.head_k_dim).normalize(dim=-1).repeat(1, rep, 1)
    v = v.reshape(B, n_v, self.head_v_dim)
    q, k, v = q.mul(self.head_k_dim**-0.5).unsqueeze(-1), k.unsqueeze(-1), v.unsqueeze(-1)

    # recurrent: decay the state, then correct what it predicts for k towards v
    state = self.recurrent_state * alpha
    state = state + ((v - state@k) * beta)@k.transpose(-1, -2)

    # store the updated state (the conv window drops its oldest row)
    conv_written = self.conv_state.uop.store(conv_window[:, 1:, :].cast(self.conv_state.dtype).uop)
    state_written = self.recurrent_state.uop.store(state.cast(self.recurrent_state.dtype).uop)
    state = Tensor(self.recurrent_state.uop.after(state_written, conv_written))

    # output
    core_attn_out = self.ssm_norm((state@q).squeeze(-1).reshape(B, 1, n_v, self.head_v_dim))
    return self.ssm_out((core_attn_out * out_gate.silu()).reshape(B, 1, -1).cast(xh.dtype))

  # recurrent state can't be partially reused after divergence, force a full rebuild
  def _state_reset_ops(self):
    if not hasattr(self, "conv_state"): return []
    return [self.conv_state.assign(Tensor.zeros_like(self.conv_state)),
            self.recurrent_state.assign(Tensor.zeros_like(self.recurrent_state))]

  def _reusable_prefix_len(self, prefix_len:int, cached_len:int) -> int:
    # all or nothing: the state is only valid when the whole cached sequence is a prefix of the new one
    return prefix_len if prefix_len == cached_len else 0

  def _init_state(self, x):
    """Create the zeroed conv window and recurrent state the first time the block runs."""
    if hasattr(self, "conv_state"): return
    self.conv_state = Tensor.zeros(x.shape[0], self.ssm_conv_kernel-1, self.conv_channels, device=x.device).clone()
    self.recurrent_state = Tensor.zeros(x.shape[0], self.num_v_heads, self.head_v_dim, self.head_v_dim, device=x.device).clone()
class Transformer:
  """Decoder-only language model: token embedding, a stack of blocks, final norm and output projection.

  The model remembers the tokens whose state is held in the block caches (`_cached_tokens`) so that a
  later `generate` call sharing a prefix can skip the shared part.
  """
  def __init__(self, config:TransformerConfig):
    # the dense variant is derived before qk_norm is overridden below, so it keeps the original qk_norm
    dense_config = replace(config, num_experts=0, num_experts_per_tok=0, shared_expert_dim=0, hidden_dim=config.dense_hidden_dim or config.hidden_dim)
    if config.ssm: config = replace(config, qk_norm=config.head_dim)
    block_cls = MLATransformerBlock if config.kv_lora_rank > 0 else TransformerBlock
    blocks:list[FFNBlock] = []
    for i in range(config.num_blocks):
      if config.ssm and (i+1) % config.full_attention_interval != 0:
        # recurrent models only use an attention block every full_attention_interval blocks
        blocks.append(GatedDeltaNetBlock(config, config.ssm))
      else:
        blocks.append(block_cls(dense_config if i < config.leading_dense_blocks else config))
    self.blk:list[FFNBlock] = blocks
    self.token_embd = nn.Embedding(config.vocab_size, config.dim)
    self.output_norm = nn.RMSNorm(config.dim, config.norm_eps)
    self.output = nn.Linear(config.dim, config.vocab_size, bias=False)
    self.max_context = config.max_context
    self.has_recurrent_block = any(isinstance(b, GatedDeltaNetBlock) for b in self.blk)
    self._cached_tokens: list[int] = []
    # we specialize the JIT for prefill and rollout
    self.prefill_jit = TinyJit(self.forward)
    self.rollout_jit = TinyJit(self.forward)

  def forward(self, tokens:Tensor, start_pos:int|UOp, temperature:Tensor) -> Tensor:
    """Run all blocks on `tokens` and sample one token id from the logits of the last position."""
    x = self.token_embd(tokens).float()  # (B, T, D)
    for block in self.blk: x = block(x, start_pos)
    logits = self.output(self.output_norm(x))[:, -1, :]
    # Gumbel-max trick: argmax(logits/temp - log(-log(uniform))) is equivalent to sampling from softmax(logits/temp)
    # both the temperature and the uniform samples are clamped to 1e-12 to keep the division and the log finite
    scaled = logits / temperature.maximum(1e-12)
    log_neg_log_u = Tensor.rand_like(logits).maximum(1e-12).log().neg().log()
    return (scaled - log_neg_log_u).argmax(-1, keepdim=True)

  def __call__(self, tokens:Tensor, start_pos:int|UOp, temperature:Tensor) -> Tensor:
    """Dispatch to the prefill JIT for multi-token input and to the rollout JIT for a single token."""
    jit = self.prefill_jit if resolve(tokens.shape[1] != 1) else self.rollout_jit
    return jit(tokens.contiguous(), start_pos, temperature)

  @staticmethod
  def from_gguf(gguf:Tensor, max_context:int|None=None, realize=bool(getenv("REALIZE", 0))) -> tuple[Transformer, dict]:
    """Build a model from GGUF data and return it together with the GGUF metadata dict.

    Args:
      gguf: the GGUF data, as accepted by `nn.state.gguf_load`.
      max_context: optional cap on the context length; the model's own context length is the upper bound.
      realize: make every parameter contiguous and realize it up front. The default is read from the
        REALIZE environment variable once, when this function is defined.
    """
    # TODO: remove the need for copy to default device
    kv, state_dict = nn.state.gguf_load(gguf.to(None).realize())

    # all state items should be float16, not float32
    state_dict = {k:v.cast('float16') if getenv("HALF", 1) else v for k,v in state_dict.items()}

    # some models like Llama 3.2 don't have an output.weight, they just tie to the token_embd.weight
    if 'output.weight' not in state_dict: state_dict['output.weight'] = state_dict['token_embd.weight']

    arch = kv['general.architecture']
    model_context = kv[f'{arch}.context_length']
    max_context = min(max_context, model_context) if max_context is not None else model_context
    n_heads, n_kv_heads = kv[f'{arch}.attention.head_count'], kv[f'{arch}.attention.head_count_kv']

    ssm = None
    if arch in ('qwen35', 'qwen35moe'):
      ssm = SSMConfig(**{k: kv[f'{arch}.ssm.{k}'] for k in ('conv_kernel','state_size','group_count','time_step_rank','inner_size')})
      # these checkpoints name the pre-FFN norm differently
      state_dict = {k.replace('post_attention_norm', 'ffn_norm'):v for k,v in state_dict.items()}

    kv_lora_rank = kv.get(f'{arch}.attention.kv_lora_rank', 0)
    head_dim = kv.get(f'{arch}.attention.key_length_mla', kv.get(f'{arch}.attention.key_length', kv[f'{arch}.embedding_length'] // n_heads))
    rope_dim = kv.get(f'{arch}.rope.dimension_count', head_dim)
    leading_dense = kv.get(f'{arch}.leading_dense_block_count', 0)

    # Permute RoPE weights from interleaved to half-split layout.
    for name in state_dict:
      if ('attn_q.weight' in name or 'attn_q_b.weight' in name) and (arch == 'llama' or kv_lora_rank):
        w = state_dict[name].reshape(n_heads, state_dict[name].shape[0]//n_heads, -1)
        # only the rotated tail of every head is permuted
        prefix = head_dim-rope_dim
        state_dict[name] = w[:, :prefix].cat(w[:, prefix:].rearrange("n (h two) d -> n (two h) d", two=2), dim=1).reshape(-1, w.shape[-1])
      elif arch == 'llama' and 'attn_k.weight' in name:
        w = state_dict[name].reshape(n_kv_heads, state_dict[name].shape[0]//n_kv_heads, -1)
        state_dict[name] = w.rearrange("n (h two) d -> n (two h) d", two=2).reshape(-1, w.shape[-1])
      elif kv_lora_rank and 'attn_kv_a_mqa.weight' in name:
        latent, rope = state_dict[name][:kv_lora_rank], state_dict[name][kv_lora_rank:]
        state_dict[name] = latent.cat(rope.rearrange("(h two) d -> (two h) d", two=2), dim=0)

    config = TransformerConfig(
      num_blocks=kv[f'{arch}.block_count'], dim=kv[f'{arch}.embedding_length'],
      hidden_dim=kv.get(f'{arch}.expert_feed_forward_length', kv.get(f'{arch}.feed_forward_length', 0)),
      n_heads=n_heads, n_kv_heads=n_kv_heads, norm_eps=kv[f'{arch}.attention.layer_norm_rms_epsilon'],
      vocab_size=len(kv['tokenizer.ggml.tokens']),
      head_dim=head_dim,
      rope_theta=kv[f'{arch}.rope.freq_base'],
      rope_dim=rope_dim,
      v_head_dim=kv.get(f'{arch}.attention.value_length_mla', kv.get(f'{arch}.attention.value_length', head_dim)),
      max_context=max_context,
      qk_norm=int(state_dict['blk.0.attn_q_norm.weight'].shape[0]) if 'blk.0.attn_q_norm.weight' in state_dict else 0,
      num_experts=kv.get(f'{arch}.expert_count', 0), num_experts_per_tok=kv.get(f'{arch}.expert_used_count', 0),
      norm_topk_prob=kv.get(f'{arch}.expert_weights_norm', arch in ('qwen3moe', 'qwen35moe')),
      kv_lora_rank=kv_lora_rank, q_lora_rank=kv.get(f'{arch}.attention.q_lora_rank', 0),
      leading_dense_blocks=leading_dense,
      shared_expert_dim=kv.get(
        f'{arch}.expert_shared_feed_forward_length',
        kv.get(f'{arch}.expert_shared_count', 0) * kv.get(f'{arch}.expert_feed_forward_length', 0)),
      # the gate is looked up on the first block after the leading dense ones
      shared_expert_gate=f"blk.{leading_dense}.ffn_gate_inp_shexp.weight" in state_dict,
      dense_hidden_dim=kv.get(f'{arch}.feed_forward_length', 0) if leading_dense else 0,
      routed_scaling_factor=kv.get(f'{arch}.expert_weights_scale', 1.0), attn_output_gate=arch in ('qwen35', 'qwen35moe'), ssm=ssm,
      full_attention_interval=kv.get(f'{arch}.full_attention_interval', 0))
    model = Transformer(config)
    nn.state.load_state_dict(model, state_dict, verbose=False, consume=True, realize=False)  # NOTE: rope_freqs.weight (32,) is unused

    # NOTE: without this contiguous, it unpacks the weights from the model every time. we shouldn't need this, but for now it's faster
    if realize:
      params = nn.state.get_parameters(model)
      for p in params: p.replace(p.contiguous())
      Tensor.realize(*params)
    return model, kv

  def get_start_pos(self, tokens:list[int]) -> int:
    """Return the position from which `tokens` still has to be run through the model.

    Only `tokens[:-1]` is compared against the cached tokens, so the last token is never treated as cached.
    Every block may lower the match through `_reusable_prefix_len`; the smallest answer is returned.
    """
    cached = self._cached_tokens
    matched = 0
    for new_tok, old_tok in zip(tokens[:-1], cached):
      if new_tok != old_tok: break
      matched += 1
    return min(blk._reusable_prefix_len(matched, len(cached)) for blk in self.blk)

  def generate(self, tokens:list[int], chunk_size:int=32, temperature:float=0.0):
    """Generator that continues the sequence in `tokens`, yielding one sampled token id per step.

    `tokens` is extended in place with every generated id, and `_cached_tokens` is updated to everything
    but the newest id. The loop only ends when `len(tokens)` reaches `max_context`; stopping earlier
    (for example on an end-of-sequence id) is up to the caller.

    Args:
      tokens: prompt token ids; mutated in place.
      chunk_size: upper bound on how many prompt tokens are fed per model call during prefill.
      temperature: sampling temperature handed to `forward`.
    """
    # recurrent blocks only support one token per call
    if self.has_recurrent_block: chunk_size = 1
    pos_var = UOp.variable("start_pos", 0, self.max_context-1)
    count_var = UOp.variable("toks", 1, chunk_size)
    # TODO: use UOp.variable for temperature once float variables are supported
    temp = Tensor(temperature).contiguous()
    # upload the whole prompt once, zero padded to max_context; each model call reads a slice of it
    padding = [0] * (self.max_context - len(tokens))
    padded = Tensor(tokens + padding, dtype="int32").reshape(1, self.max_context)
    # recompute start_pos from what's currently valid in the caches
    start_pos = self.get_start_pos(tokens)
    if start_pos < len(self._cached_tokens):
      # part of the cached sequence is being dropped: let blocks that need it reset their state
      resets = [op for block in self.blk for op in block._state_reset_ops()]
      if resets: Tensor.realize(*resets)
    prompt_len = len(tokens)
    out = None
    while len(tokens) < self.max_context:
      step = min(chunk_size, len(tokens) - start_pos)
      sp, nt = pos_var.bind(start_pos), count_var.bind(step)
      # while still inside the prompt read from the uploaded tokens, afterwards feed back the previous sample
      if start_pos < prompt_len or out is None: model_in = padded[:, sp:sp+nt]
      else: model_in = out
      out = self(model_in, sp, temp).realize()
      start_pos += nt.val
      # chunked prefill: keep processing until all prompt tokens are consumed
      if start_pos < len(tokens): continue
      tokens.append(int(out.item()))
      self._cached_tokens = tokens[:-1]
      yield tokens[-1]