mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
hotfix llama start_pos vmax is max_context-1 (#10659)
* hotfix llama start_pos vmax is max_context-1 fixed `IGNORE_OOB=0 python3 examples/llama3.py --size 1B --benchmark --temperature 0` * hotfix: multitensor transformer test tests kv cache --------- Co-authored-by: George Hotz <geohot@gmail.com>
This commit is contained in:
parent
5eb6e1e65a
commit
4a6d84c4c3
2 changed files with 8 additions and 6 deletions
|
|
@ -191,7 +191,7 @@ class Transformer:
|
|||
def __call__(self, tokens:Tensor, start_pos:int, temperature:float=0.0, top_k:int=0, top_p:float=0.8, alpha_f:float=0.0, alpha_p:float=0.0):
|
||||
# TODO: better way to handle the first call v.s. the rest?
|
||||
if tokens.shape[0:2] == (1,1) and self.forward_jit is not None and start_pos != 0:
|
||||
return self.forward_jit(tokens, Variable("start_pos", 1, self.max_context).bind(start_pos), temperature, top_k, top_p, alpha_f, alpha_p)
|
||||
return self.forward_jit(tokens, Variable("start_pos", 1, self.max_context-1).bind(start_pos), temperature, top_k, top_p, alpha_f, alpha_p)
|
||||
return self.forward(tokens, start_pos, temperature, top_k, top_p, alpha_f, alpha_p)
|
||||
|
||||
# *** helpers ***
|
||||
|
|
|
|||
|
|
@ -1180,8 +1180,8 @@ class TestMultiTransformer(unittest.TestCase):
|
|||
|
||||
from extra.models.llama import Transformer
|
||||
args = {"dim": 64, "n_heads": 1, "n_kv_heads": 1, "n_layers": 2, "norm_eps": 1e-5, "rope_theta": 500000, "vocab_size": 1024, "hidden_dim": 64}
|
||||
real_model = Transformer(**args, jit=False)
|
||||
shard_model = Transformer(**args, jit=False)
|
||||
real_model = Transformer(**args)
|
||||
shard_model = Transformer(**args)
|
||||
|
||||
# copy state
|
||||
nn.state.load_state_dict(shard_model, nn.state.get_state_dict(real_model))
|
||||
|
|
@ -1198,9 +1198,11 @@ class TestMultiTransformer(unittest.TestCase):
|
|||
else: v.shard_(device, axis=None)
|
||||
|
||||
last_tok = 0
|
||||
real_tok = real_model(Tensor([[last_tok]], device=Device.DEFAULT), 0)
|
||||
shard_tok = shard_model(Tensor([[last_tok]], device=device), 0)
|
||||
self.assertEqual(real_tok.item(), shard_tok.item())
|
||||
for i in range(10):
|
||||
real_tok = real_model(Tensor([[last_tok]], device=Device.DEFAULT), i)
|
||||
shard_tok = shard_model(Tensor([[last_tok]], device=device), i)
|
||||
last_tok = real_tok.item()
|
||||
self.assertEqual(last_tok, shard_tok.item(), f"issue at token {i}")
|
||||
|
||||
@unittest.skip("super slow")
|
||||
def test_llama1b_full(self):
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue