update mask in scaled_dot_product_attention (#8674)

built is_causal mask with ones_like and start with boolean, and reversed the mask -inf order
This commit is contained in:
chenyu 2025-01-19 15:19:23 -05:00 committed by GitHub
commit beba490ba8
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 9 additions and 7 deletions

View file

@ -1000,7 +1000,7 @@ class TestSchedule(unittest.TestCase):
def test_scaled_dot_product_attention_causal_fusion(self):
x, y, z = (Tensor.empty(32, 8, 16, 16) for _ in range(3))
out = Tensor.scaled_dot_product_attention(x, y, z, is_causal=True)
check_schedule(out, 6)
check_schedule(out, 5)
def test_adam_step_fusion(self):
with Tensor.train():

View file

@ -3565,8 +3565,7 @@ class Tensor(SimpleMathTrait):
if num_classes == -1: num_classes = (self.max()+1).item()
return self[..., None]._one_hot_along_dim(num_classes).where(1, 0)
def scaled_dot_product_attention(self, key:Tensor, value:Tensor, attn_mask:Optional[Tensor]=None,
dropout_p:float=0.0, is_causal:bool=False) -> Tensor:
def scaled_dot_product_attention(self, key:Tensor, value:Tensor, attn_mask:Tensor|None=None, dropout_p:float=0.0, is_causal:bool=False) -> Tensor:
"""
Computes scaled dot-product attention.
`self` is the query tensor, `key` is the key tensor, and `value` is the value tensor.
@ -3583,12 +3582,15 @@ class Tensor(SimpleMathTrait):
"""
# NOTE: it also works when `key` and `value` have symbolic shape.
assert all_int(self.shape), f"does not support symbolic shape {self.shape}"
qk = self.matmul(key.transpose(-2,-1), acc_dtype=least_upper_dtype(self.dtype, key.dtype, dtypes.float32)) / math.sqrt(self.shape[-1])
# handle attention mask
if is_causal:
if attn_mask is not None: raise RuntimeError("cannot set attn_mask when is_causal=True")
attn_mask = Tensor.ones(self.shape[-2], key.shape[-2], requires_grad=False, device=self.device).tril(0).cast(dtypes.bool)
if attn_mask is not None and attn_mask.dtype == dtypes.bool: attn_mask = (attn_mask == 0).where(-float("inf"), 0)
qk = self.matmul(key.transpose(-2,-1), acc_dtype=least_upper_dtype(self.dtype, key.dtype, dtypes.float32)) / math.sqrt(self.shape[-1])
return ((qk+attn_mask) if attn_mask is not None else qk).softmax(-1).cast(self.dtype).dropout(dropout_p) @ value
attn_mask = qk.ones_like(requires_grad=False, device=self.device, dtype=dtypes.bool).tril()
if attn_mask is not None:
if attn_mask.dtype == dtypes.bool: attn_mask = attn_mask.where(0, -float("inf"))
qk = qk + attn_mask
return qk.softmax(-1).cast(self.dtype).dropout(dropout_p) @ value
def _do_reduction(self, reduction:ReductionStr="mean") -> Tensor:
if reduction not in get_args(ReductionStr): raise ValueError(f"{reduction=} must be one of {get_args(ReductionStr)}")