llm is the only app (#15779)

* tinygrad/llm is the only app

* upd pyproject

* claude refs

* scoping

* min diff
This commit is contained in:
George Hotz 2026-04-17 10:44:48 +08:00 committed by GitHub
commit ec00cefa5b
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
16 changed files with 50 additions and 33 deletions

View file

@ -505,14 +505,14 @@ jobs:
with: with:
key: apps_llm key: apps_llm
- name: Test 1B LLM (llama) - name: Test 1B LLM (llama)
run: echo "What's a male chicken called? Answer with only one word." | MAX_BUFFER_SIZE=0 python3 -m tinygrad.apps.llm --model llama3.2:1b | tee /dev/stderr | grep -i rooster run: echo "What's a male chicken called? Answer with only one word." | MAX_BUFFER_SIZE=0 python3 -m tinygrad.llm --model llama3.2:1b | tee /dev/stderr | grep -i rooster
- name: Test 1B LLM (llama q4) - name: Test 1B LLM (llama q4)
run: echo "What's a male chicken called? Answer with only one word." | MAX_BUFFER_SIZE=0 python3 -m tinygrad.apps.llm --model llama3.2:1b-q4 | tee /dev/stderr | grep -i rooster run: echo "What's a male chicken called? Answer with only one word." | MAX_BUFFER_SIZE=0 python3 -m tinygrad.llm --model llama3.2:1b-q4 | tee /dev/stderr | grep -i rooster
- name: Test 1B LLM (qwen3.5) - name: Test 1B LLM (qwen3.5)
run: echo "What's a male chicken called? Answer with only one word." | MAX_BUFFER_SIZE=0 python3 -m tinygrad.apps.llm --model qwen3.5:0.8b | tee /dev/stderr | grep -i rooster run: echo "What's a male chicken called? Answer with only one word." | MAX_BUFFER_SIZE=0 python3 -m tinygrad.llm --model qwen3.5:0.8b | tee /dev/stderr | grep -i rooster
- name: Test 1B LLM (qwen) - name: Test 1B LLM (qwen)
# NOTE: qwen is dumb and only knows about female chickens # NOTE: qwen is dumb and only knows about female chickens
run: echo "What's a female chicken called? Answer with only one word." | MAX_BUFFER_SIZE=0 python3 -m tinygrad.apps.llm --model qwen3:0.6b | tee /dev/stderr | grep -i hen run: echo "What's a female chicken called? Answer with only one word." | MAX_BUFFER_SIZE=0 python3 -m tinygrad.llm --model qwen3:0.6b | tee /dev/stderr | grep -i hen
# ****** Models Tests ****** # ****** Models Tests ******

View file

@ -55,7 +55,7 @@ export PATH="$HOME/.local/bin:$PATH"
### 5. Use it! ### 5. Use it!
```bash ```bash
DEV={AMD|NV} python3 tinygrad/apps/llm.py DEV={AMD|NV} python3 -m tinygrad.llm
``` ```
**Note:** Use `JITBEAM=2` to search for faster kernels (one-time search cost, results cached). **Note:** Use `JITBEAM=2` to search for faster kernels (one-time search cost, results cached).

View file

@ -19,11 +19,11 @@ build-backend = "setuptools.build_meta"
include-package-data = true include-package-data = true
packages = [ packages = [
'tinygrad', 'tinygrad',
'tinygrad.apps',
'tinygrad.codegen', 'tinygrad.codegen',
'tinygrad.codegen.opt', 'tinygrad.codegen.opt',
'tinygrad.codegen.late', 'tinygrad.codegen.late',
'tinygrad.engine', 'tinygrad.engine',
'tinygrad.llm',
'tinygrad.mixin', 'tinygrad.mixin',
'tinygrad.nn', 'tinygrad.nn',
'tinygrad.renderer', 'tinygrad.renderer',
@ -112,9 +112,9 @@ docs = [
[tool.mutmut] [tool.mutmut]
paths_to_mutate = ["tinygrad/"] paths_to_mutate = ["tinygrad/"]
do_not_mutate = [ do_not_mutate = [
"tinygrad/apps/*",
"tinygrad/codegen/*", "tinygrad/codegen/*",
"tinygrad/engine/*", "tinygrad/engine/*",
"tinygrad/llm/*",
"tinygrad/nn/*", "tinygrad/nn/*",
"tinygrad/renderer/*", "tinygrad/renderer/*",
"tinygrad/runtime/*", "tinygrad/runtime/*",

2
sz.py
View file

@ -56,7 +56,7 @@ def gen_diff(table_old, table_new):
def display_diff(diff): return "+"+str(diff) if diff > 0 else str(diff) def display_diff(diff): return "+"+str(diff) if diff > 0 else str(diff)
NONCORE_DIRS = {"tinygrad/apps", "tinygrad/nn", "tinygrad/renderer", "tinygrad/runtime", "tinygrad/viz"} NONCORE_DIRS = {"tinygrad/llm", "tinygrad/nn", "tinygrad/renderer", "tinygrad/runtime", "tinygrad/viz"}
if __name__ == "__main__": if __name__ == "__main__":
if len(sys.argv) == 3: if len(sys.argv) == 3:

View file

@ -1,4 +1,4 @@
# eval for tinygrad.apps.llm -- hits the server via OpenAI API # eval for OpenAI API server
# uses Meta's exact ARC-Challenge prompt template from lm-evaluation-harness llama3 tasks # uses Meta's exact ARC-Challenge prompt template from lm-evaluation-harness llama3 tasks
import argparse, re, pyarrow.parquet as pq import argparse, re, pyarrow.parquet as pq
from openai import OpenAI from openai import OpenAI

View file

@ -1,7 +1,7 @@
import functools, multiprocessing import functools, multiprocessing
from transformers import AutoTokenizer from transformers import AutoTokenizer
from datasets import load_dataset from datasets import load_dataset
from tinygrad.apps.llm import SimpleTokenizer from tinygrad.llm.cli import SimpleTokenizer
from tinygrad.helpers import tqdm, getenv, partition from tinygrad.helpers import tqdm, getenv, partition
@functools.cache @functools.cache

View file

@ -1,6 +1,6 @@
import unittest import unittest
from tinygrad import Tensor, dtypes, TinyJit, UOp from tinygrad import Tensor, dtypes, TinyJit, UOp
from tinygrad.apps.llm import apply_rope as apply_rope_new, precompute_freqs_cis from tinygrad.llm.cli import apply_rope as apply_rope_new, precompute_freqs_cis
from test.helpers import assert_jit_cache_len from test.helpers import assert_jit_cache_len
def apply_rope(x:Tensor, start_pos:int): def apply_rope(x:Tensor, start_pos:int):

View file

@ -22,18 +22,15 @@ class TestLLMServer(unittest.TestCase):
cls.bos_id = 1 cls.bos_id = 1
cls.eos_id = 999 cls.eos_id = 999
import tinygrad.apps.llm as llm_module from tinygrad.llm.cli import Handler, LLMServer
llm_module.model = cls.mock_model
llm_module.model_name = "test-model"
llm_module.tok = cls.mock_tok
llm_module.bos_id = cls.bos_id
llm_module.eos_id = cls.eos_id
llm_module.eot_id = None
from tinygrad.apps.llm import Handler cls.server = LLMServer(('127.0.0.1', 0), Handler)
from tinygrad.viz.serve import TCPServerWithReuse cls.server.model = cls.mock_model
cls.server.model_name = "test-model"
cls.server = TCPServerWithReuse(('127.0.0.1', 0), Handler) cls.server.tok = cls.mock_tok
cls.server.bos_id = cls.bos_id
cls.server.eos_id = cls.eos_id
cls.server.eot_id = None
cls.port = cls.server.server_address[1] cls.port = cls.server.server_address[1]
cls.server_thread = threading.Thread(target=cls.server.serve_forever, daemon=True) cls.server_thread = threading.Thread(target=cls.server.serve_forever, daemon=True)
cls.server_thread.start() cls.server_thread.start()

View file

@ -1,5 +1,5 @@
import unittest, base64, functools, sys import unittest, base64, functools, sys
from tinygrad.apps.llm import SimpleTokenizer from tinygrad.llm.cli import SimpleTokenizer
from tinygrad.helpers import fetch from tinygrad.helpers import fetch
@unittest.skipIf(sys.platform == 'win32', "fetch race condition on Windows") @unittest.skipIf(sys.platform == 'win32', "fetch race condition on Windows")

View file

@ -1,7 +1,7 @@
import unittest import unittest
import numpy as np import numpy as np
from tinygrad import Tensor, dtypes from tinygrad import Tensor, dtypes
from tinygrad.apps.llm import ( from tinygrad.llm.cli import (
GatedDeltaNetBlock, SSMConfig, TransformerBlock, TransformerConfig, GatedDeltaNetBlock, SSMConfig, TransformerBlock, TransformerConfig,
apply_rope as apply_rope_new, precompute_freqs_cis, pairwise_topk, apply_rope as apply_rope_new, precompute_freqs_cis, pairwise_topk,
) )

View file

@ -1,7 +1,7 @@
import unittest import unittest
import numpy as np import numpy as np
from tinygrad import Tensor from tinygrad import Tensor
from tinygrad.apps.llm import Transformer, TransformerConfig, apply_rope from tinygrad.llm.cli import Transformer, TransformerConfig, apply_rope
class TestMLA(unittest.TestCase): class TestMLA(unittest.TestCase):
def _make_config(self, **kwargs): def _make_config(self, **kwargs):
@ -13,7 +13,7 @@ class TestMLA(unittest.TestCase):
def test_mla_attention_matches_naive(self): def test_mla_attention_matches_naive(self):
config = self._make_config(max_context=16) config = self._make_config(max_context=16)
from tinygrad.apps.llm import MLATransformerBlock, precompute_freqs_cis from tinygrad.llm.cli import MLATransformerBlock, precompute_freqs_cis
block = MLATransformerBlock(config) block = MLATransformerBlock(config)
c = config c = config

View file

@ -2,7 +2,7 @@ import unittest
import numpy as np import numpy as np
from dataclasses import replace from dataclasses import replace
from tinygrad import Tensor from tinygrad import Tensor
from tinygrad.apps.llm import TransformerBlock, TransformerConfig from tinygrad.llm.cli import TransformerBlock, TransformerConfig
def _moe_config(dim=8, hidden=16, n_heads=2, num_experts=4, num_experts_per_tok=2): def _moe_config(dim=8, hidden=16, n_heads=2, num_experts=4, num_experts_per_tok=2):
return TransformerConfig( return TransformerConfig(

View file

@ -2,7 +2,7 @@ import unittest
from unittest.mock import patch from unittest.mock import patch
from tinygrad import Tensor, UOp from tinygrad import Tensor, UOp
from tinygrad.schedule import schedule_cache from tinygrad.schedule import schedule_cache
from tinygrad.apps.llm import Transformer, TransformerConfig from tinygrad.llm.cli import Transformer, TransformerConfig
TEST_CONFIG = TransformerConfig(num_blocks=1, dim=64, hidden_dim=128, n_heads=2, n_kv_heads=2, TEST_CONFIG = TransformerConfig(num_blocks=1, dim=64, hidden_dim=128, n_heads=2, n_kv_heads=2,
norm_eps=1e-5, vocab_size=100, head_dim=32, rope_theta=10000.0, rope_dim=32, v_head_dim=32, max_context=32) norm_eps=1e-5, vocab_size=100, head_dim=32, rope_theta=10000.0, rope_dim=32, v_head_dim=32, max_context=32)

0
tinygrad/llm/__init__.py Normal file
View file

2
tinygrad/llm/__main__.py Normal file
View file

@ -0,0 +1,2 @@
from tinygrad.llm.cli import main
if __name__ == "__main__": main()

View file

@ -56,7 +56,7 @@ class SimpleTokenizer:
return tokens + self._encode_sentence(text[pos:]) return tokens + self._encode_sentence(text[pos:])
def decode(self, ids:list[int]) -> str: return b''.join(self._tok2bytes[tid] for tid in ids).decode(errors='replace') def decode(self, ids:list[int]) -> str: return b''.join(self._tok2bytes[tid] for tid in ids).decode(errors='replace')
def stream_decoder(self) -> typing.Callable[[int|None], str]: def stream_decoder(self) -> typing.Callable[..., str]:
dec = codecs.getincrementaldecoder('utf-8')('replace') dec = codecs.getincrementaldecoder('utf-8')('replace')
def _decode(tid:int|None=None) -> str: return dec.decode(self._tok2bytes[tid]) if tid is not None else dec.decode(b'', final=True) def _decode(tid:int|None=None) -> str: return dec.decode(self._tok2bytes[tid]) if tid is not None else dec.decode(b'', final=True)
return _decode return _decode
@ -545,12 +545,23 @@ CHAT_HTML = b'''<!DOCTYPE html><html><head><title>tinygrad chat</title><style>
} }
</script></body></html>''' </script></body></html>'''
class LLMServer(TCPServerWithReuse):
model: Transformer
model_name: str
tok: SimpleTokenizer
# TODO: tastefully move these into tokenizer
bos_id: int|None
eos_id: int
eot_id: int|None
class Handler(HTTPRequestHandler): class Handler(HTTPRequestHandler):
server: LLMServer
def log_request(self, code='-', size='-'): pass def log_request(self, code='-', size='-'): pass
def do_GET(self): def do_GET(self):
if self.path == "/v1/models": self.send_data(json.dumps({"object":"list","data":[{"id":model_name,"object":"model"}]}).encode()) if self.path == "/v1/models": self.send_data(json.dumps({"object":"list","data":[{"id":self.server.model_name,"object":"model"}]}).encode())
else: self.send_data(CHAT_HTML, content_type="text/html") else: self.send_data(CHAT_HTML, content_type="text/html")
def run_model(self, ids:list[int], model_name:str, include_usage=False, max_tokens:int|None=None, temperature:float=0.0): def run_model(self, ids:list[int], model_name:str, include_usage=False, max_tokens:int|None=None, temperature:float=0.0):
model, tok, eos_id, eot_id = self.server.model, self.server.tok, self.server.eos_id, self.server.eot_id
cache_start_pos = model.get_start_pos(ids) cache_start_pos = model.get_start_pos(ids)
stderr_log(f"{self.path} {colored('--', 'BLACK')} " stderr_log(f"{self.path} {colored('--', 'BLACK')} "
f"in:{colored(f'{cache_start_pos:5d}', 'green')} +{len(ids)-cache_start_pos:5d} {colored('--', 'BLACK')} ") f"in:{colored(f'{cache_start_pos:5d}', 'green')} +{len(ids)-cache_start_pos:5d} {colored('--', 'BLACK')} ")
@ -577,6 +588,7 @@ class Handler(HTTPRequestHandler):
f"out:{len(out):5d} {colored('--', 'BLACK')} total:{et-st:6.2f}s\n") f"out:{len(out):5d} {colored('--', 'BLACK')} total:{et-st:6.2f}s\n")
def do_POST(self): def do_POST(self):
tok, bos_id, eos_id = self.server.tok, self.server.bos_id, self.server.eos_id
raw_body = self.rfile.read(int(self.headers.get("Content-Length", "0"))) raw_body = self.rfile.read(int(self.headers.get("Content-Length", "0")))
body: dict[str, typing.Any] = json.loads(raw_body.decode("utf-8")) body: dict[str, typing.Any] = json.loads(raw_body.decode("utf-8"))
if DEBUG >= 1: print(json.dumps(body, indent=2)) if DEBUG >= 1: print(json.dumps(body, indent=2))
@ -611,7 +623,7 @@ class Handler(HTTPRequestHandler):
else: else:
raise RuntimeError(f"unhandled path {self.path}") raise RuntimeError(f"unhandled path {self.path}")
if __name__ == "__main__": def main():
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("--model", "-m", default=list(models.keys())[0], help=f"Model choice ({', '.join(models.keys())}) or path to a local GGUF file") parser.add_argument("--model", "-m", default=list(models.keys())[0], help=f"Model choice ({', '.join(models.keys())}) or path to a local GGUF file")
parser.add_argument("--max_context", type=int, default=4096, help="Max Context Length") parser.add_argument("--max_context", type=int, default=4096, help="Max Context Length")
@ -643,7 +655,11 @@ if __name__ == "__main__":
for _ in range(2): list(zip(range(2), model.generate([0]))) for _ in range(2): list(zip(range(2), model.generate([0])))
# start server # start server
if args.serve: TCPServerWithReuse(('', args.serve), Handler).serve_forever() if args.serve:
server = LLMServer(('', args.serve), Handler)
server.model, server.model_name, server.tok = model, model_name, tok
server.bos_id, server.eos_id, server.eot_id = bos_id, eos_id, eot_id
server.serve_forever()
# do benchmark # do benchmark
if args.benchmark is not None: if args.benchmark is not None:
@ -667,3 +683,5 @@ if __name__ == "__main__":
sys.stdout.write(dec(next_id) if next_id not in (eos_id, eot_id) else dec() + "\n\n") sys.stdout.write(dec(next_id) if next_id not in (eos_id, eot_id) else dec() + "\n\n")
sys.stdout.flush() sys.stdout.flush()
if next_id in (eos_id, eot_id): break if next_id in (eos_id, eot_id): break
if __name__ == "__main__": main()