mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
llama3: correctly shard wqkv (#14978)
This commit is contained in:
parent
a36a26d4ed
commit
8d9545e09e
2 changed files with 5 additions and 1 deletions
|
|
@@ -1359,6 +1359,7 @@ def train_llama3():
|
|||
elif '.attention.wq' in k: v.shard_(device, axis=0)
|
||||
elif '.attention.wk' in k: v.shard_(device, axis=0)
|
||||
elif '.attention.wv' in k: v.shard_(device, axis=0)
|
||||
+ elif '.attention.wqkv' in k: v.shard_(device, axis=0)
|
||||
elif '.attention.wo' in k: v.shard_(device, axis=1)
|
||||
elif '.feed_forward.w1.' in k: v.shard_(device, axis=0)
|
||||
elif '.feed_forward.w2.' in k: v.shard_(device, axis=1)
|
||||
|
|
|
|||
|
|
@@ -56,7 +56,10 @@ class Attention:
|
|||
def __call__(self, x:Tensor, start_pos:Union[Variable,int], freqs_cis:Tensor, mask:Optional[Tensor]=None) -> Tensor:
|
||||
if getenv("WQKV"):
|
||||
xqkv = self.wqkv(x)
|
||||
- xq, xk, xv = xqkv.split([self.n_heads * self.head_dim, self.n_kv_heads * self.head_dim, self.n_kv_heads * self.head_dim], dim=2)
|
||||
+ xqkv = xqkv.reshape(xqkv.shape[0], xqkv.shape[1], self.n_kv_heads, self.n_rep + 2, self.head_dim)
|
||||
+ xq = xqkv[:, :, :, :self.n_rep].reshape(xqkv.shape[0], xqkv.shape[1], -1)
|
||||
+ xk = xqkv[:, :, :, self.n_rep:self.n_rep+1].reshape(xqkv.shape[0], xqkv.shape[1], -1)
|
||||
+ xv = xqkv[:, :, :, self.n_rep+1:self.n_rep+2].reshape(xqkv.shape[0], xqkv.shape[1], -1)
|
||||
else:
|
||||
xq, xk, xv = self.wq(x), self.wk(x.contiguous_backward()), self.wv(x)
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue