mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
Compare commits
1 commit
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
4fb38e94ad |
5 changed files with 305 additions and 21 deletions
|
|
@ -302,5 +302,62 @@ class TestSymbolicOps(unittest.TestCase):
|
|||
expected = x_full[:, :, :val].conv2d(weight=weight, groups=1, stride=6, dilation=1, padding=(3, 3))
|
||||
np.testing.assert_allclose(result[:, :, :13].numpy(), expected.numpy(), atol=1e-5, rtol=1e-5)
|
||||
|
||||
def test_triu_symbolic(self):
  """triu() of a symbolically sliced square view must equal triu() of the same concrete slice."""
  base = Tensor.rand(10, 10).realize()
  for size in range(2, 6):
    bound = Variable("i", 1, 10).bind(size)
    sym_out = base[:bound, :bound].triu()
    # shrink the symbolic result back to a concrete shape so it can be read out
    got = sym_out[:size, :size].numpy()
    want = base[:size, :size].triu().numpy()
    np.testing.assert_allclose(got, want, atol=1e-6, rtol=1e-6)
|
||||
|
||||
def test_tril_symbolic(self):
  """tril() of a symbolically sliced square view must equal tril() of the same concrete slice."""
  base = Tensor.rand(10, 10).realize()
  for size in range(2, 6):
    bound = Variable("i", 1, 10).bind(size)
    sym_out = base[:bound, :bound].tril()
    # shrink the symbolic result back to a concrete shape so it can be read out
    got = sym_out[:size, :size].numpy()
    want = base[:size, :size].tril().numpy()
    np.testing.assert_allclose(got, want, atol=1e-6, rtol=1e-6)
|
||||
|
||||
def test_triu_symbolic_diagonal(self):
  """Symbolic triu(diagonal=k) must match the concrete result for negative, zero and positive k."""
  base = Tensor.rand(10, 10).realize()
  for size in range(3, 6):
    bound = Variable("i", 1, 10).bind(size)
    for offset in (-1, 0, 1, 2):
      sym_out = base[:bound, :bound].triu(diagonal=offset)
      got = sym_out[:size, :size].numpy()
      want = base[:size, :size].triu(diagonal=offset).numpy()
      np.testing.assert_allclose(got, want, atol=1e-6, rtol=1e-6)
|
||||
|
||||
def test_triu_symbolic_nonsquare(self):
  """Symbolic triu() with only the row count symbolic (10x8 source, all columns kept)."""
  base = Tensor.rand(10, 8).realize()
  for rows in range(2, 6):
    bound = Variable("i", 1, 10).bind(rows)
    sym_out = base[:bound, :].triu()
    got = sym_out[:rows, :].numpy()
    want = base[:rows, :].triu().numpy()
    np.testing.assert_allclose(got, want, atol=1e-6, rtol=1e-6)
|
||||
|
||||
def test_full_triu_symbolic(self):
  """Test the attention mask pattern: Tensor.full(symbolic_shape, -inf).triu(k)

  The symbolic mask, shrunk to its concrete size, must be element-wise equal to the mask built from
  plain ints, for several sequence lengths T and start offsets sp.
  """
  for T in range(2, 6):
    for sp in [0, 2, 5]:
      vT = Variable("T", 1, 10).bind(T)
      mask = Tensor.full((1, 1, vT, sp+vT), float("-inf")).triu(sp+1)
      ref = Tensor.full((1, 1, T, sp+T), float("-inf")).triu(sp+1)
      # -inf is not NaN, so counting non-NaN entries would only re-check the element count;
      # compare the values themselves so a misplaced diagonal is actually detected
      np.testing.assert_array_equal(mask[:, :, :T, :sp+T].numpy(), ref.numpy(), err_msg=f"T={T}, sp={sp}")
|
||||
|
||||
def test_arange_symbolic(self):
  """Tensor.arange with a bound Variable as stop must match arange with the plain int."""
  for stop in range(1, 6):
    bound = Variable("i", 1, 10).bind(stop)
    sym_range = Tensor.arange(bound)
    int_range = Tensor.arange(stop)
    got = sym_range[:stop].numpy()
    want = int_range.numpy()
    np.testing.assert_allclose(got, want, atol=1e-6, rtol=1e-6)
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
|
|
|||
|
|
@ -6,13 +6,17 @@ class TestTransformerGenerate(unittest.TestCase):
|
|||
def test_start_pos_parameter_is_used(self):
|
||||
"""Test that start_pos parameter is not ignored (regression test for always resetting to 0)."""
|
||||
from tinygrad.apps.llm import Transformer
|
||||
from tinygrad.uop.ops import UOp
|
||||
# Create a minimal transformer
|
||||
model = Transformer(num_blocks=1, dim=64, hidden_dim=128, n_heads=2, n_kv_heads=2,
|
||||
norm_eps=1e-5, vocab_size=100, head_dim=32, rope_theta=10000.0, max_context=32)
|
||||
|
||||
captured_inputs = []
|
||||
def mock_call(self, tokens, start_pos):
|
||||
captured_inputs.append((tokens.shape, start_pos if isinstance(start_pos, int) else start_pos.bind_val))
|
||||
s = tokens.shape[-1]
|
||||
shape_val = s.val if isinstance(s, UOp) else s
|
||||
sp_val = start_pos.val if isinstance(start_pos, UOp) else start_pos
|
||||
captured_inputs.append((shape_val, sp_val))
|
||||
return Tensor([[42]]) # return a fake next token
|
||||
|
||||
with patch.object(Transformer, '__call__', mock_call):
|
||||
|
|
@ -22,8 +26,179 @@ class TestTransformerGenerate(unittest.TestCase):
|
|||
|
||||
# With start_pos=3, the initial tensor should only have tokens[3:] = [4, 5] (length 2)
|
||||
# If the bug existed (start_pos always reset to 0), it would have all 5 tokens
|
||||
self.assertEqual(captured_inputs[0][0][-1], 2) # shape should be (1, 2)
|
||||
self.assertEqual(captured_inputs[0][0], 2) # shape should have 2 tokens
|
||||
self.assertEqual(captured_inputs[0][1], 3) # start_pos should be 3, not 0
|
||||
|
||||
class TestPrefillCorrectness(unittest.TestCase):
  """Compares symbolic prefill (``prefill_forward`` / ``generate``) against the plain ``forward`` path.

  A single small Transformer is built once in ``setUpClass`` and shared by all tests; every run that has
  to start from an empty state calls ``_reset_caches`` first.
  """

  @classmethod
  def setUpClass(cls):
    from tinygrad.apps.llm import Transformer
    cls.model = Transformer(num_blocks=2, dim=64, hidden_dim=128, n_heads=2, n_kv_heads=2,
                            norm_eps=1e-5, vocab_size=100, head_dim=32, rope_theta=10000.0, max_context=128)

  def _reset_caches(self, model):
    """Delete the ``cache_kv`` and ``causal_mask`` attributes from every block of ``model``, if present."""
    for block in model.blk:
      if hasattr(block, "cache_kv"): del block.cache_kv
      if hasattr(block, "causal_mask"): del block.causal_mask

  def test_prefill_matches_forward(self):
    """Test that prefill_forward produces the same output as forward for various lengths."""
    model = self.model
    v_prefill = model.v_prefill
    token_buf = Tensor.zeros(1, model.max_context, dtype="int32").contiguous().realize()
    for length in [2, 4, 8, 16]:
      tokens = Tensor.randint(length, high=100, dtype="int32")
      token_buf[0, :length].assign(tokens).realize()

      # concrete-shape reference
      self._reset_caches(model)
      expected = model.forward(tokens.reshape(1, -1), 0).realize()

      # same tokens through the symbolic-length path
      self._reset_caches(model)
      actual = model.prefill_forward(token_buf[:, :v_prefill.bind(length)], 0).realize()

      self.assertEqual(expected.item(), actual.item(), f"mismatch at length={length}: forward={expected.item()}, prefill={actual.item()}")

  def test_prefill_then_decode_matches_sequential(self):
    """Test that prefill_forward + decode produces the same sequence as forward on the whole prompt + decode."""
    model = self.model
    tokens = Tensor.randint(6, high=100, dtype="int32").tolist()
    num_decode = 3

    # reference: run the whole prompt through forward() in one call, then decode one token at a time
    self._reset_caches(model)
    ref_ids = list(tokens)
    t = Tensor([ref_ids], dtype="int32")
    t = model.forward(t, 0)
    ref_ids.append(int(t.item()))
    for _ in range(num_decode - 1):
      sp = len(ref_ids) - 1
      t = model.forward(t, sp)
      ref_ids.append(int(t.item()))

    # test: prefill then decode
    self._reset_caches(model)
    test_ids = list(tokens)
    token_buf = Tensor.zeros(1, model.max_context, dtype="int32").contiguous().realize()
    token_buf[0, :len(test_ids)].assign(Tensor(test_ids, dtype="int32")).realize()
    t = model.prefill_forward(token_buf[:, :model.v_prefill.bind(len(test_ids))], 0)
    test_ids.append(int(t.item()))
    for _ in range(num_decode - 1):
      sp = len(test_ids) - 1
      t = model.forward(t, sp)
      test_ids.append(int(t.item()))

    self.assertEqual(ref_ids, test_ids, f"sequence mismatch: ref={ref_ids[-num_decode:]}, test={test_ids[-num_decode:]}")

  def test_prefill_with_start_pos(self):
    """Test that prefill with start_pos > 0 matches a single forward over all tokens."""
    model = self.model
    first_tokens = Tensor.randint(4, high=100, dtype="int32").tolist()
    second_tokens = Tensor.randint(3, high=100, dtype="int32").tolist()
    all_tokens = first_tokens + second_tokens

    # reference: forward all at once from pos 0
    self._reset_caches(model)
    expected = model.forward(Tensor([all_tokens], dtype="int32"), 0).realize()

    # test: forward first chunk, then prefill second chunk at start_pos
    self._reset_caches(model)
    model.forward(Tensor([first_tokens], dtype="int32"), 0).realize()
    token_buf = Tensor.zeros(1, model.max_context, dtype="int32").contiguous().realize()
    token_buf[0, :len(second_tokens)].assign(Tensor(second_tokens, dtype="int32")).realize()
    actual = model.prefill_forward(token_buf[:, :model.v_prefill.bind(len(second_tokens))],
                                   model.v_start_pos.bind(len(first_tokens))).realize()

    self.assertEqual(expected.item(), actual.item(),
                     f"start_pos mismatch: all-at-once={expected.item()}, split={actual.item()}")

  def test_multi_turn_generate(self):
    """Test that two turns of prefill_forward + decode produce the same tokens as plain forward() calls."""
    model = self.model
    num_decode = 3

    # simulate 2 turns of conversation
    turn1_tokens = Tensor.randint(5, high=100, dtype="int32").tolist()
    turn2_tokens = Tensor.randint(4, high=100, dtype="int32").tolist()

    # reference: plain forward() only (no prefill_forward, no JIT wrappers)
    self._reset_caches(model)
    ref_ids = list(turn1_tokens)
    # turn 1 prompt in one forward call
    t = model.forward(Tensor([ref_ids], dtype="int32"), 0)
    ref_ids.append(int(t.item()))
    # decode turn 1
    for _ in range(num_decode - 1):
      t = model.forward(t, len(ref_ids) - 1)
      ref_ids.append(int(t.item()))
    # turn 2 prompt in one forward call, starting after everything produced so far
    start_pos_2 = len(ref_ids)
    ref_ids += turn2_tokens
    t = model.forward(Tensor([ref_ids[start_pos_2:]], dtype="int32"), start_pos_2)
    ref_ids.append(int(t.item()))
    # decode turn 2
    for _ in range(num_decode - 1):
      t = model.forward(t, len(ref_ids) - 1)
      ref_ids.append(int(t.item()))

    # test: using prefill_forward with symbolic variable
    self._reset_caches(model)
    test_ids = list(turn1_tokens)
    token_buf = Tensor.zeros(1, model.max_context, dtype="int32").contiguous().realize()
    # prefill turn 1
    token_buf[0, :len(test_ids)].assign(Tensor(test_ids, dtype="int32")).realize()
    t = model.prefill_forward(token_buf[:, :model.v_prefill.bind(len(test_ids))], 0)
    test_ids.append(int(t.item()))
    # decode turn 1
    for _ in range(num_decode - 1):
      t = model.forward(t, len(test_ids) - 1)
      test_ids.append(int(t.item()))
    # prefill turn 2: the new tokens are written at the front of the same buffer
    start_pos_2 = len(test_ids)
    test_ids += turn2_tokens
    new_tokens = test_ids[start_pos_2:]
    token_buf[0, :len(new_tokens)].assign(Tensor(new_tokens, dtype="int32")).realize()
    t = model.prefill_forward(token_buf[:, :model.v_prefill.bind(len(new_tokens))],
                              model.v_start_pos.bind(start_pos_2))
    test_ids.append(int(t.item()))
    # decode turn 2
    for _ in range(num_decode - 1):
      t = model.forward(t, len(test_ids) - 1)
      test_ids.append(int(t.item()))

    self.assertEqual(ref_ids, test_ids, f"multi-turn mismatch:\nref ={ref_ids}\ntest={test_ids}")

  def test_generate_matches_forward(self):
    """Test that generate() (with symbolic prefill + JIT) matches plain forward() calls for 3 turns."""
    from tinygrad import TinyJit
    model = self.model
    num_decode = 3
    turns = [Tensor.randint(5, high=100, dtype="int32").tolist() for _ in range(3)]

    # reference: each turn's prompt in one forward call, then decode one token at a time
    self._reset_caches(model)
    ref_ids: list[int] = []
    for turn_tokens in turns:
      start_pos = len(ref_ids)
      ref_ids += turn_tokens
      t = model.forward(Tensor([ref_ids[start_pos:]], dtype="int32"), start_pos)
      ref_ids.append(int(t.item()))
      for _ in range(num_decode - 1):
        t = model.forward(t, len(ref_ids) - 1)
        ref_ids.append(int(t.item()))

    # test: using generate() with symbolic prefill
    self._reset_caches(model)
    # rebuild the JIT wrappers on the shared model (presumably to discard anything captured by earlier calls)
    model.forward_jit = TinyJit(model.forward)
    model.prefill_jit = TinyJit(model.prefill_forward)
    test_ids: list[int] = []
    for turn_tokens in turns:
      start_pos = len(test_ids)
      test_ids += turn_tokens
      # generate() appends each produced token to test_ids itself
      gen = model.generate(test_ids, start_pos)
      for _ in range(num_decode): next(gen)

    self.assertEqual(ref_ids, test_ids, f"generate mismatch:\nref ={ref_ids}\ntest={test_ids}")
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
from __future__ import annotations
|
||||
import sys, argparse, typing, re, unicodedata, json, uuid, time, functools
|
||||
import sys, argparse, typing, re, unicodedata, json, uuid, time, functools, array
|
||||
from tinygrad import Tensor, nn, UOp, TinyJit, getenv, function
|
||||
from tinygrad.helpers import partition, DEBUG, Timing, GlobalCounters, stderr_log, colored
|
||||
from tinygrad.viz.serve import TCPServerWithReuse, HTTPRequestHandler
|
||||
|
|
@ -116,8 +116,7 @@ class TransformerBlock:
|
|||
self.ffn_up = nn.Linear(dim, hidden_dim, bias=False)
|
||||
self.ffn_down = nn.Linear(hidden_dim, dim, bias=False)
|
||||
|
||||
@function
|
||||
def _attention(self, x:Tensor, start_pos:int|UOp) -> Tensor:
|
||||
def _attention_inner(self, x:Tensor, start_pos:int|UOp, mask:Tensor|None) -> Tensor:
|
||||
x_norm = self.attn_norm(x) # (B,T,D)
|
||||
q, k, v = self.attn_q(x_norm), self.attn_k(x_norm), self.attn_v(x_norm)
|
||||
if self.qk_norm and self.qk_norm != self.head_dim: q, k = self.attn_q_norm(q), self.attn_k_norm(k)
|
||||
|
|
@ -142,13 +141,18 @@ class TransformerBlock:
|
|||
#k = self.cache_kv[0, :, :, 0:start_pos+T, :]
|
||||
#v = self.cache_kv[1, :, :, 0:start_pos+T, :]
|
||||
|
||||
# NOTE: this mask is causal_lower_right, not the causal_upper_left generated by is_casual = True
|
||||
mask = Tensor.full((1, 1, T, start_pos+T), float("-inf"), dtype=x.dtype, device=x.device).triu(int(start_pos)+1) if T > 1 else None
|
||||
attn = q.scaled_dot_product_attention(k, v, attn_mask=mask, enable_gqa=True) # (B,H,T,Hd)
|
||||
attn = attn.transpose(1, 2).reshape(B, T, -1) # back to (B,T,D)
|
||||
attn = self.attn_output(attn)
|
||||
return x + attn
|
||||
|
||||
@function
def _attention(self, x:Tensor, start_pos:int|UOp) -> Tensor:
  """Build the attention mask for this call and hand off to ``_attention_inner``.

  For more than one query position a ``(1, 1, T, start_pos+T)`` mask of -inf above diagonal
  ``start_pos+1`` is created; for a single position no mask is passed.
  """
  _, T, _ = x.shape
  if T > 1:
    # NOTE: this mask is causal_lower_right, not the causal_upper_left generated by is_causal = True
    neg_inf = Tensor.full((1, 1, T, start_pos+T), float("-inf"), dtype=x.dtype, device=x.device)
    mask = neg_inf.triu(int(start_pos)+1)
  else:
    mask = None
  return self._attention_inner(x, start_pos, mask)
|
||||
|
||||
@function
|
||||
def _feed_forward(self, h: Tensor) -> Tensor:
|
||||
h_norm = self.ffn_norm(h)
|
||||
|
|
@ -167,6 +171,16 @@ class TransformerBlock:
|
|||
self.cache_kv = Tensor.zeros(2, x.shape[0], self.n_kv_heads, self.max_context, self.head_dim, device=x.device).contiguous().realize()
|
||||
return self._feed_forward(self._attention(x, start_pos)).contiguous()
|
||||
|
||||
def prefill(self, x: Tensor, start_pos: int|UOp):
  """Variant of __call__ that bypasses the @function wrappers, for JIT-compatible symbolic prefill.

  On first use it allocates the KV cache and a full (max_context x max_context) causal mask; each call then
  slices the mask window for the current positions and feeds the attention output to the feed-forward.
  """
  if not hasattr(self, "cache_kv"):
    kv_shape = (2, x.shape[0], self.n_kv_heads, self.max_context, self.head_dim)
    self.cache_kv = Tensor.zeros(*kv_shape, device=x.device).contiguous().realize()
  if not hasattr(self, "causal_mask"):
    neg_inf = Tensor.full((1, 1, self.max_context, self.max_context), float("-inf"), dtype=x.dtype, device=x.device)
    self.causal_mask = neg_inf.triu(1).contiguous().realize()
  T = x.shape[1]
  # a symbolic T only counts as multi-token when its lower bound is already above 1
  multi_token = T > 1 if isinstance(T, int) else T.vmin > 1
  if multi_token:
    mask = self.causal_mask[:, :, start_pos:start_pos+T, :start_pos+T]
  else:
    mask = None
  attn_out = self._attention_inner(x, start_pos, mask)
  # .fxn is presumably the undecorated _feed_forward -- confirm against the @function decorator
  return TransformerBlock._feed_forward.fxn(self, attn_out).contiguous()
|
||||
|
||||
class Transformer:
|
||||
def __init__(self, *, num_blocks, dim, hidden_dim, n_heads, n_kv_heads, norm_eps, vocab_size, head_dim:int, rope_theta:float,
|
||||
max_context:int=0, qk_norm:int=0, num_experts:int=0, num_experts_per_tok:int=0):
|
||||
|
|
@ -176,8 +190,10 @@ class Transformer:
|
|||
self.output_norm = nn.RMSNorm(dim, norm_eps)
|
||||
self.output = nn.Linear(dim, vocab_size, bias=False)
|
||||
self.max_context = max_context
|
||||
# JIT is used if T=1 and start_pos is a UOp. TODO: make this not needed by including T in the JIT and making start_pos always a UOp
|
||||
self.forward_jit = TinyJit(self.forward)
|
||||
self.prefill_jit = TinyJit(self.prefill_forward)
|
||||
self.v_start_pos = UOp.variable("start_pos", 1, max_context-1)
|
||||
self.v_prefill = UOp.variable("prefill", 2, max_context)
|
||||
|
||||
def forward(self, tokens:Tensor, start_pos:int|UOp) -> Tensor:
|
||||
x = self.token_embd(tokens) # (B, T, D)
|
||||
|
|
@ -186,7 +202,9 @@ class Transformer:
|
|||
return self.output(self.output_norm(x))[:, -1, :].softmax(-1, dtype="float").argmax(-1, keepdim=True)
|
||||
|
||||
def __call__(self, tokens:Tensor, start_pos:int|UOp=0) -> Tensor:
  """Run the model, choosing between the two JIT wrappers and the plain forward.

  With JIT enabled (env JIT, default 1): a single-token input with a UOp start_pos goes to forward_jit,
  an input whose sequence dimension is a UOp goes to prefill_jit. Everything else calls forward directly.
  """
  # the former unconditional one-line return is dropped: it made the prefill_jit dispatch below unreachable
  if getenv("JIT", 1) and tokens.shape[1] == 1 and isinstance(start_pos, UOp): return self.forward_jit(tokens, start_pos)
  if getenv("JIT", 1) and tokens.shape[1] != 1 and isinstance(tokens.shape[1], UOp): return self.prefill_jit(tokens, start_pos)
  return self.forward(tokens, start_pos)
|
||||
|
||||
@staticmethod
|
||||
def from_gguf(gguf:Tensor, max_context:int|None=None, realize=bool(getenv("REALIZE", 1))) -> tuple[Transformer, dict]:
|
||||
|
|
@ -224,11 +242,33 @@ class Transformer:
|
|||
Tensor.realize(*params)
|
||||
return model, kv
|
||||
|
||||
def prefill_forward(self, tokens:Tensor, start_pos:int|UOp) -> Tensor:
  """Forward pass that uses each block's prefill() (no @function on attention/ffn), for symbolic prefill.

  Returns the argmax (keepdim) over the softmax of the last position's output projection.
  """
  h = self.token_embd(tokens)
  for layer in self.blk:
    h = layer.prefill(h, start_pos)
  logits = self.output(self.output_norm(h))
  last_probs = logits[:, -1, :].softmax(-1, dtype="float")
  return last_probs.argmax(-1, keepdim=True)
|
||||
|
||||
def generate(self, tokens:list[int], start_pos=0):
|
||||
v_start_pos = UOp.variable("start_pos", 1, self.max_context-1)
|
||||
t = Tensor([tokens[start_pos:]], dtype="int32")
|
||||
if not hasattr(self, "_prefill_buf"):
|
||||
self._prefill_buf = Tensor.zeros(1, self.max_context, dtype="int32").contiguous().realize()
|
||||
self._prefill_staging = array.array('i', [0] * self.max_context)
|
||||
# prefill all tokens at once using symbolic variable
|
||||
num_tokens = len(tokens) - start_pos
|
||||
if num_tokens >= 2 and getenv("SYM", 1):
|
||||
# write tokens directly into the buffer via copyin (no scheduled kernels, no cache misses)
|
||||
for i, t in enumerate(tokens[start_pos:]): self._prefill_staging[i] = t
|
||||
self._prefill_buf._buffer().copyin(memoryview(self._prefill_staging))
|
||||
v_sp = self.v_start_pos.bind(start_pos) if start_pos != 0 else start_pos
|
||||
t = self(self._prefill_buf[:, :self.v_prefill.bind(num_tokens)], v_sp)
|
||||
next_id = int(t.item())
|
||||
tokens.append(next_id)
|
||||
start_pos = len(tokens) - 1
|
||||
yield next_id
|
||||
else:
|
||||
t = Tensor([tokens[start_pos:]], dtype="int32")
|
||||
# decode one token at a time
|
||||
while len(tokens) < self.max_context:
|
||||
t = self(t, v_start_pos.bind(start_pos) if getenv("SYM", 1) and start_pos != 0 and t.shape[-1] == 1 else start_pos)
|
||||
t = self(t, self.v_start_pos.bind(start_pos) if getenv("SYM", 1) and start_pos != 0 and t.shape[-1] == 1 else start_pos)
|
||||
next_id = int(t.item())
|
||||
tokens.append(next_id)
|
||||
start_pos = len(tokens) - 1
|
||||
|
|
@ -356,6 +396,19 @@ if __name__ == "__main__":
|
|||
|
||||
# do benchmark
|
||||
if args.benchmark:
|
||||
# prefill benchmark
|
||||
token_buf = Tensor.zeros(1, args.max_context, dtype="int32").contiguous().realize()
|
||||
print("prefill benchmark:")
|
||||
for length in [4, 16, 64, 256, 1024]:
|
||||
if length > args.max_context: break
|
||||
GlobalCounters.reset()
|
||||
with Timing(prefix=f" prefill {length:4d}: ",
|
||||
on_exit=lambda x, length=length: f", {length*1e9/x:6.2f} tok/s, {GlobalCounters.global_mem/x:7.2f} GB/s,"
|
||||
f" {GlobalCounters.global_mem//1000000}/{GlobalCounters.mem_used//1000000} MB"):
|
||||
model.prefill_jit(token_buf[:, :model.v_prefill.bind(length)], 0).realize()
|
||||
|
||||
# generation benchmark
|
||||
print("generation benchmark:")
|
||||
gen = model.generate([0], 0)
|
||||
for _ in range(args.benchmark):
|
||||
GlobalCounters.reset()
|
||||
|
|
|
|||
|
|
@ -4,12 +4,11 @@ from tinygrad.helpers import all_int, dedup, get_contraction
|
|||
from tinygrad.dtype import dtypes, AddrSpace, Invalid
|
||||
from tinygrad.renderer import Renderer
|
||||
|
||||
def _dim_max(d:sint) -> int:
  # upper bound of a dimension: an int is its own bound, anything else reports it through .vmax
  if isinstance(d, int): return d
  return d.vmax
|
||||
def _group_dims(dims:tuple[sint, ...], max_sizes:tuple[int, ...]):
|
||||
# TODO: symbolic shape
|
||||
if not all_int(dims): return dims
|
||||
while len(dims) > len(max_sizes) or any(d > m for d,m in zip(dims, max_sizes)):
|
||||
while len(dims) > len(max_sizes) or any(_dim_max(d) > m for d,m in zip(dims, max_sizes)):
|
||||
for i,m in enumerate(max_sizes):
|
||||
if i < (len(dims)-1) and dims[i] * dims[i+1] <= m:
|
||||
if i < (len(dims)-1) and _dim_max(dims[i]) * _dim_max(dims[i+1]) <= m:
|
||||
dims = dims[:i] + (dims[i]*dims[i+1],) + dims[i+2:]
|
||||
break
|
||||
else: return None
|
||||
|
|
|
|||
|
|
@ -722,9 +722,10 @@ class Tensor(OpMixin):
|
|||
"""
|
||||
if stop is None: stop, start = start, 0
|
||||
dtype = kwargs.pop("dtype", dtypes.default_float if any(isinstance(x, float) for x in (start, stop, step)) else dtypes.default_int)
|
||||
if start < (dt:=to_dtype(dtype)).min or dt.max < (stop-step): raise ValueError(f"arange [{start}, {stop}) is not representable in dtype {dtype}")
|
||||
if resolve(start < (dt:=to_dtype(dtype)).min, False) or resolve(dt.max < (stop-step), False):
|
||||
raise ValueError(f"arange [{start}, {stop}) is not representable in dtype {dtype}")
|
||||
# NOTE: this matches numpy, torch raises RuntimeError if stop-start and step have different signs
|
||||
if (output_len:=ceildiv(stop-start, step)) <= 0: return Tensor([], dtype=dtype, **kwargs)
|
||||
if resolve((output_len:=ceildiv(stop-start, step)) <= 0, False): return Tensor([], dtype=dtype, **kwargs)
|
||||
return (Tensor.full((output_len,), step, dtype=dtype, **kwargs)._cumalu(0, Ops.ADD) + (start - step)).cast(dtype)
|
||||
|
||||
@staticmethod
|
||||
|
|
@ -2566,7 +2567,6 @@ class Tensor(OpMixin):
|
|||
|
||||
@staticmethod
def _tri(r:sint, c:sint, diagonal:int=0, device=None, requires_grad:bool|None=None) -> Tensor:
  """Boolean mask that is True where ``row_index + diagonal <= col_index``.

  Built by comparing a column of ``arange(r)`` (shifted by ``diagonal``) against ``arange(c)``.
  """
  # no int-only assert here: r and c may be symbolic, which the symbolic triu/tril callers rely on
  return (Tensor.arange(r, device=device).unsqueeze(-1) + diagonal <= Tensor.arange(c, device=device)).requires_grad_(requires_grad)
|
||||
|
||||
def triu(self, diagonal:int=0) -> Tensor:
|
||||
|
|
@ -3284,8 +3284,8 @@ class Tensor(OpMixin):
|
|||
print(q.scaled_dot_product_attention(k, v).numpy())
|
||||
```
|
||||
"""
|
||||
# NOTE: it also works when `key` and `value` have symbolic shape.
|
||||
assert all_int(self.shape), f"does not support symbolic shape {self.shape}"
|
||||
# NOTE: only n_heads (dim -3) and head_dim (dim -1) must be concrete; batch and sequence length dims can be symbolic.
|
||||
assert all_int(self.shape[-1:]) and all_int(self.shape[-3:-2]), f"does not support symbolic shape {self.shape}"
|
||||
|
||||
if getenv("FLASH_ATTENTION"):
|
||||
from extra.thunder.tiny.fa import flash_attention
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue