mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
Compare commits
3 commits
master
...
olmo3_supp
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
72fccb5f5c | ||
|
|
6077327b80 | ||
|
|
79079ca3a9 |
5 changed files with 235 additions and 105 deletions
100
test/null/test_llm_chat.py
Normal file
100
test/null/test_llm_chat.py
Normal file
|
|
@ -0,0 +1,100 @@
|
|||
import unittest
|
||||
from tinygrad.llm.cli import SimpleTokenizer, Chat
|
||||
|
||||
class TestChatSimple(unittest.TestCase):
  """Exercises Chat in its default (non-jinja) mode, where formatting comes from the hard-coded presets."""

  # Minimal Tekken (Mistral) vocabulary: role(user)=[INST], role(assistant)=[], end_turn=[/INST].
  # token ids follow list order, so [INST]=3, [/INST]=4, hello=5, sure=6.
  _TEKKEN_KV = {
    "tokenizer.ggml.tokens": ["<unk>", "<s>", "</s>", "[INST]", "[/INST]", "hello", "sure"],
    "tokenizer.ggml.token_type": [3, 3, 3, 3, 3, 1, 1],
    "tokenizer.ggml.pre": "tekken",
    "tokenizer.ggml.eos_token_id": 2,
  }

  @staticmethod
  def _load(kv):
    """Build the (tokenizer, chat) pair from a GGUF-style kv dict."""
    tokenizer = SimpleTokenizer.from_gguf_kv(kv)
    return tokenizer, Chat.from_gguf_kv(kv, tokenizer)

  def test_tekken_preset(self):
    _, chat = self._load(dict(self._TEKKEN_KV))
    user_turn = {"role": "user", "content": "hello"}
    assistant_turn = {"role": "assistant", "content": "sure"}
    # single user turn: [INST] hello [/INST]
    self.assertEqual(chat.apply([user_turn]), [3, 5, 4])
    # user + assistant: [INST] hello [/INST] sure [/INST]
    self.assertEqual(chat.apply([user_turn, assistant_turn]), [3, 5, 4, 6, 4])
    # add_generation_prompt on tekken appends role("assistant") which is []
    self.assertEqual(chat.apply([user_turn], add_generation_prompt=True), [3, 5, 4])

  def test_is_end_basic(self):
    kv = {
      "tokenizer.ggml.tokens": ["<unk>", "<eos>"],
      "tokenizer.ggml.token_type": [3, 3],
      "tokenizer.ggml.pre": "llama3",
      "tokenizer.ggml.eos_token_id": 1,
    }
    chat = Chat(SimpleTokenizer.from_gguf_kv(kv), preset="llama3")
    self.assertTrue(chat.is_end(1))
    self.assertFalse(chat.is_end(0))

  def test_olmo2_simple_mode(self):
    # OLMo 2: pre='dbrx' (unsupported) but arch override maps it to qwen2-style chat with <|im_end|> as turn-end.
    # tokenizer's eos_id must be untouched; Chat widens stop_ids and uses <|im_end|> for end_turn.
    tok, chat = self._load({
      "tokenizer.ggml.tokens": ["<|endoftext|>", "hello", "<|im_end|>", "<|im_start|>"],
      "tokenizer.ggml.token_type": [3, 1, 3, 3],
      "tokenizer.ggml.pre": "dbrx",
      "tokenizer.ggml.eos_token_id": 0,
      "general.architecture": "olmo2",
    })
    self.assertEqual(tok.eos_id, 0)          # raw GGUF eos (<|endoftext|>) untouched
    self.assertEqual(chat.preset, "qwen2")   # arch-overridden preset
    self.assertEqual(chat.turn_end_id, 2)    # <|im_end|>
    self.assertTrue(chat.is_end(0))          # <|endoftext|>
    self.assertTrue(chat.is_end(2))          # <|im_end|>
    self.assertFalse(chat.is_end(1))

  def test_assistant_prefill_no_end_turn(self):
    # continue_final_message should drop the trailing end_turn after the prefill assistant message.
    _, chat = self._load(dict(self._TEKKEN_KV))
    msgs = [{"role": "user", "content": "hello"}, {"role": "assistant", "content": "sure"}]
    self.assertEqual(chat.apply(msgs), [3, 5, 4, 6, 4])
    self.assertEqual(chat.apply(msgs, continue_final_message=True), [3, 5, 4, 6])

  def test_unsupported_preset_without_jinja(self):
    kv = {"tokenizer.ggml.tokens": ["<unk>"], "tokenizer.ggml.token_type": [3], "tokenizer.ggml.pre": "dbrx"}
    # the tokenizer is built outside the assertRaises block so only Chat construction is under test
    tok = SimpleTokenizer.from_gguf_kv(kv)
    with self.assertRaises(ValueError):
      Chat.from_gguf_kv(kv, tok)  # "dbrx" not in _PRESETS and not olmo2 arch
    # works with use_jinja=True (no simple-preset check), as long as a chat_template is provided
    kv_with_template = {**kv, "tokenizer.chat_template": "x"}
    Chat.from_gguf_kv(kv_with_template, SimpleTokenizer.from_gguf_kv(kv_with_template), use_jinja=True)
|
||||
|
||||
class TestChatJinja(unittest.TestCase):
  """Exercises Chat with use_jinja=True; every test is skipped when jinja2 cannot be imported."""

  def setUp(self):
    try:
      import jinja2  # noqa: F401
    except ImportError:
      self.skipTest("jinja2 not installed")

  def test_render_simple_template(self):
    # user turns are wrapped in [INST]...[/INST]; assistant turns are followed by </s>
    template = (
      "{%- for m in messages %}"
      "{%- if m.role == 'user' %}{{- '[INST]' + m.content + '[/INST]' }}"
      "{%- elif m.role == 'assistant' %}{{- m.content + '</s>' }}"
      "{%- endif %}{%- endfor %}"
    )
    kv = {
      "tokenizer.ggml.tokens": ["<unk>", "<s>", "</s>", "[INST]", "[/INST]", "hello"],
      "tokenizer.ggml.token_type": [3, 3, 3, 3, 3, 1],
      "tokenizer.ggml.pre": "tekken",
      "tokenizer.ggml.eos_token_id": 2,
      "tokenizer.chat_template": template,
    }
    chat = Chat.from_gguf_kv(kv, SimpleTokenizer.from_gguf_kv(kv), use_jinja=True)
    rendered = chat.apply([{"role": "user", "content": "hello"}])
    self.assertEqual(rendered, [3, 5, 4])

  def test_jinja_requires_template(self):
    kv = {"tokenizer.ggml.tokens": ["<unk>"], "tokenizer.ggml.token_type": [3], "tokenizer.ggml.pre": "llama3"}
    # the tokenizer is built outside the assertRaises block so only Chat construction is under test
    tok = SimpleTokenizer.from_gguf_kv(kv)
    with self.assertRaises(ValueError):
      Chat.from_gguf_kv(kv, tok, use_jinja=True)  # no tokenizer.chat_template
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
|
@ -7,17 +7,17 @@ class TestLLMServer(unittest.TestCase):
|
|||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.mock_tok = Mock()
|
||||
cls.mock_tok.role = Mock(return_value=[100, 101])
|
||||
cls.mock_tok.encode = Mock(return_value=[200, 201, 202])
|
||||
cls.mock_tok.decode = Mock(return_value="Hello")
|
||||
cls.mock_tok.stream_decoder = Mock(return_value=lambda tid=None: "Hello" if tid is not None else "")
|
||||
cls.mock_tok.end_turn = Mock(return_value=[998])
|
||||
cls.mock_tok.prefix = Mock(return_value=[1])
|
||||
cls.mock_tok.preset = "llama3"
|
||||
cls.mock_tok.bos_id = 1
|
||||
cls.mock_tok.eos_id = 999
|
||||
cls.mock_tok.eot_id = None
|
||||
cls.mock_tok.is_end = Mock(side_effect=lambda tid: tid in (999,))
|
||||
|
||||
cls.mock_chat = Mock()
|
||||
cls.mock_chat.tok = cls.mock_tok
|
||||
cls.mock_chat.apply = Mock(return_value=[1, 100, 101, 200, 201, 202, 998, 100, 101])
|
||||
cls.mock_chat.is_end = Mock(side_effect=lambda tid: tid in (999,))
|
||||
|
||||
cls.mock_model = Mock()
|
||||
cls.mock_model.generate = Mock(side_effect=lambda ids, **kwargs: iter([300, 301, 999]))
|
||||
|
|
@ -25,7 +25,7 @@ class TestLLMServer(unittest.TestCase):
|
|||
|
||||
from tinygrad.llm.cli import LLMServer
|
||||
|
||||
cls.server = LLMServer(('127.0.0.1', 0), cls.mock_model, "test-model", cls.mock_tok)
|
||||
cls.server = LLMServer(('127.0.0.1', 0), cls.mock_model, "test-model", cls.mock_chat)
|
||||
cls.port = cls.server.server_address[1]
|
||||
cls.server_thread = threading.Thread(target=cls.server.serve_forever, daemon=True)
|
||||
cls.server_thread.start()
|
||||
|
|
@ -150,35 +150,24 @@ class TestLLMServer(unittest.TestCase):
|
|||
self.assertEqual(resp.usage.completion_tokens, 2)
|
||||
|
||||
def test_assistant_prefill(self):
|
||||
"""Last assistant message should be treated as prefill (not a completed turn)."""
|
||||
"""Last assistant message should be treated as prefill (continue_final_message=True)."""
|
||||
self.mock_model.generate = Mock(side_effect=lambda ids, **kwargs: iter([300, 999]))
|
||||
captured_ids = []
|
||||
orig_generate = self.mock_model.generate.side_effect
|
||||
def capture_generate(ids, **kwargs):
|
||||
captured_ids.extend(ids)
|
||||
return orig_generate(ids, **kwargs)
|
||||
self.mock_model.generate = Mock(side_effect=capture_generate)
|
||||
|
||||
self.mock_chat.apply.reset_mock()
|
||||
resp = self.client.chat.completions.create(
|
||||
model="test", messages=[
|
||||
{"role": "user", "content": "Hello"},
|
||||
{"role": "assistant", "content": "Sure"}
|
||||
], stream=False
|
||||
)
|
||||
# prefill tokens should be in ids: role("assistant") + encode("Sure") but NO end_turn after it
|
||||
# and NO extra role("assistant") appended
|
||||
role_tokens = self.mock_tok.role.call_args_list
|
||||
# last role() call should be for "assistant" (the prefill message), not an extra one
|
||||
self.assertEqual(role_tokens[-1], unittest.mock.call("assistant"))
|
||||
# end_turn should be called once less than role() — the prefill assistant msg doesn't get end_turn
|
||||
self.assertEqual(self.mock_tok.end_turn.call_count, self.mock_tok.role.call_count - 1)
|
||||
call = self.mock_chat.apply.call_args
|
||||
self.assertTrue(call.kwargs["continue_final_message"])
|
||||
self.assertFalse(call.kwargs["add_generation_prompt"])
|
||||
self.assertIsNotNone(resp.choices[0].message.content)
|
||||
|
||||
def test_assistant_prefill_not_last(self):
|
||||
"""Assistant message that's NOT last should be a normal completed turn."""
|
||||
"""Assistant message that's NOT last should be a normal completed turn (add_generation_prompt=True)."""
|
||||
self.mock_model.generate = Mock(side_effect=lambda ids, **kwargs: iter([300, 999]))
|
||||
self.mock_tok.role.reset_mock()
|
||||
self.mock_tok.end_turn.reset_mock()
|
||||
self.mock_chat.apply.reset_mock()
|
||||
self.client.chat.completions.create(
|
||||
model="test", messages=[
|
||||
{"role": "user", "content": "Hello"},
|
||||
|
|
@ -186,11 +175,9 @@ class TestLLMServer(unittest.TestCase):
|
|||
{"role": "user", "content": "Continue"}
|
||||
], stream=False
|
||||
)
|
||||
# all messages get end_turn, plus an extra role("assistant") at the end
|
||||
# roles: user, assistant, user, assistant(generation prompt) = 4 role calls
|
||||
# end_turns: user, assistant, user = 3 end_turn calls (one per message)
|
||||
self.assertEqual(self.mock_tok.end_turn.call_count, 3)
|
||||
self.assertEqual(self.mock_tok.role.call_count, 4)
|
||||
call = self.mock_chat.apply.call_args
|
||||
self.assertFalse(call.kwargs["continue_final_message"])
|
||||
self.assertTrue(call.kwargs["add_generation_prompt"])
|
||||
|
||||
def test_models_endpoint(self):
|
||||
import requests as req
|
||||
|
|
|
|||
|
|
@ -46,19 +46,6 @@ class TestLLMTokenizer(unittest.TestCase):
|
|||
def test_llama_repeat(self): self._test_coding(self.llama_tok, "00000000000000000", [ 931, 931, 931, 931, 931, 410 ])
|
||||
def test_llama_pat(self): self._test_coding(self.llama_tok, "today\n \n", [ 31213, 14211 ])
|
||||
|
||||
def test_tekken_from_gguf_kv(self):
|
||||
kv = {
|
||||
"tokenizer.ggml.tokens": ["<unk>", "<s>", "</s>", "[INST]", "[/INST]", "hello"],
|
||||
"tokenizer.ggml.token_type": [3, 3, 3, 3, 3, 1],
|
||||
"tokenizer.ggml.pre": "tekken",
|
||||
"tokenizer.ggml.eos_token_id": 2,
|
||||
}
|
||||
tok = SimpleTokenizer.from_gguf_kv(kv)
|
||||
self.assertEqual(tok.role("user"), [3])
|
||||
self.assertEqual(tok.encode("hello"), [5])
|
||||
self.assertEqual(tok.end_turn(), [4])
|
||||
self.assertEqual(tok.role("assistant"), [])
|
||||
|
||||
def test_stream_decoder(self):
|
||||
"""stream_decoder buffers incomplete UTF-8: token 25677 has 3/4 of emoji, token 138 completes it."""
|
||||
bs = [*range(33, 127), *range(161, 173), *range(174, 256)]
|
||||
|
|
|
|||
|
|
@ -6,11 +6,8 @@ from tinygrad.viz.serve import TCPServerWithReuse, HTTPRequestHandler
|
|||
from tinygrad.llm.model import Transformer
|
||||
|
||||
class SimpleTokenizer:
|
||||
def __init__(self, normal_tokens:dict[str, int], special_tokens:dict[str, int], preset:str="llama3",
|
||||
def __init__(self, normal_tokens:dict[str, int], special_tokens:dict[str, int],
|
||||
bos_id:int|None=None, eos_id:int=0, eot_id:int|None=None):
|
||||
preset = {"qwen35":"qwen2","qwen35moe":"qwen2"}.get(preset, preset)
|
||||
if preset not in ("llama3","llama-v3","llama-bpe","qwen2","olmo","kimi-k2","tekken","glm4"):
|
||||
raise ValueError(f"Invalid tokenizer preset '{preset}'")
|
||||
# https://github.com/openai/gpt-2/blob/9b63575ef42771a015060c964af2c3da4cf7c8ab/src/encoder.py#L9
|
||||
bs = [*range(33, 127), *range(161, 173), *range(174, 256)] # bytes that map to themselves
|
||||
self._byte_decoder = {chr(b): b for b in bs} | {chr(256+i): b for i,b in enumerate(b for b in range(256) if b not in bs)}
|
||||
|
|
@ -26,7 +23,6 @@ class SimpleTokenizer:
|
|||
self._normal_tokens = {bytes(self._byte_decoder[c] for c in tok): tid for tok, tid in normal_tokens.items()}
|
||||
self._special_tokens = special_tokens
|
||||
self._tok2bytes = {tid: tok for tok, tid in self._normal_tokens.items()} | {tid: tok.encode() for tok, tid in self._special_tokens.items()}
|
||||
self.preset = preset
|
||||
self.bos_id, self.eos_id, self.eot_id = bos_id, eos_id, eot_id
|
||||
|
||||
@staticmethod
|
||||
|
|
@ -34,7 +30,7 @@ class SimpleTokenizer:
|
|||
# https://github.com/ggml-org/llama.cpp/blob/94933c8c2eeaa9a7983e3f6c08af76bd86724094/src/llama-vocab.cpp#L1818-L1820
|
||||
vocab: typing.Iterable[tuple[str, int]] = ((tok, idx) for idx, tok in enumerate(kv["tokenizer.ggml.tokens"]))
|
||||
normal_tokens, special_tokens = partition(vocab, lambda e: kv["tokenizer.ggml.token_type"][e[1]] == 1)
|
||||
return SimpleTokenizer(dict(normal_tokens), dict(special_tokens), kv["tokenizer.ggml.pre"],
|
||||
return SimpleTokenizer(dict(normal_tokens), dict(special_tokens),
|
||||
bos_id=kv.get('tokenizer.ggml.bos_token_id') if kv.get('tokenizer.ggml.add_bos_token', True) else None,
|
||||
eos_id=kv.get('tokenizer.ggml.eos_token_id', 0), eot_id=kv.get('tokenizer.ggml.eot_token_id'))
|
||||
|
||||
|
|
@ -63,26 +59,87 @@ class SimpleTokenizer:
|
|||
dec = codecs.getincrementaldecoder('utf-8')('replace')
|
||||
def _decode(tid:int|None=None) -> str: return dec.decode(self._tok2bytes[tid]) if tid is not None else dec.decode(b'', final=True)
|
||||
return _decode
|
||||
def role(self, role:str):
|
||||
if self.preset == 'olmo': return self.encode("<|" + role + "|>\n") # OLMoE Instruct format
|
||||
if self.preset == 'kimi-k2': return self.encode("<|im_" + role + "|>" + role + "<|im_middle|>")
|
||||
if self.preset == 'qwen2': return self.encode("<|im_start|>" + role + "\n")
|
||||
if self.preset == 'glm4': return self.encode("<|" + role + "|>")
|
||||
if self.preset == 'tekken':
|
||||
if role == 'user': return self.encode("[INST]")
|
||||
if role == 'assistant': return []
|
||||
raise ValueError(f"Unsupported role '{role}' for tokenizer preset '{self.preset}'")
|
||||
return self.encode("<|start_header_id|>" + role + "<|end_header_id|>\n\n")
|
||||
def end_turn(self):
|
||||
if self.preset == 'olmo': return self.encode("\n")
|
||||
if self.preset == 'kimi-k2': return [self.eos_id]
|
||||
if self.preset == 'qwen2': return [self.eos_id] + self.encode("\n")
|
||||
if self.preset == 'glm4': return []
|
||||
if self.preset == 'tekken': return self.encode("[/INST]")
|
||||
return [self.eos_id]
|
||||
def prefix(self) -> list[int]:
|
||||
return ([] if self.bos_id is None else [self.bos_id]) + (self.encode("<sop>") if self.preset == 'glm4' else [])
|
||||
def is_end(self, token_id:int) -> bool: return token_id in (self.eos_id, self.eot_id)
|
||||
|
||||
def _flatten_content(c) -> str:
|
||||
return c if isinstance(c, str) else "".join(p["text"] for p in c if p.get("type") == "text")
|
||||
|
||||
class Chat:
  """Formats messages into tokens for a given model.

  Two modes:
  - default (simple): uses a small hard-coded preset dispatch keyed on `tokenizer.ggml.pre` for the formatting.
    Covers llama3/qwen2/olmo/kimi-k2/tekken/glm4 chat formats.
  - `use_jinja=True`: renders the GGUF's `tokenizer.chat_template` with the real `jinja2` package.
    Needed for templates using features outside the simple preset set (e.g. Qwen 3.5's macros).
  """
  _PRESETS = ("llama3", "llama-v3", "llama-bpe", "qwen2", "olmo", "kimi-k2", "tekken", "glm4")

  def __init__(self, tok:SimpleTokenizer, template:str|None=None, preset:str="llama3",
               use_jinja:bool=False, extra_stop_ids:typing.Iterable[int]=(), turn_end_id:int|None=None):
    """Create a chat formatter on top of `tok`.

    template: chat template source; only required (and only used) when `use_jinja` is set.
    preset: simple-mode format name; "qwen35"/"qwen35moe" are aliases for "qwen2".
    extra_stop_ids: additional token ids that `is_end` should report as terminating generation.
    turn_end_id: token id that closes a turn in simple mode; defaults to the tokenizer's eos_id.

    Raises ValueError if `use_jinja` is set without a template, or if simple mode is
    requested with a preset that is not in `_PRESETS`.
    """
    self.tok, self.template, self.use_jinja = tok, template, use_jinja
    self.preset = {"qwen35":"qwen2", "qwen35moe":"qwen2"}.get(preset, preset)
    self.turn_end_id = turn_end_id if turn_end_id is not None else tok.eos_id
    self.stop_ids: set[int] = {x for x in (tok.eos_id, tok.eot_id, self.turn_end_id, *extra_stop_ids) if x is not None}
    # (template source, compiled jinja2.Template), filled in on the first jinja render
    self._jinja_cache: tuple[str, typing.Any]|None = None
    if use_jinja:
      if template is None: raise ValueError("use_jinja=True requires tokenizer.chat_template in the GGUF")
    elif self.preset not in self._PRESETS:
      raise ValueError(f"unsupported tokenizer preset {self.preset!r}; pass use_jinja=True to use the GGUF chat_template instead")

  @staticmethod
  def from_gguf_kv(kv:dict, tok:SimpleTokenizer, use_jinja:bool=False) -> 'Chat':
    """Build a Chat from GGUF metadata: preset from `tokenizer.ggml.pre`, template from `tokenizer.chat_template`."""
    preset = kv.get('tokenizer.ggml.pre', 'llama3')
    extra: list[int] = []
    turn_end_id: int|None = None
    # OLMo 2: tokenizer.ggml.pre is "dbrx" but the chat format is qwen2-style (<|im_start|>.../<|im_end|>).
    # <|im_end|> terminates turns (not <|endoftext|>) but both should stop generation.
    if kv.get('general.architecture') == 'olmo2':
      preset = 'qwen2'
      im_end = next((i for i,t in enumerate(kv['tokenizer.ggml.tokens']) if t == '<|im_end|>'), None)
      if im_end is not None:
        extra.append(im_end)
        turn_end_id = im_end
    return Chat(tok, kv.get('tokenizer.chat_template'), preset, use_jinja, extra, turn_end_id)

  def is_end(self, token_id:int) -> bool:
    """Return True when `token_id` is one of the ids that should stop generation."""
    return token_id in self.stop_ids

  def apply(self, messages:list[dict], add_generation_prompt:bool=False, continue_final_message:bool=False) -> list[int]:
    """Turn `messages` (dicts with "role" and "content") into prompt token ids.

    add_generation_prompt: also emit the header that opens a new assistant turn.
    continue_final_message: leave a trailing assistant message open so the model continues it.
    """
    return (self._apply_jinja if self.use_jinja else self._apply_simple)(messages, add_generation_prompt, continue_final_message)

  def _jinja_template(self):
    """Return the compiled chat template, compiling it only when the template source has changed."""
    try: import jinja2
    except ImportError as e: raise RuntimeError("use_jinja=True requires the jinja2 package: pip install jinja2") from e
    # compiling is the expensive part of rendering and the source normally never changes, so do it once
    if self._jinja_cache is None or self._jinja_cache[0] != self.template:
      self._jinja_cache = (self.template, jinja2.Template(self.template))
    return self._jinja_cache[1]

  def _apply_jinja(self, messages, add_generation_prompt, continue_final_message):
    """Render the chat template with jinja2 and encode the resulting text.

    Raises RuntimeError when jinja2 is not installed, and ValueError when
    `continue_final_message` is set but the last message is not an assistant message.
    """
    t = self._jinja_template()
    def tok_str(tid): return self.tok._tok2bytes[tid].decode(errors='replace') if tid is not None and tid in self.tok._tok2bytes else ''
    bos, eos = tok_str(self.tok.bos_id), tok_str(self.tok.eos_id)
    if continue_final_message:
      # explicit raise rather than assert: this checks caller input and must survive `python -O`
      if not messages or messages[-1]["role"] != "assistant":
        raise ValueError("continue_final_message requires trailing assistant message")
      # render everything before the prefill with an open assistant header, then append the prefill text as-is
      head = t.render(messages=messages[:-1], add_generation_prompt=True, bos_token=bos, eos_token=eos)
      return self.tok.encode(head + _flatten_content(messages[-1]["content"]))
    return self.tok.encode(t.render(messages=messages, add_generation_prompt=add_generation_prompt, bos_token=bos, eos_token=eos))

  def _apply_simple(self, messages, add_generation_prompt, continue_final_message):
    """Format `messages` with the hard-coded preset for `self.preset`."""
    tok, p, e = self.tok, self.preset, self.turn_end_id

    # role header template (role name interpolated as {0}); llama3 is the default
    role_tmpl = {'qwen2': "<|im_start|>{0}\n",
                 'olmo': "<|{0}|>\n",
                 'kimi-k2': "<|im_{0}|>{0}<|im_middle|>",
                 'glm4': "<|{0}|>"}.get(p, "<|start_header_id|>{0}<|end_header_id|>\n\n")
    def role(r):  # tekken is asymmetric: empty header for assistant
      # NOTE(review): every non-user role (e.g. "system") also gets an empty header here - confirm that is intended
      if p == 'tekken': return tok.encode("[INST]") if r == "user" else []
      return tok.encode(role_tmpl.format(r))

    # end-of-turn token ids; llama3 is the default
    if p == 'qwen2': end_turn = [e, *tok.encode("\n")]
    elif p == 'olmo': end_turn = tok.encode("\n")
    elif p == 'glm4': end_turn = []
    elif p == 'tekken': end_turn = tok.encode("[/INST]")
    else: end_turn = [e]  # llama3, kimi-k2

    # prefill only applies when the conversation actually ends on an assistant message
    prefill = bool(continue_final_message and messages and messages[-1]["role"] == "assistant")
    ids = ([tok.bos_id] if tok.bos_id is not None else []) + (tok.encode("<sop>") if p == 'glm4' else [])
    for i, m in enumerate(messages):
      ids += role(m["role"]) + tok.encode(_flatten_content(m["content"]))
      # the prefill message (always the last one) stays open: no end_turn after it
      if not prefill or i < len(messages) - 1: ids += end_turn
    if add_generation_prompt and not prefill: ids += role("assistant")
    return ids
|
||||
|
||||
models = {
|
||||
"llama3.2:1b": "https://huggingface.co/bartowski/Llama-3.2-1B-Instruct-GGUF/resolve/main/Llama-3.2-1B-Instruct-Q6_K.gguf",
|
||||
|
|
@ -113,7 +170,7 @@ class Handler(HTTPRequestHandler):
|
|||
if self.path == "/v1/models": self.send_data(json.dumps({"object":"list","data":[{"id":self.server.model_name,"object":"model"}]}).encode())
|
||||
else: self.send_data((pathlib.Path(__file__).parent / "chat.html").read_bytes(), content_type="text/html")
|
||||
def run_model(self, ids:list[int], model_name:str, include_usage=False, max_tokens:int|None=None, temperature:float=0.0):
|
||||
model, tok = self.server.model, self.server.tok
|
||||
model, chat = self.server.model, self.server.chat
|
||||
cache_start_pos = model.get_start_pos(ids)
|
||||
stderr_log(f"{self.path} {colored('--', 'BLACK')} "
|
||||
f"in:{colored(f'{cache_start_pos:5d}', 'green')} +{len(ids)-cache_start_pos:5d} {colored('--', 'BLACK')} ")
|
||||
|
|
@ -122,10 +179,10 @@ class Handler(HTTPRequestHandler):
|
|||
out: list[int] = []
|
||||
finish_reason = "stop"
|
||||
st = time.perf_counter()
|
||||
dec = tok.stream_decoder()
|
||||
dec = chat.tok.stream_decoder()
|
||||
for next_id in model.generate(ids, temperature=temperature):
|
||||
if len(out) == 0: stderr_log(f"prefill:{(len(ids)-cache_start_pos)/((pt:=time.perf_counter())-st):4.0f} tok/s {colored('--', 'BLACK')} ")
|
||||
if tok.is_end(next_id): break
|
||||
if chat.is_end(next_id): break
|
||||
out.append(next_id)
|
||||
yield {"choices": [{"index":0, "delta":{"content":dec(next_id)}, "finish_reason":None}], **tmpl}
|
||||
if max_tokens is not None and len(out) >= max_tokens:
|
||||
|
|
@ -140,25 +197,15 @@ class Handler(HTTPRequestHandler):
|
|||
f"out:{len(out):5d} {colored('--', 'BLACK')} total:{et-st:6.2f}s\n")
|
||||
|
||||
def do_POST(self):
|
||||
tok = self.server.tok
|
||||
chat = self.server.chat
|
||||
raw_body = self.rfile.read(int(self.headers.get("Content-Length", "0")))
|
||||
body: dict[str, typing.Any] = json.loads(raw_body.decode("utf-8"))
|
||||
if DEBUG >= 1: print(json.dumps(body, indent=2))
|
||||
if self.path == "/v1/chat/completions":
|
||||
# extract tokens, last assistant message is treated as prefill
|
||||
ids: list[int] = tok.prefix()
|
||||
for i, msg in enumerate(body["messages"]):
|
||||
ids += tok.role(msg["role"])
|
||||
content = msg["content"]
|
||||
if isinstance(content, str): ids += tok.encode(content)
|
||||
elif isinstance(content, list):
|
||||
for c in content:
|
||||
if c["type"] == "text": ids += tok.encode(c["text"])
|
||||
else: raise RuntimeError(f"unhandled type: {c['type']}")
|
||||
else: raise RuntimeError(f"unknown content type: {type(content)}")
|
||||
if msg["role"] == "assistant" and i == len(body["messages"]) - 1: break
|
||||
ids += tok.end_turn()
|
||||
else: ids += tok.role("assistant")
|
||||
messages = [{"role": m["role"], "content": _flatten_content(m["content"])} for m in body["messages"]]
|
||||
prefill = bool(messages) and messages[-1]["role"] == "assistant"
|
||||
ids = chat.apply(messages, add_generation_prompt=not prefill, continue_final_message=prefill)
|
||||
|
||||
# reply
|
||||
max_tokens = body.get("max_completion_tokens") or body.get("max_tokens")
|
||||
|
|
@ -176,8 +223,8 @@ class Handler(HTTPRequestHandler):
|
|||
raise RuntimeError(f"unhandled path {self.path}")
|
||||
|
||||
class LLMServer(TCPServerWithReuse):
|
||||
def __init__(self, server_address:tuple, model:Transformer, model_name:str, tok:SimpleTokenizer):
|
||||
self.model, self.model_name, self.tok = model, model_name, tok
|
||||
def __init__(self, server_address:tuple, model:Transformer, model_name:str, chat:Chat):
|
||||
self.model, self.model_name, self.chat = model, model_name, chat
|
||||
super().__init__(server_address, Handler)
|
||||
|
||||
def main():
|
||||
|
|
@ -187,6 +234,7 @@ def main():
|
|||
parser.add_argument("--serve", nargs='?', type=int, const=8000, metavar="PORT", help="Run OpenAI compatible API (optional port, default 8000)")
|
||||
parser.add_argument("--warmup", action="store_true", help="warmup the JIT")
|
||||
parser.add_argument("--benchmark", nargs='?', type=int, const=20, metavar="COUNT", help="Benchmark tok/s (optional count, default 20)")
|
||||
parser.add_argument("--jinja", action="store_true", help="Render the GGUF chat_template with the real jinja2 package (needs `pip install jinja2`)")
|
||||
args = parser.parse_args()
|
||||
|
||||
# load the model
|
||||
|
|
@ -196,8 +244,9 @@ def main():
|
|||
print(f"using model \"{model_name}\" with {raw_model.nbytes():,} bytes and {sum(x.numel() for x in nn.state.get_parameters(model)):,} params")
|
||||
del raw_model
|
||||
|
||||
# get tokenizer
|
||||
# get tokenizer and chat formatter
|
||||
tok = SimpleTokenizer.from_gguf_kv(kv)
|
||||
chat = Chat.from_gguf_kv(kv, tok, use_jinja=args.jinja)
|
||||
|
||||
# warmup the JIT
|
||||
if args.warmup or args.serve:
|
||||
|
|
@ -206,7 +255,7 @@ def main():
|
|||
for _ in range(2): list(zip(range(2), model.generate([0])))
|
||||
|
||||
# start server
|
||||
if args.serve: LLMServer(('', args.serve), model, model_name, tok).serve_forever()
|
||||
if args.serve: LLMServer(('', args.serve), model, model_name, chat).serve_forever()
|
||||
|
||||
# do benchmark
|
||||
if args.benchmark is not None:
|
||||
|
|
@ -218,17 +267,19 @@ def main():
|
|||
tok.decode(toks).replace("\n", "\\n")): next(gen)
|
||||
exit(0)
|
||||
|
||||
# interactive chat
|
||||
ids: list[int] = tok.prefix()
|
||||
# interactive chat (falls back to pure completion when the GGUF has no chat template)
|
||||
messages: list[dict] = []
|
||||
while 1:
|
||||
try:
|
||||
ids += tok.role("user") + tok.encode(input('>>> ')) + tok.end_turn() + tok.role("assistant")
|
||||
except EOFError:
|
||||
break
|
||||
dec = tok.stream_decoder()
|
||||
for next_id in model.generate(ids):
|
||||
sys.stdout.write(dec(next_id) if not tok.is_end(next_id) else dec() + "\n\n")
|
||||
try: user = input('>>> ')
|
||||
except EOFError: break
|
||||
messages.append({"role": "user", "content": user})
|
||||
ids = chat.apply(messages, add_generation_prompt=True)
|
||||
dec, assistant = tok.stream_decoder(), []
|
||||
for next_id in model.generate(list(ids)):
|
||||
sys.stdout.write(dec(next_id) if not chat.is_end(next_id) else dec() + "\n\n")
|
||||
sys.stdout.flush()
|
||||
if tok.is_end(next_id): break
|
||||
if chat.is_end(next_id): break
|
||||
assistant.append(next_id)
|
||||
messages.append({"role": "assistant", "content": tok.decode(assistant)})
|
||||
|
||||
if __name__ == "__main__": main()
|
||||
if __name__ == "__main__": main()
|
||||
|
|
|
|||
|
|
@ -69,14 +69,19 @@ class TransformerConfig:
|
|||
leading_dense_blocks: int = 0
|
||||
dense_hidden_dim: int = 0
|
||||
routed_scaling_factor: float = 1.0
|
||||
post_norm: bool = False # norms applied to sublayer outputs rather than inputs (OLMo 2)
|
||||
|
||||
class FFNBlock:
|
||||
def __init__(self, config:TransformerConfig):
|
||||
self.config = config
|
||||
|
||||
# --- RMSNorms --------------------------------------------------------
|
||||
self.attn_norm = nn.RMSNorm(config.dim, config.norm_eps)
|
||||
self.ffn_norm = nn.RMSNorm(config.dim, config.norm_eps)
|
||||
# pre-norm slots (no-op when the model uses post-norm, e.g. OLMo 2)
|
||||
self.attn_norm = (lambda x: x) if config.post_norm else nn.RMSNorm(config.dim, config.norm_eps)
|
||||
self.ffn_norm = (lambda x: x) if config.post_norm else nn.RMSNorm(config.dim, config.norm_eps)
|
||||
# post-norm slots (no-op for standard pre-norm models)
|
||||
self.post_attention_norm = nn.RMSNorm(config.dim, config.norm_eps) if config.post_norm else (lambda x: x)
|
||||
self.post_ffw_norm = nn.RMSNorm(config.dim, config.norm_eps) if config.post_norm else (lambda x: x)
|
||||
|
||||
# --- feed-forward (MoE or dense) -------------------------------------
|
||||
if config.num_experts > 0:
|
||||
|
|
@ -130,8 +135,8 @@ class FFNBlock:
|
|||
# we pass in the weights implicitly so we unpack the GGUF on the fly
|
||||
@function(precompile=True, allow_implicit=True)
|
||||
def _run(x:Tensor, start_pos:int|UOp):
|
||||
h = x + self._attention(self.attn_norm(x), start_pos)
|
||||
return (h + self._feed_forward(self.ffn_norm(h))).contiguous()
|
||||
h = x + self.post_attention_norm(self._attention(self.attn_norm(x), start_pos))
|
||||
return (h + self.post_ffw_norm(self._feed_forward(self.ffn_norm(h)))).contiguous()
|
||||
return _run(x, start_pos)
|
||||
|
||||
class TransformerBlock(FFNBlock):
|
||||
|
|
@ -374,7 +379,7 @@ class Transformer:
|
|||
shared_expert_gate=f"blk.{kv.get(f'{arch}.leading_dense_block_count', 0)}.ffn_gate_inp_shexp.weight" in state_dict,
|
||||
dense_hidden_dim=kv.get(f'{arch}.feed_forward_length', 0) if kv.get(f'{arch}.leading_dense_block_count', 0) else 0,
|
||||
routed_scaling_factor=kv.get(f'{arch}.expert_weights_scale', 1.0), attn_output_gate=arch in ('qwen35', 'qwen35moe'), ssm=ssm,
|
||||
full_attention_interval=kv.get(f'{arch}.full_attention_interval', 0))
|
||||
full_attention_interval=kv.get(f'{arch}.full_attention_interval', 0), post_norm=(arch == 'olmo2'))
|
||||
model = Transformer(config)
|
||||
nn.state.load_state_dict(model, state_dict, verbose=False, consume=True, realize=False) # NOTE: rope_freqs.weight (32,) is unused
|
||||
# NOTE: without this contiguous, it unpacks the weights from the model every time. we shouldn't need this, but for now it's faster
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue