mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
llm: fix nan in kvcache (#15552)
This commit is contained in:
parent
3af25ccdb4
commit
5181c8e23a
2 changed files with 24 additions and 4 deletions
|
|
@ -30,10 +30,10 @@ class TestTransformerGenerate(unittest.TestCase):
|
|||
gen = model.generate(tokens)
|
||||
next(gen)
|
||||
|
||||
# should only process tokens[7:] = [10, 11, 12] since first 7 are cached
|
||||
# should process tokens[6:] = [42, 10, 11, 12] since first 6 have cached k/v
|
||||
toks_shape = captured_inputs[0][0][-1]
|
||||
self.assertEqual(toks_shape.val if isinstance(toks_shape, UOp) else toks_shape, 3)
|
||||
self.assertEqual(captured_inputs[0][1], 7)
|
||||
self.assertEqual(toks_shape.val if isinstance(toks_shape, UOp) else toks_shape, 4)
|
||||
self.assertEqual(captured_inputs[0][1], 6)
|
||||
|
||||
def test_kv_cache_invalidation(self):
|
||||
"""Test that generate invalidates the KV cache when tokens diverge from the cached prefix."""
|
||||
|
|
@ -106,6 +106,26 @@ class TestTransformerGenerate(unittest.TestCase):
|
|||
# 4 tokens, chunk_size=4 -> 1 prefill chunk
|
||||
self.assertEqual(get_prefill_flags(list(range(4)), 4), [True, False, False])
|
||||
|
||||
def test_kv_cache_resume_matches_fresh(self):
|
||||
model = Transformer(TEST_CONFIG)
|
||||
|
||||
# generate 2 tokens, then abandon
|
||||
prompt = list(range(1, 6))
|
||||
gen = model.generate(list(prompt))
|
||||
out1, out2 = next(gen), next(gen)
|
||||
|
||||
# resume with conversation history + new user tokens appended
|
||||
extended = prompt + [out1, out2, 10, 11, 12]
|
||||
gen = model.generate(list(extended))
|
||||
resumed_out = [next(gen) for _ in range(3)]
|
||||
|
||||
# compare against fresh generation (no cache) of the same prompt
|
||||
model._cached_tokens = []
|
||||
gen = model.generate(list(extended))
|
||||
fresh_out = [next(gen) for _ in range(3)]
|
||||
|
||||
self.assertEqual(fresh_out, resumed_out)
|
||||
|
||||
def test_temperature_zero_is_greedy(self):
|
||||
"""Temperature 0 (or near 0) should produce deterministic output."""
|
||||
model = Transformer(TEST_CONFIG)
|
||||
|
|
|
|||
|
|
@ -271,7 +271,7 @@ class Transformer:
|
|||
# chunked prefill: keep processing until all prompt tokens are consumed
|
||||
if start_pos < len(tokens): continue
|
||||
tokens.append(int(out.item()))
|
||||
self._cached_tokens = tokens[:]
|
||||
self._cached_tokens = tokens[:-1]
|
||||
yield tokens[-1]
|
||||
|
||||
models = {
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue