mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
Compare commits
3 commits
master
...
olmo3_supp
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
72fccb5f5c | ||
|
|
6077327b80 | ||
|
|
79079ca3a9 |
5 changed files with 235 additions and 105 deletions
100
test/null/test_llm_chat.py
Normal file
100
test/null/test_llm_chat.py
Normal file
|
|
@ -0,0 +1,100 @@
|
||||||
|
import unittest
|
||||||
|
from tinygrad.llm.cli import SimpleTokenizer, Chat
|
||||||
|
|
||||||
|
class TestChatSimple(unittest.TestCase):
  """Checks Chat's preset-driven ("simple") formatting path against tiny hand-built GGUF metadata dicts."""

  # minimal tekken (Mistral) vocabulary: token_type 3 entries are special tokens, token_type 1 entries are normal tokens
  _TEKKEN_KV = {
    "tokenizer.ggml.tokens": ["<unk>", "<s>", "</s>", "[INST]", "[/INST]", "hello", "sure"],
    "tokenizer.ggml.token_type": [3, 3, 3, 3, 3, 1, 1],
    "tokenizer.ggml.pre": "tekken",
    "tokenizer.ggml.eos_token_id": 2,
  }

  @staticmethod
  def _tekken_chat():
    # every test gets its own copy of the metadata so no state is shared between tests
    meta = dict(TestChatSimple._TEKKEN_KV)
    return Chat.from_gguf_kv(meta, SimpleTokenizer.from_gguf_kv(meta))

  def test_tekken_preset(self):
    # Tekken (Mistral): role(user)=[INST], role(assistant)=[], end_turn=[/INST].
    formatter = self._tekken_chat()
    user_turn = {"role": "user", "content": "hello"}
    assistant_turn = {"role": "assistant", "content": "sure"}
    # single user turn: [INST] hello [/INST]
    self.assertEqual(formatter.apply([user_turn]), [3, 5, 4])
    # user + assistant: [INST] hello [/INST] sure [/INST]
    self.assertEqual(formatter.apply([user_turn, assistant_turn]), [3, 5, 4, 6, 4])
    # add_generation_prompt on tekken appends role("assistant") which is []
    self.assertEqual(formatter.apply([user_turn], add_generation_prompt=True), [3, 5, 4])

  def test_is_end_basic(self):
    meta = {
      "tokenizer.ggml.tokens": ["<unk>", "<eos>"],
      "tokenizer.ggml.token_type": [3, 3],
      "tokenizer.ggml.pre": "llama3",
      "tokenizer.ggml.eos_token_id": 1,
    }
    formatter = Chat(SimpleTokenizer.from_gguf_kv(meta), preset="llama3")
    self.assertTrue(formatter.is_end(1))
    self.assertFalse(formatter.is_end(0))

  def test_olmo2_simple_mode(self):
    # OLMo 2: pre='dbrx' (unsupported) but arch override maps it to qwen2-style chat with <|im_end|> as turn-end.
    # tokenizer's eos_id must be untouched; Chat widens stop_ids and uses <|im_end|> for end_turn.
    meta = {
      "tokenizer.ggml.tokens": ["<|endoftext|>", "hello", "<|im_end|>", "<|im_start|>"],
      "tokenizer.ggml.token_type": [3, 1, 3, 3],
      "tokenizer.ggml.pre": "dbrx",
      "tokenizer.ggml.eos_token_id": 0,
      "general.architecture": "olmo2",
    }
    tokenizer = SimpleTokenizer.from_gguf_kv(meta)
    formatter = Chat.from_gguf_kv(meta, tokenizer)
    self.assertEqual(tokenizer.eos_id, 0)       # raw GGUF eos (<|endoftext|>) untouched
    self.assertEqual(formatter.preset, "qwen2")  # arch-overridden preset
    self.assertEqual(formatter.turn_end_id, 2)   # <|im_end|>
    self.assertTrue(formatter.is_end(0))         # <|endoftext|>
    self.assertTrue(formatter.is_end(2))         # <|im_end|>
    self.assertFalse(formatter.is_end(1))

  def test_assistant_prefill_no_end_turn(self):
    # continue_final_message should drop the trailing end_turn after the prefill assistant message.
    formatter = self._tekken_chat()
    conversation = [{"role": "user", "content": "hello"}, {"role": "assistant", "content": "sure"}]
    self.assertEqual(formatter.apply(conversation), [3, 5, 4, 6, 4])
    self.assertEqual(formatter.apply(conversation, continue_final_message=True), [3, 5, 4, 6])

  def test_unsupported_preset_without_jinja(self):
    meta = {"tokenizer.ggml.tokens": ["<unk>"], "tokenizer.ggml.token_type": [3], "tokenizer.ggml.pre": "dbrx"}
    tokenizer = SimpleTokenizer.from_gguf_kv(meta)
    # "dbrx" not in _PRESETS and not olmo2 arch
    with self.assertRaises(ValueError):
      Chat.from_gguf_kv(meta, tokenizer)
    # works with use_jinja=True (no simple-preset check), as long as a chat_template is provided
    meta_with_template = {**meta, "tokenizer.chat_template": "x"}
    Chat.from_gguf_kv(meta_with_template, SimpleTokenizer.from_gguf_kv(meta_with_template), use_jinja=True)
|
||||||
|
|
||||||
|
class TestChatJinja(unittest.TestCase):
  """Checks Chat's use_jinja=True path; every test is skipped when the jinja2 package cannot be imported."""

  def setUp(self):
    try:
      import jinja2  # noqa: F401
    except ImportError:
      self.skipTest("jinja2 not installed")

  def test_render_simple_template(self):
    # user turns are wrapped in [INST]...[/INST], assistant turns are followed by </s>
    chat_template = "".join((
      "{%- for m in messages %}{%- if m.role == 'user' %}{{- '[INST]' + m.content + '[/INST]' }}",
      "{%- elif m.role == 'assistant' %}{{- m.content + '</s>' }}{%- endif %}{%- endfor %}",
    ))
    meta = {
      "tokenizer.ggml.tokens": ["<unk>", "<s>", "</s>", "[INST]", "[/INST]", "hello"],
      "tokenizer.ggml.token_type": [3, 3, 3, 3, 3, 1],
      "tokenizer.ggml.pre": "tekken",
      "tokenizer.ggml.eos_token_id": 2,
      "tokenizer.chat_template": chat_template,
    }
    formatter = Chat.from_gguf_kv(meta, SimpleTokenizer.from_gguf_kv(meta), use_jinja=True)
    rendered_ids = formatter.apply([{"role": "user", "content": "hello"}])
    self.assertEqual(rendered_ids, [3, 5, 4])

  def test_jinja_requires_template(self):
    meta = {"tokenizer.ggml.tokens": ["<unk>"], "tokenizer.ggml.token_type": [3], "tokenizer.ggml.pre": "llama3"}
    tokenizer = SimpleTokenizer.from_gguf_kv(meta)
    # no tokenizer.chat_template
    with self.assertRaises(ValueError):
      Chat.from_gguf_kv(meta, tokenizer, use_jinja=True)
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
unittest.main()
|
||||||
|
|
@ -7,17 +7,17 @@ class TestLLMServer(unittest.TestCase):
|
||||||
@classmethod
|
@classmethod
|
||||||
def setUpClass(cls):
|
def setUpClass(cls):
|
||||||
cls.mock_tok = Mock()
|
cls.mock_tok = Mock()
|
||||||
cls.mock_tok.role = Mock(return_value=[100, 101])
|
|
||||||
cls.mock_tok.encode = Mock(return_value=[200, 201, 202])
|
cls.mock_tok.encode = Mock(return_value=[200, 201, 202])
|
||||||
cls.mock_tok.decode = Mock(return_value="Hello")
|
cls.mock_tok.decode = Mock(return_value="Hello")
|
||||||
cls.mock_tok.stream_decoder = Mock(return_value=lambda tid=None: "Hello" if tid is not None else "")
|
cls.mock_tok.stream_decoder = Mock(return_value=lambda tid=None: "Hello" if tid is not None else "")
|
||||||
cls.mock_tok.end_turn = Mock(return_value=[998])
|
|
||||||
cls.mock_tok.prefix = Mock(return_value=[1])
|
|
||||||
cls.mock_tok.preset = "llama3"
|
|
||||||
cls.mock_tok.bos_id = 1
|
cls.mock_tok.bos_id = 1
|
||||||
cls.mock_tok.eos_id = 999
|
cls.mock_tok.eos_id = 999
|
||||||
cls.mock_tok.eot_id = None
|
cls.mock_tok.eot_id = None
|
||||||
cls.mock_tok.is_end = Mock(side_effect=lambda tid: tid in (999,))
|
|
||||||
|
cls.mock_chat = Mock()
|
||||||
|
cls.mock_chat.tok = cls.mock_tok
|
||||||
|
cls.mock_chat.apply = Mock(return_value=[1, 100, 101, 200, 201, 202, 998, 100, 101])
|
||||||
|
cls.mock_chat.is_end = Mock(side_effect=lambda tid: tid in (999,))
|
||||||
|
|
||||||
cls.mock_model = Mock()
|
cls.mock_model = Mock()
|
||||||
cls.mock_model.generate = Mock(side_effect=lambda ids, **kwargs: iter([300, 301, 999]))
|
cls.mock_model.generate = Mock(side_effect=lambda ids, **kwargs: iter([300, 301, 999]))
|
||||||
|
|
@ -25,7 +25,7 @@ class TestLLMServer(unittest.TestCase):
|
||||||
|
|
||||||
from tinygrad.llm.cli import LLMServer
|
from tinygrad.llm.cli import LLMServer
|
||||||
|
|
||||||
cls.server = LLMServer(('127.0.0.1', 0), cls.mock_model, "test-model", cls.mock_tok)
|
cls.server = LLMServer(('127.0.0.1', 0), cls.mock_model, "test-model", cls.mock_chat)
|
||||||
cls.port = cls.server.server_address[1]
|
cls.port = cls.server.server_address[1]
|
||||||
cls.server_thread = threading.Thread(target=cls.server.serve_forever, daemon=True)
|
cls.server_thread = threading.Thread(target=cls.server.serve_forever, daemon=True)
|
||||||
cls.server_thread.start()
|
cls.server_thread.start()
|
||||||
|
|
@ -150,35 +150,24 @@ class TestLLMServer(unittest.TestCase):
|
||||||
self.assertEqual(resp.usage.completion_tokens, 2)
|
self.assertEqual(resp.usage.completion_tokens, 2)
|
||||||
|
|
||||||
def test_assistant_prefill(self):
  """Last assistant message should be treated as prefill (continue_final_message=True)."""
  self.mock_model.generate = Mock(side_effect=lambda ids, **kwargs: iter([300, 999]))
  # forget calls made by earlier tests so call_args below refers to this request only
  self.mock_chat.apply.reset_mock()
  conversation = [
    {"role": "user", "content": "Hello"},
    {"role": "assistant", "content": "Sure"},
  ]
  resp = self.client.chat.completions.create(model="test", messages=conversation, stream=False)
  apply_kwargs = self.mock_chat.apply.call_args.kwargs
  self.assertTrue(apply_kwargs["continue_final_message"])
  self.assertFalse(apply_kwargs["add_generation_prompt"])
  self.assertIsNotNone(resp.choices[0].message.content)
|
||||||
|
|
||||||
def test_assistant_prefill_not_last(self):
|
def test_assistant_prefill_not_last(self):
|
||||||
"""Assistant message that's NOT last should be a normal completed turn."""
|
"""Assistant message that's NOT last should be a normal completed turn (add_generation_prompt=True)."""
|
||||||
self.mock_model.generate = Mock(side_effect=lambda ids, **kwargs: iter([300, 999]))
|
self.mock_model.generate = Mock(side_effect=lambda ids, **kwargs: iter([300, 999]))
|
||||||
self.mock_tok.role.reset_mock()
|
self.mock_chat.apply.reset_mock()
|
||||||
self.mock_tok.end_turn.reset_mock()
|
|
||||||
self.client.chat.completions.create(
|
self.client.chat.completions.create(
|
||||||
model="test", messages=[
|
model="test", messages=[
|
||||||
{"role": "user", "content": "Hello"},
|
{"role": "user", "content": "Hello"},
|
||||||
|
|
@ -186,11 +175,9 @@ class TestLLMServer(unittest.TestCase):
|
||||||
{"role": "user", "content": "Continue"}
|
{"role": "user", "content": "Continue"}
|
||||||
], stream=False
|
], stream=False
|
||||||
)
|
)
|
||||||
# all messages get end_turn, plus an extra role("assistant") at the end
|
call = self.mock_chat.apply.call_args
|
||||||
# roles: user, assistant, user, assistant(generation prompt) = 4 role calls
|
self.assertFalse(call.kwargs["continue_final_message"])
|
||||||
# end_turns: user, assistant, user = 3 end_turn calls (one per message)
|
self.assertTrue(call.kwargs["add_generation_prompt"])
|
||||||
self.assertEqual(self.mock_tok.end_turn.call_count, 3)
|
|
||||||
self.assertEqual(self.mock_tok.role.call_count, 4)
|
|
||||||
|
|
||||||
def test_models_endpoint(self):
|
def test_models_endpoint(self):
|
||||||
import requests as req
|
import requests as req
|
||||||
|
|
|
||||||
|
|
@ -46,19 +46,6 @@ class TestLLMTokenizer(unittest.TestCase):
|
||||||
def test_llama_repeat(self): self._test_coding(self.llama_tok, "00000000000000000", [ 931, 931, 931, 931, 931, 410 ])
|
def test_llama_repeat(self): self._test_coding(self.llama_tok, "00000000000000000", [ 931, 931, 931, 931, 931, 410 ])
|
||||||
def test_llama_pat(self): self._test_coding(self.llama_tok, "today\n \n", [ 31213, 14211 ])
|
def test_llama_pat(self): self._test_coding(self.llama_tok, "today\n \n", [ 31213, 14211 ])
|
||||||
|
|
||||||
def test_tekken_from_gguf_kv(self):
|
|
||||||
kv = {
|
|
||||||
"tokenizer.ggml.tokens": ["<unk>", "<s>", "</s>", "[INST]", "[/INST]", "hello"],
|
|
||||||
"tokenizer.ggml.token_type": [3, 3, 3, 3, 3, 1],
|
|
||||||
"tokenizer.ggml.pre": "tekken",
|
|
||||||
"tokenizer.ggml.eos_token_id": 2,
|
|
||||||
}
|
|
||||||
tok = SimpleTokenizer.from_gguf_kv(kv)
|
|
||||||
self.assertEqual(tok.role("user"), [3])
|
|
||||||
self.assertEqual(tok.encode("hello"), [5])
|
|
||||||
self.assertEqual(tok.end_turn(), [4])
|
|
||||||
self.assertEqual(tok.role("assistant"), [])
|
|
||||||
|
|
||||||
def test_stream_decoder(self):
|
def test_stream_decoder(self):
|
||||||
"""stream_decoder buffers incomplete UTF-8: token 25677 has 3/4 of emoji, token 138 completes it."""
|
"""stream_decoder buffers incomplete UTF-8: token 25677 has 3/4 of emoji, token 138 completes it."""
|
||||||
bs = [*range(33, 127), *range(161, 173), *range(174, 256)]
|
bs = [*range(33, 127), *range(161, 173), *range(174, 256)]
|
||||||
|
|
|
||||||
|
|
@ -6,11 +6,8 @@ from tinygrad.viz.serve import TCPServerWithReuse, HTTPRequestHandler
|
||||||
from tinygrad.llm.model import Transformer
|
from tinygrad.llm.model import Transformer
|
||||||
|
|
||||||
class SimpleTokenizer:
|
class SimpleTokenizer:
|
||||||
def __init__(self, normal_tokens:dict[str, int], special_tokens:dict[str, int], preset:str="llama3",
|
def __init__(self, normal_tokens:dict[str, int], special_tokens:dict[str, int],
|
||||||
bos_id:int|None=None, eos_id:int=0, eot_id:int|None=None):
|
bos_id:int|None=None, eos_id:int=0, eot_id:int|None=None):
|
||||||
preset = {"qwen35":"qwen2","qwen35moe":"qwen2"}.get(preset, preset)
|
|
||||||
if preset not in ("llama3","llama-v3","llama-bpe","qwen2","olmo","kimi-k2","tekken","glm4"):
|
|
||||||
raise ValueError(f"Invalid tokenizer preset '{preset}'")
|
|
||||||
# https://github.com/openai/gpt-2/blob/9b63575ef42771a015060c964af2c3da4cf7c8ab/src/encoder.py#L9
|
# https://github.com/openai/gpt-2/blob/9b63575ef42771a015060c964af2c3da4cf7c8ab/src/encoder.py#L9
|
||||||
bs = [*range(33, 127), *range(161, 173), *range(174, 256)] # bytes that map to themselves
|
bs = [*range(33, 127), *range(161, 173), *range(174, 256)] # bytes that map to themselves
|
||||||
self._byte_decoder = {chr(b): b for b in bs} | {chr(256+i): b for i,b in enumerate(b for b in range(256) if b not in bs)}
|
self._byte_decoder = {chr(b): b for b in bs} | {chr(256+i): b for i,b in enumerate(b for b in range(256) if b not in bs)}
|
||||||
|
|
@ -26,7 +23,6 @@ class SimpleTokenizer:
|
||||||
self._normal_tokens = {bytes(self._byte_decoder[c] for c in tok): tid for tok, tid in normal_tokens.items()}
|
self._normal_tokens = {bytes(self._byte_decoder[c] for c in tok): tid for tok, tid in normal_tokens.items()}
|
||||||
self._special_tokens = special_tokens
|
self._special_tokens = special_tokens
|
||||||
self._tok2bytes = {tid: tok for tok, tid in self._normal_tokens.items()} | {tid: tok.encode() for tok, tid in self._special_tokens.items()}
|
self._tok2bytes = {tid: tok for tok, tid in self._normal_tokens.items()} | {tid: tok.encode() for tok, tid in self._special_tokens.items()}
|
||||||
self.preset = preset
|
|
||||||
self.bos_id, self.eos_id, self.eot_id = bos_id, eos_id, eot_id
|
self.bos_id, self.eos_id, self.eot_id = bos_id, eos_id, eot_id
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
|
@ -34,7 +30,7 @@ class SimpleTokenizer:
|
||||||
# https://github.com/ggml-org/llama.cpp/blob/94933c8c2eeaa9a7983e3f6c08af76bd86724094/src/llama-vocab.cpp#L1818-L1820
|
# https://github.com/ggml-org/llama.cpp/blob/94933c8c2eeaa9a7983e3f6c08af76bd86724094/src/llama-vocab.cpp#L1818-L1820
|
||||||
vocab: typing.Iterable[tuple[str, int]] = ((tok, idx) for idx, tok in enumerate(kv["tokenizer.ggml.tokens"]))
|
vocab: typing.Iterable[tuple[str, int]] = ((tok, idx) for idx, tok in enumerate(kv["tokenizer.ggml.tokens"]))
|
||||||
normal_tokens, special_tokens = partition(vocab, lambda e: kv["tokenizer.ggml.token_type"][e[1]] == 1)
|
normal_tokens, special_tokens = partition(vocab, lambda e: kv["tokenizer.ggml.token_type"][e[1]] == 1)
|
||||||
return SimpleTokenizer(dict(normal_tokens), dict(special_tokens), kv["tokenizer.ggml.pre"],
|
return SimpleTokenizer(dict(normal_tokens), dict(special_tokens),
|
||||||
bos_id=kv.get('tokenizer.ggml.bos_token_id') if kv.get('tokenizer.ggml.add_bos_token', True) else None,
|
bos_id=kv.get('tokenizer.ggml.bos_token_id') if kv.get('tokenizer.ggml.add_bos_token', True) else None,
|
||||||
eos_id=kv.get('tokenizer.ggml.eos_token_id', 0), eot_id=kv.get('tokenizer.ggml.eot_token_id'))
|
eos_id=kv.get('tokenizer.ggml.eos_token_id', 0), eot_id=kv.get('tokenizer.ggml.eot_token_id'))
|
||||||
|
|
||||||
|
|
@ -63,26 +59,87 @@ class SimpleTokenizer:
|
||||||
dec = codecs.getincrementaldecoder('utf-8')('replace')
|
dec = codecs.getincrementaldecoder('utf-8')('replace')
|
||||||
def _decode(tid:int|None=None) -> str: return dec.decode(self._tok2bytes[tid]) if tid is not None else dec.decode(b'', final=True)
|
def _decode(tid:int|None=None) -> str: return dec.decode(self._tok2bytes[tid]) if tid is not None else dec.decode(b'', final=True)
|
||||||
return _decode
|
return _decode
|
||||||
def role(self, role:str):
|
|
||||||
if self.preset == 'olmo': return self.encode("<|" + role + "|>\n") # OLMoE Instruct format
|
def _flatten_content(c) -> str:
  """Collapse a chat message `content` value into one plain string.

  `c` may be a string (returned unchanged), a list of part dicts (the "text" fields of the parts whose
  "type" is "text" are concatenated in order; every other part is ignored), or None (treated as empty
  content instead of failing while iterating it).
  """
  if c is None: return ""
  if isinstance(c, str): return c
  return "".join(p["text"] for p in c if p.get("type") == "text")
|
||||||
if self.preset == 'qwen2': return self.encode("<|im_start|>" + role + "\n")
|
|
||||||
if self.preset == 'glm4': return self.encode("<|" + role + "|>")
|
class Chat:
  """Formats messages into tokens for a given model.

  Two modes:
  - default (simple): uses a small hard-coded preset dispatch keyed on `tokenizer.ggml.pre` for the formatting.
    Covers llama3/qwen2/olmo/kimi-k2/tekken/glm4 chat formats.
  - `use_jinja=True`: renders the GGUF's `tokenizer.chat_template` with the real `jinja2` package.
    Needed for templates using features outside the simple preset set (e.g. Qwen 3.5's macros).
  """
  _PRESETS = ("llama3", "llama-v3", "llama-bpe", "qwen2", "olmo", "kimi-k2", "tekken", "glm4")

  def __init__(self, tok:SimpleTokenizer, template:str|None=None, preset:str="llama3",
               use_jinja:bool=False, extra_stop_ids:typing.Iterable[int]=(), turn_end_id:int|None=None):
    """Store the tokenizer and formatting options.

    `turn_end_id` defaults to the tokenizer's eos id. Generation stops on any of: the tokenizer's eos/eot ids,
    `turn_end_id`, or `extra_stop_ids`.

    Raises ValueError when `use_jinja` is set without a template, or when simple mode is requested with a
    preset that is not in `_PRESETS`.
    """
    self.tok, self.template, self.use_jinja = tok, template, use_jinja
    # the qwen35 variants are formatted exactly like qwen2
    self.preset = {"qwen35":"qwen2", "qwen35moe":"qwen2"}.get(preset, preset)
    self.turn_end_id = turn_end_id if turn_end_id is not None else tok.eos_id
    self.stop_ids: set[int] = {x for x in (tok.eos_id, tok.eot_id, self.turn_end_id, *extra_stop_ids) if x is not None}
    if use_jinja:
      if template is None: raise ValueError("use_jinja=True requires tokenizer.chat_template in the GGUF")
    elif self.preset not in self._PRESETS:
      raise ValueError(f"unsupported tokenizer preset {self.preset!r}; pass use_jinja=True to use the GGUF chat_template instead")

  @staticmethod
  def from_gguf_kv(kv:dict, tok:SimpleTokenizer, use_jinja:bool=False) -> 'Chat':
    """Build a Chat from GGUF metadata: preset from `tokenizer.ggml.pre` (default 'llama3'), template from `tokenizer.chat_template`."""
    preset = kv.get('tokenizer.ggml.pre', 'llama3')
    extra: list[int] = []
    turn_end_id: int|None = None
    # OLMo 2: tokenizer.ggml.pre is "dbrx" but the chat format is qwen2-style (<|im_start|>.../<|im_end|>).
    # <|im_end|> terminates turns (not <|endoftext|>) but both should stop generation.
    if kv.get('general.architecture') == 'olmo2':
      preset = 'qwen2'
      im_end = next((i for i,t in enumerate(kv['tokenizer.ggml.tokens']) if t == '<|im_end|>'), None)
      if im_end is not None: extra.append(im_end); turn_end_id = im_end
    return Chat(tok, kv.get('tokenizer.chat_template'), preset, use_jinja, extra, turn_end_id)

  def is_end(self, token_id:int) -> bool:
    """Return True when `token_id` is one of the ids that stop generation."""
    return token_id in self.stop_ids

  def apply(self, messages:list[dict], add_generation_prompt:bool=False, continue_final_message:bool=False) -> list[int]:
    """Turn `messages` into prompt token ids using the jinja template or the simple preset, depending on `use_jinja`."""
    return (self._apply_jinja if self.use_jinja else self._apply_simple)(messages, add_generation_prompt, continue_final_message)

  def _apply_jinja(self, messages, add_generation_prompt, continue_final_message):
    """Render the chat template with jinja2 and encode the resulting text.

    Raises RuntimeError when jinja2 is not installed and ValueError when `continue_final_message` is set
    but the last message is not an assistant message.
    """
    try: import jinja2
    except ImportError as e: raise RuntimeError("use_jinja=True requires the jinja2 package: pip install jinja2") from e
    # explicit raise instead of assert: the check must survive `python -O`
    if continue_final_message and not (messages and messages[-1]["role"] == "assistant"):
      raise ValueError("continue_final_message requires trailing assistant message")
    # text of a token id, or '' when the id is None or unknown to the tokenizer
    def tok_str(tid): return self.tok._tok2bytes[tid].decode(errors='replace') if tid is not None and tid in self.tok._tok2bytes else ''
    bos, eos, t = tok_str(self.tok.bos_id), tok_str(self.tok.eos_id), jinja2.Template(self.template)
    if continue_final_message:
      # render everything before the prefill with a generation prompt, then append the raw prefill text
      head = t.render(messages=messages[:-1], add_generation_prompt=True, bos_token=bos, eos_token=eos)
      return self.tok.encode(head + _flatten_content(messages[-1]["content"]))
    return self.tok.encode(t.render(messages=messages, add_generation_prompt=add_generation_prompt, bos_token=bos, eos_token=eos))

  def _apply_simple(self, messages, add_generation_prompt, continue_final_message):
    """Format `messages` with the hard-coded preset rules.

    Raises ValueError for a role other than user/assistant under the tekken preset.
    """
    tok, p, e = self.tok, self.preset, self.turn_end_id

    # role header template (role name interpolated as {0}); llama3 is the default
    role_tmpl = {'qwen2': "<|im_start|>{0}\n",
                 'olmo': "<|{0}|>\n",
                 'kimi-k2': "<|im_{0}|>{0}<|im_middle|>",
                 'glm4': "<|{0}|>"}.get(p, "<|start_header_id|>{0}<|end_header_id|>\n\n")
    def role(r):  # tekken is asymmetric: empty header for assistant
      if p == 'tekken':
        if r == "user": return tok.encode("[INST]")
        if r == "assistant": return []
        # any other role has no tekken header; fail loudly instead of emitting content with no opening tag
        raise ValueError(f"Unsupported role '{r}' for tokenizer preset '{p}'")
      return tok.encode(role_tmpl.format(r))

    # end-of-turn token ids; llama3 is the default
    if p == 'qwen2': end_turn = [e, *tok.encode("\n")]
    elif p == 'olmo': end_turn = tok.encode("\n")
    elif p == 'glm4': end_turn = []
    elif p == 'tekken': end_turn = tok.encode("[/INST]")
    else: end_turn = [e]  # llama3, kimi-k2

    # prefill: the trailing assistant message is left open (no end_turn, no extra generation prompt)
    prefill = bool(continue_final_message and messages and messages[-1]["role"] == "assistant")
    ids = ([tok.bos_id] if tok.bos_id is not None else []) + (tok.encode("<sop>") if p == 'glm4' else [])
    for i, m in enumerate(messages):
      ids += role(m["role"]) + tok.encode(_flatten_content(m["content"]))
      if not prefill or i < len(messages) - 1: ids += end_turn
    if add_generation_prompt and not prefill: ids += role("assistant")
    return ids
|
||||||
|
|
||||||
models = {
|
models = {
|
||||||
"llama3.2:1b": "https://huggingface.co/bartowski/Llama-3.2-1B-Instruct-GGUF/resolve/main/Llama-3.2-1B-Instruct-Q6_K.gguf",
|
"llama3.2:1b": "https://huggingface.co/bartowski/Llama-3.2-1B-Instruct-GGUF/resolve/main/Llama-3.2-1B-Instruct-Q6_K.gguf",
|
||||||
|
|
@ -113,7 +170,7 @@ class Handler(HTTPRequestHandler):
|
||||||
if self.path == "/v1/models": self.send_data(json.dumps({"object":"list","data":[{"id":self.server.model_name,"object":"model"}]}).encode())
|
if self.path == "/v1/models": self.send_data(json.dumps({"object":"list","data":[{"id":self.server.model_name,"object":"model"}]}).encode())
|
||||||
else: self.send_data((pathlib.Path(__file__).parent / "chat.html").read_bytes(), content_type="text/html")
|
else: self.send_data((pathlib.Path(__file__).parent / "chat.html").read_bytes(), content_type="text/html")
|
||||||
def run_model(self, ids:list[int], model_name:str, include_usage=False, max_tokens:int|None=None, temperature:float=0.0):
|
def run_model(self, ids:list[int], model_name:str, include_usage=False, max_tokens:int|None=None, temperature:float=0.0):
|
||||||
model, tok = self.server.model, self.server.tok
|
model, chat = self.server.model, self.server.chat
|
||||||
cache_start_pos = model.get_start_pos(ids)
|
cache_start_pos = model.get_start_pos(ids)
|
||||||
stderr_log(f"{self.path} {colored('--', 'BLACK')} "
|
stderr_log(f"{self.path} {colored('--', 'BLACK')} "
|
||||||
f"in:{colored(f'{cache_start_pos:5d}', 'green')} +{len(ids)-cache_start_pos:5d} {colored('--', 'BLACK')} ")
|
f"in:{colored(f'{cache_start_pos:5d}', 'green')} +{len(ids)-cache_start_pos:5d} {colored('--', 'BLACK')} ")
|
||||||
|
|
@ -122,10 +179,10 @@ class Handler(HTTPRequestHandler):
|
||||||
out: list[int] = []
|
out: list[int] = []
|
||||||
finish_reason = "stop"
|
finish_reason = "stop"
|
||||||
st = time.perf_counter()
|
st = time.perf_counter()
|
||||||
dec = tok.stream_decoder()
|
dec = chat.tok.stream_decoder()
|
||||||
for next_id in model.generate(ids, temperature=temperature):
|
for next_id in model.generate(ids, temperature=temperature):
|
||||||
if len(out) == 0: stderr_log(f"prefill:{(len(ids)-cache_start_pos)/((pt:=time.perf_counter())-st):4.0f} tok/s {colored('--', 'BLACK')} ")
|
if len(out) == 0: stderr_log(f"prefill:{(len(ids)-cache_start_pos)/((pt:=time.perf_counter())-st):4.0f} tok/s {colored('--', 'BLACK')} ")
|
||||||
if tok.is_end(next_id): break
|
if chat.is_end(next_id): break
|
||||||
out.append(next_id)
|
out.append(next_id)
|
||||||
yield {"choices": [{"index":0, "delta":{"content":dec(next_id)}, "finish_reason":None}], **tmpl}
|
yield {"choices": [{"index":0, "delta":{"content":dec(next_id)}, "finish_reason":None}], **tmpl}
|
||||||
if max_tokens is not None and len(out) >= max_tokens:
|
if max_tokens is not None and len(out) >= max_tokens:
|
||||||
|
|
@ -140,25 +197,15 @@ class Handler(HTTPRequestHandler):
|
||||||
f"out:{len(out):5d} {colored('--', 'BLACK')} total:{et-st:6.2f}s\n")
|
f"out:{len(out):5d} {colored('--', 'BLACK')} total:{et-st:6.2f}s\n")
|
||||||
|
|
||||||
def do_POST(self):
|
def do_POST(self):
|
||||||
tok = self.server.tok
|
chat = self.server.chat
|
||||||
raw_body = self.rfile.read(int(self.headers.get("Content-Length", "0")))
|
raw_body = self.rfile.read(int(self.headers.get("Content-Length", "0")))
|
||||||
body: dict[str, typing.Any] = json.loads(raw_body.decode("utf-8"))
|
body: dict[str, typing.Any] = json.loads(raw_body.decode("utf-8"))
|
||||||
if DEBUG >= 1: print(json.dumps(body, indent=2))
|
if DEBUG >= 1: print(json.dumps(body, indent=2))
|
||||||
if self.path == "/v1/chat/completions":
|
if self.path == "/v1/chat/completions":
|
||||||
# extract tokens, last assistant message is treated as prefill
|
# extract tokens, last assistant message is treated as prefill
|
||||||
ids: list[int] = tok.prefix()
|
messages = [{"role": m["role"], "content": _flatten_content(m["content"])} for m in body["messages"]]
|
||||||
for i, msg in enumerate(body["messages"]):
|
prefill = bool(messages) and messages[-1]["role"] == "assistant"
|
||||||
ids += tok.role(msg["role"])
|
ids = chat.apply(messages, add_generation_prompt=not prefill, continue_final_message=prefill)
|
||||||
content = msg["content"]
|
|
||||||
if isinstance(content, str): ids += tok.encode(content)
|
|
||||||
elif isinstance(content, list):
|
|
||||||
for c in content:
|
|
||||||
if c["type"] == "text": ids += tok.encode(c["text"])
|
|
||||||
else: raise RuntimeError(f"unhandled type: {c['type']}")
|
|
||||||
else: raise RuntimeError(f"unknown content type: {type(content)}")
|
|
||||||
if msg["role"] == "assistant" and i == len(body["messages"]) - 1: break
|
|
||||||
ids += tok.end_turn()
|
|
||||||
else: ids += tok.role("assistant")
|
|
||||||
|
|
||||||
# reply
|
# reply
|
||||||
max_tokens = body.get("max_completion_tokens") or body.get("max_tokens")
|
max_tokens = body.get("max_completion_tokens") or body.get("max_tokens")
|
||||||
|
|
@ -176,8 +223,8 @@ class Handler(HTTPRequestHandler):
|
||||||
raise RuntimeError(f"unhandled path {self.path}")
|
raise RuntimeError(f"unhandled path {self.path}")
|
||||||
|
|
||||||
class LLMServer(TCPServerWithReuse):
  """Server that keeps the model, its name and the Chat formatter as attributes and dispatches requests to Handler."""
  def __init__(self, server_address:tuple, model:Transformer, model_name:str, chat:Chat):
    # set before the base-class constructor runs, same order as a single tuple assignment
    self.model = model
    self.model_name = model_name
    self.chat = chat
    super().__init__(server_address, Handler)
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
|
|
@ -187,6 +234,7 @@ def main():
|
||||||
parser.add_argument("--serve", nargs='?', type=int, const=8000, metavar="PORT", help="Run OpenAI compatible API (optional port, default 8000)")
|
parser.add_argument("--serve", nargs='?', type=int, const=8000, metavar="PORT", help="Run OpenAI compatible API (optional port, default 8000)")
|
||||||
parser.add_argument("--warmup", action="store_true", help="warmup the JIT")
|
parser.add_argument("--warmup", action="store_true", help="warmup the JIT")
|
||||||
parser.add_argument("--benchmark", nargs='?', type=int, const=20, metavar="COUNT", help="Benchmark tok/s (optional count, default 20)")
|
parser.add_argument("--benchmark", nargs='?', type=int, const=20, metavar="COUNT", help="Benchmark tok/s (optional count, default 20)")
|
||||||
|
parser.add_argument("--jinja", action="store_true", help="Render the GGUF chat_template with the real jinja2 package (needs `pip install jinja2`)")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
# load the model
|
# load the model
|
||||||
|
|
@ -196,8 +244,9 @@ def main():
|
||||||
print(f"using model \"{model_name}\" with {raw_model.nbytes():,} bytes and {sum(x.numel() for x in nn.state.get_parameters(model)):,} params")
|
print(f"using model \"{model_name}\" with {raw_model.nbytes():,} bytes and {sum(x.numel() for x in nn.state.get_parameters(model)):,} params")
|
||||||
del raw_model
|
del raw_model
|
||||||
|
|
||||||
# get tokenizer
|
# get tokenizer and chat formatter
|
||||||
tok = SimpleTokenizer.from_gguf_kv(kv)
|
tok = SimpleTokenizer.from_gguf_kv(kv)
|
||||||
|
chat = Chat.from_gguf_kv(kv, tok, use_jinja=args.jinja)
|
||||||
|
|
||||||
# warmup the JIT
|
# warmup the JIT
|
||||||
if args.warmup or args.serve:
|
if args.warmup or args.serve:
|
||||||
|
|
@ -206,7 +255,7 @@ def main():
|
||||||
for _ in range(2): list(zip(range(2), model.generate([0])))
|
for _ in range(2): list(zip(range(2), model.generate([0])))
|
||||||
|
|
||||||
# start server
|
# start server
|
||||||
if args.serve: LLMServer(('', args.serve), model, model_name, tok).serve_forever()
|
if args.serve: LLMServer(('', args.serve), model, model_name, chat).serve_forever()
|
||||||
|
|
||||||
# do benchmark
|
# do benchmark
|
||||||
if args.benchmark is not None:
|
if args.benchmark is not None:
|
||||||
|
|
@ -218,17 +267,19 @@ def main():
|
||||||
tok.decode(toks).replace("\n", "\\n")): next(gen)
|
tok.decode(toks).replace("\n", "\\n")): next(gen)
|
||||||
exit(0)
|
exit(0)
|
||||||
|
|
||||||
# interactive chat
|
# interactive chat (falls back to pure completion when the GGUF has no chat template)
|
||||||
ids: list[int] = tok.prefix()
|
messages: list[dict] = []
|
||||||
while 1:
|
while 1:
|
||||||
try:
|
try: user = input('>>> ')
|
||||||
ids += tok.role("user") + tok.encode(input('>>> ')) + tok.end_turn() + tok.role("assistant")
|
except EOFError: break
|
||||||
except EOFError:
|
messages.append({"role": "user", "content": user})
|
||||||
break
|
ids = chat.apply(messages, add_generation_prompt=True)
|
||||||
dec = tok.stream_decoder()
|
dec, assistant = tok.stream_decoder(), []
|
||||||
for next_id in model.generate(ids):
|
for next_id in model.generate(list(ids)):
|
||||||
sys.stdout.write(dec(next_id) if not tok.is_end(next_id) else dec() + "\n\n")
|
sys.stdout.write(dec(next_id) if not chat.is_end(next_id) else dec() + "\n\n")
|
||||||
sys.stdout.flush()
|
sys.stdout.flush()
|
||||||
if tok.is_end(next_id): break
|
if chat.is_end(next_id): break
|
||||||
|
assistant.append(next_id)
|
||||||
|
messages.append({"role": "assistant", "content": tok.decode(assistant)})
|
||||||
|
|
||||||
if __name__ == "__main__": main()
|
if __name__ == "__main__": main()
|
||||||
|
|
@ -69,14 +69,19 @@ class TransformerConfig:
|
||||||
leading_dense_blocks: int = 0
|
leading_dense_blocks: int = 0
|
||||||
dense_hidden_dim: int = 0
|
dense_hidden_dim: int = 0
|
||||||
routed_scaling_factor: float = 1.0
|
routed_scaling_factor: float = 1.0
|
||||||
|
post_norm: bool = False # norms applied to sublayer outputs rather than inputs (OLMo 2)
|
||||||
|
|
||||||
class FFNBlock:
|
class FFNBlock:
|
||||||
def __init__(self, config:TransformerConfig):
|
def __init__(self, config:TransformerConfig):
|
||||||
self.config = config
|
self.config = config
|
||||||
|
|
||||||
# --- RMSNorms --------------------------------------------------------
|
# --- RMSNorms --------------------------------------------------------
|
||||||
self.attn_norm = nn.RMSNorm(config.dim, config.norm_eps)
|
# pre-norm slots (no-op when the model uses post-norm, e.g. OLMo 2)
|
||||||
self.ffn_norm = nn.RMSNorm(config.dim, config.norm_eps)
|
self.attn_norm = (lambda x: x) if config.post_norm else nn.RMSNorm(config.dim, config.norm_eps)
|
||||||
|
self.ffn_norm = (lambda x: x) if config.post_norm else nn.RMSNorm(config.dim, config.norm_eps)
|
||||||
|
# post-norm slots (no-op for standard pre-norm models)
|
||||||
|
self.post_attention_norm = nn.RMSNorm(config.dim, config.norm_eps) if config.post_norm else (lambda x: x)
|
||||||
|
self.post_ffw_norm = nn.RMSNorm(config.dim, config.norm_eps) if config.post_norm else (lambda x: x)
|
||||||
|
|
||||||
# --- feed-forward (MoE or dense) -------------------------------------
|
# --- feed-forward (MoE or dense) -------------------------------------
|
||||||
if config.num_experts > 0:
|
if config.num_experts > 0:
|
||||||
|
|
@ -130,8 +135,8 @@ class FFNBlock:
|
||||||
# we pass in the weights implicitly so we unpack the GGUF on the fly
|
# we pass in the weights implicitly so we unpack the GGUF on the fly
|
||||||
@function(precompile=True, allow_implicit=True)
|
@function(precompile=True, allow_implicit=True)
|
||||||
def _run(x:Tensor, start_pos:int|UOp):
|
def _run(x:Tensor, start_pos:int|UOp):
|
||||||
h = x + self._attention(self.attn_norm(x), start_pos)
|
h = x + self.post_attention_norm(self._attention(self.attn_norm(x), start_pos))
|
||||||
return (h + self._feed_forward(self.ffn_norm(h))).contiguous()
|
return (h + self.post_ffw_norm(self._feed_forward(self.ffn_norm(h)))).contiguous()
|
||||||
return _run(x, start_pos)
|
return _run(x, start_pos)
|
||||||
|
|
||||||
class TransformerBlock(FFNBlock):
|
class TransformerBlock(FFNBlock):
|
||||||
|
|
@ -374,7 +379,7 @@ class Transformer:
|
||||||
shared_expert_gate=f"blk.{kv.get(f'{arch}.leading_dense_block_count', 0)}.ffn_gate_inp_shexp.weight" in state_dict,
|
shared_expert_gate=f"blk.{kv.get(f'{arch}.leading_dense_block_count', 0)}.ffn_gate_inp_shexp.weight" in state_dict,
|
||||||
dense_hidden_dim=kv.get(f'{arch}.feed_forward_length', 0) if kv.get(f'{arch}.leading_dense_block_count', 0) else 0,
|
dense_hidden_dim=kv.get(f'{arch}.feed_forward_length', 0) if kv.get(f'{arch}.leading_dense_block_count', 0) else 0,
|
||||||
routed_scaling_factor=kv.get(f'{arch}.expert_weights_scale', 1.0), attn_output_gate=arch in ('qwen35', 'qwen35moe'), ssm=ssm,
|
routed_scaling_factor=kv.get(f'{arch}.expert_weights_scale', 1.0), attn_output_gate=arch in ('qwen35', 'qwen35moe'), ssm=ssm,
|
||||||
full_attention_interval=kv.get(f'{arch}.full_attention_interval', 0))
|
full_attention_interval=kv.get(f'{arch}.full_attention_interval', 0), post_norm=(arch == 'olmo2'))
|
||||||
model = Transformer(config)
|
model = Transformer(config)
|
||||||
nn.state.load_state_dict(model, state_dict, verbose=False, consume=True, realize=False) # NOTE: rope_freqs.weight (32,) is unused
|
nn.state.load_state_dict(model, state_dict, verbose=False, consume=True, realize=False) # NOTE: rope_freqs.weight (32,) is unused
|
||||||
# NOTE: without this contiguous, it unpacks the weights from the model every time. we shouldn't need this, but for now it's faster
|
# NOTE: without this contiguous, it unpacks the weights from the model every time. we shouldn't need this, but for now it's faster
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue