simpler llama - don't shrink twice (#1981)

This commit is contained in:
chenyu 2023-10-05 17:31:46 -04:00 committed by GitHub
commit da2b3e55f4
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -124,12 +124,6 @@ class TransformerBlock:
cache_k.lazydata.var_vals[pos] = start_pos
cache_v.lazydata.var_vals[pos] = start_pos
# get only the part of freqs_cis that we are using.
freqs_cis = freqs_cis.shrink(((0, freqs_cis.shape[0]), (pos, pos+seqlen), (0, freqs_cis.shape[2]), (0, freqs_cis.shape[3]), (0, freqs_cis.shape[4])))
freqs_cis.lazydata.var_vals[pos] = start_pos
else:
freqs_cis = freqs_cis.shrink(((0, freqs_cis.shape[0]), (start_pos, start_pos+seqlen), (0, freqs_cis.shape[2]), (0, freqs_cis.shape[3]), (0, freqs_cis.shape[4])))
output, cache_k, cache_v = self.attention(self.attention_norm(x), cache_k, cache_v, start_pos, freqs_cis, mask, jit_ctx=jit_ctx)
h = x + output
return (h + self.feed_forward(self.ffn_norm(h))).realize(), cache_k.realize(), cache_v.realize()
@ -157,11 +151,12 @@ class Transformer:
_bsz, seqlen = tokens.shape
if seqlen == 1 and JIT:
pos = Variable("pos", 1, 1024)
# get only the part of freqs_cis that we are using.
freqs_cis = self.freqs_cis.shrink(((0, self.freqs_cis.shape[0]), (pos, pos+seqlen),(0, self.freqs_cis.shape[2]),(0, self.freqs_cis.shape[3]),(0, self.freqs_cis.shape[4])))
freqs_cis.lazydata.var_vals[pos] = start_pos
h = self.tok_embeddings_jitted(tokens)
for i, (layer, (cache_k, cache_v)) in enumerate(zip(self.layers_jitted, self.kv_caches)):
h, cache_k, cache_v = layer(h, cache_k, cache_v, start_pos=start_pos, freqs_cis=self.freqs_cis, mask=None, jit_ctx={pos: start_pos})
h, cache_k, cache_v = layer(h, cache_k, cache_v, start_pos=start_pos, freqs_cis=freqs_cis, mask=None, jit_ctx={pos: start_pos})
# TODO: move the kv cache into Attention, pre-allocate the cache and instead of cat, update the cache in-place
self.kv_caches[i] = (cache_k, cache_v)
return self.postprocess_jitted(h, temperature)
@ -174,7 +169,7 @@ class Transformer:
if cache_k is not None and start_pos > 0:
cache_k = cache_k.reshape(cache_k.shape[0], start_pos, cache_k.shape[2], cache_k.shape[3])
cache_v = cache_v.reshape(cache_v.shape[0], start_pos, cache_v.shape[2], cache_v.shape[3])
h, cache_k, cache_v = layer(h, cache_k, cache_v, start_pos=start_pos, freqs_cis=self.freqs_cis, mask=mask)
h, cache_k, cache_v = layer(h, cache_k, cache_v, start_pos=start_pos, freqs_cis=freqs_cis, mask=mask)
self.kv_caches[i] = (cache_k, cache_v)
return self.postprocess(h, temperature)