Compare commits

...

5 commits

Author SHA1 Message Date
George Hotz
63c16143fc fix test attention 2025-12-14 20:45:19 -05:00
George Hotz
d4385d62d3 Merge remote-tracking branch 'origin/master' into fast_llm 2025-12-14 20:23:50 -05:00
George Hotz
257f683778 debug logging 2025-12-14 14:39:34 -05:00
George Hotz
afbbcdb530 fix speed 2025-12-14 13:29:36 -05:00
George Hotz
11b57898f2 llm: add --benchmark support 2025-12-14 08:09:58 -05:00
3 changed files with 38 additions and 19 deletions

View file

@ -1,8 +1,14 @@
import unittest
from tinygrad import Tensor, dtypes, TinyJit, UOp
from tinygrad.apps.llm import apply_rope
from tinygrad.apps.llm import apply_rope as apply_rope_new, precompute_freqs_cis
#from tinygrad.engine.realize import run_schedule
def apply_rope(x:Tensor, start_pos:int):
B, H, T, Hd = x.shape
precompute_freqs_cis.cache_clear()
freqs_cis = precompute_freqs_cis(Hd, start_pos+T)[start_pos:start_pos+T]
return apply_rope_new(x, freqs_cis)
# TODO: test_scheduler, but just in uint
class TestAttention(unittest.TestCase):
def test_half_qkv_buffers(self):
@ -39,7 +45,7 @@ class TestAttention(unittest.TestCase):
prune_size = len(rope_prune.captured.jit_cache)
self.assertGreater(noprune_size, prune_size)
self.assertGreaterEqual(noprune_size, 3)
self.assertGreaterEqual(noprune_size, 2)
self.assertEqual(prune_size, 1)
if __name__ == '__main__':

View file

@ -1,7 +1,7 @@
from __future__ import annotations
import sys, argparse, typing, re, unicodedata, json, uuid, time
import sys, argparse, typing, re, unicodedata, json, uuid, time, functools
from tinygrad import Tensor, nn, UOp, TinyJit, getenv
from tinygrad.helpers import partition, TCPServerWithReuse, HTTPRequestHandler, tqdm, DEBUG, Timing, GlobalCounters
from tinygrad.helpers import partition, TCPServerWithReuse, HTTPRequestHandler, DEBUG, Timing, GlobalCounters, stderr_log, colored
class SimpleTokenizer:
def __init__(self, normal_tokens:dict[str, int], special_tokens:dict[str, int]):
@ -52,15 +52,18 @@ class SimpleTokenizer:
def decode(self, ids:list[int]) -> str: return b''.join(self._tok2bytes[tid] for tid in ids).decode()
def role(self, role:str): return self.encode("<|start_header_id|>" + role + "<|end_header_id|>\n\n")
def apply_rope(x:Tensor, start_pos:int|UOp, base:float = 10000.0) -> Tensor:
@functools.cache
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0) -> Tensor:
freqs = 1.0 / (theta ** (Tensor.arange(0, dim, 2)[:(dim // 2)] / dim))
freqs = Tensor.arange(end).unsqueeze(dim=1) * freqs.unsqueeze(dim=0)
return Tensor.stack(freqs.cos(), freqs.sin(), dim=-1).contiguous()
def apply_rope(x:Tensor, freqs_cis:Tensor) -> Tensor:
B, H, T, Hd = x.shape
assert isinstance(Hd, int) and (Hd & 1) == 0, "RoPE requires an even head dimension"
half = Hd // 2
t_start_pos = start_pos if isinstance(start_pos, int) else Tensor(start_pos)
angles = (Tensor.arange(T, dtype="float32") + t_start_pos)[:, None] * (base ** (-(Tensor.arange(half, dtype="float32") / half)))[None, :]
# contiguous here allows RoPE to be pruned in the JIT
cos, sin = angles.cos().reshape(1, 1, T, half).cast(x.dtype).contiguous(), angles.sin().reshape(1, 1, T, half).cast(x.dtype).contiguous()
x_pairs = x.reshape(B, H, T, half, 2)
x_pairs = x.reshape(B, H, T, Hd//2, 2)
cos = freqs_cis.reshape(1, 1, T, Hd//2, 2)[..., 0]
sin = freqs_cis.reshape(1, 1, T, Hd//2, 2)[..., 1]
return Tensor.stack(x_pairs[..., 0] * cos - x_pairs[..., 1] * sin,
x_pairs[..., 0] * sin + x_pairs[..., 1] * cos, dim=-1).reshape(B, H, T, Hd)
@ -96,8 +99,10 @@ class TransformerBlock:
k = k.reshape(B, T, self.n_kv_heads, self.head_dim).transpose(1, 2) # (B,KvH,T,Hd)
v = v.reshape(B, T, self.n_kv_heads, self.head_dim).transpose(1, 2) # (B,KvH,T,Hd)
q = apply_rope(q, start_pos)
k = apply_rope(k, start_pos)
# TODO: make UOp have SupportsIndex
freqs_cis = precompute_freqs_cis(self.head_dim, self.max_context)[start_pos:start_pos+T] # type: ignore
q = apply_rope(q, freqs_cis)
k = apply_rope(k, freqs_cis)
# TODO: remove these kv cache realizes
if not hasattr(self, "cache_kv"):
@ -115,7 +120,8 @@ class TransformerBlock:
def _feed_forward(self, h: Tensor) -> Tensor:
h_norm = self.ffn_norm(h)
gated = self.ffn_gate(h_norm).silu() * self.ffn_up(h_norm)
# TODO: remove the need for this contiguous
gated = self.ffn_gate(h_norm).silu().contiguous() * self.ffn_up(h_norm)
return h + self.ffn_down(gated)
def __call__(self, x: Tensor, start_pos: int|UOp):
@ -185,24 +191,27 @@ models = {
# OPENAI_BASE_URL=http://localhost:11434/v1 OPENAI_API_KEY=ollama uvx --from gpt-command-line gpt
class Handler(HTTPRequestHandler):
def log_request(self, code='-', size='-'): pass
def run_model(self, ids:list[int], model_name:str, include_usage=False):
stderr_log(f"{self.path} {colored('--', 'BLACK')} in:{len(ids):5d} {colored('--', 'BLACK')} ")
tmpl = {"id":f"chatcmpl-{uuid.uuid4().hex[:24]}", "object":"chat.completion.chunk", "created":int(time.time()), "model":model_name}
yield {"choices": [{"index":0, "delta":{"role":"assistant","content":""}, "finish_reason":None}], **tmpl}
out = []
for next_id in tqdm(model.generate(ids), disable=not DEBUG>=1):
out: list[int] = []
st = time.perf_counter()
for next_id in model.generate(ids):
if len(out) == 0: stderr_log(f"prefill:{len(ids)/((pt:=time.perf_counter())-st):4.0f} tok/s {colored('--', 'BLACK')} ")
if next_id == eos_id: break
out.append(next_id)
yield {"choices": [{"index":0, "delta":{"content":tok.decode([next_id])}, "finish_reason":None}], **tmpl}
yield {"choices": [{"index":0, "delta":{},"finish_reason":"stop"}], **tmpl}
if include_usage:
yield {"choices": [], "usage": {"prompt_tokens": len(ids), "completion_tokens": len(out), "total_tokens": len(ids) + len(out)}, **tmpl}
stderr_log(f"out:{len(out):5d} {colored('--', 'BLACK')} gen: {len(out)/(time.perf_counter()-pt):4.0f} tok/s\n")
def do_POST(self):
raw_body = self.rfile.read(int(self.headers.get("Content-Length", "0")))
body: dict[str, typing.Any] = json.loads(raw_body.decode("utf-8"))
if DEBUG >= 1:
print(self.path)
print(json.dumps(body, indent=2))
if DEBUG >= 1: print(json.dumps(body, indent=2))
if self.path == "/v1/chat/completions":
# extract tokens
ids = [bos_id]

View file

@ -149,6 +149,10 @@ def getenv(key:str, default:Any=0): return type(default)(os.getenv(key, default)
def temp(x:str, append_user:bool=False) -> str:
return (pathlib.Path(tempfile.gettempdir()) / (f"{x}.{getpass.getuser()}" if append_user else x)).as_posix()
def stderr_log(msg):
sys.stderr.write(msg)
sys.stderr.flush()
class Context(contextlib.ContextDecorator):
def __init__(self, **kwargs): self.kwargs = kwargs
def __enter__(self):