mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
simpler llama - don't shrink twice (#1981)
This commit is contained in:
parent
972d9ea215
commit
da2b3e55f4
1 changed files with 3 additions and 8 deletions
|
|
@ -124,12 +124,6 @@ class TransformerBlock:
|
|||
cache_k.lazydata.var_vals[pos] = start_pos
|
||||
cache_v.lazydata.var_vals[pos] = start_pos
|
||||
|
||||
# get only the part of freqs_cis that we are using.
|
||||
freqs_cis = freqs_cis.shrink(((0, freqs_cis.shape[0]), (pos, pos+seqlen), (0, freqs_cis.shape[2]), (0, freqs_cis.shape[3]), (0, freqs_cis.shape[4])))
|
||||
freqs_cis.lazydata.var_vals[pos] = start_pos
|
||||
else:
|
||||
freqs_cis = freqs_cis.shrink(((0, freqs_cis.shape[0]), (start_pos, start_pos+seqlen), (0, freqs_cis.shape[2]), (0, freqs_cis.shape[3]), (0, freqs_cis.shape[4])))
|
||||
|
||||
output, cache_k, cache_v = self.attention(self.attention_norm(x), cache_k, cache_v, start_pos, freqs_cis, mask, jit_ctx=jit_ctx)
|
||||
h = x + output
|
||||
return (h + self.feed_forward(self.ffn_norm(h))).realize(), cache_k.realize(), cache_v.realize()
|
||||
|
|
@ -157,11 +151,12 @@ class Transformer:
|
|||
_bsz, seqlen = tokens.shape
|
||||
if seqlen == 1 and JIT:
|
||||
pos = Variable("pos", 1, 1024)
|
||||
# get only the part of freqs_cis that we are using.
|
||||
freqs_cis = self.freqs_cis.shrink(((0, self.freqs_cis.shape[0]), (pos, pos+seqlen),(0, self.freqs_cis.shape[2]),(0, self.freqs_cis.shape[3]),(0, self.freqs_cis.shape[4])))
|
||||
freqs_cis.lazydata.var_vals[pos] = start_pos
|
||||
h = self.tok_embeddings_jitted(tokens)
|
||||
for i, (layer, (cache_k, cache_v)) in enumerate(zip(self.layers_jitted, self.kv_caches)):
|
||||
h, cache_k, cache_v = layer(h, cache_k, cache_v, start_pos=start_pos, freqs_cis=self.freqs_cis, mask=None, jit_ctx={pos: start_pos})
|
||||
h, cache_k, cache_v = layer(h, cache_k, cache_v, start_pos=start_pos, freqs_cis=freqs_cis, mask=None, jit_ctx={pos: start_pos})
|
||||
# TODO: move the kv cache into Attention, pre-allocate the cache and instead of cat, update the cache in-place
|
||||
self.kv_caches[i] = (cache_k, cache_v)
|
||||
return self.postprocess_jitted(h, temperature)
|
||||
|
|
@ -174,7 +169,7 @@ class Transformer:
|
|||
if cache_k is not None and start_pos > 0:
|
||||
cache_k = cache_k.reshape(cache_k.shape[0], start_pos, cache_k.shape[2], cache_k.shape[3])
|
||||
cache_v = cache_v.reshape(cache_v.shape[0], start_pos, cache_v.shape[2], cache_v.shape[3])
|
||||
h, cache_k, cache_v = layer(h, cache_k, cache_v, start_pos=start_pos, freqs_cis=self.freqs_cis, mask=mask)
|
||||
h, cache_k, cache_v = layer(h, cache_k, cache_v, start_pos=start_pos, freqs_cis=freqs_cis, mask=mask)
|
||||
self.kv_caches[i] = (cache_k, cache_v)
|
||||
return self.postprocess(h, temperature)
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue