Compare commits

...

1 commit

Author SHA1 Message Date
George Hotz
4fb38e94ad symbolic llm with prefill (ai slop) 2026-03-02 02:13:42 +00:00
5 changed files with 305 additions and 21 deletions

View file

@ -302,5 +302,62 @@ class TestSymbolicOps(unittest.TestCase):
expected = x_full[:, :, :val].conv2d(weight=weight, groups=1, stride=6, dilation=1, padding=(3, 3))
np.testing.assert_allclose(result[:, :, :13].numpy(), expected.numpy(), atol=1e-5, rtol=1e-5)
def test_triu_symbolic(self):
a = Tensor.rand(10, 10).realize()
for i in range(2, 6):
vi = Variable("i", 1, 10).bind(i)
symbolic = a[:vi, :vi].triu()
# extract concrete-sized result from symbolic output
symbolic_np = symbolic[:i, :i].numpy()
expected = a[:i, :i].triu().numpy()
np.testing.assert_allclose(symbolic_np, expected, atol=1e-6, rtol=1e-6)
def test_tril_symbolic(self):
a = Tensor.rand(10, 10).realize()
for i in range(2, 6):
vi = Variable("i", 1, 10).bind(i)
symbolic = a[:vi, :vi].tril()
symbolic_np = symbolic[:i, :i].numpy()
expected = a[:i, :i].tril().numpy()
np.testing.assert_allclose(symbolic_np, expected, atol=1e-6, rtol=1e-6)
def test_triu_symbolic_diagonal(self):
a = Tensor.rand(10, 10).realize()
for i in range(3, 6):
vi = Variable("i", 1, 10).bind(i)
for diag in [-1, 0, 1, 2]:
symbolic = a[:vi, :vi].triu(diagonal=diag)
symbolic_np = symbolic[:i, :i].numpy()
expected = a[:i, :i].triu(diagonal=diag).numpy()
np.testing.assert_allclose(symbolic_np, expected, atol=1e-6, rtol=1e-6)
def test_triu_symbolic_nonsquare(self):
a = Tensor.rand(10, 8).realize()
for i in range(2, 6):
vi = Variable("i", 1, 10).bind(i)
symbolic = a[:vi, :].triu()
symbolic_np = symbolic[:i, :].numpy()
expected = a[:i, :].triu().numpy()
np.testing.assert_allclose(symbolic_np, expected, atol=1e-6, rtol=1e-6)
def test_full_triu_symbolic(self):
"""Test the attention mask pattern: Tensor.full(symbolic_shape, -inf).triu(k)"""
for T in range(2, 6):
for sp in [0, 2, 5]:
vT = Variable("T", 1, 10).bind(T)
mask = Tensor.full((1, 1, vT, sp+vT), float("-inf")).triu(sp+1)
ref = Tensor.full((1, 1, T, sp+T), float("-inf")).triu(sp+1)
# compare element sums (counting -inf elements)
sym_finite = mask[:, :, :T, :sp+T].isnan().logical_not().cast(dtypes.int32).sum().item()
ref_finite = ref.isnan().logical_not().cast(dtypes.int32).sum().item()
self.assertEqual(sym_finite, ref_finite)
def test_arange_symbolic(self):
for i in range(1, 6):
vi = Variable("i", 1, 10).bind(i)
symbolic = Tensor.arange(vi)
expected = Tensor.arange(i)
np.testing.assert_allclose(symbolic[:i].numpy(), expected.numpy(), atol=1e-6, rtol=1e-6)
if __name__ == '__main__':
unittest.main()

View file

@ -6,13 +6,17 @@ class TestTransformerGenerate(unittest.TestCase):
def test_start_pos_parameter_is_used(self):
"""Test that start_pos parameter is not ignored (regression test for always resetting to 0)."""
from tinygrad.apps.llm import Transformer
from tinygrad.uop.ops import UOp
# Create a minimal transformer
model = Transformer(num_blocks=1, dim=64, hidden_dim=128, n_heads=2, n_kv_heads=2,
norm_eps=1e-5, vocab_size=100, head_dim=32, rope_theta=10000.0, max_context=32)
captured_inputs = []
def mock_call(self, tokens, start_pos):
captured_inputs.append((tokens.shape, start_pos if isinstance(start_pos, int) else start_pos.bind_val))
s = tokens.shape[-1]
shape_val = s.val if isinstance(s, UOp) else s
sp_val = start_pos.val if isinstance(start_pos, UOp) else start_pos
captured_inputs.append((shape_val, sp_val))
return Tensor([[42]]) # return a fake next token
with patch.object(Transformer, '__call__', mock_call):
@ -22,8 +26,179 @@ class TestTransformerGenerate(unittest.TestCase):
# With start_pos=3, the initial tensor should only have tokens[3:] = [4, 5] (length 2)
# If the bug existed (start_pos always reset to 0), it would have all 5 tokens
self.assertEqual(captured_inputs[0][0][-1], 2) # shape should be (1, 2)
self.assertEqual(captured_inputs[0][0], 2) # shape should have 2 tokens
self.assertEqual(captured_inputs[0][1], 3) # start_pos should be 3, not 0
class TestPrefillCorrectness(unittest.TestCase):
@classmethod
def setUpClass(cls):
from tinygrad.apps.llm import Transformer
cls.model = Transformer(num_blocks=2, dim=64, hidden_dim=128, n_heads=2, n_kv_heads=2,
norm_eps=1e-5, vocab_size=100, head_dim=32, rope_theta=10000.0, max_context=128)
def _reset_caches(self, model):
for block in model.blk:
if hasattr(block, "cache_kv"): del block.cache_kv
if hasattr(block, "causal_mask"): del block.causal_mask
def test_prefill_matches_forward(self):
"""Test that prefill_forward produces the same output as forward for various lengths."""
model = self.__class__.model
v_prefill = model.v_prefill
token_buf = Tensor.zeros(1, model.max_context, dtype="int32").contiguous().realize()
for length in [2, 4, 8, 16]:
tokens = Tensor.randint(length, high=100, dtype="int32")
token_buf[0, :length].assign(tokens).realize()
self._reset_caches(model)
expected = model.forward(tokens.reshape(1, -1), 0).realize()
self._reset_caches(model)
actual = model.prefill_forward(token_buf[:, :v_prefill.bind(length)], 0).realize()
self.assertEqual(expected.item(), actual.item(), f"mismatch at length={length}: forward={expected.item()}, prefill={actual.item()}")
def test_prefill_then_decode_matches_sequential(self):
"""Test that prefill + decode produces the same sequence as token-by-token forward."""
from tinygrad import UOp, getenv
model = self.__class__.model
tokens = Tensor.randint(6, high=100, dtype="int32").tolist()
num_decode = 3
# reference: feed all tokens one at a time, then decode
self._reset_caches(model)
ref_ids = list(tokens)
t = Tensor([ref_ids], dtype="int32")
t = model.forward(t, 0)
ref_ids.append(int(t.item()))
for i in range(num_decode - 1):
sp = len(ref_ids) - 1
t = model.forward(t, sp)
ref_ids.append(int(t.item()))
# test: prefill then decode
self._reset_caches(model)
test_ids = list(tokens)
token_buf = Tensor.zeros(1, model.max_context, dtype="int32").contiguous().realize()
token_buf[0, :len(test_ids)].assign(Tensor(test_ids, dtype="int32")).realize()
t = model.prefill_forward(token_buf[:, :model.v_prefill.bind(len(test_ids))], 0)
test_ids.append(int(t.item()))
for i in range(num_decode - 1):
sp = len(test_ids) - 1
t = model.forward(t, sp)
test_ids.append(int(t.item()))
self.assertEqual(ref_ids, test_ids, f"sequence mismatch: ref={ref_ids[-num_decode:]}, test={test_ids[-num_decode:]}")
def test_prefill_with_start_pos(self):
"""Test that prefill with start_pos > 0 matches sequential forward."""
model = self.__class__.model
first_tokens = Tensor.randint(4, high=100, dtype="int32").tolist()
second_tokens = Tensor.randint(3, high=100, dtype="int32").tolist()
all_tokens = first_tokens + second_tokens
# reference: forward all at once from pos 0
self._reset_caches(model)
expected = model.forward(Tensor([all_tokens], dtype="int32"), 0).realize()
# test: forward first chunk, then prefill second chunk at start_pos
self._reset_caches(model)
model.forward(Tensor([first_tokens], dtype="int32"), 0).realize()
token_buf = Tensor.zeros(1, model.max_context, dtype="int32").contiguous().realize()
token_buf[0, :len(second_tokens)].assign(Tensor(second_tokens, dtype="int32")).realize()
actual = model.prefill_forward(token_buf[:, :model.v_prefill.bind(len(second_tokens))],
model.v_start_pos.bind(len(first_tokens))).realize()
self.assertEqual(expected.item(), actual.item(),
f"start_pos mismatch: all-at-once={expected.item()}, split={actual.item()}")
def test_multi_turn_generate(self):
"""Test that multi-turn generate produces same tokens as reference (token-by-token forward)."""
model = self.__class__.model
num_decode = 3
# simulate 2 turns of conversation
turn1_tokens = Tensor.randint(5, high=100, dtype="int32").tolist()
turn2_tokens = Tensor.randint(4, high=100, dtype="int32").tolist()
# reference: token-by-token forward (no prefill, no JIT)
self._reset_caches(model)
ref_ids = list(turn1_tokens)
# prefill turn 1
t = model.forward(Tensor([ref_ids], dtype="int32"), 0)
ref_ids.append(int(t.item()))
# decode turn 1
for _ in range(num_decode - 1):
t = model.forward(t, len(ref_ids) - 1)
ref_ids.append(int(t.item()))
# prefill turn 2
start_pos_2 = len(ref_ids)
ref_ids += turn2_tokens
t = model.forward(Tensor([ref_ids[start_pos_2:]], dtype="int32"), start_pos_2)
ref_ids.append(int(t.item()))
# decode turn 2
for _ in range(num_decode - 1):
t = model.forward(t, len(ref_ids) - 1)
ref_ids.append(int(t.item()))
# test: using prefill_forward with symbolic variable
self._reset_caches(model)
test_ids = list(turn1_tokens)
token_buf = Tensor.zeros(1, model.max_context, dtype="int32").contiguous().realize()
# prefill turn 1
token_buf[0, :len(test_ids)].assign(Tensor(test_ids, dtype="int32")).realize()
t = model.prefill_forward(token_buf[:, :model.v_prefill.bind(len(test_ids))], 0)
test_ids.append(int(t.item()))
# decode turn 1
for _ in range(num_decode - 1):
t = model.forward(t, len(test_ids) - 1)
test_ids.append(int(t.item()))
# prefill turn 2
start_pos_2 = len(test_ids)
test_ids += turn2_tokens
new_tokens = test_ids[start_pos_2:]
token_buf[0, :len(new_tokens)].assign(Tensor(new_tokens, dtype="int32")).realize()
t = model.prefill_forward(token_buf[:, :model.v_prefill.bind(len(new_tokens))],
model.v_start_pos.bind(start_pos_2))
test_ids.append(int(t.item()))
# decode turn 2
for _ in range(num_decode - 1):
t = model.forward(t, len(test_ids) - 1)
test_ids.append(int(t.item()))
self.assertEqual(ref_ids, test_ids, f"multi-turn mismatch:\nref ={ref_ids}\ntest={test_ids}")
def test_generate_matches_forward(self):
"""Test that generate() (with symbolic prefill + JIT) matches token-by-token forward for 3 turns."""
from tinygrad import TinyJit
model = self.__class__.model
num_decode = 3
turns = [Tensor.randint(5, high=100, dtype="int32").tolist() for _ in range(3)]
# reference: token-by-token forward
self._reset_caches(model)
ref_ids: list[int] = []
for turn_tokens in turns:
start_pos = len(ref_ids)
ref_ids += turn_tokens
t = model.forward(Tensor([ref_ids[start_pos:]], dtype="int32"), start_pos)
ref_ids.append(int(t.item()))
for _ in range(num_decode - 1):
t = model.forward(t, len(ref_ids) - 1)
ref_ids.append(int(t.item()))
# test: using generate() with symbolic prefill
self._reset_caches(model)
model.forward_jit = TinyJit(model.forward)
model.prefill_jit = TinyJit(model.prefill_forward)
test_ids: list[int] = []
for turn_tokens in turns:
start_pos = len(test_ids)
test_ids += turn_tokens
gen = model.generate(test_ids, start_pos)
for _ in range(num_decode): next(gen)
self.assertEqual(ref_ids, test_ids, f"generate mismatch:\nref ={ref_ids}\ntest={test_ids}")
if __name__ == '__main__':
unittest.main()

View file

@ -1,5 +1,5 @@
from __future__ import annotations
import sys, argparse, typing, re, unicodedata, json, uuid, time, functools
import sys, argparse, typing, re, unicodedata, json, uuid, time, functools, array
from tinygrad import Tensor, nn, UOp, TinyJit, getenv, function
from tinygrad.helpers import partition, DEBUG, Timing, GlobalCounters, stderr_log, colored
from tinygrad.viz.serve import TCPServerWithReuse, HTTPRequestHandler
@ -116,8 +116,7 @@ class TransformerBlock:
self.ffn_up = nn.Linear(dim, hidden_dim, bias=False)
self.ffn_down = nn.Linear(hidden_dim, dim, bias=False)
@function
def _attention(self, x:Tensor, start_pos:int|UOp) -> Tensor:
def _attention_inner(self, x:Tensor, start_pos:int|UOp, mask:Tensor|None) -> Tensor:
x_norm = self.attn_norm(x) # (B,T,D)
q, k, v = self.attn_q(x_norm), self.attn_k(x_norm), self.attn_v(x_norm)
if self.qk_norm and self.qk_norm != self.head_dim: q, k = self.attn_q_norm(q), self.attn_k_norm(k)
@ -142,13 +141,18 @@ class TransformerBlock:
#k = self.cache_kv[0, :, :, 0:start_pos+T, :]
#v = self.cache_kv[1, :, :, 0:start_pos+T, :]
# NOTE: this mask is causal_lower_right, not the causal_upper_left generated by is_casual = True
mask = Tensor.full((1, 1, T, start_pos+T), float("-inf"), dtype=x.dtype, device=x.device).triu(int(start_pos)+1) if T > 1 else None
attn = q.scaled_dot_product_attention(k, v, attn_mask=mask, enable_gqa=True) # (B,H,T,Hd)
attn = attn.transpose(1, 2).reshape(B, T, -1) # back to (B,T,D)
attn = self.attn_output(attn)
return x + attn
@function
def _attention(self, x:Tensor, start_pos:int|UOp) -> Tensor:
B, T, _ = x.shape
# NOTE: this mask is causal_lower_right, not the causal_upper_left generated by is_casual = True
mask = Tensor.full((1, 1, T, start_pos+T), float("-inf"), dtype=x.dtype, device=x.device).triu(int(start_pos)+1) if T > 1 else None
return self._attention_inner(x, start_pos, mask)
@function
def _feed_forward(self, h: Tensor) -> Tensor:
h_norm = self.ffn_norm(h)
@ -167,6 +171,16 @@ class TransformerBlock:
self.cache_kv = Tensor.zeros(2, x.shape[0], self.n_kv_heads, self.max_context, self.head_dim, device=x.device).contiguous().realize()
return self._feed_forward(self._attention(x, start_pos)).contiguous()
def prefill(self, x: Tensor, start_pos: int|UOp):
"""Like __call__ but without @function wrappers, for JIT-compatible symbolic prefill."""
if not hasattr(self, "cache_kv"):
self.cache_kv = Tensor.zeros(2, x.shape[0], self.n_kv_heads, self.max_context, self.head_dim, device=x.device).contiguous().realize()
if not hasattr(self, "causal_mask"):
self.causal_mask = Tensor.full((1, 1, self.max_context, self.max_context), float("-inf"), dtype=x.dtype, device=x.device).triu(1).contiguous().realize()
T = x.shape[1]
mask = self.causal_mask[:, :, start_pos:start_pos+T, :start_pos+T] if (T > 1 if isinstance(T, int) else T.vmin > 1) else None
return TransformerBlock._feed_forward.fxn(self, self._attention_inner(x, start_pos, mask)).contiguous()
class Transformer:
def __init__(self, *, num_blocks, dim, hidden_dim, n_heads, n_kv_heads, norm_eps, vocab_size, head_dim:int, rope_theta:float,
max_context:int=0, qk_norm:int=0, num_experts:int=0, num_experts_per_tok:int=0):
@ -176,8 +190,10 @@ class Transformer:
self.output_norm = nn.RMSNorm(dim, norm_eps)
self.output = nn.Linear(dim, vocab_size, bias=False)
self.max_context = max_context
# JIT is used if T=1 and start_pos is a UOp. TODO: make this not needed by including T in the JIT and making start_pos always a UOp
self.forward_jit = TinyJit(self.forward)
self.prefill_jit = TinyJit(self.prefill_forward)
self.v_start_pos = UOp.variable("start_pos", 1, max_context-1)
self.v_prefill = UOp.variable("prefill", 2, max_context)
def forward(self, tokens:Tensor, start_pos:int|UOp) -> Tensor:
x = self.token_embd(tokens) # (B, T, D)
@ -186,7 +202,9 @@ class Transformer:
return self.output(self.output_norm(x))[:, -1, :].softmax(-1, dtype="float").argmax(-1, keepdim=True)
def __call__(self, tokens:Tensor, start_pos:int|UOp=0) -> Tensor:
return (self.forward_jit if getenv("JIT", 1) and tokens.shape[1] == 1 and isinstance(start_pos, UOp) else self.forward)(tokens, start_pos)
if getenv("JIT", 1) and tokens.shape[1] == 1 and isinstance(start_pos, UOp): return self.forward_jit(tokens, start_pos)
if getenv("JIT", 1) and tokens.shape[1] != 1 and isinstance(tokens.shape[1], UOp): return self.prefill_jit(tokens, start_pos)
return self.forward(tokens, start_pos)
@staticmethod
def from_gguf(gguf:Tensor, max_context:int|None=None, realize=bool(getenv("REALIZE", 1))) -> tuple[Transformer, dict]:
@ -224,11 +242,33 @@ class Transformer:
Tensor.realize(*params)
return model, kv
def prefill_forward(self, tokens:Tensor, start_pos:int|UOp) -> Tensor:
"""Like forward but without @function on attention/ffn, for symbolic prefill."""
x = self.token_embd(tokens)
for block in self.blk: x = block.prefill(x, start_pos)
return self.output(self.output_norm(x))[:, -1, :].softmax(-1, dtype="float").argmax(-1, keepdim=True)
def generate(self, tokens:list[int], start_pos=0):
v_start_pos = UOp.variable("start_pos", 1, self.max_context-1)
t = Tensor([tokens[start_pos:]], dtype="int32")
if not hasattr(self, "_prefill_buf"):
self._prefill_buf = Tensor.zeros(1, self.max_context, dtype="int32").contiguous().realize()
self._prefill_staging = array.array('i', [0] * self.max_context)
# prefill all tokens at once using symbolic variable
num_tokens = len(tokens) - start_pos
if num_tokens >= 2 and getenv("SYM", 1):
# write tokens directly into the buffer via copyin (no scheduled kernels, no cache misses)
for i, t in enumerate(tokens[start_pos:]): self._prefill_staging[i] = t
self._prefill_buf._buffer().copyin(memoryview(self._prefill_staging))
v_sp = self.v_start_pos.bind(start_pos) if start_pos != 0 else start_pos
t = self(self._prefill_buf[:, :self.v_prefill.bind(num_tokens)], v_sp)
next_id = int(t.item())
tokens.append(next_id)
start_pos = len(tokens) - 1
yield next_id
else:
t = Tensor([tokens[start_pos:]], dtype="int32")
# decode one token at a time
while len(tokens) < self.max_context:
t = self(t, v_start_pos.bind(start_pos) if getenv("SYM", 1) and start_pos != 0 and t.shape[-1] == 1 else start_pos)
t = self(t, self.v_start_pos.bind(start_pos) if getenv("SYM", 1) and start_pos != 0 and t.shape[-1] == 1 else start_pos)
next_id = int(t.item())
tokens.append(next_id)
start_pos = len(tokens) - 1
@ -356,6 +396,19 @@ if __name__ == "__main__":
# do benchmark
if args.benchmark:
# prefill benchmark
token_buf = Tensor.zeros(1, args.max_context, dtype="int32").contiguous().realize()
print("prefill benchmark:")
for length in [4, 16, 64, 256, 1024]:
if length > args.max_context: break
GlobalCounters.reset()
with Timing(prefix=f" prefill {length:4d}: ",
on_exit=lambda x, length=length: f", {length*1e9/x:6.2f} tok/s, {GlobalCounters.global_mem/x:7.2f} GB/s,"
f" {GlobalCounters.global_mem//1000000}/{GlobalCounters.mem_used//1000000} MB"):
model.prefill_jit(token_buf[:, :model.v_prefill.bind(length)], 0).realize()
# generation benchmark
print("generation benchmark:")
gen = model.generate([0], 0)
for _ in range(args.benchmark):
GlobalCounters.reset()

View file

@ -4,12 +4,11 @@ from tinygrad.helpers import all_int, dedup, get_contraction
from tinygrad.dtype import dtypes, AddrSpace, Invalid
from tinygrad.renderer import Renderer
def _dim_max(d:sint) -> int: return d if isinstance(d, int) else d.vmax
def _group_dims(dims:tuple[sint, ...], max_sizes:tuple[int, ...]):
# TODO: symbolic shape
if not all_int(dims): return dims
while len(dims) > len(max_sizes) or any(d > m for d,m in zip(dims, max_sizes)):
while len(dims) > len(max_sizes) or any(_dim_max(d) > m for d,m in zip(dims, max_sizes)):
for i,m in enumerate(max_sizes):
if i < (len(dims)-1) and dims[i] * dims[i+1] <= m:
if i < (len(dims)-1) and _dim_max(dims[i]) * _dim_max(dims[i+1]) <= m:
dims = dims[:i] + (dims[i]*dims[i+1],) + dims[i+2:]
break
else: return None

View file

@ -722,9 +722,10 @@ class Tensor(OpMixin):
"""
if stop is None: stop, start = start, 0
dtype = kwargs.pop("dtype", dtypes.default_float if any(isinstance(x, float) for x in (start, stop, step)) else dtypes.default_int)
if start < (dt:=to_dtype(dtype)).min or dt.max < (stop-step): raise ValueError(f"arange [{start}, {stop}) is not representable in dtype {dtype}")
if resolve(start < (dt:=to_dtype(dtype)).min, False) or resolve(dt.max < (stop-step), False):
raise ValueError(f"arange [{start}, {stop}) is not representable in dtype {dtype}")
# NOTE: this matches numpy, torch raises RuntimeError if stop-start and step have different signs
if (output_len:=ceildiv(stop-start, step)) <= 0: return Tensor([], dtype=dtype, **kwargs)
if resolve((output_len:=ceildiv(stop-start, step)) <= 0, False): return Tensor([], dtype=dtype, **kwargs)
return (Tensor.full((output_len,), step, dtype=dtype, **kwargs)._cumalu(0, Ops.ADD) + (start - step)).cast(dtype)
@staticmethod
@ -2566,7 +2567,6 @@ class Tensor(OpMixin):
@staticmethod
def _tri(r:sint, c:sint, diagonal:int=0, device=None, requires_grad:bool|None=None) -> Tensor:
assert isinstance(r, int) and isinstance(c, int), f"does not support symbolic, getting {r=}, {c=}"
return (Tensor.arange(r, device=device).unsqueeze(-1) + diagonal <= Tensor.arange(c, device=device)).requires_grad_(requires_grad)
def triu(self, diagonal:int=0) -> Tensor:
@ -3284,8 +3284,8 @@ class Tensor(OpMixin):
print(q.scaled_dot_product_attention(k, v).numpy())
```
"""
# NOTE: it also works when `key` and `value` have symbolic shape.
assert all_int(self.shape), f"does not support symbolic shape {self.shape}"
# NOTE: only n_heads (dim -3) and head_dim (dim -1) must be concrete; batch and sequence length dims can be symbolic.
assert all_int(self.shape[-1:]) and all_int(self.shape[-3:-2]), f"does not support symbolic shape {self.shape}"
if getenv("FLASH_ATTENTION"):
from extra.thunder.tiny.fa import flash_attention