llama3: correctly shard wqkv (#14978)

This commit is contained in:
wozeparrot 2026-02-23 23:57:10 -08:00 committed by GitHub
commit 8d9545e09e
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 5 additions and 1 deletions

View file

@ -1359,6 +1359,7 @@ def train_llama3():
elif '.attention.wq' in k: v.shard_(device, axis=0)
elif '.attention.wk' in k: v.shard_(device, axis=0)
elif '.attention.wv' in k: v.shard_(device, axis=0)
elif '.attention.wqkv' in k: v.shard_(device, axis=0)
elif '.attention.wo' in k: v.shard_(device, axis=1)
elif '.feed_forward.w1.' in k: v.shard_(device, axis=0)
elif '.feed_forward.w2.' in k: v.shard_(device, axis=1)

View file

@ -56,7 +56,10 @@ class Attention:
def __call__(self, x:Tensor, start_pos:Union[Variable,int], freqs_cis:Tensor, mask:Optional[Tensor]=None) -> Tensor:
if getenv("WQKV"):
xqkv = self.wqkv(x)
xq, xk, xv = xqkv.split([self.n_heads * self.head_dim, self.n_kv_heads * self.head_dim, self.n_kv_heads * self.head_dim], dim=2)
xqkv = xqkv.reshape(xqkv.shape[0], xqkv.shape[1], self.n_kv_heads, self.n_rep + 2, self.head_dim)
xq = xqkv[:, :, :, :self.n_rep].reshape(xqkv.shape[0], xqkv.shape[1], -1)
xk = xqkv[:, :, :, self.n_rep:self.n_rep+1].reshape(xqkv.shape[0], xqkv.shape[1], -1)
xv = xqkv[:, :, :, self.n_rep+1:self.n_rep+2].reshape(xqkv.shape[0], xqkv.shape[1], -1)
else:
xq, xk, xv = self.wq(x), self.wk(x.contiguous_backward()), self.wv(x)