mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
Compare commits
1 commit
master
...
gemma4_gpt
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
19c0e4a11d |
1 changed file with 245 additions and 50 deletions
|
|
@ -9,7 +9,7 @@ from tinygrad.viz.serve import TCPServerWithReuse, HTTPRequestHandler
|
|||
class SimpleTokenizer:
|
||||
def __init__(self, normal_tokens:dict[str, int], special_tokens:dict[str, int], preset:str="llama3"):
|
||||
preset = {"qwen35":"qwen2","qwen35moe":"qwen2"}.get(preset, preset)
|
||||
if preset not in ("llama3","llama-v3","llama-bpe","qwen2","olmo","kimi-k2"): raise ValueError(f"Invalid tokenizer preset '{preset}'")
|
||||
if preset not in ("llama3","llama-v3","llama-bpe","qwen2","olmo","kimi-k2","gemma4"): raise ValueError(f"Invalid tokenizer preset '{preset}'")
|
||||
# https://github.com/openai/gpt-2/blob/9b63575ef42771a015060c964af2c3da4cf7c8ab/src/encoder.py#L9
|
||||
bs = [*range(33, 127), *range(161, 173), *range(174, 256)] # bytes that map to themselves
|
||||
self._byte_decoder = {chr(b): b for b in bs} | {chr(256+i): b for i,b in enumerate(b for b in range(256) if b not in bs)}
|
||||
|
|
@ -22,9 +22,11 @@ class SimpleTokenizer:
|
|||
f"[^\\r\\n{r_p_N}{r_p_L}]?[{r_p_L}]+|[{r_p_N}]{{1,3}}| ?[^{r_ws}{r_p_N}{r_p_L}]+[\\r\\n]*|[{r_ws}]*[\\r\\n]+|[{r_ws}]+(?![^{r_ws}])|[{r_ws}]+")
|
||||
self._split_to_sentence = re.compile("|".join(re.escape(tok) for tok in special_tokens.keys()) if special_tokens else r"(?!)")
|
||||
|
||||
self._normal_tokens = {bytes(self._byte_decoder[c] for c in tok): tid for tok, tid in normal_tokens.items()}
|
||||
tok_bytes = (lambda tok: tok.replace("▁", " ").encode()) if preset == "gemma4" else (lambda tok: bytes(self._byte_decoder[c] for c in tok))
|
||||
self._normal_tokens = {tok_bytes(tok): tid for tok, tid in normal_tokens.items()}
|
||||
self._special_tokens = special_tokens
|
||||
self._tok2bytes = {tid: tok for tok, tid in self._normal_tokens.items()} | {tid: tok.encode() for tok, tid in self._special_tokens.items()}
|
||||
self._tok2bytes = {tid: tok for tok, tid in self._normal_tokens.items()} | {
|
||||
tid: (b'' if preset == "gemma4" else tok.encode()) for tok, tid in self._special_tokens.items()}
|
||||
self.preset = preset
|
||||
|
||||
@staticmethod
|
||||
|
|
@ -32,7 +34,9 @@ class SimpleTokenizer:
|
|||
# https://github.com/ggml-org/llama.cpp/blob/94933c8c2eeaa9a7983e3f6c08af76bd86724094/src/llama-vocab.cpp#L1818-L1820
|
||||
vocab: typing.Iterable[tuple[str, int]] = ((tok, idx) for idx, tok in enumerate(kv["tokenizer.ggml.tokens"]))
|
||||
normal_tokens, special_tokens = partition(vocab, lambda e: kv["tokenizer.ggml.token_type"][e[1]] == 1)
|
||||
return SimpleTokenizer(dict(normal_tokens), dict(special_tokens), kv["tokenizer.ggml.pre"])
|
||||
return SimpleTokenizer(
|
||||
dict(normal_tokens), dict(special_tokens),
|
||||
kv.get("tokenizer.ggml.pre") or kv.get("tokenizer.ggml.model", "llama3"))
|
||||
|
||||
def _encode_word(self, word:bytes) -> list[int]:
|
||||
if (early_token:=self._normal_tokens.get(word)) is not None: return [early_token]
|
||||
|
|
@ -63,11 +67,13 @@ class SimpleTokenizer:
|
|||
if self.preset == 'olmo': return self.encode("<|" + role + "|>\n") # OLMoE Instruct format
|
||||
if self.preset == 'kimi-k2': return self.encode("<|im_" + role + "|>" + role + "<|im_middle|>")
|
||||
if self.preset == 'qwen2': return self.encode("<|im_start|>" + role + "\n")
|
||||
if self.preset == 'gemma4': return self.encode("<|turn>" + ("model" if role == "assistant" else role) + "\n")
|
||||
return self.encode("<|start_header_id|>" + role + "<|end_header_id|>\n\n")
|
||||
def end_turn(self, eos_id:int):
  """Return the token ids that close a chat turn for this tokenizer's preset.

  `eos_id` is the model's end-of-sequence token id, supplied by the caller.
  Presets without an entry below (e.g. 'kimi-k2' and the llama-style default) end a turn with `eos_id` alone.
  """
  # closers are lambdas so self.encode only runs for the preset actually selected
  closers = {
    'olmo': lambda: self.encode("\n"),
    'qwen2': lambda: [eos_id] + self.encode("\n"),
    'gemma4': lambda: self.encode("<turn|>\n"),
  }
  closer = closers.get(self.preset)
  return [eos_id] if closer is None else closer()
|
||||
|
||||
@functools.cache
|
||||
|
|
@ -90,6 +96,27 @@ def apply_rope(x:Tensor, freqs_cis:Tensor) -> Tensor:
|
|||
x1, x2 = x.chunk(2, dim=-1)
|
||||
return (x1 * cos - x2 * sin).cat(x2 * cos + x1 * sin, dim=-1)
|
||||
|
||||
class ScaledLinear:
  """Bias-free linear layer that multiplies its input by a per-input-feature `scale` before the projection.

  `weight` has shape (out_features, in_features) and `scale` has shape (in_features,).
  Both start as placeholders (zeros / ones) - presumably overwritten when the model weights are loaded.
  """
  def __init__(self, in_features:int, out_features:int):
    self.weight = Tensor.zeros(out_features, in_features)
    self.scale = Tensor.ones(in_features)

  def __call__(self, x:Tensor) -> Tensor:
    # scale broadcasts over the last (feature) axis of x, then project with weight^T
    scaled = x * self.scale
    return scaled @ self.weight.transpose(-1, -2)
|
||||
|
||||
class ScaledExpertWeights(ExpertWeights):
  """ExpertWeights that additionally carries one scalar `scale` per expert."""
  def __init__(self, num_experts:int, in_features:int, out_features:int):
    super().__init__(num_experts, in_features, out_features)
    # shape (num_experts,); starts at ones - presumably replaced when weights are loaded
    self.scale = Tensor.ones(num_experts)
|
||||
|
||||
class ScalarWeight:
  """Holder for a single-element `weight` tensor, initialised to one."""
  def __init__(self):
    self.weight = Tensor.ones(1)
|
||||
|
||||
def rms_norm_no_weight(x:Tensor, eps:float) -> Tensor:
  """RMS-normalise `x` over its last axis without any learned weight.

  Computes x / sqrt(mean(x^2, last axis) + eps); the result has the same shape as `x`.
  """
  mean_sq = x.square().mean(axis=-1, keepdim=True)
  inv_rms = (mean_sq + eps).rsqrt()
  return x * inv_rms
|
||||
|
||||
def repeat_kv(x:Tensor, n_rep:int) -> Tensor:
  """Repeat every head of a 4-D tensor `n_rep` times along axis 1.

  For x of shape (d0, d1, d2, d3) the result has shape (d0, d1 * n_rep, d2, d3), and the copies of
  each original head are adjacent. With n_rep == 1 the input is returned unchanged.
  """
  if n_rep == 1: return x
  d0, d1, d2, d3 = x.shape[0], x.shape[1], x.shape[2], x.shape[3]
  # insert a repeat axis after the head axis, broadcast it, then fold it into the head axis
  expanded = x.unsqueeze(2).expand(d0, d1, n_rep, d2, d3)
  return expanded.reshape(d0, d1 * n_rep, d2, d3)
|
||||
|
||||
def pairwise_topk(x: Tensor, k: int) -> tuple[Tensor, Tensor]:
|
||||
n = x.shape[-1]
|
||||
vals = Tensor.arange(n).reshape(1,1,n).cast(x.dtype).expand(x.shape)
|
||||
|
|
@ -110,17 +137,17 @@ class SSMConfig:
|
|||
class TransformerConfig:
|
||||
num_blocks: int
|
||||
dim: int
|
||||
hidden_dim: int
|
||||
hidden_dim: int|tuple[int, ...]
|
||||
n_heads: int
|
||||
n_kv_heads: int
|
||||
n_kv_heads: int|tuple[int, ...]
|
||||
norm_eps: float
|
||||
vocab_size: int
|
||||
head_dim: int
|
||||
rope_theta: float
|
||||
head_dim: int|tuple[int, ...]
|
||||
rope_theta: float|tuple[float, ...]
|
||||
rope_dim: int
|
||||
v_head_dim: int
|
||||
max_context: int = 0
|
||||
qk_norm: int = 0
|
||||
qk_norm: int|tuple[int, ...] = 0
|
||||
num_experts: int = 0
|
||||
num_experts_per_tok: int = 0
|
||||
norm_topk_prob: bool = False
|
||||
|
|
@ -133,36 +160,66 @@ class TransformerConfig:
|
|||
leading_dense_blocks: int = 0
|
||||
dense_hidden_dim: int = 0
|
||||
routed_scaling_factor: float = 1.0
|
||||
sliding_window: int = 0
|
||||
sliding_window_pattern: tuple[bool, ...] = ()
|
||||
per_layer_input_dim: int = 0
|
||||
final_logit_softcap: float = 0.0
|
||||
num_kv_shared_layers: int = 0
|
||||
gemma4: bool = False
|
||||
expert_hidden_dim: int = 0
|
||||
|
||||
class FFNBlock:
|
||||
def __init__(self, config:TransformerConfig):
|
||||
self.config = config
|
||||
self.hidden_dim = config.hidden_dim
|
||||
gemma_moe = config.gemma4 and config.num_experts > 0
|
||||
|
||||
# --- RMSNorms --------------------------------------------------------
|
||||
self.attn_norm = nn.RMSNorm(config.dim, config.norm_eps)
|
||||
self.ffn_norm = nn.RMSNorm(config.dim, config.norm_eps)
|
||||
|
||||
# --- feed-forward (MoE or dense) -------------------------------------
|
||||
if config.num_experts > 0:
|
||||
self.ffn_gate_inp: nn.Linear|ScaledLinear
|
||||
self.ffn_down_exps: ExpertWeights|ScaledExpertWeights
|
||||
if gemma_moe or config.num_experts == 0:
|
||||
self.ffn_gate = nn.Linear(config.dim, self.hidden_dim, bias=False)
|
||||
self.ffn_up = nn.Linear(config.dim, self.hidden_dim, bias=False)
|
||||
self.ffn_down = nn.Linear(self.hidden_dim, config.dim, bias=False)
|
||||
if gemma_moe:
|
||||
self.ffn_gate_inp = ScaledLinear(config.dim, config.num_experts)
|
||||
self.ffn_gate_up_exps = ExpertWeights(config.num_experts, config.dim, config.expert_hidden_dim * 2)
|
||||
self.ffn_down_exps = ScaledExpertWeights(config.num_experts, config.expert_hidden_dim, config.dim)
|
||||
self.post_ffw_norm_1 = nn.RMSNorm(config.dim, config.norm_eps)
|
||||
self.pre_ffw_norm_2 = nn.RMSNorm(config.dim, config.norm_eps)
|
||||
self.post_ffw_norm_2 = nn.RMSNorm(config.dim, config.norm_eps)
|
||||
elif config.num_experts > 0:
|
||||
self.ffn_gate_inp = nn.Linear(config.dim, config.num_experts, bias=False) # router
|
||||
if config.kv_lora_rank > 0: self.exp_probs_b = {"bias": Tensor.zeros(config.num_experts)}
|
||||
self.ffn_gate_exps = ExpertWeights(config.num_experts, config.dim, config.hidden_dim)
|
||||
self.ffn_up_exps = ExpertWeights(config.num_experts, config.dim, config.hidden_dim)
|
||||
self.ffn_down_exps = ExpertWeights(config.num_experts, config.hidden_dim, config.dim)
|
||||
self.ffn_gate_exps = ExpertWeights(config.num_experts, config.dim, self.hidden_dim)
|
||||
self.ffn_up_exps = ExpertWeights(config.num_experts, config.dim, self.hidden_dim)
|
||||
self.ffn_down_exps = ExpertWeights(config.num_experts, self.hidden_dim, config.dim)
|
||||
if config.shared_expert_dim > 0:
|
||||
self.ffn_gate_shexp = nn.Linear(config.dim, config.shared_expert_dim, bias=False)
|
||||
self.ffn_up_shexp = nn.Linear(config.dim, config.shared_expert_dim, bias=False)
|
||||
self.ffn_down_shexp = nn.Linear(config.shared_expert_dim, config.dim, bias=False)
|
||||
if config.shared_expert_gate: self.ffn_gate_inp_shexp = {"weight": Tensor.zeros(config.dim)}
|
||||
else:
|
||||
self.ffn_gate = nn.Linear(config.dim, config.hidden_dim, bias=False)
|
||||
self.ffn_up = nn.Linear(config.dim, config.hidden_dim, bias=False)
|
||||
self.ffn_down = nn.Linear(config.hidden_dim, config.dim, bias=False)
|
||||
|
||||
def _feed_forward(self, x:Tensor) -> Tensor:
|
||||
h_norm = self.ffn_norm(x) if self.config.gemma4 else x
|
||||
if self.config.gemma4 and self.config.num_experts > 0:
|
||||
ffn_gate_inp = typing.cast(ScaledLinear, self.ffn_gate_inp)
|
||||
ffn_down_exps = typing.cast(ScaledExpertWeights, self.ffn_down_exps)
|
||||
dense = self.post_ffw_norm_1(self.ffn_down((self.ffn_gate(h_norm).gelu().contiguous()) * self.ffn_up(h_norm)))
|
||||
router_probs = (
|
||||
rms_norm_no_weight(x, self.config.norm_eps) * (self.config.dim ** -0.5) * ffn_gate_inp.scale
|
||||
) @ ffn_gate_inp.weight.transpose(-1, -2)
|
||||
vals, sel = pairwise_topk(router_probs.softmax(-1), self.config.num_experts_per_tok)
|
||||
probs = vals / vals.sum(axis=-1, keepdim=True) * ffn_down_exps.scale[sel]
|
||||
gate, up = self.ffn_gate_up_exps(sel, self.pre_ffw_norm_2(x).unsqueeze(2)).chunk(2, dim=-1)
|
||||
return dense + self.post_ffw_norm_2((ffn_down_exps(sel, gate.gelu().contiguous() * up) * probs.unsqueeze(-1)).sum(axis=2))
|
||||
if hasattr(self, 'ffn_gate_exps'):
|
||||
h = x.unsqueeze(2) # (B, T, 1, D) - add expert dim for broadcasting
|
||||
logits = self.ffn_gate_inp(x)
|
||||
h = h_norm.unsqueeze(2) # (B, T, 1, D) - add expert dim for broadcasting
|
||||
logits = self.ffn_gate_inp(h_norm)
|
||||
if hasattr(self, 'exp_probs_b'):
|
||||
probs = logits.sigmoid()
|
||||
_, sel = pairwise_topk(probs + self.exp_probs_b["bias"], self.config.num_experts_per_tok)
|
||||
|
|
@ -175,12 +232,13 @@ class FFNBlock:
|
|||
x_down = self.ffn_down_exps(sel, self.ffn_gate_exps(sel, h).silu() * self.ffn_up_exps(sel, h)) # (B, T, k, D)
|
||||
out = (x_down * probs.unsqueeze(-1)).sum(axis=2) # (B, T, D)
|
||||
if hasattr(self, 'ffn_gate_shexp'):
|
||||
shexp = self.ffn_down_shexp(self.ffn_gate_shexp(x).silu().contiguous() * self.ffn_up_shexp(x))
|
||||
if hasattr(self, 'ffn_gate_inp_shexp'): shexp = shexp * (x * self.ffn_gate_inp_shexp["weight"]).sum(axis=-1, keepdim=True).sigmoid()
|
||||
shexp = self.ffn_down_shexp(self.ffn_gate_shexp(h_norm).silu().contiguous() * self.ffn_up_shexp(h_norm))
|
||||
if hasattr(self, 'ffn_gate_inp_shexp'): shexp = shexp * (h_norm * self.ffn_gate_inp_shexp["weight"]).sum(axis=-1, keepdim=True).sigmoid()
|
||||
out = out + shexp
|
||||
return out
|
||||
# TODO: remove the need for this contiguous
|
||||
return self.ffn_down(self.ffn_gate(x).silu().contiguous() * self.ffn_up(x))
|
||||
act = self.ffn_gate(h_norm).gelu() if self.config.gemma4 else self.ffn_gate(h_norm).silu()
|
||||
return self.ffn_down(act.contiguous() * self.ffn_up(h_norm))
|
||||
|
||||
# given the token-prefix match, return how much cached state this block can still reuse
|
||||
def _reusable_prefix_len(self, prefix_len:int, cached_len:int) -> int: return prefix_len
|
||||
|
|
@ -201,29 +259,78 @@ class FFNBlock:
|
|||
class TransformerBlock(FFNBlock):
|
||||
def __init__(self, config:TransformerConfig):
|
||||
super().__init__(config)
|
||||
assert config.v_head_dim == config.head_dim, "TransformerBlock requires v_head_dim == head_dim"
|
||||
self.head_dim = config.head_dim
|
||||
self.rope_theta = config.rope_theta
|
||||
self.qk_norm = config.qk_norm
|
||||
self.n_kv_heads = config.n_kv_heads
|
||||
self.is_sliding = config.sliding_window > 0 and bool(config.sliding_window_pattern) and config.sliding_window_pattern[0]
|
||||
self.use_alternative_attention = config.gemma4 and config.num_experts > 0 and not self.is_sliding
|
||||
self.store_full_length_kv = False
|
||||
self.shared_kv_src_idx: int|None = None
|
||||
self.full_kv_cache: Tensor|None = None
|
||||
if not config.gemma4: assert config.v_head_dim == self.head_dim, "TransformerBlock requires v_head_dim == head_dim"
|
||||
|
||||
# --- attention projections (all linear, bias-free) ------------------
|
||||
q_proj_out = config.head_dim * config.n_heads * (2 if config.attn_output_gate else 1)
|
||||
kv_proj_out = config.head_dim * config.n_kv_heads
|
||||
q_proj_out = self.head_dim * config.n_heads * (2 if config.attn_output_gate else 1)
|
||||
kv_proj_out = self.head_dim * self.n_kv_heads
|
||||
self.attn_q = nn.Linear(config.dim, q_proj_out, bias=False)
|
||||
self.attn_k = nn.Linear(config.dim, kv_proj_out, bias=False)
|
||||
self.attn_v = nn.Linear(config.dim, kv_proj_out, bias=False)
|
||||
self.attn_output = nn.Linear(config.head_dim * config.n_heads, config.dim, bias=False)
|
||||
if config.qk_norm: self.attn_q_norm, self.attn_k_norm = nn.RMSNorm(config.qk_norm, config.norm_eps), nn.RMSNorm(config.qk_norm, config.norm_eps)
|
||||
if not self.use_alternative_attention: self.attn_v = nn.Linear(config.dim, kv_proj_out, bias=False)
|
||||
self.attn_output = nn.Linear(self.head_dim * config.n_heads, config.dim, bias=False)
|
||||
if self.qk_norm: self.attn_q_norm, self.attn_k_norm = nn.RMSNorm(self.qk_norm, config.norm_eps), nn.RMSNorm(self.qk_norm, config.norm_eps)
|
||||
if config.gemma4:
|
||||
self.post_attention_norm = nn.RMSNorm(config.dim, config.norm_eps)
|
||||
self.post_ffw_norm = nn.RMSNorm(config.dim, config.norm_eps)
|
||||
self.layer_output_scale = ScalarWeight()
|
||||
if config.per_layer_input_dim:
|
||||
self.inp_gate = nn.Linear(config.dim, config.per_layer_input_dim, bias=False)
|
||||
self.proj = nn.Linear(config.per_layer_input_dim, config.dim, bias=False)
|
||||
self.post_norm = nn.RMSNorm(config.dim, config.norm_eps)
|
||||
|
||||
def _attention(self, x:Tensor, start_pos:int|UOp, shared_kv_cache:Tensor|None=None) -> Tensor:
|
||||
if self.config.gemma4:
|
||||
x_norm = self.attn_norm(x)
|
||||
q, k = self.attn_q(x_norm), self.attn_k(x_norm)
|
||||
if self.qk_norm and self.qk_norm != self.head_dim: q, k = self.attn_q_norm(q), self.attn_k_norm(k)
|
||||
|
||||
B, T, _ = x.shape
|
||||
q = q.reshape(B, T, self.config.n_heads, self.head_dim).transpose(1, 2)
|
||||
if self.qk_norm == self.head_dim: q = self.attn_q_norm(q)
|
||||
q = apply_rope(q, self.freqs_cis[start_pos:start_pos+T])
|
||||
if shared_kv_cache is not None:
|
||||
k = shared_kv_cache[0, :, :, 0:start_pos+T, :]
|
||||
v = shared_kv_cache[1, :, :, 0:start_pos+T, :]
|
||||
else:
|
||||
raw_k = k.reshape(B, T, self.n_kv_heads, self.head_dim)
|
||||
k = raw_k.transpose(1, 2)
|
||||
if self.qk_norm == self.head_dim: k = self.attn_k_norm(k)
|
||||
raw_v = raw_k if self.use_alternative_attention else self.attn_v(x_norm).reshape(B, T, self.n_kv_heads, self.head_dim)
|
||||
v = rms_norm_no_weight(raw_v, self.config.norm_eps).transpose(1, 2)
|
||||
k = apply_rope(k, self.freqs_cis[start_pos:start_pos+T])
|
||||
assigned_kv = Tensor(self.cache_kv.uop.after(self.cache_kv[:, :, :, start_pos:start_pos+T, :].uop.store(Tensor.stack(k, v).uop)))
|
||||
if self.store_full_length_kv: self.full_kv_cache = assigned_kv
|
||||
k = assigned_kv[0, :, :, 0:start_pos+T, :]
|
||||
v = assigned_kv[1, :, :, 0:start_pos+T, :]
|
||||
|
||||
mask = None
|
||||
if resolve(T != 1) or self.is_sliding:
|
||||
mask = Tensor.full((1, 1, T, start_pos+T), float("-inf"), dtype=x.dtype, device=x.device).triu(start_pos+1)
|
||||
if self.is_sliding:
|
||||
mask = mask + Tensor.full((1, 1, T, start_pos+T), float("-inf"), dtype=x.dtype, device=x.device).tril(start_pos-self.config.sliding_window)
|
||||
k, v = repeat_kv(k, self.config.n_heads // self.n_kv_heads), repeat_kv(v, self.config.n_heads // self.n_kv_heads)
|
||||
return self.attn_output((((q @ k.transpose(-1, -2)) + (mask if mask is not None else 0)).softmax(-1) @ v).transpose(1, 2).reshape(B, T, -1))
|
||||
|
||||
def _attention(self, x:Tensor, start_pos:int|UOp) -> Tensor:
|
||||
q, k, v = self.attn_q(x), self.attn_k(x), self.attn_v(x)
|
||||
if self.config.qk_norm and self.config.qk_norm != self.config.head_dim: q, k = self.attn_q_norm(q), self.attn_k_norm(k)
|
||||
if self.qk_norm and self.qk_norm != self.head_dim: q, k = self.attn_q_norm(q), self.attn_k_norm(k)
|
||||
|
||||
B, T, _ = x.shape
|
||||
if self.config.attn_output_gate:
|
||||
qg = q.reshape(B, T, self.config.n_heads, 2, self.config.head_dim)
|
||||
q, gate = qg[:, :, :, 0, :], qg[:, :, :, 1, :].reshape(B, T, self.config.n_heads * self.config.head_dim)
|
||||
q = q.reshape(B, T, self.config.n_heads, self.config.head_dim).transpose(1, 2) # (B,H,T,Hd)
|
||||
k = k.reshape(B, T, self.config.n_kv_heads, self.config.head_dim).transpose(1, 2) # (B,KvH,T,Hd)
|
||||
v = v.reshape(B, T, self.config.n_kv_heads, self.config.head_dim).transpose(1, 2) # (B,KvH,T,Hd)
|
||||
if self.config.qk_norm == self.config.head_dim: q, k = self.attn_q_norm(q), self.attn_k_norm(k)
|
||||
qg = q.reshape(B, T, self.config.n_heads, 2, self.head_dim)
|
||||
q, gate = qg[:, :, :, 0, :], qg[:, :, :, 1, :].reshape(B, T, self.config.n_heads * self.head_dim)
|
||||
q = q.reshape(B, T, self.config.n_heads, self.head_dim).transpose(1, 2) # (B,H,T,Hd)
|
||||
k = k.reshape(B, T, self.n_kv_heads, self.head_dim).transpose(1, 2) # (B,KvH,T,Hd)
|
||||
v = v.reshape(B, T, self.n_kv_heads, self.head_dim).transpose(1, 2) # (B,KvH,T,Hd)
|
||||
if self.qk_norm == self.head_dim: q, k = self.attn_q_norm(q), self.attn_k_norm(k)
|
||||
|
||||
q = apply_rope(q[..., :self.config.rope_dim], self.freqs_cis[start_pos:start_pos+T]).cat(q[..., self.config.rope_dim:], dim=-1)
|
||||
k = apply_rope(k[..., :self.config.rope_dim], self.freqs_cis[start_pos:start_pos+T]).cat(k[..., self.config.rope_dim:], dim=-1)
|
||||
|
|
@ -247,8 +354,19 @@ class TransformerBlock(FFNBlock):
|
|||
def _init_state(self, x:Tensor):
|
||||
if not hasattr(self, "cache_kv"):
|
||||
# TODO: how is the dtype of this determined?
|
||||
self.cache_kv = Tensor.empty(2, x.shape[0], self.config.n_kv_heads, self.config.max_context, self.config.head_dim, device=x.device)
|
||||
self.freqs_cis = precompute_freqs_cis(self.config.rope_dim, self.config.max_context, self.config.rope_theta)
|
||||
self.cache_kv = Tensor.empty(2, x.shape[0], self.n_kv_heads, self.config.max_context, self.head_dim, device=x.device)
|
||||
self.freqs_cis = precompute_freqs_cis(self.head_dim if self.config.gemma4 else self.config.rope_dim, self.config.max_context, self.rope_theta)
|
||||
|
||||
def __call__(self, x: Tensor, start_pos: int|UOp, per_layer_input:Tensor|None=None, shared_kv_cache:Tensor|None=None):
  # only gemma4 takes the path below; every other config defers to the parent class's __call__
  if not self.config.gemma4: return super().__call__(x, start_pos)
  self._init_state(x)

  @function(precompile=True, allow_implicit=True)
  def _run(x:Tensor, start_pos:int|UOp, per_layer_input:Tensor|None=None, shared_kv_cache:Tensor|None=None):
    # attention sub-layer: the attention output is normed before the residual add
    attn_out = self._attention(x, start_pos, shared_kv_cache)
    h = x + self.post_attention_norm(attn_out)
    # feed-forward sub-layer, again normed before the residual add
    ffn_out = self._feed_forward(h)
    h = h + self.post_ffw_norm(ffn_out)
    # optional per-layer input: gate h, mix with the per-layer input, project back and add as a residual
    if per_layer_input is not None:
      gated = self.inp_gate(h).gelu() * per_layer_input
      h = h + self.post_norm(self.proj(gated))
    scaled = h * self.layer_output_scale.weight
    return scaled.contiguous()
  return _run(x, start_pos, per_layer_input, shared_kv_cache)
|
||||
|
||||
class MLATransformerBlock(FFNBlock):
|
||||
def __init__(self, config:TransformerConfig):
|
||||
|
|
@ -347,15 +465,48 @@ class GatedDeltaNetBlock(FFNBlock):
|
|||
|
||||
class Transformer:
|
||||
def __init__(self, config:TransformerConfig):
|
||||
dense_config = replace(config, num_experts=0, num_experts_per_tok=0, shared_expert_dim=0, hidden_dim=config.dense_hidden_dim or config.hidden_dim)
|
||||
if config.ssm: config = replace(config, qk_norm=config.head_dim)
|
||||
block_cls = MLATransformerBlock if config.kv_lora_rank > 0 else TransformerBlock
|
||||
self.blk:list[FFNBlock] = [GatedDeltaNetBlock(config, config.ssm) if config.ssm and (i+1) % config.full_attention_interval != 0 else
|
||||
block_cls(dense_config if i < config.leading_dense_blocks else config) for i in range(config.num_blocks)]
|
||||
self.config = config
|
||||
def layer_config(i:int) -> TransformerConfig:
|
||||
return replace(
|
||||
config,
|
||||
hidden_dim=config.hidden_dim[i] if isinstance(config.hidden_dim, tuple) else config.hidden_dim,
|
||||
n_kv_heads=config.n_kv_heads[i] if isinstance(config.n_kv_heads, tuple) else config.n_kv_heads,
|
||||
head_dim=config.head_dim[i] if isinstance(config.head_dim, tuple) else config.head_dim,
|
||||
rope_theta=config.rope_theta[i] if isinstance(config.rope_theta, tuple) else config.rope_theta,
|
||||
qk_norm=config.qk_norm[i] if isinstance(config.qk_norm, tuple) else config.qk_norm,
|
||||
sliding_window_pattern=(config.sliding_window_pattern[i],) if config.sliding_window_pattern else ())
|
||||
if config.gemma4:
|
||||
self.blk = [TransformerBlock(layer_config(i)) for i in range(config.num_blocks)]
|
||||
else:
|
||||
dense_config = replace(
|
||||
config, num_experts=0, num_experts_per_tok=0, shared_expert_dim=0,
|
||||
hidden_dim=config.dense_hidden_dim or config.hidden_dim)
|
||||
block_cls = MLATransformerBlock if config.kv_lora_rank > 0 else TransformerBlock
|
||||
self.blk:list[FFNBlock] = [GatedDeltaNetBlock(config, config.ssm) if config.ssm and (i+1) % config.full_attention_interval != 0 else
|
||||
block_cls(dense_config if i < config.leading_dense_blocks else config) for i in range(config.num_blocks)]
|
||||
self.token_embd = nn.Embedding(config.vocab_size, config.dim)
|
||||
self.output_norm = nn.RMSNorm(config.dim, config.norm_eps)
|
||||
self.output = nn.Linear(config.dim, config.vocab_size, bias=False)
|
||||
if config.per_layer_input_dim:
|
||||
self.per_layer_model_proj = nn.Linear(config.dim, config.num_blocks * config.per_layer_input_dim, bias=False)
|
||||
self.per_layer_proj_norm = nn.RMSNorm(config.per_layer_input_dim, config.norm_eps)
|
||||
self.per_layer_token_embd = nn.Embedding(config.vocab_size, config.num_blocks * config.per_layer_input_dim)
|
||||
self.max_context = config.max_context
|
||||
self.embed_scale = config.dim ** 0.5 if config.gemma4 else 1.0
|
||||
self.per_layer_embed_scale = config.per_layer_input_dim ** 0.5 if config.per_layer_input_dim else 1.0
|
||||
self.per_layer_input_scale = 2 ** -0.5
|
||||
self.per_layer_model_projection_scale = config.dim ** -0.5 if config.gemma4 else 1.0
|
||||
self.final_logit_softcap = config.final_logit_softcap
|
||||
if config.num_kv_shared_layers:
|
||||
first_shared = config.num_blocks - config.num_kv_shared_layers
|
||||
last_of_type = {
|
||||
False: max(i for i in range(first_shared) if not config.sliding_window_pattern[i]),
|
||||
True: max(i for i in range(first_shared) if config.sliding_window_pattern[i])}
|
||||
for idx, block in enumerate(self.blk[:first_shared]):
|
||||
if bool(config.sliding_window_pattern) and idx == last_of_type[config.sliding_window_pattern[idx]]: block.store_full_length_kv = True
|
||||
for idx, block in enumerate(self.blk[first_shared:], start=first_shared):
|
||||
block.shared_kv_src_idx = last_of_type[config.sliding_window_pattern[idx]]
|
||||
self.has_recurrent_block = any(isinstance(b, GatedDeltaNetBlock) for b in self.blk)
|
||||
self._cached_tokens: list[int] = []
|
||||
# we specialize the JIT for prefill and rollout
|
||||
|
|
@ -363,9 +514,23 @@ class Transformer:
|
|||
self.rollout_jit = TinyJit(self.forward)
|
||||
|
||||
def forward(self, tokens:Tensor, start_pos:int|UOp, temperature:Tensor) -> Tensor:
|
||||
x = self.token_embd(tokens).float() # (B, T, D)
|
||||
for block in self.blk: x = block(x, start_pos)
|
||||
x = self.token_embd(tokens).float() * self.embed_scale # (B, T, D)
|
||||
if not self.config.gemma4:
|
||||
for block in self.blk: x = block(x, start_pos)
|
||||
else:
|
||||
per_layer_inputs = None
|
||||
if hasattr(self, 'per_layer_token_embd'):
|
||||
B, T, _ = x.shape
|
||||
per_layer_inputs = self.per_layer_proj_norm(
|
||||
(self.per_layer_model_proj(x) * self.per_layer_model_projection_scale).reshape(B, T, len(self.blk), -1))
|
||||
per_layer_inputs = (
|
||||
per_layer_inputs + self.per_layer_token_embd(tokens).float().reshape(B, T, len(self.blk), -1) * self.per_layer_embed_scale
|
||||
) * self.per_layer_input_scale
|
||||
for i, block in enumerate(self.blk):
|
||||
shared_kv_cache = self.blk[block.shared_kv_src_idx].full_kv_cache if block.shared_kv_src_idx is not None else None
|
||||
x = block(x, start_pos, None if per_layer_inputs is None else per_layer_inputs[:,:,i,:], shared_kv_cache)
|
||||
logits = self.output(self.output_norm(x))[:, -1, :]
|
||||
if self.final_logit_softcap: logits = (logits / self.final_logit_softcap).tanh() * self.final_logit_softcap
|
||||
# Gumbel-max trick: argmax(logits/temp - log(-log(uniform))) is equivalent to sampling from softmax(logits/temp)
|
||||
return (logits / temperature.maximum(1e-12) - (Tensor.rand_like(logits).maximum(1e-12).log().neg()).log()).argmax(-1, keepdim=True)
|
||||
|
||||
|
|
@ -407,17 +572,36 @@ class Transformer:
|
|||
state_dict[name] = w.rearrange("n (h two) d -> n (two h) d", two=2).reshape(-1, w.shape[-1])
|
||||
elif kv_lora_rank and 'attn_kv_a_mqa.weight' in name:
|
||||
state_dict[name] = state_dict[name][:kv_lora_rank].cat(state_dict[name][kv_lora_rank:].rearrange("(h two) d -> (two h) d", two=2), dim=0)
|
||||
hidden_dim = kv[f'{arch}.feed_forward_length'] if arch == 'gemma4' else \
|
||||
kv.get(f'{arch}.expert_feed_forward_length', kv.get(f'{arch}.feed_forward_length', 0))
|
||||
if arch == 'gemma4' and isinstance(hidden_dim, list): hidden_dim = tuple(hidden_dim)
|
||||
if arch == 'gemma4':
|
||||
sliding_window_pattern = tuple(kv[f'{arch}.attention.sliding_window_pattern'])
|
||||
n_kv_heads = tuple(n_kv_heads) if isinstance(n_kv_heads, list) else n_kv_heads
|
||||
head_dim = tuple(
|
||||
kv[f'{arch}.attention.key_length_swa'] if is_sliding else kv[f'{arch}.attention.key_length']
|
||||
for is_sliding in sliding_window_pattern)
|
||||
rope_theta = tuple(
|
||||
kv.get(f'{arch}.rope.freq_base_swa', kv[f'{arch}.rope.freq_base']) if is_sliding else kv[f'{arch}.rope.freq_base']
|
||||
for is_sliding in sliding_window_pattern)
|
||||
else:
|
||||
sliding_window_pattern = ()
|
||||
rope_theta = kv[f'{arch}.rope.freq_base']
|
||||
|
||||
config = TransformerConfig(
|
||||
num_blocks=kv[f'{arch}.block_count'], dim=kv[f'{arch}.embedding_length'],
|
||||
hidden_dim=kv.get(f'{arch}.expert_feed_forward_length', kv.get(f'{arch}.feed_forward_length', 0)),
|
||||
hidden_dim=hidden_dim,
|
||||
n_heads=n_heads, n_kv_heads=n_kv_heads, norm_eps=kv[f'{arch}.attention.layer_norm_rms_epsilon'],
|
||||
vocab_size=len(kv['tokenizer.ggml.tokens']),
|
||||
head_dim=head_dim,
|
||||
rope_theta=kv[f'{arch}.rope.freq_base'],
|
||||
rope_theta=rope_theta,
|
||||
rope_dim=rope_dim,
|
||||
v_head_dim=kv.get(f'{arch}.attention.value_length_mla', kv.get(f'{arch}.attention.value_length', head_dim)),
|
||||
v_head_dim=kv.get(
|
||||
f'{arch}.attention.value_length_mla',
|
||||
kv.get(f'{arch}.attention.value_length', head_dim if isinstance(head_dim, int) else head_dim[0])),
|
||||
max_context=max_context,
|
||||
qk_norm=int(state_dict['blk.0.attn_q_norm.weight'].shape[0]) if 'blk.0.attn_q_norm.weight' in state_dict else 0,
|
||||
qk_norm=head_dim if arch == 'gemma4' else (
|
||||
int(state_dict['blk.0.attn_q_norm.weight'].shape[0]) if 'blk.0.attn_q_norm.weight' in state_dict else 0),
|
||||
num_experts=kv.get(f'{arch}.expert_count', 0), num_experts_per_tok=kv.get(f'{arch}.expert_used_count', 0),
|
||||
norm_topk_prob=kv.get(f'{arch}.expert_weights_norm', arch in ('qwen3moe', 'qwen35moe')),
|
||||
kv_lora_rank=kv_lora_rank,
|
||||
|
|
@ -427,8 +611,17 @@ class Transformer:
|
|||
kv.get(f'{arch}.expert_shared_count', 0) * kv.get(f'{arch}.expert_feed_forward_length', 0)),
|
||||
shared_expert_gate=f"blk.{kv.get(f'{arch}.leading_dense_block_count', 0)}.ffn_gate_inp_shexp.weight" in state_dict,
|
||||
dense_hidden_dim=kv.get(f'{arch}.feed_forward_length', 0) if kv.get(f'{arch}.leading_dense_block_count', 0) else 0,
|
||||
routed_scaling_factor=kv.get(f'{arch}.expert_weights_scale', 1.0), attn_output_gate=arch in ('qwen35', 'qwen35moe'), ssm=ssm,
|
||||
full_attention_interval=kv.get(f'{arch}.full_attention_interval', 0))
|
||||
routed_scaling_factor=kv.get(f'{arch}.expert_weights_scale', 1.0),
|
||||
full_attention_interval=kv.get(f'{arch}.full_attention_interval', 0),
|
||||
attn_output_gate=arch in ('qwen35', 'qwen35moe'),
|
||||
ssm=ssm,
|
||||
sliding_window=kv.get(f'{arch}.attention.sliding_window', 0),
|
||||
sliding_window_pattern=sliding_window_pattern,
|
||||
per_layer_input_dim=kv.get(f'{arch}.embedding_length_per_layer_input', 0),
|
||||
final_logit_softcap=kv.get(f'{arch}.final_logit_softcapping', 0.0),
|
||||
num_kv_shared_layers=kv.get(f'{arch}.attention.shared_kv_layers', 0),
|
||||
gemma4=arch == 'gemma4',
|
||||
expert_hidden_dim=kv.get(f'{arch}.expert_feed_forward_length', 0))
|
||||
model = Transformer(config)
|
||||
nn.state.load_state_dict(model, state_dict, verbose=False, consume=True, realize=False) # NOTE: rope_freqs.weight (32,) is unused
|
||||
# NOTE: without this contiguous, it unpacks the weights from the model every time. we shouldn't need this, but for now it's faster
|
||||
|
|
@ -478,6 +671,8 @@ models = {
|
|||
"qwen3.5:9b": "https://huggingface.co/unsloth/Qwen3.5-9B-GGUF/resolve/main/Qwen3.5-9B-Q4_K_M.gguf",
|
||||
"qwen3.5:27b": "https://huggingface.co/unsloth/Qwen3.5-27B-GGUF/resolve/main/Qwen3.5-27B-Q4_K_M.gguf",
|
||||
"qwen3.5:35b-a3b": "https://huggingface.co/unsloth/Qwen3.5-35B-A3B-GGUF/resolve/main/Qwen3.5-35B-A3B-Q4_K_M.gguf",
|
||||
"gemma4:e2b-q4": "https://huggingface.co/unsloth/gemma-4-E2B-it-GGUF/resolve/main/gemma-4-E2B-it-Q4_K_M.gguf",
|
||||
"gemma4:26b-a4b-q4": "https://huggingface.co/unsloth/gemma-4-26B-A4B-it-GGUF/resolve/main/gemma-4-26B-A4B-it-UD-Q4_K_M.gguf",
|
||||
"olmoe": "https://huggingface.co/allenai/OLMoE-1B-7B-0924-Instruct-GGUF/resolve/main/olmoe-1b-7b-0924-instruct-q4_k_m.gguf",
|
||||
"moonlight": "https://huggingface.co/gabriellarson/Moonlight-16B-A3B-Instruct-GGUF/resolve/main/Moonlight-16B-A3B-Instruct-Q4_K_M.gguf",
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue