mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
update mask in scaled_dot_product_attention (#8674)
built is_causal mask with ones_like and start with boolean, and reversed the mask -inf order
This commit is contained in:
parent
5842ee56c6
commit
beba490ba8
2 changed files with 9 additions and 7 deletions
|
|
@ -1000,7 +1000,7 @@ class TestSchedule(unittest.TestCase):
|
|||
def test_scaled_dot_product_attention_causal_fusion(self):
|
||||
x, y, z = (Tensor.empty(32, 8, 16, 16) for _ in range(3))
|
||||
out = Tensor.scaled_dot_product_attention(x, y, z, is_causal=True)
|
||||
check_schedule(out, 6)
|
||||
check_schedule(out, 5)
|
||||
|
||||
def test_adam_step_fusion(self):
|
||||
with Tensor.train():
|
||||
|
|
|
|||
|
|
@ -3565,8 +3565,7 @@ class Tensor(SimpleMathTrait):
|
|||
if num_classes == -1: num_classes = (self.max()+1).item()
|
||||
return self[..., None]._one_hot_along_dim(num_classes).where(1, 0)
|
||||
|
||||
def scaled_dot_product_attention(self, key:Tensor, value:Tensor, attn_mask:Optional[Tensor]=None,
|
||||
dropout_p:float=0.0, is_causal:bool=False) -> Tensor:
|
||||
def scaled_dot_product_attention(self, key:Tensor, value:Tensor, attn_mask:Tensor|None=None, dropout_p:float=0.0, is_causal:bool=False) -> Tensor:
|
||||
"""
|
||||
Computes scaled dot-product attention.
|
||||
`self` is the query tensor, `key` is the key tensor, and `value` is the value tensor.
|
||||
|
|
@ -3583,12 +3582,15 @@ class Tensor(SimpleMathTrait):
|
|||
"""
|
||||
# NOTE: it also works when `key` and `value` have symbolic shape.
|
||||
assert all_int(self.shape), f"does not support symbolic shape {self.shape}"
|
||||
qk = self.matmul(key.transpose(-2,-1), acc_dtype=least_upper_dtype(self.dtype, key.dtype, dtypes.float32)) / math.sqrt(self.shape[-1])
|
||||
# handle attention mask
|
||||
if is_causal:
|
||||
if attn_mask is not None: raise RuntimeError("cannot set attn_mask when is_causal=True")
|
||||
attn_mask = Tensor.ones(self.shape[-2], key.shape[-2], requires_grad=False, device=self.device).tril(0).cast(dtypes.bool)
|
||||
if attn_mask is not None and attn_mask.dtype == dtypes.bool: attn_mask = (attn_mask == 0).where(-float("inf"), 0)
|
||||
qk = self.matmul(key.transpose(-2,-1), acc_dtype=least_upper_dtype(self.dtype, key.dtype, dtypes.float32)) / math.sqrt(self.shape[-1])
|
||||
return ((qk+attn_mask) if attn_mask is not None else qk).softmax(-1).cast(self.dtype).dropout(dropout_p) @ value
|
||||
attn_mask = qk.ones_like(requires_grad=False, device=self.device, dtype=dtypes.bool).tril()
|
||||
if attn_mask is not None:
|
||||
if attn_mask.dtype == dtypes.bool: attn_mask = attn_mask.where(0, -float("inf"))
|
||||
qk = qk + attn_mask
|
||||
return qk.softmax(-1).cast(self.dtype).dropout(dropout_p) @ value
|
||||
|
||||
def _do_reduction(self, reduction:ReductionStr="mean") -> Tensor:
|
||||
if reduction not in get_args(ReductionStr): raise ValueError(f"{reduction=} must be one of {get_args(ReductionStr)}")
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue