hotfix llama start_pos vmax is max_context-1 (#10659)

* hotfix llama start_pos vmax is max_context-1

fixed `IGNORE_OOB=0 python3 examples/llama3.py --size 1B --benchmark --temperature 0`

* hotfix: multitensor transformer test tests kv cache

---------

Co-authored-by: George Hotz <geohot@gmail.com>
This commit is contained in:
chenyu 2025-06-06 00:41:25 -04:00 committed by GitHub
commit 4a6d84c4c3
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 8 additions and 6 deletions

View file

@ -191,7 +191,7 @@ class Transformer:
def __call__(self, tokens:Tensor, start_pos:int, temperature:float=0.0, top_k:int=0, top_p:float=0.8, alpha_f:float=0.0, alpha_p:float=0.0):
# TODO: better way to handle the first call v.s. the rest?
if tokens.shape[0:2] == (1,1) and self.forward_jit is not None and start_pos != 0:
return self.forward_jit(tokens, Variable("start_pos", 1, self.max_context).bind(start_pos), temperature, top_k, top_p, alpha_f, alpha_p)
return self.forward_jit(tokens, Variable("start_pos", 1, self.max_context-1).bind(start_pos), temperature, top_k, top_p, alpha_f, alpha_p)
return self.forward(tokens, start_pos, temperature, top_k, top_p, alpha_f, alpha_p)
# *** helpers ***

View file

@ -1180,8 +1180,8 @@ class TestMultiTransformer(unittest.TestCase):
from extra.models.llama import Transformer
args = {"dim": 64, "n_heads": 1, "n_kv_heads": 1, "n_layers": 2, "norm_eps": 1e-5, "rope_theta": 500000, "vocab_size": 1024, "hidden_dim": 64}
real_model = Transformer(**args, jit=False)
shard_model = Transformer(**args, jit=False)
real_model = Transformer(**args)
shard_model = Transformer(**args)
# copy state
nn.state.load_state_dict(shard_model, nn.state.get_state_dict(real_model))
@ -1198,9 +1198,11 @@ class TestMultiTransformer(unittest.TestCase):
else: v.shard_(device, axis=None)
last_tok = 0
real_tok = real_model(Tensor([[last_tok]], device=Device.DEFAULT), 0)
shard_tok = shard_model(Tensor([[last_tok]], device=device), 0)
self.assertEqual(real_tok.item(), shard_tok.item())
for i in range(10):
real_tok = real_model(Tensor([[last_tok]], device=Device.DEFAULT), i)
shard_tok = shard_model(Tensor([[last_tok]], device=device), i)
last_tok = real_tok.item()
self.assertEqual(last_tok, shard_tok.item(), f"issue at token {i}")
@unittest.skip("super slow")
def test_llama1b_full(self):