Compare commits

...

5 commits

Author SHA1 Message Date
George Hotz
63c16143fc fix test attention 2025-12-14 20:45:19 -05:00
George Hotz
d4385d62d3 Merge remote-tracking branch 'origin/master' into fast_llm 2025-12-14 20:23:50 -05:00
George Hotz
257f683778 debug logging 2025-12-14 14:39:34 -05:00
George Hotz
afbbcdb530 fix speed 2025-12-14 13:29:36 -05:00
George Hotz
11b57898f2 llm: add --benchmark support 2025-12-14 08:09:58 -05:00
3 changed files with 38 additions and 19 deletions

View file

@ -1,8 +1,14 @@
import unittest import unittest
from tinygrad import Tensor, dtypes, TinyJit, UOp from tinygrad import Tensor, dtypes, TinyJit, UOp
from tinygrad.apps.llm import apply_rope from tinygrad.apps.llm import apply_rope as apply_rope_new, precompute_freqs_cis
#from tinygrad.engine.realize import run_schedule #from tinygrad.engine.realize import run_schedule
def apply_rope(x:Tensor, start_pos:int):
B, H, T, Hd = x.shape
precompute_freqs_cis.cache_clear()
freqs_cis = precompute_freqs_cis(Hd, start_pos+T)[start_pos:start_pos+T]
return apply_rope_new(x, freqs_cis)
# TODO: test_scheduler, but just in uint # TODO: test_scheduler, but just in uint
class TestAttention(unittest.TestCase): class TestAttention(unittest.TestCase):
def test_half_qkv_buffers(self): def test_half_qkv_buffers(self):
@ -39,7 +45,7 @@ class TestAttention(unittest.TestCase):
prune_size = len(rope_prune.captured.jit_cache) prune_size = len(rope_prune.captured.jit_cache)
self.assertGreater(noprune_size, prune_size) self.assertGreater(noprune_size, prune_size)
self.assertGreaterEqual(noprune_size, 3) self.assertGreaterEqual(noprune_size, 2)
self.assertEqual(prune_size, 1) self.assertEqual(prune_size, 1)
if __name__ == '__main__': if __name__ == '__main__':

View file

@ -1,7 +1,7 @@
from __future__ import annotations from __future__ import annotations
import sys, argparse, typing, re, unicodedata, json, uuid, time import sys, argparse, typing, re, unicodedata, json, uuid, time, functools
from tinygrad import Tensor, nn, UOp, TinyJit, getenv from tinygrad import Tensor, nn, UOp, TinyJit, getenv
from tinygrad.helpers import partition, TCPServerWithReuse, HTTPRequestHandler, tqdm, DEBUG, Timing, GlobalCounters from tinygrad.helpers import partition, TCPServerWithReuse, HTTPRequestHandler, DEBUG, Timing, GlobalCounters, stderr_log, colored
class SimpleTokenizer: class SimpleTokenizer:
def __init__(self, normal_tokens:dict[str, int], special_tokens:dict[str, int]): def __init__(self, normal_tokens:dict[str, int], special_tokens:dict[str, int]):
@ -52,15 +52,18 @@ class SimpleTokenizer:
def decode(self, ids:list[int]) -> str: return b''.join(self._tok2bytes[tid] for tid in ids).decode() def decode(self, ids:list[int]) -> str: return b''.join(self._tok2bytes[tid] for tid in ids).decode()
def role(self, role:str): return self.encode("<|start_header_id|>" + role + "<|end_header_id|>\n\n") def role(self, role:str): return self.encode("<|start_header_id|>" + role + "<|end_header_id|>\n\n")
def apply_rope(x:Tensor, start_pos:int|UOp, base:float = 10000.0) -> Tensor: @functools.cache
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0) -> Tensor:
freqs = 1.0 / (theta ** (Tensor.arange(0, dim, 2)[:(dim // 2)] / dim))
freqs = Tensor.arange(end).unsqueeze(dim=1) * freqs.unsqueeze(dim=0)
return Tensor.stack(freqs.cos(), freqs.sin(), dim=-1).contiguous()
def apply_rope(x:Tensor, freqs_cis:Tensor) -> Tensor:
B, H, T, Hd = x.shape B, H, T, Hd = x.shape
assert isinstance(Hd, int) and (Hd & 1) == 0, "RoPE requires an even head dimension" assert isinstance(Hd, int) and (Hd & 1) == 0, "RoPE requires an even head dimension"
half = Hd // 2 x_pairs = x.reshape(B, H, T, Hd//2, 2)
t_start_pos = start_pos if isinstance(start_pos, int) else Tensor(start_pos) cos = freqs_cis.reshape(1, 1, T, Hd//2, 2)[..., 0]
angles = (Tensor.arange(T, dtype="float32") + t_start_pos)[:, None] * (base ** (-(Tensor.arange(half, dtype="float32") / half)))[None, :] sin = freqs_cis.reshape(1, 1, T, Hd//2, 2)[..., 1]
# contiguous here allows RoPE to be pruned in the JIT
cos, sin = angles.cos().reshape(1, 1, T, half).cast(x.dtype).contiguous(), angles.sin().reshape(1, 1, T, half).cast(x.dtype).contiguous()
x_pairs = x.reshape(B, H, T, half, 2)
return Tensor.stack(x_pairs[..., 0] * cos - x_pairs[..., 1] * sin, return Tensor.stack(x_pairs[..., 0] * cos - x_pairs[..., 1] * sin,
x_pairs[..., 0] * sin + x_pairs[..., 1] * cos, dim=-1).reshape(B, H, T, Hd) x_pairs[..., 0] * sin + x_pairs[..., 1] * cos, dim=-1).reshape(B, H, T, Hd)
@ -96,8 +99,10 @@ class TransformerBlock:
k = k.reshape(B, T, self.n_kv_heads, self.head_dim).transpose(1, 2) # (B,KvH,T,Hd) k = k.reshape(B, T, self.n_kv_heads, self.head_dim).transpose(1, 2) # (B,KvH,T,Hd)
v = v.reshape(B, T, self.n_kv_heads, self.head_dim).transpose(1, 2) # (B,KvH,T,Hd) v = v.reshape(B, T, self.n_kv_heads, self.head_dim).transpose(1, 2) # (B,KvH,T,Hd)
q = apply_rope(q, start_pos) # TODO: make UOp have SupportsIndex
k = apply_rope(k, start_pos) freqs_cis = precompute_freqs_cis(self.head_dim, self.max_context)[start_pos:start_pos+T] # type: ignore
q = apply_rope(q, freqs_cis)
k = apply_rope(k, freqs_cis)
# TODO: remove these kv cache realizes # TODO: remove these kv cache realizes
if not hasattr(self, "cache_kv"): if not hasattr(self, "cache_kv"):
@ -115,7 +120,8 @@ class TransformerBlock:
def _feed_forward(self, h: Tensor) -> Tensor: def _feed_forward(self, h: Tensor) -> Tensor:
h_norm = self.ffn_norm(h) h_norm = self.ffn_norm(h)
gated = self.ffn_gate(h_norm).silu() * self.ffn_up(h_norm) # TODO: remove the need for this contiguous
gated = self.ffn_gate(h_norm).silu().contiguous() * self.ffn_up(h_norm)
return h + self.ffn_down(gated) return h + self.ffn_down(gated)
def __call__(self, x: Tensor, start_pos: int|UOp): def __call__(self, x: Tensor, start_pos: int|UOp):
@ -185,24 +191,27 @@ models = {
# OPENAI_BASE_URL=http://localhost:11434/v1 OPENAI_API_KEY=ollama uvx --from gpt-command-line gpt # OPENAI_BASE_URL=http://localhost:11434/v1 OPENAI_API_KEY=ollama uvx --from gpt-command-line gpt
class Handler(HTTPRequestHandler): class Handler(HTTPRequestHandler):
def log_request(self, code='-', size='-'): pass
def run_model(self, ids:list[int], model_name:str, include_usage=False): def run_model(self, ids:list[int], model_name:str, include_usage=False):
stderr_log(f"{self.path} {colored('--', 'BLACK')} in:{len(ids):5d} {colored('--', 'BLACK')} ")
tmpl = {"id":f"chatcmpl-{uuid.uuid4().hex[:24]}", "object":"chat.completion.chunk", "created":int(time.time()), "model":model_name} tmpl = {"id":f"chatcmpl-{uuid.uuid4().hex[:24]}", "object":"chat.completion.chunk", "created":int(time.time()), "model":model_name}
yield {"choices": [{"index":0, "delta":{"role":"assistant","content":""}, "finish_reason":None}], **tmpl} yield {"choices": [{"index":0, "delta":{"role":"assistant","content":""}, "finish_reason":None}], **tmpl}
out = [] out: list[int] = []
for next_id in tqdm(model.generate(ids), disable=not DEBUG>=1): st = time.perf_counter()
for next_id in model.generate(ids):
if len(out) == 0: stderr_log(f"prefill:{len(ids)/((pt:=time.perf_counter())-st):4.0f} tok/s {colored('--', 'BLACK')} ")
if next_id == eos_id: break if next_id == eos_id: break
out.append(next_id) out.append(next_id)
yield {"choices": [{"index":0, "delta":{"content":tok.decode([next_id])}, "finish_reason":None}], **tmpl} yield {"choices": [{"index":0, "delta":{"content":tok.decode([next_id])}, "finish_reason":None}], **tmpl}
yield {"choices": [{"index":0, "delta":{},"finish_reason":"stop"}], **tmpl} yield {"choices": [{"index":0, "delta":{},"finish_reason":"stop"}], **tmpl}
if include_usage: if include_usage:
yield {"choices": [], "usage": {"prompt_tokens": len(ids), "completion_tokens": len(out), "total_tokens": len(ids) + len(out)}, **tmpl} yield {"choices": [], "usage": {"prompt_tokens": len(ids), "completion_tokens": len(out), "total_tokens": len(ids) + len(out)}, **tmpl}
stderr_log(f"out:{len(out):5d} {colored('--', 'BLACK')} gen: {len(out)/(time.perf_counter()-pt):4.0f} tok/s\n")
def do_POST(self): def do_POST(self):
raw_body = self.rfile.read(int(self.headers.get("Content-Length", "0"))) raw_body = self.rfile.read(int(self.headers.get("Content-Length", "0")))
body: dict[str, typing.Any] = json.loads(raw_body.decode("utf-8")) body: dict[str, typing.Any] = json.loads(raw_body.decode("utf-8"))
if DEBUG >= 1: if DEBUG >= 1: print(json.dumps(body, indent=2))
print(self.path)
print(json.dumps(body, indent=2))
if self.path == "/v1/chat/completions": if self.path == "/v1/chat/completions":
# extract tokens # extract tokens
ids = [bos_id] ids = [bos_id]

View file

@ -149,6 +149,10 @@ def getenv(key:str, default:Any=0): return type(default)(os.getenv(key, default)
def temp(x:str, append_user:bool=False) -> str: def temp(x:str, append_user:bool=False) -> str:
return (pathlib.Path(tempfile.gettempdir()) / (f"{x}.{getpass.getuser()}" if append_user else x)).as_posix() return (pathlib.Path(tempfile.gettempdir()) / (f"{x}.{getpass.getuser()}" if append_user else x)).as_posix()
def stderr_log(msg):
sys.stderr.write(msg)
sys.stderr.flush()
class Context(contextlib.ContextDecorator): class Context(contextlib.ContextDecorator):
def __init__(self, **kwargs): self.kwargs = kwargs def __init__(self, **kwargs): self.kwargs = kwargs
def __enter__(self): def __enter__(self):