mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
llm is the only app (#15779)
* tinygrad/llm is the only app * upd pyproject * claude refs * scoping * min diff
This commit is contained in:
parent
0e69388f6b
commit
ec00cefa5b
16 changed files with 50 additions and 33 deletions
8
.github/workflows/test.yml
vendored
8
.github/workflows/test.yml
vendored
|
|
@ -505,14 +505,14 @@ jobs:
|
||||||
with:
|
with:
|
||||||
key: apps_llm
|
key: apps_llm
|
||||||
- name: Test 1B LLM (llama)
|
- name: Test 1B LLM (llama)
|
||||||
run: echo "What's a male chicken called? Answer with only one word." | MAX_BUFFER_SIZE=0 python3 -m tinygrad.apps.llm --model llama3.2:1b | tee /dev/stderr | grep -i rooster
|
run: echo "What's a male chicken called? Answer with only one word." | MAX_BUFFER_SIZE=0 python3 -m tinygrad.llm --model llama3.2:1b | tee /dev/stderr | grep -i rooster
|
||||||
- name: Test 1B LLM (llama q4)
|
- name: Test 1B LLM (llama q4)
|
||||||
run: echo "What's a male chicken called? Answer with only one word." | MAX_BUFFER_SIZE=0 python3 -m tinygrad.apps.llm --model llama3.2:1b-q4 | tee /dev/stderr | grep -i rooster
|
run: echo "What's a male chicken called? Answer with only one word." | MAX_BUFFER_SIZE=0 python3 -m tinygrad.llm --model llama3.2:1b-q4 | tee /dev/stderr | grep -i rooster
|
||||||
- name: Test 1B LLM (qwen3.5)
|
- name: Test 1B LLM (qwen3.5)
|
||||||
run: echo "What's a male chicken called? Answer with only one word." | MAX_BUFFER_SIZE=0 python3 -m tinygrad.apps.llm --model qwen3.5:0.8b | tee /dev/stderr | grep -i rooster
|
run: echo "What's a male chicken called? Answer with only one word." | MAX_BUFFER_SIZE=0 python3 -m tinygrad.llm --model qwen3.5:0.8b | tee /dev/stderr | grep -i rooster
|
||||||
- name: Test 1B LLM (qwen)
|
- name: Test 1B LLM (qwen)
|
||||||
# NOTE: qwen is dumb and only knows about female chickens
|
# NOTE: qwen is dumb and only knows about female chickens
|
||||||
run: echo "What's a female chicken called? Answer with only one word." | MAX_BUFFER_SIZE=0 python3 -m tinygrad.apps.llm --model qwen3:0.6b | tee /dev/stderr | grep -i hen
|
run: echo "What's a female chicken called? Answer with only one word." | MAX_BUFFER_SIZE=0 python3 -m tinygrad.llm --model qwen3:0.6b | tee /dev/stderr | grep -i hen
|
||||||
|
|
||||||
# ****** Models Tests ******
|
# ****** Models Tests ******
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -55,7 +55,7 @@ export PATH="$HOME/.local/bin:$PATH"
|
||||||
### 5. Use it!
|
### 5. Use it!
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
DEV={AMD|NV} python3 tinygrad/apps/llm.py
|
DEV={AMD|NV} python3 -m tinygrad.llm
|
||||||
```
|
```
|
||||||
|
|
||||||
**Note:** Use `JITBEAM=2` to search for faster kernels (one-time search cost, results cached).
|
**Note:** Use `JITBEAM=2` to search for faster kernels (one-time search cost, results cached).
|
||||||
|
|
|
||||||
|
|
@ -19,11 +19,11 @@ build-backend = "setuptools.build_meta"
|
||||||
include-package-data = true
|
include-package-data = true
|
||||||
packages = [
|
packages = [
|
||||||
'tinygrad',
|
'tinygrad',
|
||||||
'tinygrad.apps',
|
|
||||||
'tinygrad.codegen',
|
'tinygrad.codegen',
|
||||||
'tinygrad.codegen.opt',
|
'tinygrad.codegen.opt',
|
||||||
'tinygrad.codegen.late',
|
'tinygrad.codegen.late',
|
||||||
'tinygrad.engine',
|
'tinygrad.engine',
|
||||||
|
'tinygrad.llm',
|
||||||
'tinygrad.mixin',
|
'tinygrad.mixin',
|
||||||
'tinygrad.nn',
|
'tinygrad.nn',
|
||||||
'tinygrad.renderer',
|
'tinygrad.renderer',
|
||||||
|
|
@ -112,9 +112,9 @@ docs = [
|
||||||
[tool.mutmut]
|
[tool.mutmut]
|
||||||
paths_to_mutate = ["tinygrad/"]
|
paths_to_mutate = ["tinygrad/"]
|
||||||
do_not_mutate = [
|
do_not_mutate = [
|
||||||
"tinygrad/apps/*",
|
|
||||||
"tinygrad/codegen/*",
|
"tinygrad/codegen/*",
|
||||||
"tinygrad/engine/*",
|
"tinygrad/engine/*",
|
||||||
|
"tinygrad/llm/*",
|
||||||
"tinygrad/nn/*",
|
"tinygrad/nn/*",
|
||||||
"tinygrad/renderer/*",
|
"tinygrad/renderer/*",
|
||||||
"tinygrad/runtime/*",
|
"tinygrad/runtime/*",
|
||||||
|
|
|
||||||
2
sz.py
2
sz.py
|
|
@ -56,7 +56,7 @@ def gen_diff(table_old, table_new):
|
||||||
|
|
||||||
def display_diff(diff): return "+"+str(diff) if diff > 0 else str(diff)
|
def display_diff(diff): return "+"+str(diff) if diff > 0 else str(diff)
|
||||||
|
|
||||||
NONCORE_DIRS = {"tinygrad/apps", "tinygrad/nn", "tinygrad/renderer", "tinygrad/runtime", "tinygrad/viz"}
|
NONCORE_DIRS = {"tinygrad/llm", "tinygrad/nn", "tinygrad/renderer", "tinygrad/runtime", "tinygrad/viz"}
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
if len(sys.argv) == 3:
|
if len(sys.argv) == 3:
|
||||||
|
|
|
||||||
2
test/external/external_llm_eval.py
vendored
2
test/external/external_llm_eval.py
vendored
|
|
@ -1,4 +1,4 @@
|
||||||
# eval for tinygrad.apps.llm -- hits the server via OpenAI API
|
# eval for OpenAI API server
|
||||||
# uses Meta's exact ARC-Challenge prompt template from lm-evaluation-harness llama3 tasks
|
# uses Meta's exact ARC-Challenge prompt template from lm-evaluation-harness llama3 tasks
|
||||||
import argparse, re, pyarrow.parquet as pq
|
import argparse, re, pyarrow.parquet as pq
|
||||||
from openai import OpenAI
|
from openai import OpenAI
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,7 @@
|
||||||
import functools, multiprocessing
|
import functools, multiprocessing
|
||||||
from transformers import AutoTokenizer
|
from transformers import AutoTokenizer
|
||||||
from datasets import load_dataset
|
from datasets import load_dataset
|
||||||
from tinygrad.apps.llm import SimpleTokenizer
|
from tinygrad.llm.cli import SimpleTokenizer
|
||||||
from tinygrad.helpers import tqdm, getenv, partition
|
from tinygrad.helpers import tqdm, getenv, partition
|
||||||
|
|
||||||
@functools.cache
|
@functools.cache
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,6 @@
|
||||||
import unittest
|
import unittest
|
||||||
from tinygrad import Tensor, dtypes, TinyJit, UOp
|
from tinygrad import Tensor, dtypes, TinyJit, UOp
|
||||||
from tinygrad.apps.llm import apply_rope as apply_rope_new, precompute_freqs_cis
|
from tinygrad.llm.cli import apply_rope as apply_rope_new, precompute_freqs_cis
|
||||||
from test.helpers import assert_jit_cache_len
|
from test.helpers import assert_jit_cache_len
|
||||||
|
|
||||||
def apply_rope(x:Tensor, start_pos:int):
|
def apply_rope(x:Tensor, start_pos:int):
|
||||||
|
|
|
||||||
|
|
@ -22,18 +22,15 @@ class TestLLMServer(unittest.TestCase):
|
||||||
cls.bos_id = 1
|
cls.bos_id = 1
|
||||||
cls.eos_id = 999
|
cls.eos_id = 999
|
||||||
|
|
||||||
import tinygrad.apps.llm as llm_module
|
from tinygrad.llm.cli import Handler, LLMServer
|
||||||
llm_module.model = cls.mock_model
|
|
||||||
llm_module.model_name = "test-model"
|
|
||||||
llm_module.tok = cls.mock_tok
|
|
||||||
llm_module.bos_id = cls.bos_id
|
|
||||||
llm_module.eos_id = cls.eos_id
|
|
||||||
llm_module.eot_id = None
|
|
||||||
|
|
||||||
from tinygrad.apps.llm import Handler
|
cls.server = LLMServer(('127.0.0.1', 0), Handler)
|
||||||
from tinygrad.viz.serve import TCPServerWithReuse
|
cls.server.model = cls.mock_model
|
||||||
|
cls.server.model_name = "test-model"
|
||||||
cls.server = TCPServerWithReuse(('127.0.0.1', 0), Handler)
|
cls.server.tok = cls.mock_tok
|
||||||
|
cls.server.bos_id = cls.bos_id
|
||||||
|
cls.server.eos_id = cls.eos_id
|
||||||
|
cls.server.eot_id = None
|
||||||
cls.port = cls.server.server_address[1]
|
cls.port = cls.server.server_address[1]
|
||||||
cls.server_thread = threading.Thread(target=cls.server.serve_forever, daemon=True)
|
cls.server_thread = threading.Thread(target=cls.server.serve_forever, daemon=True)
|
||||||
cls.server_thread.start()
|
cls.server_thread.start()
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,5 @@
|
||||||
import unittest, base64, functools, sys
|
import unittest, base64, functools, sys
|
||||||
from tinygrad.apps.llm import SimpleTokenizer
|
from tinygrad.llm.cli import SimpleTokenizer
|
||||||
from tinygrad.helpers import fetch
|
from tinygrad.helpers import fetch
|
||||||
|
|
||||||
@unittest.skipIf(sys.platform == 'win32', "fetch race condition on Windows")
|
@unittest.skipIf(sys.platform == 'win32', "fetch race condition on Windows")
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,7 @@
|
||||||
import unittest
|
import unittest
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from tinygrad import Tensor, dtypes
|
from tinygrad import Tensor, dtypes
|
||||||
from tinygrad.apps.llm import (
|
from tinygrad.llm.cli import (
|
||||||
GatedDeltaNetBlock, SSMConfig, TransformerBlock, TransformerConfig,
|
GatedDeltaNetBlock, SSMConfig, TransformerBlock, TransformerConfig,
|
||||||
apply_rope as apply_rope_new, precompute_freqs_cis, pairwise_topk,
|
apply_rope as apply_rope_new, precompute_freqs_cis, pairwise_topk,
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,7 @@
|
||||||
import unittest
|
import unittest
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from tinygrad import Tensor
|
from tinygrad import Tensor
|
||||||
from tinygrad.apps.llm import Transformer, TransformerConfig, apply_rope
|
from tinygrad.llm.cli import Transformer, TransformerConfig, apply_rope
|
||||||
|
|
||||||
class TestMLA(unittest.TestCase):
|
class TestMLA(unittest.TestCase):
|
||||||
def _make_config(self, **kwargs):
|
def _make_config(self, **kwargs):
|
||||||
|
|
@ -13,7 +13,7 @@ class TestMLA(unittest.TestCase):
|
||||||
|
|
||||||
def test_mla_attention_matches_naive(self):
|
def test_mla_attention_matches_naive(self):
|
||||||
config = self._make_config(max_context=16)
|
config = self._make_config(max_context=16)
|
||||||
from tinygrad.apps.llm import MLATransformerBlock, precompute_freqs_cis
|
from tinygrad.llm.cli import MLATransformerBlock, precompute_freqs_cis
|
||||||
|
|
||||||
block = MLATransformerBlock(config)
|
block = MLATransformerBlock(config)
|
||||||
c = config
|
c = config
|
||||||
|
|
|
||||||
|
|
@ -2,7 +2,7 @@ import unittest
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from dataclasses import replace
|
from dataclasses import replace
|
||||||
from tinygrad import Tensor
|
from tinygrad import Tensor
|
||||||
from tinygrad.apps.llm import TransformerBlock, TransformerConfig
|
from tinygrad.llm.cli import TransformerBlock, TransformerConfig
|
||||||
|
|
||||||
def _moe_config(dim=8, hidden=16, n_heads=2, num_experts=4, num_experts_per_tok=2):
|
def _moe_config(dim=8, hidden=16, n_heads=2, num_experts=4, num_experts_per_tok=2):
|
||||||
return TransformerConfig(
|
return TransformerConfig(
|
||||||
|
|
|
||||||
|
|
@ -2,7 +2,7 @@ import unittest
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
from tinygrad import Tensor, UOp
|
from tinygrad import Tensor, UOp
|
||||||
from tinygrad.schedule import schedule_cache
|
from tinygrad.schedule import schedule_cache
|
||||||
from tinygrad.apps.llm import Transformer, TransformerConfig
|
from tinygrad.llm.cli import Transformer, TransformerConfig
|
||||||
|
|
||||||
TEST_CONFIG = TransformerConfig(num_blocks=1, dim=64, hidden_dim=128, n_heads=2, n_kv_heads=2,
|
TEST_CONFIG = TransformerConfig(num_blocks=1, dim=64, hidden_dim=128, n_heads=2, n_kv_heads=2,
|
||||||
norm_eps=1e-5, vocab_size=100, head_dim=32, rope_theta=10000.0, rope_dim=32, v_head_dim=32, max_context=32)
|
norm_eps=1e-5, vocab_size=100, head_dim=32, rope_theta=10000.0, rope_dim=32, v_head_dim=32, max_context=32)
|
||||||
|
|
|
||||||
0
tinygrad/llm/__init__.py
Normal file
0
tinygrad/llm/__init__.py
Normal file
2
tinygrad/llm/__main__.py
Normal file
2
tinygrad/llm/__main__.py
Normal file
|
|
@ -0,0 +1,2 @@
|
||||||
|
from tinygrad.llm.cli import main
|
||||||
|
if __name__ == "__main__": main()
|
||||||
|
|
@ -56,7 +56,7 @@ class SimpleTokenizer:
|
||||||
return tokens + self._encode_sentence(text[pos:])
|
return tokens + self._encode_sentence(text[pos:])
|
||||||
|
|
||||||
def decode(self, ids:list[int]) -> str: return b''.join(self._tok2bytes[tid] for tid in ids).decode(errors='replace')
|
def decode(self, ids:list[int]) -> str: return b''.join(self._tok2bytes[tid] for tid in ids).decode(errors='replace')
|
||||||
def stream_decoder(self) -> typing.Callable[[int|None], str]:
|
def stream_decoder(self) -> typing.Callable[..., str]:
|
||||||
dec = codecs.getincrementaldecoder('utf-8')('replace')
|
dec = codecs.getincrementaldecoder('utf-8')('replace')
|
||||||
def _decode(tid:int|None=None) -> str: return dec.decode(self._tok2bytes[tid]) if tid is not None else dec.decode(b'', final=True)
|
def _decode(tid:int|None=None) -> str: return dec.decode(self._tok2bytes[tid]) if tid is not None else dec.decode(b'', final=True)
|
||||||
return _decode
|
return _decode
|
||||||
|
|
@ -545,12 +545,23 @@ CHAT_HTML = b'''<!DOCTYPE html><html><head><title>tinygrad chat</title><style>
|
||||||
}
|
}
|
||||||
</script></body></html>'''
|
</script></body></html>'''
|
||||||
|
|
||||||
|
class LLMServer(TCPServerWithReuse):
|
||||||
|
model: Transformer
|
||||||
|
model_name: str
|
||||||
|
tok: SimpleTokenizer
|
||||||
|
# TODO: tastefully move these into tokenizer
|
||||||
|
bos_id: int|None
|
||||||
|
eos_id: int
|
||||||
|
eot_id: int|None
|
||||||
|
|
||||||
class Handler(HTTPRequestHandler):
|
class Handler(HTTPRequestHandler):
|
||||||
|
server: LLMServer
|
||||||
def log_request(self, code='-', size='-'): pass
|
def log_request(self, code='-', size='-'): pass
|
||||||
def do_GET(self):
|
def do_GET(self):
|
||||||
if self.path == "/v1/models": self.send_data(json.dumps({"object":"list","data":[{"id":model_name,"object":"model"}]}).encode())
|
if self.path == "/v1/models": self.send_data(json.dumps({"object":"list","data":[{"id":self.server.model_name,"object":"model"}]}).encode())
|
||||||
else: self.send_data(CHAT_HTML, content_type="text/html")
|
else: self.send_data(CHAT_HTML, content_type="text/html")
|
||||||
def run_model(self, ids:list[int], model_name:str, include_usage=False, max_tokens:int|None=None, temperature:float=0.0):
|
def run_model(self, ids:list[int], model_name:str, include_usage=False, max_tokens:int|None=None, temperature:float=0.0):
|
||||||
|
model, tok, eos_id, eot_id = self.server.model, self.server.tok, self.server.eos_id, self.server.eot_id
|
||||||
cache_start_pos = model.get_start_pos(ids)
|
cache_start_pos = model.get_start_pos(ids)
|
||||||
stderr_log(f"{self.path} {colored('--', 'BLACK')} "
|
stderr_log(f"{self.path} {colored('--', 'BLACK')} "
|
||||||
f"in:{colored(f'{cache_start_pos:5d}', 'green')} +{len(ids)-cache_start_pos:5d} {colored('--', 'BLACK')} ")
|
f"in:{colored(f'{cache_start_pos:5d}', 'green')} +{len(ids)-cache_start_pos:5d} {colored('--', 'BLACK')} ")
|
||||||
|
|
@ -577,6 +588,7 @@ class Handler(HTTPRequestHandler):
|
||||||
f"out:{len(out):5d} {colored('--', 'BLACK')} total:{et-st:6.2f}s\n")
|
f"out:{len(out):5d} {colored('--', 'BLACK')} total:{et-st:6.2f}s\n")
|
||||||
|
|
||||||
def do_POST(self):
|
def do_POST(self):
|
||||||
|
tok, bos_id, eos_id = self.server.tok, self.server.bos_id, self.server.eos_id
|
||||||
raw_body = self.rfile.read(int(self.headers.get("Content-Length", "0")))
|
raw_body = self.rfile.read(int(self.headers.get("Content-Length", "0")))
|
||||||
body: dict[str, typing.Any] = json.loads(raw_body.decode("utf-8"))
|
body: dict[str, typing.Any] = json.loads(raw_body.decode("utf-8"))
|
||||||
if DEBUG >= 1: print(json.dumps(body, indent=2))
|
if DEBUG >= 1: print(json.dumps(body, indent=2))
|
||||||
|
|
@ -611,7 +623,7 @@ class Handler(HTTPRequestHandler):
|
||||||
else:
|
else:
|
||||||
raise RuntimeError(f"unhandled path {self.path}")
|
raise RuntimeError(f"unhandled path {self.path}")
|
||||||
|
|
||||||
if __name__ == "__main__":
|
def main():
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument("--model", "-m", default=list(models.keys())[0], help=f"Model choice ({', '.join(models.keys())}) or path to a local GGUF file")
|
parser.add_argument("--model", "-m", default=list(models.keys())[0], help=f"Model choice ({', '.join(models.keys())}) or path to a local GGUF file")
|
||||||
parser.add_argument("--max_context", type=int, default=4096, help="Max Context Length")
|
parser.add_argument("--max_context", type=int, default=4096, help="Max Context Length")
|
||||||
|
|
@ -643,7 +655,11 @@ if __name__ == "__main__":
|
||||||
for _ in range(2): list(zip(range(2), model.generate([0])))
|
for _ in range(2): list(zip(range(2), model.generate([0])))
|
||||||
|
|
||||||
# start server
|
# start server
|
||||||
if args.serve: TCPServerWithReuse(('', args.serve), Handler).serve_forever()
|
if args.serve:
|
||||||
|
server = LLMServer(('', args.serve), Handler)
|
||||||
|
server.model, server.model_name, server.tok = model, model_name, tok
|
||||||
|
server.bos_id, server.eos_id, server.eot_id = bos_id, eos_id, eot_id
|
||||||
|
server.serve_forever()
|
||||||
|
|
||||||
# do benchmark
|
# do benchmark
|
||||||
if args.benchmark is not None:
|
if args.benchmark is not None:
|
||||||
|
|
@ -667,3 +683,5 @@ if __name__ == "__main__":
|
||||||
sys.stdout.write(dec(next_id) if next_id not in (eos_id, eot_id) else dec() + "\n\n")
|
sys.stdout.write(dec(next_id) if next_id not in (eos_id, eot_id) else dec() + "\n\n")
|
||||||
sys.stdout.flush()
|
sys.stdout.flush()
|
||||||
if next_id in (eos_id, eot_id): break
|
if next_id in (eos_id, eot_id): break
|
||||||
|
|
||||||
|
if __name__ == "__main__": main()
|
||||||
Loading…
Add table
Add a link
Reference in a new issue