mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
refactor llm into files (#15780)
* refactor llm into files * chat.html * tokenizer cleanup * cleanup * tests
This commit is contained in:
parent
1fac03ce54
commit
a9b6cfece0
10 changed files with 493 additions and 501 deletions
|
|
@ -1,6 +1,6 @@
|
|||
import unittest
|
||||
from tinygrad import Tensor, dtypes, TinyJit, UOp
|
||||
from tinygrad.llm.cli import apply_rope as apply_rope_new, precompute_freqs_cis
|
||||
from tinygrad.llm.model import apply_rope as apply_rope_new, precompute_freqs_cis
|
||||
from test.helpers import assert_jit_cache_len
|
||||
|
||||
def apply_rope(x:Tensor, start_pos:int):
|
||||
|
|
|
|||
|
|
@ -14,23 +14,18 @@ class TestLLMServer(unittest.TestCase):
|
|||
cls.mock_tok.end_turn = Mock(return_value=[998])
|
||||
cls.mock_tok.prefix = Mock(return_value=[1])
|
||||
cls.mock_tok.preset = "llama3"
|
||||
cls.mock_tok.bos_id = 1
|
||||
cls.mock_tok.eos_id = 999
|
||||
cls.mock_tok.eot_id = None
|
||||
cls.mock_tok.is_end = Mock(side_effect=lambda tid: tid in (999,))
|
||||
|
||||
cls.mock_model = Mock()
|
||||
cls.mock_model.generate = Mock(side_effect=lambda ids, **kwargs: iter([300, 301, 999]))
|
||||
cls.mock_model.get_start_pos = Mock(return_value=0)
|
||||
|
||||
cls.bos_id = 1
|
||||
cls.eos_id = 999
|
||||
from tinygrad.llm.cli import LLMServer
|
||||
|
||||
from tinygrad.llm.cli import Handler, LLMServer
|
||||
|
||||
cls.server = LLMServer(('127.0.0.1', 0), Handler)
|
||||
cls.server.model = cls.mock_model
|
||||
cls.server.model_name = "test-model"
|
||||
cls.server.tok = cls.mock_tok
|
||||
cls.server.bos_id = cls.bos_id
|
||||
cls.server.eos_id = cls.eos_id
|
||||
cls.server.eot_id = None
|
||||
cls.server = LLMServer(('127.0.0.1', 0), cls.mock_model, "test-model", cls.mock_tok)
|
||||
cls.port = cls.server.server_address[1]
|
||||
cls.server_thread = threading.Thread(target=cls.server.serve_forever, daemon=True)
|
||||
cls.server_thread.start()
|
||||
|
|
|
|||
|
|
@ -51,11 +51,12 @@ class TestLLMTokenizer(unittest.TestCase):
|
|||
"tokenizer.ggml.tokens": ["<unk>", "<s>", "</s>", "[INST]", "[/INST]", "hello"],
|
||||
"tokenizer.ggml.token_type": [3, 3, 3, 3, 3, 1],
|
||||
"tokenizer.ggml.pre": "tekken",
|
||||
"tokenizer.ggml.eos_token_id": 2,
|
||||
}
|
||||
tok = SimpleTokenizer.from_gguf_kv(kv)
|
||||
self.assertEqual(tok.role("user"), [3])
|
||||
self.assertEqual(tok.encode("hello"), [5])
|
||||
self.assertEqual(tok.end_turn(2), [4])
|
||||
self.assertEqual(tok.end_turn(), [4])
|
||||
self.assertEqual(tok.role("assistant"), [])
|
||||
|
||||
def test_stream_decoder(self):
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
import unittest
|
||||
import numpy as np
|
||||
from tinygrad import Tensor, dtypes
|
||||
from tinygrad.llm.cli import (
|
||||
from tinygrad.llm.model import (
|
||||
GatedDeltaNetBlock, SSMConfig, TransformerBlock, TransformerConfig,
|
||||
apply_rope as apply_rope_new, precompute_freqs_cis, pairwise_topk,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
import unittest
|
||||
import numpy as np
|
||||
from tinygrad import Tensor
|
||||
from tinygrad.llm.cli import Transformer, TransformerConfig, apply_rope
|
||||
from tinygrad.llm.model import Transformer, TransformerConfig, apply_rope, MLATransformerBlock, precompute_freqs_cis
|
||||
|
||||
class TestMLA(unittest.TestCase):
|
||||
def _make_config(self, **kwargs):
|
||||
|
|
@ -13,7 +13,6 @@ class TestMLA(unittest.TestCase):
|
|||
|
||||
def test_mla_attention_matches_naive(self):
|
||||
config = self._make_config(max_context=16)
|
||||
from tinygrad.llm.cli import MLATransformerBlock, precompute_freqs_cis
|
||||
|
||||
block = MLATransformerBlock(config)
|
||||
c = config
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ import unittest
|
|||
import numpy as np
|
||||
from dataclasses import replace
|
||||
from tinygrad import Tensor
|
||||
from tinygrad.llm.cli import TransformerBlock, TransformerConfig
|
||||
from tinygrad.llm.model import TransformerBlock, TransformerConfig
|
||||
|
||||
def _moe_config(dim=8, hidden=16, n_heads=2, num_experts=4, num_experts_per_tok=2):
|
||||
return TransformerConfig(
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ import unittest
|
|||
from unittest.mock import patch
|
||||
from tinygrad import Tensor, UOp
|
||||
from tinygrad.schedule import schedule_cache
|
||||
from tinygrad.llm.cli import Transformer, TransformerConfig
|
||||
from tinygrad.llm.model import Transformer, TransformerConfig
|
||||
|
||||
TEST_CONFIG = TransformerConfig(num_blocks=1, dim=64, hidden_dim=128, n_heads=2, n_kv_heads=2,
|
||||
norm_eps=1e-5, vocab_size=100, head_dim=32, rope_theta=10000.0, rope_dim=32, v_head_dim=32, max_context=32)
|
||||
|
|
|
|||
38
tinygrad/llm/chat.html
Normal file
38
tinygrad/llm/chat.html
Normal file
|
|
@ -0,0 +1,38 @@
|
|||
<!DOCTYPE html><html><head><title>tinygrad chat</title><style>
|
||||
* { margin: 0 }
|
||||
body { background: #212121; color: #e3e3e3; font-family: system-ui;
|
||||
height: 100vh; display: flex; flex-direction: column }
|
||||
#chat { flex: 1; overflow-y: auto; padding: 20px }
|
||||
.msg { padding: 10px 16px; margin: 8px 0; white-space: pre-wrap; border-radius: 18px }
|
||||
.user { background: #2f2f2f; margin-left: auto; width: fit-content; max-width: 70% }
|
||||
#input { max-width: 768px; width: 100%; margin: 20px auto; padding: 14px 20px;
|
||||
background: #2f2f2f; color: inherit; font: inherit;
|
||||
border: none; outline: none; resize: none; border-radius: 24px; field-sizing: content }
|
||||
</style></head><body><div id="chat"></div>
|
||||
<textarea id="input" rows="1" placeholder="Ask anything" autofocus></textarea>
|
||||
<script>
|
||||
input.onkeydown = (e) => { if (e.key === 'Enter' && !e.shiftKey && !e.isComposing) { e.preventDefault(); send() } }
|
||||
const msgs = [];
|
||||
// Send the current textarea contents as a user turn and stream the assistant reply into the chat.
// Reads/writes the page-level `input`, `chat` and `msgs` objects.
async function send() {
  const text = input.value.trim();
  if (!text) return;
  msgs.push({role: 'user', content: text});
  // Insert the user's text as a text node, never as markup: assigning through innerHTML would
  // let typed "<...>" be parsed as HTML and would also re-parse every earlier message.
  const u = document.createElement('div'); u.className = 'msg user'; u.textContent = text; chat.appendChild(u);
  input.value = '';
  // Empty bubble that the streamed reply is appended to.
  const d = document.createElement('div'); d.className = 'msg'; chat.appendChild(d);
  const r = await fetch('/v1/chat/completions', {method: 'POST', headers: {'Content-Type': 'application/json'},
    body: JSON.stringify({model: 'llama', messages: msgs, stream: true, temperature: 0.7})});
  let buf = '';
  for (const rd = r.body.getReader(), dec = new TextDecoder();;) {
    const {done, value} = await rd.read();
    if (done) break;
    buf += dec.decode(value, {stream: true});
    // Only complete lines are parsed; the trailing partial line stays in buf for the next chunk.
    const lines = buf.split('\n');
    buf = lines.pop();
    for (const ln of lines)
      if (ln.startsWith('data: ') && !ln.includes('[DONE]'))
        try { d.textContent += JSON.parse(ln.slice(6)).choices[0]?.delta?.content || '' } catch {}
    chat.scrollTop = chat.scrollHeight;
  }
  msgs.push({role: 'assistant', content: d.textContent});
}
|
||||
</script></body></html>
|
||||
|
|
@ -1,13 +1,13 @@
|
|||
from __future__ import annotations
|
||||
import sys, argparse, codecs, typing, re, unicodedata, json, uuid, time, functools, itertools
|
||||
from dataclasses import dataclass, replace
|
||||
from tinygrad import Tensor, nn, UOp, TinyJit, getenv, function
|
||||
from tinygrad.uop.ops import resolve
|
||||
import sys, argparse, codecs, typing, re, unicodedata, json, uuid, time, pathlib
|
||||
from tinygrad import Tensor, nn
|
||||
from tinygrad.helpers import partition, DEBUG, Timing, GlobalCounters, stderr_log, colored, Context
|
||||
from tinygrad.viz.serve import TCPServerWithReuse, HTTPRequestHandler
|
||||
from tinygrad.llm.model import Transformer
|
||||
|
||||
class SimpleTokenizer:
|
||||
def __init__(self, normal_tokens:dict[str, int], special_tokens:dict[str, int], preset:str="llama3"):
|
||||
def __init__(self, normal_tokens:dict[str, int], special_tokens:dict[str, int], preset:str="llama3",
|
||||
bos_id:int|None=None, eos_id:int=0, eot_id:int|None=None):
|
||||
preset = {"qwen35":"qwen2","qwen35moe":"qwen2"}.get(preset, preset)
|
||||
if preset not in ("llama3","llama-v3","llama-bpe","qwen2","olmo","kimi-k2","tekken","glm4"):
|
||||
raise ValueError(f"Invalid tokenizer preset '{preset}'")
|
||||
|
|
@ -27,13 +27,16 @@ class SimpleTokenizer:
|
|||
self._special_tokens = special_tokens
|
||||
self._tok2bytes = {tid: tok for tok, tid in self._normal_tokens.items()} | {tid: tok.encode() for tok, tid in self._special_tokens.items()}
|
||||
self.preset = preset
|
||||
self.bos_id, self.eos_id, self.eot_id = bos_id, eos_id, eot_id
|
||||
|
||||
@staticmethod
|
||||
def from_gguf_kv(kv:dict):
|
||||
# https://github.com/ggml-org/llama.cpp/blob/94933c8c2eeaa9a7983e3f6c08af76bd86724094/src/llama-vocab.cpp#L1818-L1820
|
||||
vocab: typing.Iterable[tuple[str, int]] = ((tok, idx) for idx, tok in enumerate(kv["tokenizer.ggml.tokens"]))
|
||||
normal_tokens, special_tokens = partition(vocab, lambda e: kv["tokenizer.ggml.token_type"][e[1]] == 1)
|
||||
return SimpleTokenizer(dict(normal_tokens), dict(special_tokens), kv["tokenizer.ggml.pre"])
|
||||
return SimpleTokenizer(dict(normal_tokens), dict(special_tokens), kv["tokenizer.ggml.pre"],
|
||||
bos_id=kv.get('tokenizer.ggml.bos_token_id') if kv.get('tokenizer.ggml.add_bos_token', True) else None,
|
||||
eos_id=kv.get('tokenizer.ggml.eos_token_id', 0), eot_id=kv.get('tokenizer.ggml.eot_token_id'))
|
||||
|
||||
def _encode_word(self, word:bytes) -> list[int]:
|
||||
if (early_token:=self._normal_tokens.get(word)) is not None: return [early_token]
|
||||
|
|
@ -70,419 +73,16 @@ class SimpleTokenizer:
|
|||
if role == 'assistant': return []
|
||||
raise ValueError(f"Unsupported role '{role}' for tokenizer preset '{self.preset}'")
|
||||
return self.encode("<|start_header_id|>" + role + "<|end_header_id|>\n\n")
|
||||
def end_turn(self, eos_id:int):
|
||||
def end_turn(self):
|
||||
if self.preset == 'olmo': return self.encode("\n")
|
||||
if self.preset == 'kimi-k2': return [eos_id]
|
||||
if self.preset == 'qwen2': return [eos_id] + self.encode("\n")
|
||||
if self.preset == 'kimi-k2': return [self.eos_id]
|
||||
if self.preset == 'qwen2': return [self.eos_id] + self.encode("\n")
|
||||
if self.preset == 'glm4': return []
|
||||
if self.preset == 'tekken': return self.encode("[/INST]")
|
||||
return [eos_id]
|
||||
def prefix(self, bos_id:int|None) -> list[int]:
|
||||
return ([] if bos_id is None else [bos_id]) + (self.encode("<sop>") if self.preset == 'glm4' else [])
|
||||
|
||||
@functools.cache
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0) -> Tensor:
  """Build the rotary-embedding lookup table consumed by apply_rope.

  Row p holds the cosines of the angles for position p followed by the matching sines
  (concatenated on the last axis). The result is memoized per (dim, end, theta).
  """
  exponents = Tensor.arange(0, dim, 2)[:(dim // 2)] / dim
  inv_freq = 1.0 / (theta ** exponents)
  # outer product: one row per position, one column per frequency
  angles = Tensor.arange(end).unsqueeze(dim=1) * inv_freq.unsqueeze(dim=0)
  cos_half, sin_half = angles.cos(), angles.sin()
  return cos_half.cat(sin_half, dim=-1).contiguous()
|
||||
|
||||
class ExpertWeights:
  """Bias-free linear layer with a leading expert axis; weight is (num_experts, out_features, in_features)."""
  def __init__(self, num_experts:int, in_features:int, out_features:int):
    self.weight = Tensor.zeros(num_experts, out_features, in_features)

  def __call__(self, sel:Tensor, x:Tensor) -> Tensor:
    """Apply the experts picked by `sel` to `x`.

    sel: (B, T, k), x: (B, T, 1, in) or (B, T, k, in) -> output: (B, T, k, out)
    """
    # gather the selected experts' matrices and flip them to (in, out) for the matmul
    picked_t = self.weight[sel].transpose(-1, -2)
    # treat each input vector as a 1-row matrix, multiply, then drop the helper axis again
    row = x.unsqueeze(-2)
    return (row @ picked_t).squeeze(-2)
|
||||
|
||||
def apply_rope(x:Tensor, freqs_cis:Tensor) -> Tensor:
  """Rotate `x` with rotary position embeddings in the half-split layout.

  `freqs_cis` is reshaped to (1, 1, x.shape[2], -1) and split into a cos half and a sin half;
  the last axis of `x` is split the same way and the two halves are rotated against each other.
  """
  assert x.shape[-1] % 2 == 0
  table = freqs_cis.reshape(1, 1, x.shape[2], -1)
  cos, sin = table.chunk(2, dim=-1)
  lo, hi = x.chunk(2, dim=-1)
  rotated_lo = lo * cos - hi * sin
  rotated_hi = hi * cos + lo * sin
  return rotated_lo.cat(rotated_hi, dim=-1)
|
||||
|
||||
def pairwise_topk(x: Tensor, k: int) -> tuple[Tensor, Tensor]:
  """Top-k over the last axis of a 3-D tensor, built from pairwise comparisons.

  Returns (values, indices). Each element's rank is the number of elements it beats, where
  ties are won by the lower index; the k highest ranks are returned in ascending rank order.
  """
  n = x.shape[-1]
  idx = Tensor.arange(n)
  index_vals = idx.reshape(1,1,n).cast(x.dtype).expand(x.shape)
  lhs, rhs = x.unsqueeze(-1), x.unsqueeze(-2)
  # beats[..., i, j] is true when element i outranks element j
  tie_break = idx.reshape(1,1,n,1) < idx.reshape(1,1,1,n)
  beats = (lhs > rhs) | ((lhs == rhs) & tie_break)
  rank = beats.sum(axis=-1).cast('int32')
  # place every index at its rank, then keep the k highest-ranked slots
  by_rank = Tensor.zeros_like(x).scatter(-1, rank, index_vals)
  sel = by_rank[:,:,n-k:].cast('int32')
  return x.gather(-1, sel), sel
|
||||
|
||||
@dataclass(frozen=True)
class SSMConfig:
  """Immutable hyper-parameters consumed by GatedDeltaNetBlock."""
  # length of the conv window (the block keeps conv_kernel-1 past rows as state)
  conv_kernel: int
  # per-head q/k width (head_k_dim in the block)
  state_size: int
  # number of q/k heads (num_k_heads in the block)
  group_count: int
  # number of v heads (num_v_heads in the block)
  time_step_rank: int
  # total v width; inner_size // time_step_rank is the per-head v width
  inner_size: int
|
||||
|
||||
@dataclass(frozen=True)
class TransformerConfig:
  """Immutable hyper-parameters shared by Transformer and its blocks."""
  # --- required -----------------------------------------------------------
  num_blocks: int          # number of blocks the Transformer builds
  dim: int                 # model width (embedding / residual stream)
  hidden_dim: int          # feed-forward inner width (per expert when MoE)
  n_heads: int             # query heads
  n_kv_heads: int          # key/value heads
  norm_eps: float          # eps passed to every RMSNorm
  vocab_size: int          # rows of the token embedding / output projection
  head_dim: int            # per-head q/k width
  rope_theta: float        # base passed to precompute_freqs_cis
  rope_dim: int            # per-head channels that receive rotary embedding
  v_head_dim: int          # per-head v width (TransformerBlock requires == head_dim)
  # --- optional -----------------------------------------------------------
  max_context: int = 0     # length of the caches and of the rope table
  qk_norm: int = 0         # 0 = off, otherwise the width of the q/k RMSNorm
  num_experts: int = 0     # > 0 switches the feed-forward to mixture-of-experts
  num_experts_per_tok: int = 0  # experts selected per token
  norm_topk_prob: bool = False  # normalize routing weights over the selected experts
  q_lora_rank: int = 0     # > 0 uses the low-rank q projection in MLATransformerBlock
  kv_lora_rank: int = 0    # > 0 selects MLATransformerBlock (and the router bias in MoE)
  shared_expert_dim: int = 0    # > 0 adds a shared expert of this width to MoE blocks
  full_attention_interval: int = 0  # with ssm set, every N-th block is a full-attention block
  attn_output_gate: bool = False    # q projection also produces a sigmoid gate for the attention output
  ssm: SSMConfig|None = None        # set to use GatedDeltaNetBlock for the non-full-attention blocks
  shared_expert_gate: bool = True   # gate the shared expert with a learned per-token sigmoid
  leading_dense_blocks: int = 0     # first N attention blocks use the dense feed-forward
  dense_hidden_dim: int = 0         # hidden_dim for those dense blocks (0 falls back to hidden_dim)
  routed_scaling_factor: float = 1.0  # multiplier applied to the routing weights
|
||||
|
||||
class FFNBlock:
  """Base residual block: pre-norm attention followed by a pre-norm feed-forward (dense or MoE).

  Subclasses supply `_attention` and `_init_state`; this class owns the two RMSNorms,
  the feed-forward weights and the `__call__` wrapper.
  """
  def __init__(self, config:TransformerConfig):
    self.config = config
    dim, eps = config.dim, config.norm_eps

    # norms applied before attention and before the feed-forward
    self.attn_norm = nn.RMSNorm(dim, eps)
    self.ffn_norm = nn.RMSNorm(dim, eps)

    # dense gated feed-forward
    if config.num_experts <= 0:
      self.ffn_gate = nn.Linear(dim, config.hidden_dim, bias=False)
      self.ffn_up = nn.Linear(dim, config.hidden_dim, bias=False)
      self.ffn_down = nn.Linear(config.hidden_dim, dim, bias=False)
      return

    # mixture-of-experts feed-forward
    self.ffn_gate_inp = nn.Linear(dim, config.num_experts, bias=False)  # router
    if config.kv_lora_rank > 0: self.exp_probs_b = {"bias": Tensor.zeros(config.num_experts)}
    self.ffn_gate_exps = ExpertWeights(config.num_experts, dim, config.hidden_dim)
    self.ffn_up_exps = ExpertWeights(config.num_experts, dim, config.hidden_dim)
    self.ffn_down_exps = ExpertWeights(config.num_experts, config.hidden_dim, dim)
    if config.shared_expert_dim > 0:
      self.ffn_gate_shexp = nn.Linear(dim, config.shared_expert_dim, bias=False)
      self.ffn_up_shexp = nn.Linear(dim, config.shared_expert_dim, bias=False)
      self.ffn_down_shexp = nn.Linear(config.shared_expert_dim, dim, bias=False)
      if config.shared_expert_gate: self.ffn_gate_inp_shexp = {"weight": Tensor.zeros(dim)}

  def _feed_forward(self, x:Tensor) -> Tensor:
    """Run the dense or mixture-of-experts feed-forward, depending on which weights exist."""
    if not hasattr(self, 'ffn_gate_exps'):
      # TODO: remove the need for this contiguous
      return self.ffn_down(self.ffn_gate(x).silu().contiguous() * self.ffn_up(x))

    cfg = self.config
    h = x.unsqueeze(2)  # (B, T, 1, D) - add expert dim for broadcasting
    logits = self.ffn_gate_inp(x)
    if hasattr(self, 'exp_probs_b'):
      # sigmoid routing: the bias only influences which experts are picked, not their weights
      probs = logits.sigmoid()
      _, sel = pairwise_topk(probs + self.exp_probs_b["bias"], cfg.num_experts_per_tok)
      probs = probs.gather(-1, sel)
      if cfg.norm_topk_prob: probs = probs / probs.sum(axis=-1, keepdim=True)
    else:
      # softmax routing, either over the selected logits or over all logits
      vals, sel = pairwise_topk(logits, cfg.num_experts_per_tok)
      if cfg.norm_topk_prob: probs = vals.softmax(-1)
      else: probs = logits.softmax(-1).gather(-1, sel)
    probs = probs * cfg.routed_scaling_factor

    gated = self.ffn_gate_exps(sel, h).silu() * self.ffn_up_exps(sel, h)
    x_down = self.ffn_down_exps(sel, gated)  # (B, T, k, D)
    out = (x_down * probs.unsqueeze(-1)).sum(axis=2)  # (B, T, D)
    if hasattr(self, 'ffn_gate_shexp'):
      shexp = self.ffn_down_shexp(self.ffn_gate_shexp(x).silu().contiguous() * self.ffn_up_shexp(x))
      if hasattr(self, 'ffn_gate_inp_shexp'):
        shexp = shexp * (x * self.ffn_gate_inp_shexp["weight"]).sum(axis=-1, keepdim=True).sigmoid()
      out = out + shexp
    return out

  def _reusable_prefix_len(self, prefix_len:int, cached_len:int) -> int:
    """Given the token-prefix match, return how much cached state this block can still reuse."""
    return prefix_len

  def _state_reset_ops(self) -> list[Tensor]:
    """Return the writes that reset this block's state after a cache mismatch (none by default)."""
    return []

  def _init_state(self, x:Tensor): raise NotImplementedError
  def _attention(self, x:Tensor, start_pos:int|UOp) -> Tensor: raise NotImplementedError

  def __call__(self, x: Tensor, start_pos: int|UOp):
    """Apply the block to `x`: x + attention(norm(x)), then + feed_forward(norm(.))."""
    self._init_state(x)
    # we pass in the weights implicitly so we unpack the GGUF on the fly
    @function(precompile=True, allow_implicit=True)
    def _run(x:Tensor, start_pos:int|UOp):
      attn_out = self._attention(self.attn_norm(x), start_pos)
      h = x + attn_out
      ffn_out = self._feed_forward(self.ffn_norm(h))
      return (h + ffn_out).contiguous()
    return _run(x, start_pos)
|
||||
|
||||
class TransformerBlock(FFNBlock):
  """Attention block with separate q/k/v projections, rotary embeddings and a preallocated KV cache."""
  def __init__(self, config:TransformerConfig):
    super().__init__(config)
    assert config.v_head_dim == config.head_dim, "TransformerBlock requires v_head_dim == head_dim"

    # bias-free projections; q is twice as wide when it also carries the output gate
    q_proj_out = config.head_dim * config.n_heads * (2 if config.attn_output_gate else 1)
    kv_proj_out = config.head_dim * config.n_kv_heads
    self.attn_q = nn.Linear(config.dim, q_proj_out, bias=False)
    self.attn_k = nn.Linear(config.dim, kv_proj_out, bias=False)
    self.attn_v = nn.Linear(config.dim, kv_proj_out, bias=False)
    self.attn_output = nn.Linear(config.head_dim * config.n_heads, config.dim, bias=False)
    if config.qk_norm:
      self.attn_q_norm = nn.RMSNorm(config.qk_norm, config.norm_eps)
      self.attn_k_norm = nn.RMSNorm(config.qk_norm, config.norm_eps)

  def _attention(self, x:Tensor, start_pos:int|UOp) -> Tensor:
    """Attend the T new positions in `x` against cache positions 0..start_pos+T."""
    cfg = self.config
    q, k, v = self.attn_q(x), self.attn_k(x), self.attn_v(x)
    # a qk_norm width different from head_dim is applied on the flat projections
    if cfg.qk_norm and cfg.qk_norm != cfg.head_dim: q, k = self.attn_q_norm(q), self.attn_k_norm(k)

    B, T, _ = x.shape
    if cfg.attn_output_gate:
      # each head's projection holds [query, gate] back to back
      qg = q.reshape(B, T, cfg.n_heads, 2, cfg.head_dim)
      q, gate = qg[:, :, :, 0, :], qg[:, :, :, 1, :].reshape(B, T, cfg.n_heads * cfg.head_dim)
    q = q.reshape(B, T, cfg.n_heads, cfg.head_dim).transpose(1, 2)     # (B,H,T,Hd)
    k = k.reshape(B, T, cfg.n_kv_heads, cfg.head_dim).transpose(1, 2)  # (B,KvH,T,Hd)
    v = v.reshape(B, T, cfg.n_kv_heads, cfg.head_dim).transpose(1, 2)  # (B,KvH,T,Hd)
    # a qk_norm width equal to head_dim is applied per head
    if cfg.qk_norm == cfg.head_dim: q, k = self.attn_q_norm(q), self.attn_k_norm(k)

    # only the first rope_dim channels of each head are rotated
    rope = self.freqs_cis[start_pos:start_pos+T]
    q = apply_rope(q[..., :cfg.rope_dim], rope).cat(q[..., cfg.rope_dim:], dim=-1)
    k = apply_rope(k[..., :cfg.rope_dim], rope).cat(k[..., cfg.rope_dim:], dim=-1)

    # NOTE: we don't want to change self.cache_kv, the function API doesn't support this well,
    # so the store is expressed on the uop instead of assigning into self.cache_kv and slicing it
    kv_store = self.cache_kv[:, :, :, start_pos:start_pos+T, :].uop.store(Tensor.stack(k, v).uop)
    assigned_kv = Tensor(self.cache_kv.uop.after(kv_store))
    k = assigned_kv[0, :, :, 0:start_pos+T, :]
    v = assigned_kv[1, :, :, 0:start_pos+T, :]

    # NOTE: this mask is causal_lower_right, not the causal_upper_left generated by is_causal = True
    # TODO: this if statement should be removed and it shouldn't generate extra kernels
    if resolve(T != 1):
      mask = Tensor.full((1, 1, T, start_pos+T), float("-inf"), dtype=x.dtype, device=x.device).triu(start_pos+1)
    else:
      mask = None
    attn = q.scaled_dot_product_attention(k, v, attn_mask=mask, enable_gqa=True)  # (B,H,T,Hd)
    attn = attn.transpose(1, 2).reshape(B, T, -1)                                 # back to (B,T,D)
    if cfg.attn_output_gate: attn = attn * gate.sigmoid()
    return self.attn_output(attn)

  def _init_state(self, x:Tensor):
    """Allocate the KV cache and the rope table the first time the block is called."""
    if hasattr(self, "cache_kv"): return
    cfg = self.config
    # TODO: how is the dtype of this determined?
    self.cache_kv = Tensor.empty(2, x.shape[0], cfg.n_kv_heads, cfg.max_context, cfg.head_dim, device=x.device)
    self.freqs_cis = precompute_freqs_cis(cfg.rope_dim, cfg.max_context, cfg.rope_theta)
|
||||
|
||||
class MLATransformerBlock(FFNBlock):
  """Attention block that caches a compressed (kv_lora_rank) latent plus one shared rope key per position."""
  def __init__(self, config:TransformerConfig):
    super().__init__(config)
    nope_dim = config.head_dim - config.rope_dim
    q_out = config.n_heads * config.head_dim
    if config.q_lora_rank > 0:
      # low-rank query: down-project, normalize, up-project
      self.attn_q_a = nn.Linear(config.dim, config.q_lora_rank, bias=False)
      self.attn_q_a_norm = nn.RMSNorm(config.q_lora_rank, config.norm_eps)
      self.attn_q_b = nn.Linear(config.q_lora_rank, q_out, bias=False)
    else:
      self.attn_q = nn.Linear(config.dim, q_out, bias=False)
    # one projection yields the compressed kv latent followed by the rope part of the key
    self.attn_kv_a_mqa = nn.Linear(config.dim, config.kv_lora_rank + config.rope_dim, bias=False)
    self.attn_kv_a_norm = nn.RMSNorm(config.kv_lora_rank, config.norm_eps)
    self.attn_k_b = {"weight": Tensor.zeros(config.n_heads, config.kv_lora_rank, nope_dim)}
    self.attn_v_b = {"weight": Tensor.zeros(config.n_heads, config.v_head_dim, config.kv_lora_rank)}
    self.attn_output = nn.Linear(config.n_heads * config.v_head_dim, config.dim, bias=False)

  def _attention(self, x:Tensor, start_pos:int|UOp) -> Tensor:
    """Attend the T new positions in `x` against cache positions 0..start_pos+T in the latent space."""
    cfg = self.config
    B, T, _ = x.shape
    nope_dim = cfg.head_dim - cfg.rope_dim
    rope = self.freqs_cis[start_pos:start_pos+T]

    # queries: the non-rope part is mapped into the latent space through attn_k_b, the rope part is rotated
    if cfg.q_lora_rank > 0: q_proj = self.attn_q_b(self.attn_q_a_norm(self.attn_q_a(x)))
    else: q_proj = self.attn_q(x)
    q = q_proj.reshape(B, T, cfg.n_heads, cfg.head_dim).transpose(1, 2)
    q_nope, q_rope = q[..., :nope_dim], q[..., nope_dim:]
    q = (q_nope @ self.attn_k_b["weight"].transpose(-1, -2)).cat(apply_rope(q_rope, rope), dim=-1)

    # keys/values: normalized latent plus a single rotated rope key shared by all heads
    kv_a = self.attn_kv_a_mqa(x)
    c_kv = self.attn_kv_a_norm(kv_a[..., :cfg.kv_lora_rank])
    k_rope_in = kv_a[..., cfg.kv_lora_rank:].reshape(B, T, 1, cfg.rope_dim).transpose(1, 2)
    k_rope = apply_rope(k_rope_in, rope)

    latent = c_kv.reshape(B, 1, T, cfg.kv_lora_rank)
    k_store = latent.cat(k_rope.reshape(B, 1, T, cfg.rope_dim), dim=-1)
    v_store = c_kv.reshape(B, 1, T, cfg.kv_lora_rank)
    # write the new rows through the uop and read the caches back up to start_pos+T
    k_write = self.cache_k[:, :, start_pos:start_pos+T, :].uop.store(k_store.uop)
    k = Tensor(self.cache_k.uop.after(k_write))[:, :, 0:start_pos+T, :]
    v_write = self.cache_v[:, :, start_pos:start_pos+T, :].uop.store(v_store.uop)
    v = Tensor(self.cache_v.uop.after(v_write))[:, :, 0:start_pos+T, :]

    if resolve(T != 1):
      mask = Tensor.full((1, 1, T, start_pos+T), float("-inf"), dtype=x.dtype, device=x.device).triu(start_pos+1)
    else:
      mask = None
    attn = q @ k.transpose(-1, -2) * (1.0 / cfg.head_dim ** 0.5)
    if mask is not None: attn = attn + mask
    attn = attn.softmax(-1)
    # attend over the latent, then expand to v_head_dim per head through attn_v_b
    attn = ((attn @ v) @ self.attn_v_b["weight"].transpose(-1, -2)).transpose(1, 2).reshape(B, T, -1)
    return self.attn_output(attn)

  def _init_state(self, x:Tensor):
    """Allocate the latent k/v caches and the rope table the first time the block is called."""
    if hasattr(self, "cache_k"): return
    cfg = self.config
    self.cache_k = Tensor.empty(x.shape[0], 1, cfg.max_context, cfg.kv_lora_rank + cfg.rope_dim, device=x.device)
    self.cache_v = Tensor.empty(x.shape[0], 1, cfg.max_context, cfg.kv_lora_rank, device=x.device)
    self.freqs_cis = precompute_freqs_cis(cfg.rope_dim, cfg.max_context, cfg.rope_theta)
|
||||
|
||||
class GatedDeltaNetBlock(FFNBlock):
  """Recurrent block: a short conv over q/k/v followed by a gated update of a per-head state matrix.

  Processes exactly one token per call and keeps `conv_state` and `recurrent_state` between calls.
  """
  def __init__(self, config:TransformerConfig, ssm:SSMConfig):
    super().__init__(config)
    self.head_k_dim = ssm.state_size
    self.num_k_heads = ssm.group_count
    self.num_v_heads = ssm.time_step_rank
    assert self.num_v_heads % self.num_k_heads == 0
    self.head_v_dim = ssm.inner_size // ssm.time_step_rank
    self.ssm_conv_kernel = ssm.conv_kernel
    # the conv runs over q, k and v side by side: q and k are q_dim wide each, v is inner_size wide
    self.conv_channels = ssm.inner_size + 2*ssm.group_count*ssm.state_size
    self.q_dim = ssm.state_size*ssm.group_count
    self.attn_qkv = nn.Linear(config.dim, self.conv_channels, bias=False)
    self.attn_gate = nn.Linear(config.dim, ssm.inner_size, bias=False)
    self.ssm_alpha = nn.Linear(config.dim, self.num_v_heads, bias=False)
    self.ssm_beta = nn.Linear(config.dim, self.num_v_heads, bias=False)
    self.ssm_conv1d = {"weight": Tensor.zeros(self.conv_channels, self.ssm_conv_kernel)}
    self.ssm_dt = {"bias": Tensor.zeros(self.num_v_heads)}
    self.ssm_a = Tensor.zeros(self.num_v_heads)
    self.ssm_norm = nn.RMSNorm(self.head_v_dim, config.norm_eps)
    self.ssm_out = nn.Linear(ssm.inner_size, config.dim, bias=False)

  def _attention(self, x:Tensor, start_pos:int|UOp) -> Tensor:
    """Advance the recurrent state by one token and return the block's mixing output."""
    B, T, _ = x.shape
    assert T == 1, "GatedDeltaNetBlock currently only supports T=1"
    heads_per_group = self.num_v_heads//self.num_k_heads

    # input processing
    x = x.half()
    out_gate = self.attn_gate(x).reshape(B, 1, self.num_v_heads, self.head_v_dim)
    beta = self.ssm_beta(x).sigmoid().reshape(B, self.num_v_heads, 1, 1)
    decay_log = (self.ssm_alpha(x).float() + self.ssm_dt["bias"]).softplus() * self.ssm_a
    alpha = decay_log.reshape(B, self.num_v_heads, 1, 1).exp()

    # qkv conv: append the new row to the stored window and reduce over the window axis
    conv_window = self.conv_state.cat(self.attn_qkv(x), dim=1)
    conv_out = (conv_window * self.ssm_conv1d["weight"].T.unsqueeze(0)).sum(1).silu()
    q, k, v = conv_out.split([self.q_dim, self.q_dim, self.conv_channels - 2*self.q_dim], dim=-1)
    q = q.reshape(B, self.num_k_heads, self.head_k_dim).normalize(dim=-1).repeat(1, heads_per_group, 1)
    k = k.reshape(B, self.num_k_heads, self.head_k_dim).normalize(dim=-1).repeat(1, heads_per_group, 1)
    v = v.reshape(B, self.num_v_heads, self.head_v_dim)
    q = q.mul(self.head_k_dim**-0.5).unsqueeze(-1)
    k = k.unsqueeze(-1)
    v = v.unsqueeze(-1)

    # recurrent: decay the state, then correct it towards v along k, scaled by beta
    recurrent_state = self.recurrent_state * alpha
    recurrent_state = recurrent_state + ((v - recurrent_state@k) * beta)@k.transpose(-1, -2)

    # store the updated state (conv window without its oldest row, and the new recurrent state)
    conv_state_store = self.conv_state.uop.store(conv_window[:, 1:, :].cast(self.conv_state.dtype).uop)
    recurrent_state_store = self.recurrent_state.uop.store(recurrent_state.cast(self.recurrent_state.dtype).uop)
    recurrent_state = Tensor(self.recurrent_state.uop.after(recurrent_state_store, conv_state_store))

    # output
    read_out = (recurrent_state@q).squeeze(-1).reshape(B, 1, self.num_v_heads, self.head_v_dim)
    core_attn_out = self.ssm_norm(read_out)
    return self.ssm_out((core_attn_out * out_gate.silu()).reshape(B, 1, -1).cast(x.dtype))

  def _state_reset_ops(self):
    """Return writes that zero both states, or nothing when the states were never allocated."""
    # recurrent state can't be partially reused after divergence, force a full rebuild
    if not hasattr(self, "conv_state"): return []
    return [self.conv_state.assign(Tensor.zeros_like(self.conv_state)),
            self.recurrent_state.assign(Tensor.zeros_like(self.recurrent_state))]

  def _reusable_prefix_len(self, prefix_len:int, cached_len:int) -> int:
    """The state is reusable only when the whole cached sequence matched; otherwise nothing is."""
    return 0 if prefix_len != cached_len else prefix_len

  def _init_state(self, x):
    """Allocate zeroed conv and recurrent states the first time the block is called."""
    if hasattr(self, "conv_state"): return
    self.conv_state = Tensor.zeros(x.shape[0], self.ssm_conv_kernel-1, self.conv_channels, device=x.device).clone()
    self.recurrent_state = Tensor.zeros(x.shape[0], self.num_v_heads, self.head_v_dim, self.head_v_dim, device=x.device).clone()
|
||||
|
||||
class Transformer:
|
||||
def __init__(self, config:TransformerConfig):
  """Build the block stack, embedding, final norm and output projection for `config`."""
  # config for the leading dense blocks: MoE switched off, optional dedicated hidden size
  dense_config = replace(config, num_experts=0, num_experts_per_tok=0, shared_expert_dim=0,
                         hidden_dim=config.dense_hidden_dim or config.hidden_dim)
  if config.ssm: config = replace(config, qk_norm=config.head_dim)
  block_cls = MLATransformerBlock if config.kv_lora_rank > 0 else TransformerBlock

  blocks:list[FFNBlock] = []
  for i in range(config.num_blocks):
    # with an ssm config, every full_attention_interval-th block is attention, the rest are recurrent
    if config.ssm and (i+1) % config.full_attention_interval != 0:
      blocks.append(GatedDeltaNetBlock(config, config.ssm))
    else:
      blocks.append(block_cls(dense_config if i < config.leading_dense_blocks else config))
  self.blk:list[FFNBlock] = blocks

  self.token_embd = nn.Embedding(config.vocab_size, config.dim)
  self.output_norm = nn.RMSNorm(config.dim, config.norm_eps)
  self.output = nn.Linear(config.dim, config.vocab_size, bias=False)
  self.max_context = config.max_context
  self.has_recurrent_block = any(isinstance(b, GatedDeltaNetBlock) for b in self.blk)
  self._cached_tokens: list[int] = []
  # we specialize the JIT for prefill and rollout
  self.prefill_jit = TinyJit(self.forward)
  self.rollout_jit = TinyJit(self.forward)
|
||||
|
||||
def forward(self, tokens:Tensor, start_pos:int|UOp, temperature:Tensor) -> Tensor:
  """Run every block over `tokens` and sample one next-token id from the last position's logits."""
  x = self.token_embd(tokens).float()  # (B, T, D)
  for block in self.blk:
    x = block(x, start_pos)
  logits = self.output(self.output_norm(x))[:, -1, :]
  # Gumbel-max trick: argmax(logits/temp - log(-log(uniform))) is equivalent to sampling from softmax(logits/temp)
  scaled = logits / temperature.maximum(1e-12)
  log_neg_log_u = (Tensor.rand_like(logits).maximum(1e-12).log().neg()).log()
  return (scaled - log_neg_log_u).argmax(-1, keepdim=True)
|
||||
|
||||
def __call__(self, tokens:Tensor, start_pos:int|UOp, temperature:Tensor) -> Tensor:
  """Dispatch to the prefill JIT for multi-token input and to the rollout JIT for a single token."""
  if resolve(tokens.shape[1] != 1): jit = self.prefill_jit
  else: jit = self.rollout_jit
  return jit(tokens.contiguous(), start_pos, temperature)
|
||||
|
||||
@staticmethod
|
||||
def from_gguf(gguf:Tensor, max_context:int|None=None, realize=bool(getenv("REALIZE", 0))) -> tuple[Transformer, dict]:
|
||||
# TODO: remove the need for copy to default device
|
||||
kv, state_dict = nn.state.gguf_load(gguf.to(None).realize())
|
||||
|
||||
# all state items should be float16, not float32
|
||||
state_dict = {k:v.cast('float16') if getenv("HALF", 1) else v for k,v in state_dict.items()}
|
||||
|
||||
# some models like Llama 3.2 don't have an output.weight, they just tie to the token_embd.weight
|
||||
if 'output.weight' not in state_dict: state_dict['output.weight'] = state_dict['token_embd.weight']
|
||||
|
||||
arch = kv['general.architecture']
|
||||
max_context = min(max_context, kv[f'{arch}.context_length']) if max_context is not None else kv[f'{arch}.context_length']
|
||||
n_heads, n_kv_heads = kv[f'{arch}.attention.head_count'], kv[f'{arch}.attention.head_count_kv']
|
||||
|
||||
ssm = None
|
||||
if arch in ('qwen35', 'qwen35moe'):
|
||||
ssm = SSMConfig(**{k: kv[f'{arch}.ssm.{k}'] for k in ('conv_kernel','state_size','group_count','time_step_rank','inner_size')})
|
||||
state_dict = {k.replace('post_attention_norm', 'ffn_norm'):v for k,v in state_dict.items()}
|
||||
|
||||
kv_lora_rank = kv.get(f'{arch}.attention.kv_lora_rank', 0)
|
||||
head_dim = kv.get(f'{arch}.attention.key_length_mla', kv.get(f'{arch}.attention.key_length', kv[f'{arch}.embedding_length'] // n_heads))
|
||||
rope_dim = kv.get(f'{arch}.rope.dimension_count', head_dim)
|
||||
|
||||
# Permute RoPE weights from interleaved to half-split layout.
|
||||
for name in state_dict:
|
||||
if ('attn_q.weight' in name or 'attn_q_b.weight' in name) and (arch == 'llama' or kv_lora_rank):
|
||||
w = state_dict[name].reshape(n_heads, state_dict[name].shape[0]//n_heads, -1)
|
||||
prefix = head_dim-rope_dim
|
||||
state_dict[name] = w[:, :prefix].cat(w[:, prefix:].rearrange("n (h two) d -> n (two h) d", two=2), dim=1).reshape(-1, w.shape[-1])
|
||||
elif arch == 'llama' and 'attn_k.weight' in name:
|
||||
w = state_dict[name].reshape(n_kv_heads, state_dict[name].shape[0]//n_kv_heads, -1)
|
||||
state_dict[name] = w.rearrange("n (h two) d -> n (two h) d", two=2).reshape(-1, w.shape[-1])
|
||||
elif kv_lora_rank and 'attn_kv_a_mqa.weight' in name:
|
||||
state_dict[name] = state_dict[name][:kv_lora_rank].cat(state_dict[name][kv_lora_rank:].rearrange("(h two) d -> (two h) d", two=2), dim=0)
|
||||
config = TransformerConfig(
|
||||
num_blocks=kv[f'{arch}.block_count'], dim=kv[f'{arch}.embedding_length'],
|
||||
hidden_dim=kv.get(f'{arch}.expert_feed_forward_length', kv.get(f'{arch}.feed_forward_length', 0)),
|
||||
n_heads=n_heads, n_kv_heads=n_kv_heads, norm_eps=kv[f'{arch}.attention.layer_norm_rms_epsilon'],
|
||||
vocab_size=len(kv['tokenizer.ggml.tokens']),
|
||||
head_dim=head_dim,
|
||||
rope_theta=kv[f'{arch}.rope.freq_base'],
|
||||
rope_dim=rope_dim,
|
||||
v_head_dim=kv.get(f'{arch}.attention.value_length_mla', kv.get(f'{arch}.attention.value_length', head_dim)),
|
||||
max_context=max_context,
|
||||
qk_norm=int(state_dict['blk.0.attn_q_norm.weight'].shape[0]) if 'blk.0.attn_q_norm.weight' in state_dict else 0,
|
||||
num_experts=kv.get(f'{arch}.expert_count', 0), num_experts_per_tok=kv.get(f'{arch}.expert_used_count', 0),
|
||||
norm_topk_prob=kv.get(f'{arch}.expert_weights_norm', arch in ('qwen3moe', 'qwen35moe')),
|
||||
kv_lora_rank=kv_lora_rank, q_lora_rank=kv.get(f'{arch}.attention.q_lora_rank', 0),
|
||||
leading_dense_blocks=kv.get(f'{arch}.leading_dense_block_count', 0),
|
||||
shared_expert_dim=kv.get(
|
||||
f'{arch}.expert_shared_feed_forward_length',
|
||||
kv.get(f'{arch}.expert_shared_count', 0) * kv.get(f'{arch}.expert_feed_forward_length', 0)),
|
||||
shared_expert_gate=f"blk.{kv.get(f'{arch}.leading_dense_block_count', 0)}.ffn_gate_inp_shexp.weight" in state_dict,
|
||||
dense_hidden_dim=kv.get(f'{arch}.feed_forward_length', 0) if kv.get(f'{arch}.leading_dense_block_count', 0) else 0,
|
||||
routed_scaling_factor=kv.get(f'{arch}.expert_weights_scale', 1.0), attn_output_gate=arch in ('qwen35', 'qwen35moe'), ssm=ssm,
|
||||
full_attention_interval=kv.get(f'{arch}.full_attention_interval', 0))
|
||||
model = Transformer(config)
|
||||
nn.state.load_state_dict(model, state_dict, verbose=False, consume=True, realize=False) # NOTE: rope_freqs.weight (32,) is unused
|
||||
# NOTE: without this contiguous, it unpacks the weights from the model every time. we shouldn't need this, but for now it's faster
|
||||
if realize:
|
||||
for s in (params:=nn.state.get_parameters(model)): s.replace(s.contiguous())
|
||||
Tensor.realize(*params)
|
||||
return model, kv
|
||||
|
||||
def get_start_pos(self, tokens:list[int]) -> int:
|
||||
prefix_len = sum(1 for _ in itertools.takewhile(lambda ab: ab[0] == ab[1], zip(tokens[:-1], self._cached_tokens)))
|
||||
return min(block._reusable_prefix_len(prefix_len, len(self._cached_tokens)) for block in self.blk)
|
||||
|
||||
def generate(self, tokens:list[int], chunk_size:int=32, temperature:float=0.0):
|
||||
if self.has_recurrent_block: chunk_size = 1
|
||||
v_start_pos = UOp.variable("start_pos", 0, self.max_context-1)
|
||||
v_toks = UOp.variable("toks", 1, chunk_size)
|
||||
# TODO: use UOp.variable for temperature once float variables are supported
|
||||
temp = Tensor(temperature).contiguous()
|
||||
# assign all input tokens once, then slice from start_pos for the model call
|
||||
t = Tensor(tokens + [0] * (self.max_context - len(tokens)), dtype="int32").reshape(1, self.max_context)
|
||||
# recompute start_pos from what's currently valid in the caches
|
||||
start_pos = self.get_start_pos(tokens)
|
||||
if start_pos < len(self._cached_tokens) and (resets := [r for b in self.blk for r in b._state_reset_ops()]): Tensor.realize(*resets)
|
||||
out, prompt_len = None, len(tokens)
|
||||
while len(tokens) < self.max_context:
|
||||
sp, nt = v_start_pos.bind(start_pos), v_toks.bind(min(chunk_size, len(tokens) - start_pos))
|
||||
out = self(t[:, sp:sp+nt] if start_pos < prompt_len or out is None else out, sp, temp).realize()
|
||||
start_pos += nt.val
|
||||
# chunked prefill: keep processing until all prompt tokens are consumed
|
||||
if start_pos < len(tokens): continue
|
||||
tokens.append(int(out.item()))
|
||||
self._cached_tokens = tokens[:-1]
|
||||
yield tokens[-1]
|
||||
return [self.eos_id]
|
||||
def prefix(self) -> list[int]:
  """Return the token ids that open every conversation.

  This is the BOS id (when one is configured) followed, for the 'glm4' preset only, by the encoding of "<sop>".
  """
  ids: list[int] = []
  if self.bos_id is not None:
    ids = ids + [self.bos_id]
  if self.preset == 'glm4':
    ids = ids + self.encode("<sop>")
  return ids
|
||||
def is_end(self, token_id:int) -> bool:
  """Return True when `token_id` is the tokenizer's eos_id or its eot_id."""
  return token_id == self.eos_id or token_id == self.eot_id
|
||||
|
||||
models = {
|
||||
"llama3.2:1b": "https://huggingface.co/bartowski/Llama-3.2-1B-Instruct-GGUF/resolve/main/Llama-3.2-1B-Instruct-Q6_K.gguf",
|
||||
|
|
@ -506,62 +106,14 @@ models = {
|
|||
|
||||
# *** simple OpenAI API compatible server with web interface on http://localhost:8000/ ***
|
||||
|
||||
# Single-page chat UI served for every GET other than /v1/models.
# The page keeps the conversation in `msgs`, POSTs it to /v1/chat/completions with stream=true
# and appends the streamed delta content to the newest message div.
# The user's text is inserted via innerHTML, so '&' and '<' must be entity-escaped first;
# the assistant's text goes through textContent and needs no escaping.
CHAT_HTML = b'''<!DOCTYPE html><html><head><title>tinygrad chat</title><style>
  * { margin: 0 }
  body { background: #212121; color: #e3e3e3; font-family: system-ui;
         height: 100vh; display: flex; flex-direction: column }
  #chat { flex: 1; overflow-y: auto; padding: 20px }
  .msg { padding: 10px 16px; margin: 8px 0; white-space: pre-wrap; border-radius: 18px }
  .user { background: #2f2f2f; margin-left: auto; width: fit-content; max-width: 70% }
  #input { max-width: 768px; width: 100%; margin: 20px auto; padding: 14px 20px;
           background: #2f2f2f; color: inherit; font: inherit;
           border: none; outline: none; resize: none; border-radius: 24px; field-sizing: content }
</style></head><body><div id="chat"></div>
<textarea id="input" rows="1" placeholder="Ask anything" autofocus></textarea>
<script>
  input.onkeydown = (e) => { if (e.key === 'Enter' && !e.shiftKey && !e.isComposing) { e.preventDefault(); send() } }
  const msgs = [];
  async function send() {
    if (!input.value.trim()) return;
    msgs.push({role: 'user', content: input.value.trim()});
    chat.innerHTML += '<div class="msg user">' + input.value.trim().replace(/&/g, '&amp;').replace(/</g, '&lt;') + '</div>';
    input.value = '';
    const d = document.createElement('div'); d.className = 'msg'; chat.appendChild(d);
    const r = await fetch('/v1/chat/completions', {method: 'POST', headers: {'Content-Type': 'application/json'},
      body: JSON.stringify({model: 'llama', messages: msgs, stream: true, temperature: 0.7})});
    let buf = '';
    for (const rd = r.body.getReader(), dec = new TextDecoder();;) {
      const {done, value} = await rd.read();
      if (done) break;
      buf += dec.decode(value, {stream: true});
      const lines = buf.split('\\n');
      buf = lines.pop();
      for (const ln of lines)
        if (ln.startsWith('data: ') && !ln.includes('[DONE]'))
          try { d.textContent += JSON.parse(ln.slice(6)).choices[0]?.delta?.content || '' } catch {}
      chat.scrollTop = chat.scrollHeight;
    }
    msgs.push({role: 'assistant', content: d.textContent});
  }
</script></body></html>'''
|
||||
|
||||
# Server object that carries the shared model/tokenizer state; Handler reaches it via `self.server`.
# Only attribute annotations are declared here; the values are assigned by whoever creates the server.
class LLMServer(TCPServerWithReuse):
  model: Transformer
  # name reported by GET /v1/models
  model_name: str
  tok: SimpleTokenizer
  # TODO: tastefully move these into tokenizer
  # special token ids; bos_id and eot_id may be None
  bos_id: int|None
  eos_id: int
  eot_id: int|None
|
||||
|
||||
class Handler(HTTPRequestHandler):
|
||||
server: LLMServer
|
||||
# no-op override; presumably replaces the base handler's per-request logging (run_model writes its own stats line)
def log_request(self, code='-', size='-'): pass
|
||||
def do_GET(self):
  """Serve GET requests.

  `/v1/models` returns an OpenAI-style model list holding the single served model name.
  Every other path returns the chat page, read from `chat.html` next to this file on each request.
  """
  if self.path == "/v1/models": self.send_data(json.dumps({"object":"list","data":[{"id":self.server.model_name,"object":"model"}]}).encode())
  else: self.send_data((pathlib.Path(__file__).parent / "chat.html").read_bytes(), content_type="text/html")
|
||||
def run_model(self, ids:list[int], model_name:str, include_usage=False, max_tokens:int|None=None, temperature:float=0.0):
|
||||
model, tok, eos_id, eot_id = self.server.model, self.server.tok, self.server.eos_id, self.server.eot_id
|
||||
model, tok = self.server.model, self.server.tok
|
||||
cache_start_pos = model.get_start_pos(ids)
|
||||
stderr_log(f"{self.path} {colored('--', 'BLACK')} "
|
||||
f"in:{colored(f'{cache_start_pos:5d}', 'green')} +{len(ids)-cache_start_pos:5d} {colored('--', 'BLACK')} ")
|
||||
|
|
@ -573,7 +125,7 @@ class Handler(HTTPRequestHandler):
|
|||
dec = tok.stream_decoder()
|
||||
for next_id in model.generate(ids, temperature=temperature):
|
||||
if len(out) == 0: stderr_log(f"prefill:{(len(ids)-cache_start_pos)/((pt:=time.perf_counter())-st):4.0f} tok/s {colored('--', 'BLACK')} ")
|
||||
if next_id in (eos_id, eot_id): break
|
||||
if tok.is_end(next_id): break
|
||||
out.append(next_id)
|
||||
yield {"choices": [{"index":0, "delta":{"content":dec(next_id)}, "finish_reason":None}], **tmpl}
|
||||
if max_tokens is not None and len(out) >= max_tokens:
|
||||
|
|
@ -588,13 +140,13 @@ class Handler(HTTPRequestHandler):
|
|||
f"out:{len(out):5d} {colored('--', 'BLACK')} total:{et-st:6.2f}s\n")
|
||||
|
||||
def do_POST(self):
|
||||
tok, bos_id, eos_id = self.server.tok, self.server.bos_id, self.server.eos_id
|
||||
tok = self.server.tok
|
||||
raw_body = self.rfile.read(int(self.headers.get("Content-Length", "0")))
|
||||
body: dict[str, typing.Any] = json.loads(raw_body.decode("utf-8"))
|
||||
if DEBUG >= 1: print(json.dumps(body, indent=2))
|
||||
if self.path == "/v1/chat/completions":
|
||||
# extract tokens, last assistant message is treated as prefill
|
||||
ids: list[int] = tok.prefix(bos_id)
|
||||
ids: list[int] = tok.prefix()
|
||||
for i, msg in enumerate(body["messages"]):
|
||||
ids += tok.role(msg["role"])
|
||||
content = msg["content"]
|
||||
|
|
@ -605,7 +157,7 @@ class Handler(HTTPRequestHandler):
|
|||
else: raise RuntimeError(f"unhandled type: {c['type']}")
|
||||
else: raise RuntimeError(f"unknown content type: {type(content)}")
|
||||
if msg["role"] == "assistant" and i == len(body["messages"]) - 1: break
|
||||
ids += tok.end_turn(eos_id)
|
||||
ids += tok.end_turn()
|
||||
else: ids += tok.role("assistant")
|
||||
|
||||
# reply
|
||||
|
|
@ -623,6 +175,11 @@ class Handler(HTTPRequestHandler):
|
|||
else:
|
||||
raise RuntimeError(f"unhandled path {self.path}")
|
||||
|
||||
class LLMServer(TCPServerWithReuse):
  """Server that holds the model, its reported name and the tokenizer, and always uses `Handler` for requests."""
  def __init__(self, server_address:tuple, model:Transformer, model_name:str, tok:SimpleTokenizer):
    # Handler reads these through self.server; they are set before the base class is initialised
    self.model, self.model_name, self.tok = model, model_name, tok
    super().__init__(server_address, Handler)
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--model", "-m", default=list(models.keys())[0], help=f"Model choice ({', '.join(models.keys())}) or path to a local GGUF file")
|
||||
|
|
@ -644,9 +201,6 @@ def main():
|
|||
gc.collect()
|
||||
|
||||
tok = SimpleTokenizer.from_gguf_kv(kv)
|
||||
bos_id: int|None = kv.get('tokenizer.ggml.bos_token_id') if kv.get('tokenizer.ggml.add_bos_token', True) else None
|
||||
eos_id: int = kv['tokenizer.ggml.eos_token_id']
|
||||
eot_id: int|None = kv.get('tokenizer.ggml.eot_token_id')
|
||||
|
||||
# warmup the JIT
|
||||
if args.warmup or args.serve:
|
||||
|
|
@ -655,15 +209,11 @@ def main():
|
|||
for _ in range(2): list(zip(range(2), model.generate([0])))
|
||||
|
||||
# start server
|
||||
if args.serve:
|
||||
server = LLMServer(('', args.serve), Handler)
|
||||
server.model, server.model_name, server.tok = model, model_name, tok
|
||||
server.bos_id, server.eos_id, server.eot_id = bos_id, eos_id, eot_id
|
||||
server.serve_forever()
|
||||
if args.serve: LLMServer(('', args.serve), model, model_name, tok).serve_forever()
|
||||
|
||||
# do benchmark
|
||||
if args.benchmark is not None:
|
||||
gen = model.generate(toks:=[bos_id or 0])
|
||||
gen = model.generate(toks:=[tok.bos_id or 0])
|
||||
for _ in range(args.benchmark):
|
||||
GlobalCounters.reset()
|
||||
with Timing(on_exit=lambda x: f", {1e9/x:6.2f} tok/s, {GlobalCounters.global_mem/x:7.2f} GB/s,"
|
||||
|
|
@ -672,16 +222,16 @@ def main():
|
|||
exit(0)
|
||||
|
||||
# interactive chat
|
||||
ids: list[int] = tok.prefix(bos_id)
|
||||
ids: list[int] = tok.prefix()
|
||||
while 1:
|
||||
try:
|
||||
ids += tok.role("user") + tok.encode(input('>>> ')) + tok.end_turn(eos_id) + tok.role("assistant")
|
||||
ids += tok.role("user") + tok.encode(input('>>> ')) + tok.end_turn() + tok.role("assistant")
|
||||
except EOFError:
|
||||
break
|
||||
dec = tok.stream_decoder()
|
||||
for next_id in model.generate(ids):
|
||||
sys.stdout.write(dec(next_id) if next_id not in (eos_id, eot_id) else dec() + "\n\n")
|
||||
sys.stdout.write(dec(next_id) if not tok.is_end(next_id) else dec() + "\n\n")
|
||||
sys.stdout.flush()
|
||||
if next_id in (eos_id, eot_id): break
|
||||
if tok.is_end(next_id): break
|
||||
|
||||
if __name__ == "__main__": main()
|
||||
409
tinygrad/llm/model.py
Normal file
409
tinygrad/llm/model.py
Normal file
|
|
@ -0,0 +1,409 @@
|
|||
from __future__ import annotations
|
||||
import functools, itertools
|
||||
from dataclasses import dataclass, replace
|
||||
from tinygrad import Tensor, nn, UOp, TinyJit, getenv, function
|
||||
from tinygrad.uop.ops import resolve
|
||||
|
||||
@functools.cache
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0) -> Tensor:
  """Build the RoPE table for positions 0..end-1.

  Each row holds the cosines of the dim//2 rotation angles followed by their sines, so the last
  axis has cos in the first half and sin in the second. Results are memoised per (dim, end, theta).
  """
  exponents = Tensor.arange(0, dim, 2)[:(dim // 2)] / dim
  inv_freq = 1.0 / (theta ** exponents)
  # outer product of positions and inverse frequencies -> one angle per (position, frequency)
  angles = Tensor.arange(end).unsqueeze(dim=1) * inv_freq.unsqueeze(dim=0)
  return angles.cos().cat(angles.sin(), dim=-1).contiguous()
|
||||
|
||||
class ExpertWeights:
  """A bank of per-expert linear weights, the multi-expert counterpart of nn.Linear.

  `weight` has shape (num_experts, out_features, in_features).
  """
  def __init__(self, num_experts:int, in_features:int, out_features:int):
    self.weight = Tensor.zeros(num_experts, out_features, in_features)

  def __call__(self, sel:Tensor, x:Tensor) -> Tensor:
    # sel: (B, T, k) expert indices, x: (B, T, 1, in) or (B, T, k, in) -> output: (B, T, k, out)
    chosen = self.weight[sel]   # gather the weight matrix of every selected expert
    projected = x.unsqueeze(-2) @ chosen.transpose(-1, -2)
    return projected.squeeze(-2)
|
||||
|
||||
def apply_rope(x:Tensor, freqs_cis:Tensor) -> Tensor:
  """Rotate `x` by the angles in `freqs_cis` using the half-split RoPE layout.

  The last axis of `x` must be even; its first half is paired with its second half.
  `freqs_cis` is reshaped to (1, 1, x.shape[2], -1) and split into cos | sin along the last axis.
  """
  assert x.shape[-1] % 2 == 0
  table = freqs_cis.reshape(1, 1, x.shape[2], -1)
  cos, sin = table.chunk(2, dim=-1)
  lo, hi = x.chunk(2, dim=-1)
  rot_lo = lo * cos - hi * sin
  rot_hi = hi * cos + lo * sin
  return rot_lo.cat(rot_hi, dim=-1)
|
||||
|
||||
def pairwise_topk(x: Tensor, k: int) -> tuple[Tensor, Tensor]:
  """Top-k along the last axis of a 3-D tensor, computed with an all-pairs comparison instead of a sort.

  Returns `(values, indices)`; indices are int32 and the k entries are ordered from smallest to largest value.
  The reshape(1,1,n) and [:,:,n-k:] below fix `x` to rank 3, and the comparison matrix is n x n per row.
  """
  n = x.shape[-1]
  # each element's own index, stored in x's dtype so it can be scattered into a tensor shaped like x
  vals = Tensor.arange(n).reshape(1,1,n).cast(x.dtype).expand(x.shape)
  # cmp[..., i, j] is True when element i beats element j: strictly larger, or equal with the lower index
  cmp = (x.unsqueeze(-1) > x.unsqueeze(-2)) | ((x.unsqueeze(-1) == x.unsqueeze(-2)) & \
         (Tensor.arange(n).reshape(1,1,n,1) < Tensor.arange(n).reshape(1,1,1,n)))
  # the number of wins is each element's unique rank; scatter writes index i at position rank(i), the last k slots are the top-k
  sel = Tensor.zeros_like(x).scatter(-1, cmp.sum(axis=-1).cast('int32'), vals)[:,:,n-k:].cast('int32')
  return x.gather(-1, sel), sel
|
||||
|
||||
# Immutable settings for GatedDeltaNetBlock; from_gguf fills each field from the GGUF key `{arch}.ssm.<field>`.
@dataclass(frozen=True)
class SSMConfig:
  conv_kernel: int     # conv window length; conv_kernel-1 past steps are kept in conv_state
  state_size: int      # used as the per-head q/k dim (head_k_dim)
  group_count: int     # used as the number of q/k heads (num_k_heads)
  time_step_rank: int  # used as the number of v heads (num_v_heads)
  inner_size: int      # total v width; head_v_dim = inner_size // time_step_rank
|
||||
|
||||
# Immutable model hyperparameters shared by Transformer and every block type.
@dataclass(frozen=True)
class TransformerConfig:
  num_blocks: int   # number of blocks in the stack
  dim: int          # embedding / residual stream width
  hidden_dim: int   # FFN hidden width (per-expert width when num_experts > 0)
  n_heads: int      # query heads
  n_kv_heads: int   # key/value heads in TransformerBlock
  norm_eps: float   # epsilon passed to every RMSNorm
  vocab_size: int
  head_dim: int     # per-head q/k width
  rope_theta: float # base passed to precompute_freqs_cis
  rope_dim: int     # number of per-head dims that receive RoPE
  v_head_dim: int   # per-head value width; TransformerBlock asserts it equals head_dim
  max_context: int = 0   # number of positions allocated in the KV caches and RoPE table
  qk_norm: int = 0       # size of the RMSNorm applied to q and k; 0 disables it
  num_experts: int = 0   # > 0 switches the FFN to mixture-of-experts
  num_experts_per_tok: int = 0   # k for the router's top-k selection
  norm_topk_prob: bool = False   # normalise routing probabilities over the selected experts only
  q_lora_rank: int = 0   # > 0 makes MLATransformerBlock use the low-rank q projection
  kv_lora_rank: int = 0  # > 0 selects MLATransformerBlock; width of the compressed kv
  shared_expert_dim: int = 0     # > 0 adds a shared expert of this hidden width
  full_attention_interval: int = 0   # with ssm set, every N-th block is an attention block
  attn_output_gate: bool = False # q projection also produces a sigmoid gate on the attention output
  ssm: SSMConfig|None = None     # when set, non-attention blocks are GatedDeltaNetBlock
  shared_expert_gate: bool = True    # shared expert output is scaled by a learned sigmoid gate
  leading_dense_blocks: int = 0  # the first N blocks use the dense (non-MoE) config
  dense_hidden_dim: int = 0      # FFN hidden width of those dense blocks; 0 falls back to hidden_dim
  routed_scaling_factor: float = 1.0 # multiplier applied to the routed experts' probabilities
|
||||
|
||||
class FFNBlock:
  """Base class for one residual block: the two RMSNorms and the feed-forward half (dense or MoE).

  Subclasses supply the token-mixing half by implementing `_attention` and `_init_state`.
  """
  def __init__(self, config:TransformerConfig):
    self.config = config

    # --- RMSNorms --------------------------------------------------------
    self.attn_norm   = nn.RMSNorm(config.dim, config.norm_eps)
    self.ffn_norm    = nn.RMSNorm(config.dim, config.norm_eps)

    # --- feed-forward (MoE or dense) -------------------------------------
    if config.num_experts > 0:
      self.ffn_gate_inp = nn.Linear(config.dim, config.num_experts, bias=False)  # router
      # the per-expert selection bias is only created for MLA configs (kv_lora_rank > 0)
      if config.kv_lora_rank > 0: self.exp_probs_b = {"bias": Tensor.zeros(config.num_experts)}
      self.ffn_gate_exps = ExpertWeights(config.num_experts, config.dim, config.hidden_dim)
      self.ffn_up_exps = ExpertWeights(config.num_experts, config.dim, config.hidden_dim)
      self.ffn_down_exps = ExpertWeights(config.num_experts, config.hidden_dim, config.dim)
      if config.shared_expert_dim > 0:
        self.ffn_gate_shexp = nn.Linear(config.dim, config.shared_expert_dim, bias=False)
        self.ffn_up_shexp = nn.Linear(config.dim, config.shared_expert_dim, bias=False)
        self.ffn_down_shexp = nn.Linear(config.shared_expert_dim, config.dim, bias=False)
        if config.shared_expert_gate: self.ffn_gate_inp_shexp = {"weight": Tensor.zeros(config.dim)}
    else:
      self.ffn_gate    = nn.Linear(config.dim, config.hidden_dim, bias=False)
      self.ffn_up      = nn.Linear(config.dim, config.hidden_dim, bias=False)
      self.ffn_down    = nn.Linear(config.hidden_dim, config.dim, bias=False)

  def _feed_forward(self, x:Tensor) -> Tensor:
    """Gated (SiLU) feed-forward. The MoE path is taken when expert weights exist, otherwise the dense path."""
    if hasattr(self, 'ffn_gate_exps'):
      h = x.unsqueeze(2)  # (B, T, 1, D) - add expert dim for broadcasting
      logits = self.ffn_gate_inp(x)
      if hasattr(self, 'exp_probs_b'):
        # sigmoid routing: the bias only influences which experts are selected, not their weights
        probs = logits.sigmoid()
        _, sel = pairwise_topk(probs + self.exp_probs_b["bias"], self.config.num_experts_per_tok)
        probs = probs.gather(-1, sel)
        if self.config.norm_topk_prob: probs = probs / probs.sum(axis=-1, keepdim=True)
      else:
        # softmax routing: either softmax over the selected logits only, or softmax over all experts then gather
        vals, sel = pairwise_topk(logits, self.config.num_experts_per_tok)
        probs = vals.softmax(-1) if self.config.norm_topk_prob else logits.softmax(-1).gather(-1, sel)
      probs = probs * self.config.routed_scaling_factor
      x_down = self.ffn_down_exps(sel, self.ffn_gate_exps(sel, h).silu() * self.ffn_up_exps(sel, h))  # (B, T, k, D)
      out = (x_down * probs.unsqueeze(-1)).sum(axis=2)  # (B, T, D)
      if hasattr(self, 'ffn_gate_shexp'):
        # the shared expert runs on every token and is added on top of the routed output
        shexp = self.ffn_down_shexp(self.ffn_gate_shexp(x).silu().contiguous() * self.ffn_up_shexp(x))
        if hasattr(self, 'ffn_gate_inp_shexp'): shexp = shexp * (x * self.ffn_gate_inp_shexp["weight"]).sum(axis=-1, keepdim=True).sigmoid()
        out = out + shexp
      return out
    # TODO: remove the need for this contiguous
    return self.ffn_down(self.ffn_gate(x).silu().contiguous() * self.ffn_up(x))

  # given the token-prefix match, return how much cached state this block can still reuse
  def _reusable_prefix_len(self, prefix_len:int, cached_len:int) -> int: return prefix_len
  # return writes that reset this block's state after a cache mismatch
  def _state_reset_ops(self) -> list[Tensor]: return []
  # subclass hooks: allocate per-block state on first use / compute the token-mixing output
  def _init_state(self, x:Tensor): raise NotImplementedError
  def _attention(self, x:Tensor, start_pos:int|UOp) -> Tensor: raise NotImplementedError

  def __call__(self, x: Tensor, start_pos: int|UOp):
    """Apply the block: x + attention(norm(x)), then + feed_forward(norm(.)). Both halves are pre-norm residuals."""
    # runs on every call, outside the compiled function; subclasses only allocate when the state is missing
    self._init_state(x)
    # we pass in the weights implicitly so we unpack the GGUF on the fly
    @function(precompile=True, allow_implicit=True)
    def _run(x:Tensor, start_pos:int|UOp):
      h = x + self._attention(self.attn_norm(x), start_pos)
      return (h + self._feed_forward(self.ffn_norm(h))).contiguous()
    return _run(x, start_pos)
|
||||
|
||||
class TransformerBlock(FFNBlock):
  """Attention block with separate q/k/v projections, RoPE, a KV cache and optional q/k norm and output gate."""
  def __init__(self, config:TransformerConfig):
    super().__init__(config)
    assert config.v_head_dim == config.head_dim, "TransformerBlock requires v_head_dim == head_dim"

    # --- attention projections (all linear, bias-free) ------------------
    # with attn_output_gate the q projection is twice as wide: each head carries its query and its gate
    q_proj_out       = config.head_dim * config.n_heads * (2 if config.attn_output_gate else 1)
    kv_proj_out      = config.head_dim * config.n_kv_heads
    self.attn_q      = nn.Linear(config.dim, q_proj_out,  bias=False)
    self.attn_k      = nn.Linear(config.dim, kv_proj_out, bias=False)
    self.attn_v      = nn.Linear(config.dim, kv_proj_out, bias=False)
    self.attn_output = nn.Linear(config.head_dim * config.n_heads, config.dim, bias=False)
    if config.qk_norm: self.attn_q_norm, self.attn_k_norm = nn.RMSNorm(config.qk_norm, config.norm_eps), nn.RMSNorm(config.qk_norm, config.norm_eps)

  def _attention(self, x:Tensor, start_pos:int|UOp) -> Tensor:
    """Attention over positions 0..start_pos+T-1, writing this call's k/v into the cache at start_pos."""
    q, k, v = self.attn_q(x), self.attn_k(x), self.attn_v(x)
    # a qk_norm size different from head_dim is applied here, on the un-split projection; a per-head norm is applied after the reshape below
    if self.config.qk_norm and self.config.qk_norm != self.config.head_dim: q, k = self.attn_q_norm(q), self.attn_k_norm(k)

    B, T, _ = x.shape
    if self.config.attn_output_gate:
      # split each head's doubled projection into the query and the output gate
      qg = q.reshape(B, T, self.config.n_heads, 2, self.config.head_dim)
      q, gate = qg[:, :, :, 0, :], qg[:, :, :, 1, :].reshape(B, T, self.config.n_heads * self.config.head_dim)
    q = q.reshape(B, T, self.config.n_heads,    self.config.head_dim).transpose(1, 2)  # (B,H,T,Hd)
    k = k.reshape(B, T, self.config.n_kv_heads, self.config.head_dim).transpose(1, 2)  # (B,KvH,T,Hd)
    v = v.reshape(B, T, self.config.n_kv_heads, self.config.head_dim).transpose(1, 2)  # (B,KvH,T,Hd)
    if self.config.qk_norm == self.config.head_dim: q, k = self.attn_q_norm(q), self.attn_k_norm(k)

    # RoPE is applied to the first rope_dim dims of each head; the remaining dims pass through unchanged
    q = apply_rope(q[..., :self.config.rope_dim], self.freqs_cis[start_pos:start_pos+T]).cat(q[..., self.config.rope_dim:], dim=-1)
    k = apply_rope(k[..., :self.config.rope_dim], self.freqs_cis[start_pos:start_pos+T]).cat(k[..., self.config.rope_dim:], dim=-1)

    # NOTE: we don't want to change self.cache_kv, the function API doesn't support this well
    # the store into the cache slice is ordered before the read via .after, without reassigning self.cache_kv
    assigned_kv = Tensor(self.cache_kv.uop.after(self.cache_kv[:, :, :, start_pos:start_pos+T, :].uop.store(Tensor.stack(k, v).uop)))
    k = assigned_kv[0, :, :, 0:start_pos+T, :]
    v = assigned_kv[1, :, :, 0:start_pos+T, :]

    #self.cache_kv[:, :, :, start_pos:start_pos+T, :].assign(Tensor.stack(k, v))
    #k = self.cache_kv[0, :, :, 0:start_pos+T, :]
    #v = self.cache_kv[1, :, :, 0:start_pos+T, :]

    # NOTE: this mask is causal_lower_right, not the causal_upper_left generated by is_causal = True
    # TODO: this if statement should be removed and it shouldn't generate extra kernels
    mask = Tensor.full((1, 1, T, start_pos+T), float("-inf"), dtype=x.dtype, device=x.device).triu(start_pos+1) if resolve(T != 1) else None
    attn = q.scaled_dot_product_attention(k, v, attn_mask=mask, enable_gqa=True)     # (B,H,T,Hd)
    attn = attn.transpose(1, 2).reshape(B, T, -1)                                    # back to (B,T,D)
    return self.attn_output(attn if not self.config.attn_output_gate else (attn * gate.sigmoid()))

  def _init_state(self, x:Tensor):
    """Allocate the KV cache and fetch the RoPE table on first use; later calls do nothing."""
    if not hasattr(self, "cache_kv"):
      # TODO: how is the dtype of this determined?
      # layout: (k|v, B, KvH, max_context, Hd)
      self.cache_kv = Tensor.empty(2, x.shape[0], self.config.n_kv_heads, self.config.max_context, self.config.head_dim, device=x.device)
      self.freqs_cis = precompute_freqs_cis(self.config.rope_dim, self.config.max_context, self.config.rope_theta)
|
||||
|
||||
class MLATransformerBlock(FFNBlock):
  """Attention block that caches a compressed kv of width kv_lora_rank (plus rope_dim rotated key dims) per position.

  Queries are mapped into the compressed space with `attn_k_b`, and the attention output is expanded
  back to v_head_dim per head with `attn_v_b`.
  """
  def __init__(self, config:TransformerConfig):
    super().__init__(config)
    # the part of each q head that does not receive RoPE
    qk_nope_head_dim = config.head_dim - config.rope_dim
    if config.q_lora_rank > 0:
      # low-rank q: down-project, normalise, up-project
      self.attn_q_a = nn.Linear(config.dim, config.q_lora_rank, bias=False)
      self.attn_q_a_norm = nn.RMSNorm(config.q_lora_rank, config.norm_eps)
      self.attn_q_b = nn.Linear(config.q_lora_rank, config.n_heads * config.head_dim, bias=False)
    else:
      self.attn_q = nn.Linear(config.dim, config.n_heads * config.head_dim, bias=False)
    # one projection produces the compressed kv (kv_lora_rank) and the shared rope key (rope_dim)
    self.attn_kv_a_mqa = nn.Linear(config.dim, config.kv_lora_rank + config.rope_dim, bias=False)
    self.attn_kv_a_norm = nn.RMSNorm(config.kv_lora_rank, config.norm_eps)
    self.attn_k_b = {"weight": Tensor.zeros(config.n_heads, config.kv_lora_rank, qk_nope_head_dim)}
    self.attn_v_b = {"weight": Tensor.zeros(config.n_heads, config.v_head_dim, config.kv_lora_rank)}
    self.attn_output = nn.Linear(config.n_heads * config.v_head_dim, config.dim, bias=False)

  def _attention(self, x:Tensor, start_pos:int|UOp) -> Tensor:
    """Attention over positions 0..start_pos+T-1 in the compressed kv space, writing this call's entries into the caches."""
    B, T, _ = x.shape
    q_nope_head_dim = self.config.head_dim - self.config.rope_dim
    q_proj = self.attn_q_b(self.attn_q_a_norm(self.attn_q_a(x))) if self.config.q_lora_rank > 0 else self.attn_q(x)
    q = q_proj.reshape(B, T, self.config.n_heads, self.config.head_dim).transpose(1, 2)
    # here the rope part is the trailing rope_dim dims of each head
    q_nope, q_rope = q[..., :q_nope_head_dim], q[..., q_nope_head_dim:]
    # map the non-rope query into the kv_lora_rank space so it can be scored against the compressed cache directly
    q = (q_nope @ self.attn_k_b["weight"].transpose(-1, -2)).cat(apply_rope(q_rope, self.freqs_cis[start_pos:start_pos+T]), dim=-1)

    kv_a = self.attn_kv_a_mqa(x)
    c_kv = self.attn_kv_a_norm(kv_a[..., :self.config.kv_lora_rank])
    # a single rope key head, shared by all query heads
    k_rope = apply_rope(
      kv_a[..., self.config.kv_lora_rank:].reshape(B, T, 1, self.config.rope_dim).transpose(1, 2),
      self.freqs_cis[start_pos:start_pos+T])

    # keys are [compressed kv | rotated rope key]; values are the compressed kv alone
    k_store = c_kv.reshape(B, 1, T, self.config.kv_lora_rank).cat(k_rope.reshape(B, 1, T, self.config.rope_dim), dim=-1)
    v_store = c_kv.reshape(B, 1, T, self.config.kv_lora_rank)
    # store into the cache slice, then read back everything up to start_pos+T (store ordered before read via .after)
    k = Tensor(self.cache_k.uop.after(self.cache_k[:, :, start_pos:start_pos+T, :].uop.store(k_store.uop)))[:, :, 0:start_pos+T, :]
    v = Tensor(self.cache_v.uop.after(self.cache_v[:, :, start_pos:start_pos+T, :].uop.store(v_store.uop)))[:, :, 0:start_pos+T, :]

    # additive causal mask, built only when more than one query position is processed
    mask = Tensor.full((1, 1, T, start_pos+T), float("-inf"), dtype=x.dtype, device=x.device).triu(start_pos+1) if resolve(T != 1) else None
    # scores are scaled by 1/sqrt(head_dim), the original per-head q/k width
    attn = q @ k.transpose(-1, -2) * (1.0 / self.config.head_dim ** 0.5)
    if mask is not None: attn = attn + mask
    attn = attn.softmax(-1)
    # expand the compressed result to v_head_dim per head, then merge heads back into (B, T, n_heads*v_head_dim)
    attn = ((attn @ v) @ self.attn_v_b["weight"].transpose(-1, -2)).transpose(1, 2).reshape(B, T, -1)
    return self.attn_output(attn)

  def _init_state(self, x:Tensor):
    """Allocate the key/value caches and fetch the RoPE table on first use; later calls do nothing."""
    if not hasattr(self, "cache_k"):
      self.cache_k = Tensor.empty(x.shape[0], 1, self.config.max_context, self.config.kv_lora_rank + self.config.rope_dim, device=x.device)
      self.cache_v = Tensor.empty(x.shape[0], 1, self.config.max_context, self.config.kv_lora_rank, device=x.device)
      self.freqs_cis = precompute_freqs_cis(self.config.rope_dim, self.config.max_context, self.config.rope_theta)
|
||||
|
||||
class GatedDeltaNetBlock(FFNBlock):
|
||||
def __init__(self, config:TransformerConfig, ssm:SSMConfig):
|
||||
super().__init__(config)
|
||||
self.head_k_dim, self.num_k_heads, self.num_v_heads = ssm.state_size, ssm.group_count, ssm.time_step_rank
|
||||
assert self.num_v_heads % self.num_k_heads == 0
|
||||
self.head_v_dim, self.ssm_conv_kernel = ssm.inner_size // ssm.time_step_rank, ssm.conv_kernel
|
||||
self.conv_channels, self.q_dim = ssm.inner_size + 2*ssm.group_count*ssm.state_size, ssm.state_size*ssm.group_count
|
||||
self.attn_qkv, self.attn_gate = nn.Linear(config.dim, self.conv_channels, bias=False), nn.Linear(config.dim, ssm.inner_size, bias=False)
|
||||
self.ssm_alpha, self.ssm_beta = nn.Linear(config.dim, self.num_v_heads, bias=False), nn.Linear(config.dim, self.num_v_heads, bias=False)
|
||||
self.ssm_conv1d = {"weight": Tensor.zeros(self.conv_channels, self.ssm_conv_kernel)}
|
||||
self.ssm_dt = {"bias": Tensor.zeros(self.num_v_heads)}
|
||||
self.ssm_a = Tensor.zeros(self.num_v_heads)
|
||||
self.ssm_norm, self.ssm_out = nn.RMSNorm(self.head_v_dim, config.norm_eps), nn.Linear(ssm.inner_size, config.dim, bias=False)
|
||||
|
||||
def _attention(self, x:Tensor, start_pos:int|UOp) -> Tensor:
|
||||
B, T, _ = x.shape
|
||||
assert T == 1, "GatedDeltaNetBlock currently only supports T=1"
|
||||
|
||||
# input processing
|
||||
x = x.half()
|
||||
out_gate = self.attn_gate(x).reshape(B, 1, self.num_v_heads, self.head_v_dim)
|
||||
beta = self.ssm_beta(x).sigmoid().reshape(B, self.num_v_heads, 1, 1)
|
||||
alpha = ((self.ssm_alpha(x).float() + self.ssm_dt["bias"]).softplus() * self.ssm_a).reshape(B, self.num_v_heads, 1, 1).exp()
|
||||
|
||||
# qkv conv
|
||||
conv_window = self.conv_state.cat(self.attn_qkv(x), dim=1)
|
||||
conv_out = (conv_window * self.ssm_conv1d["weight"].T.unsqueeze(0)).sum(1).silu()
|
||||
q, k, v = conv_out.split([self.q_dim, self.q_dim, self.conv_channels - 2*self.q_dim], dim=-1)
|
||||
q = q.reshape(B, self.num_k_heads, self.head_k_dim).normalize(dim=-1).repeat(1, self.num_v_heads//self.num_k_heads, 1)
|
||||
k = k.reshape(B, self.num_k_heads, self.head_k_dim).normalize(dim=-1).repeat(1, self.num_v_heads//self.num_k_heads, 1)
|
||||
v = v.reshape(B, self.num_v_heads, self.head_v_dim)
|
||||
q, k, v = q.mul(self.head_k_dim**-0.5).unsqueeze(-1), k.unsqueeze(-1), v.unsqueeze(-1)
|
||||
|
||||
# recurrent
|
||||
recurrent_state = self.recurrent_state * alpha
|
||||
recurrent_state = recurrent_state + ((v - recurrent_state@k) * beta)@k.transpose(-1, -2)
|
||||
|
||||
# store the updated state
|
||||
conv_state_store = self.conv_state.uop.store(conv_window[:, 1:, :].cast(self.conv_state.dtype).uop)
|
||||
recurrent_state_store = self.recurrent_state.uop.store(recurrent_state.cast(self.recurrent_state.dtype).uop)
|
||||
recurrent_state = Tensor(self.recurrent_state.uop.after(recurrent_state_store, conv_state_store))
|
||||
|
||||
# output
|
||||
core_attn_out = self.ssm_norm((recurrent_state@q).squeeze(-1).reshape(B, 1, self.num_v_heads, self.head_v_dim))
|
||||
return self.ssm_out((core_attn_out * out_gate.silu()).reshape(B, 1, -1).cast(x.dtype))
|
||||
|
||||
# recurrent state can't be partially reused after divergence, force a full rebuild
|
||||
def _state_reset_ops(self):
  """Build the assignments that zero this block's conv and recurrent state.

  Returns an empty list when the state was never allocated (no `conv_state` attribute yet), otherwise a two-element
  list: the result of assigning zeros to `conv_state`, then the result of assigning zeros to `recurrent_state`.
  The caller decides when to realize them (`Transformer.generate` passes them to `Tensor.realize`).
  """
  # nothing to reset until the state tensors exist
  if not hasattr(self, "conv_state"): return []
  zero_conv = self.conv_state.assign(Tensor.zeros_like(self.conv_state))
  zero_recurrent = self.recurrent_state.assign(Tensor.zeros_like(self.recurrent_state))
  return [zero_conv, zero_recurrent]
|
||||
def _reusable_prefix_len(self, prefix_len:int, cached_len:int) -> int:
  """Return how much of the cached sequence this block can reuse.

  The whole prefix is reusable only when it covers everything that was cached (`prefix_len == cached_len`);
  any shorter or longer match reuses nothing and 0 is returned.
  """
  if prefix_len == cached_len: return prefix_len
  # partial match: the recurrent state cannot be partially reused, so start over from position 0
  return 0
|
||||
|
||||
def _init_state(self, x):
  """Allocate the zeroed conv and recurrent state on first use; do nothing if they already exist.

  The batch size and device are taken from `x` (`x.shape[0]`, `x.device`):
    conv_state:      (batch, ssm_conv_kernel-1, conv_channels)
    recurrent_state: (batch, num_v_heads, head_v_dim, head_v_dim)
  """
  # state already allocated by an earlier call
  if hasattr(self, "conv_state"): return
  batch, dev = x.shape[0], x.device
  # NOTE(review): .clone() presumably gives each state its own writable buffer -- confirm
  self.conv_state = Tensor.zeros(batch, self.ssm_conv_kernel-1, self.conv_channels, device=dev).clone()
  self.recurrent_state = Tensor.zeros(batch, self.num_v_heads, self.head_v_dim, self.head_v_dim, device=dev).clone()
|
||||
|
||||
class Transformer:
|
||||
def __init__(self, config:TransformerConfig):
  """Build the block stack, embedding, final norm and output projection from `config`, plus the two JIT wrappers.

  Block chosen for index i (0-based):
    - `config.ssm` set and (i+1) not a multiple of `config.full_attention_interval` -> GatedDeltaNetBlock
    - otherwise MLATransformerBlock when `kv_lora_rank > 0`, else TransformerBlock; for i < `leading_dense_blocks`
      that block gets the dense config (experts and shared expert disabled, hidden_dim = dense_hidden_dim or hidden_dim)
  """
  # NOTE(review): the dense config is derived before the ssm qk_norm override below, so it keeps the incoming qk_norm
  dense_config = replace(config, num_experts=0, num_experts_per_tok=0, shared_expert_dim=0,
                         hidden_dim=config.dense_hidden_dim or config.hidden_dim)
  if config.ssm: config = replace(config, qk_norm=config.head_dim)
  attn_cls = MLATransformerBlock if config.kv_lora_rank > 0 else TransformerBlock

  layers = []
  for idx in range(config.num_blocks):
    if config.ssm and (idx+1) % config.full_attention_interval != 0:
      layers.append(GatedDeltaNetBlock(config, config.ssm))
    else:
      layers.append(attn_cls(dense_config if idx < config.leading_dense_blocks else config))
  self.blk:list[FFNBlock] = layers

  self.token_embd = nn.Embedding(config.vocab_size, config.dim)
  self.output_norm = nn.RMSNorm(config.dim, config.norm_eps)
  self.output = nn.Linear(config.dim, config.vocab_size, bias=False)
  self.max_context = config.max_context
  self.has_recurrent_block = any(isinstance(layer, GatedDeltaNetBlock) for layer in self.blk)
  # tokens whose state is currently held in the per-block caches (maintained by generate)
  self._cached_tokens: list[int] = []
  # we specialize the JIT for prefill and rollout
  self.prefill_jit = TinyJit(self.forward)
  self.rollout_jit = TinyJit(self.forward)
|
||||
|
||||
def forward(self, tokens:Tensor, start_pos:int|UOp, temperature:Tensor) -> Tensor:
  """Run the whole stack on `tokens` and sample one next-token id from the last position.

  `tokens` is embedded and cast to float, passed through every block of `self.blk` in order (each block also
  receives `start_pos`), then normalized and projected; only the logits of the last position are kept.
  Sampling uses the Gumbel-max trick: argmax(logits/temp - log(-log(uniform))) is equivalent to sampling
  from softmax(logits/temp).

  Returns the argmax index over the last axis, with that axis kept (keepdim=True).
  """
  hidden = self.token_embd(tokens).float()  # (B, T, D)
  for layer in self.blk: hidden = layer(hidden, start_pos)
  last_logits = self.output(self.output_norm(hidden))[:, -1, :]
  # clamp the temperature so a value of 0 does not divide by zero
  scaled_logits = last_logits / temperature.maximum(1e-12)
  # -log(u); u is clamped away from 0 before taking the log
  neg_log_uniform = Tensor.rand_like(last_logits).maximum(1e-12).log().neg()
  return (scaled_logits - neg_log_uniform.log()).argmax(-1, keepdim=True)
|
||||
|
||||
def __call__(self, tokens:Tensor, start_pos:int|UOp, temperature:Tensor) -> Tensor:
  """Dispatch to the prefill or rollout JIT of `forward` based on the sequence length `tokens.shape[1]`.

  `prefill_jit` is used when `resolve(tokens.shape[1] != 1)` is truthy, `rollout_jit` otherwise. The selected JIT
  is called with a contiguous copy of `tokens`, plus `start_pos` and `temperature` unchanged.
  """
  # NOTE(review): keep the predicate in its `!= 1` form -- for symbolic lengths resolve() may not treat `== 1` as its
  # exact negation; confirm before flipping it
  use_prefill = resolve(tokens.shape[1] != 1)
  jitted_forward = self.prefill_jit if use_prefill else self.rollout_jit
  return jitted_forward(tokens.contiguous(), start_pos, temperature)
|
||||
|
||||
@staticmethod
def from_gguf(gguf:Tensor, max_context:int|None=None, realize=bool(getenv("REALIZE", 0))) -> tuple[Transformer, dict]:
  """Build a Transformer from a GGUF tensor and load its weights.

  Args:
    gguf: tensor holding the GGUF file contents, passed to `nn.state.gguf_load`.
    max_context: optional upper bound on the context length; the result is capped at the model's own
      `{arch}.context_length`. When None the model's context length is used as is.
    realize: when truthy, every parameter is replaced by a contiguous version and realized before returning.
      The default is read from the REALIZE env var once, when this function is defined.

  Returns:
    (model, kv): the constructed model and the GGUF metadata dict.

  Raises:
    KeyError: for metadata keys that are indexed directly (`kv[...]`) and missing from the file.
  """
  # TODO: remove the need for copy to default device
  kv, state_dict = nn.state.gguf_load(gguf.to(None).realize())

  # all state items should be float16, not float32
  # (the cast is skipped when the HALF env var is set to 0)
  state_dict = {k:v.cast('float16') if getenv("HALF", 1) else v for k,v in state_dict.items()}

  # some models like Llama 3.2 don't have an output.weight, they just tie to the token_embd.weight
  if 'output.weight' not in state_dict: state_dict['output.weight'] = state_dict['token_embd.weight']

  arch = kv['general.architecture']
  max_context = min(max_context, kv[f'{arch}.context_length']) if max_context is not None else kv[f'{arch}.context_length']
  n_heads, n_kv_heads = kv[f'{arch}.attention.head_count'], kv[f'{arch}.attention.head_count_kv']

  # only the qwen35 architectures carry an SSM config; their post_attention_norm weights are loaded under the ffn_norm name
  ssm = None
  if arch in ('qwen35', 'qwen35moe'):
    ssm = SSMConfig(**{k: kv[f'{arch}.ssm.{k}'] for k in ('conv_kernel','state_size','group_count','time_step_rank','inner_size')})
    state_dict = {k.replace('post_attention_norm', 'ffn_norm'):v for k,v in state_dict.items()}

  kv_lora_rank = kv.get(f'{arch}.attention.kv_lora_rank', 0)
  # head_dim preference: key_length_mla, then key_length, then embedding_length // n_heads.
  # the defaults are evaluated eagerly, so `{arch}.embedding_length` must be present even when an earlier key wins
  head_dim = kv.get(f'{arch}.attention.key_length_mla', kv.get(f'{arch}.attention.key_length', kv[f'{arch}.embedding_length'] // n_heads))
  rope_dim = kv.get(f'{arch}.rope.dimension_count', head_dim)

  # Permute RoPE weights from interleaved to half-split layout.
  # "(h two) -> (two h)" moves the even-indexed rows to the first half and the odd-indexed rows to the second half.
  # values of existing keys are replaced while iterating; no key is added or removed
  for name in state_dict:
    if ('attn_q.weight' in name or 'attn_q_b.weight' in name) and (arch == 'llama' or kv_lora_rank):
      w = state_dict[name].reshape(n_heads, state_dict[name].shape[0]//n_heads, -1)
      # the first head_dim-rope_dim rows of each head are left untouched, only the remaining rows are permuted
      prefix = head_dim-rope_dim
      state_dict[name] = w[:, :prefix].cat(w[:, prefix:].rearrange("n (h two) d -> n (two h) d", two=2), dim=1).reshape(-1, w.shape[-1])
    elif arch == 'llama' and 'attn_k.weight' in name:
      w = state_dict[name].reshape(n_kv_heads, state_dict[name].shape[0]//n_kv_heads, -1)
      state_dict[name] = w.rearrange("n (h two) d -> n (two h) d", two=2).reshape(-1, w.shape[-1])
    elif kv_lora_rank and 'attn_kv_a_mqa.weight' in name:
      # the first kv_lora_rank rows are kept as is, the rows after them are permuted
      state_dict[name] = state_dict[name][:kv_lora_rank].cat(state_dict[name][kv_lora_rank:].rearrange("(h two) d -> (two h) d", two=2), dim=0)
  config = TransformerConfig(
    num_blocks=kv[f'{arch}.block_count'], dim=kv[f'{arch}.embedding_length'],
    # expert_feed_forward_length takes precedence over feed_forward_length when both are present
    hidden_dim=kv.get(f'{arch}.expert_feed_forward_length', kv.get(f'{arch}.feed_forward_length', 0)),
    n_heads=n_heads, n_kv_heads=n_kv_heads, norm_eps=kv[f'{arch}.attention.layer_norm_rms_epsilon'],
    vocab_size=len(kv['tokenizer.ggml.tokens']),
    head_dim=head_dim,
    rope_theta=kv[f'{arch}.rope.freq_base'],
    rope_dim=rope_dim,
    v_head_dim=kv.get(f'{arch}.attention.value_length_mla', kv.get(f'{arch}.attention.value_length', head_dim)),
    max_context=max_context,
    # qk_norm is the length of the first block's q-norm weight, or 0 when that weight does not exist
    qk_norm=int(state_dict['blk.0.attn_q_norm.weight'].shape[0]) if 'blk.0.attn_q_norm.weight' in state_dict else 0,
    num_experts=kv.get(f'{arch}.expert_count', 0), num_experts_per_tok=kv.get(f'{arch}.expert_used_count', 0),
    # when the file has no expert_weights_norm key, the default is True only for the qwen3moe / qwen35moe architectures
    norm_topk_prob=kv.get(f'{arch}.expert_weights_norm', arch in ('qwen3moe', 'qwen35moe')),
    kv_lora_rank=kv_lora_rank, q_lora_rank=kv.get(f'{arch}.attention.q_lora_rank', 0),
    leading_dense_blocks=kv.get(f'{arch}.leading_dense_block_count', 0),
    # explicit shared feed-forward length if present, otherwise expert_shared_count * expert_feed_forward_length
    shared_expert_dim=kv.get(
      f'{arch}.expert_shared_feed_forward_length',
      kv.get(f'{arch}.expert_shared_count', 0) * kv.get(f'{arch}.expert_feed_forward_length', 0)),
    # detected from the weights: looks for the shared-expert gate in the first block after the leading dense blocks
    shared_expert_gate=f"blk.{kv.get(f'{arch}.leading_dense_block_count', 0)}.ffn_gate_inp_shexp.weight" in state_dict,
    # only set when the model has leading dense blocks
    dense_hidden_dim=kv.get(f'{arch}.feed_forward_length', 0) if kv.get(f'{arch}.leading_dense_block_count', 0) else 0,
    routed_scaling_factor=kv.get(f'{arch}.expert_weights_scale', 1.0), attn_output_gate=arch in ('qwen35', 'qwen35moe'), ssm=ssm,
    full_attention_interval=kv.get(f'{arch}.full_attention_interval', 0))
  model = Transformer(config)
  nn.state.load_state_dict(model, state_dict, verbose=False, consume=True, realize=False)  # NOTE: rope_freqs.weight (32,) is unused
  # NOTE: without this contiguous, it unpacks the weights from the model every time. we shouldn't need this, but for now it's faster
  if realize:
    for s in (params:=nn.state.get_parameters(model)): s.replace(s.contiguous())
    Tensor.realize(*params)
  return model, kv
|
||||
|
||||
def get_start_pos(self, tokens:list[int]) -> int:
  """Return the position generation can resume from for `tokens`, given what is currently cached.

  The shared prefix is counted between `tokens[:-1]` (the last token is never counted) and `self._cached_tokens`.
  Every block then reports how much of that prefix it can reuse through `_reusable_prefix_len(prefix_len, cached_len)`,
  and the smallest answer wins.
  """
  cached = self._cached_tokens
  shared = 0
  for new_tok, cached_tok in zip(tokens[:-1], cached):
    if new_tok != cached_tok: break
    shared += 1
  # the most restrictive block decides how much can be kept
  return min(layer._reusable_prefix_len(shared, len(cached)) for layer in self.blk)
|
||||
|
||||
def generate(self, tokens:list[int], chunk_size:int=32, temperature:float=0.0):
  """Generator that feeds `tokens` through the model and yields newly sampled token ids one at a time.

  Args:
    tokens: prompt token ids. This list is mutated in place: every sampled token is appended to it.
    chunk_size: maximum number of prompt tokens processed per model call during prefill. Forced to 1 when the
      model contains a recurrent block.
    temperature: sampling temperature handed to the model as a tensor.

  Yields:
    int: the token id that was just sampled and appended to `tokens`.

  The loop only stops by itself once `len(tokens)` reaches `self.max_context`; no end-of-sequence check is done
  here, so the caller is responsible for stopping the iteration earlier.
  """
  # recurrent blocks only accept one token per call, so chunked prefill is disabled for them
  if self.has_recurrent_block: chunk_size = 1
  # symbolic start position and chunk length, bound to concrete values on every iteration below
  v_start_pos = UOp.variable("start_pos", 0, self.max_context-1)
  v_toks = UOp.variable("toks", 1, chunk_size)
  # TODO: use UOp.variable for temperature once float variables are supported
  temp = Tensor(temperature).contiguous()
  # assign all input tokens once, then slice from start_pos for the model call
  # (the prompt is zero-padded up to max_context)
  t = Tensor(tokens + [0] * (self.max_context - len(tokens)), dtype="int32").reshape(1, self.max_context)
  # recompute start_pos from what's currently valid in the caches
  start_pos = self.get_start_pos(tokens)
  # resuming before the end of what was cached: realize whatever reset ops the blocks provide
  if start_pos < len(self._cached_tokens) and (resets := [r for b in self.blk for r in b._state_reset_ops()]): Tensor.realize(*resets)
  out, prompt_len = None, len(tokens)
  while len(tokens) < self.max_context:
    sp, nt = v_start_pos.bind(start_pos), v_toks.bind(min(chunk_size, len(tokens) - start_pos))
    # while inside the prompt the input is a slice of t; afterwards the previous output tensor is fed back directly
    out = self(t[:, sp:sp+nt] if start_pos < prompt_len or out is None else out, sp, temp).realize()
    start_pos += nt.val
    # chunked prefill: keep processing until all prompt tokens are consumed
    if start_pos < len(tokens): continue
    tokens.append(int(out.item()))
    # the token just appended has been sampled but not yet fed through the model, so it is left out of the cache record
    self._cached_tokens = tokens[:-1]
    yield tokens[-1]
|
||||
Loading…
Add table
Add a link
Reference in a new issue