Revert "remove TransformerBlock contiguous in llama (#10104)" (#10108)

This reverts commit b8d07dcc54.
This commit is contained in:
chenyu 2025-04-29 15:28:38 -04:00 committed by GitHub
commit 573bbb9746
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@@ -107,7 +107,7 @@ class TransformerBlock:
def __call__(self, x:Tensor, start_pos:Union[Variable,int], freqs_cis:Tensor, mask:Optional[Tensor]):
h = x + self.attention(self.attention_norm(x), start_pos, freqs_cis, mask)
- return h + self.feed_forward(self.ffn_norm(h))
+ return (h + self.feed_forward(self.ffn_norm(h))).contiguous()
# standard openai sampling
def sample(logits: Tensor, temp: float, k: int, p: float, af: float, ap: float):