mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
dropout and scaled_dot_product_attention to mixin (#16707)
This commit is contained in:
parent
33b635d23a
commit
0138480910
2 changed files with 54 additions and 56 deletions
|
|
@ -1,8 +1,8 @@
|
|||
from __future__ import annotations
|
||||
import math
|
||||
from typing import Self, cast
|
||||
from tinygrad.dtype import DType, DTypeLike, dtypes, to_dtype
|
||||
from tinygrad.helpers import all_int, argfix, ceildiv, prod
|
||||
from tinygrad.dtype import DType, DTypeLike, dtypes, least_upper_dtype, to_dtype
|
||||
from tinygrad.helpers import all_int, argfix, ceildiv, prod, TRAINING
|
||||
from tinygrad.mixin import OpMixin
|
||||
from tinygrad.device import canonicalize_device
|
||||
|
||||
|
|
@ -273,3 +273,54 @@ class RandMixin(OpMixin):
|
|||
# Efraimidis-Spirakis
|
||||
indices = (weight.rand_like(dtype=dtypes.float32).log2() / weight).topk(num_samples, dim=1)[1]
|
||||
return (indices.squeeze(0) if self.ndim == 1 else indices).cast(dtypes.int32)
|
||||
|
||||
def dropout(self, p=0.5) -> Self:
|
||||
"""
|
||||
Applies dropout to `self`.
|
||||
|
||||
NOTE: dropout is only applied when `TRAINING` is set (e.g. inside `Context(TRAINING=1)`).
|
||||
|
||||
- Paper: https://jmlr.org/papers/v15/srivastava14a.html
|
||||
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
Tensor.manual_seed(42)
|
||||
t = Tensor.randn(2, 2)
|
||||
with Context(TRAINING=1):
|
||||
print(t.dropout().numpy())
|
||||
```
|
||||
"""
|
||||
if not 0 <= p <= 1: raise ValueError(f"{p=} is out of range [0, 1]")
|
||||
if not TRAINING or p == 0: return self
|
||||
if p == 1: return self.const_like(0)
|
||||
return (self.rand_like(dtype=dtypes.default_float, contiguous=False) >= p).contiguous().where(self, 0) / (1.0 - p)
|
||||
|
||||
def scaled_dot_product_attention(self, key:Self, value:Self, attn_mask:Self|None=None, dropout_p:float=0.0,
|
||||
is_causal:bool=False, enable_gqa:bool=False) -> Self:
|
||||
"""
|
||||
Computes scaled dot-product attention.
|
||||
`self` is the query tensor, `key` is the key tensor, and `value` is the value tensor.
|
||||
|
||||
- Paper: https://arxiv.org/abs/1706.03762v7
|
||||
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
q = Tensor.randn(2, 4, 8)
|
||||
k = Tensor.randn(2, 4, 8)
|
||||
v = Tensor.randn(2, 4, 8)
|
||||
print(q.scaled_dot_product_attention(k, v).numpy())
|
||||
```
|
||||
"""
|
||||
# GQA: https://docs.pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
|
||||
if enable_gqa:
|
||||
key = key.repeat_interleave(int(self.shape[-3] // key.shape[-3]), dim=-3)
|
||||
value = value.repeat_interleave(int(self.shape[-3] // value.shape[-3]), dim=-3)
|
||||
|
||||
q = self
|
||||
qk = q.matmul(key.transpose(-2,-1), dtype=least_upper_dtype(q.dtype, key.dtype, dtypes.float32)) / math.sqrt(q.shape[-1])
|
||||
# handle attention mask
|
||||
if is_causal:
|
||||
if attn_mask is not None: raise RuntimeError("cannot set attn_mask when is_causal=True")
|
||||
attn_mask = qk.const_like(1).cast(dtypes.bool).tril()
|
||||
if attn_mask is not None:
|
||||
if attn_mask.dtype == dtypes.bool: attn_mask = attn_mask.where(0, -float("inf"))
|
||||
qk = qk + attn_mask
|
||||
return qk.cast(self.dtype).softmax(-1).dropout(dropout_p) @ value
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@ from __future__ import annotations
|
|||
import time, math, itertools, functools, sys, inspect, pathlib, hashlib, weakref
|
||||
from typing import Any, Callable, Sequence, cast, get_args, ParamSpec, TypeVar, Generic, TYPE_CHECKING
|
||||
if TYPE_CHECKING: import numpy
|
||||
from tinygrad.dtype import DType, DTypeLike, dtypes, ConstType, least_upper_dtype, to_dtype
|
||||
from tinygrad.dtype import DType, DTypeLike, dtypes, ConstType, to_dtype
|
||||
from tinygrad.dtype import _from_np_dtype, _to_np_dtype, PyConst, Invalid
|
||||
from tinygrad.helpers import argfix, flatten, prod, all_int, round_up, getenv, fully_flatten, ceildiv, fetch, flat_to_grouped
|
||||
from tinygrad.helpers import resolve_pool_pads, IMAGE, FLOAT16, WINO, Metadata, TRACEMETA, is_numpy_ndarray, TracingKey, cpu_profile
|
||||
|
|
@ -785,59 +785,6 @@ class Tensor(RandMixin, metaclass=TensorMeta):
|
|||
fn = UOp(Ops.CUSTOM_FUNCTION, dtypes.void, src=(frame_pos.src[0], *[UOp.const(dtypes.int, s) for s in shape]), arg="encdec")
|
||||
return Tensor(out.uop.after(fn.call(*[s.uop for s in srcs], frame_pos)))
|
||||
|
||||
# ***** functional nn ops *****
|
||||
|
||||
def dropout(self, p=0.5) -> Tensor:
|
||||
"""
|
||||
Applies dropout to `self`.
|
||||
|
||||
NOTE: dropout is only applied when `Tensor.training` is `True`.
|
||||
|
||||
- Paper: https://jmlr.org/papers/v15/srivastava14a.html
|
||||
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
Tensor.manual_seed(42)
|
||||
t = Tensor.randn(2, 2)
|
||||
with Context(TRAINING=1):
|
||||
print(t.dropout().numpy())
|
||||
```
|
||||
"""
|
||||
if not 0 <= p <= 1: raise ValueError(f"{p=} is out of range [0, 1]")
|
||||
if not Tensor.training or p == 0: return self
|
||||
if p == 1: return self.const_like(0)
|
||||
return (Tensor.rand_like(self, dtype=dtypes.default_float, contiguous=False) >= p).contiguous().where(self, 0) / (1.0 - p)
|
||||
|
||||
def scaled_dot_product_attention(self, key:Tensor, value:Tensor, attn_mask:Tensor|None=None, dropout_p:float=0.0,
|
||||
is_causal:bool=False, enable_gqa:bool=False) -> Tensor:
|
||||
"""
|
||||
Computes scaled dot-product attention.
|
||||
`self` is the query tensor, `key` is the key tensor, and `value` is the value tensor.
|
||||
|
||||
- Paper: https://arxiv.org/abs/1706.03762v7
|
||||
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
q = Tensor.randn(2, 4, 8)
|
||||
k = Tensor.randn(2, 4, 8)
|
||||
v = Tensor.randn(2, 4, 8)
|
||||
print(q.scaled_dot_product_attention(k, v).numpy())
|
||||
```
|
||||
"""
|
||||
# GQA: https://docs.pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
|
||||
if enable_gqa:
|
||||
key = key.repeat_interleave(int(self.shape[-3] // key.shape[-3]), dim=-3)
|
||||
value = value.repeat_interleave(int(self.shape[-3] // value.shape[-3]), dim=-3)
|
||||
|
||||
q = self
|
||||
qk = q.matmul(key.transpose(-2,-1), dtype=least_upper_dtype(q.dtype, key.dtype, dtypes.float32)) / math.sqrt(q.shape[-1])
|
||||
# handle attention mask
|
||||
if is_causal:
|
||||
if attn_mask is not None: raise RuntimeError("cannot set attn_mask when is_causal=True")
|
||||
attn_mask = qk.const_like(1).cast(dtypes.bool).tril()
|
||||
if attn_mask is not None:
|
||||
if attn_mask.dtype == dtypes.bool: attn_mask = attn_mask.where(0, -float("inf"))
|
||||
qk = qk + attn_mask
|
||||
return qk.cast(self.dtype).softmax(-1).dropout(dropout_p) @ value
|
||||
|
||||
# ***** cast ops *****
|
||||
|
||||
def bitcast(self, dtype:DTypeLike) -> Tensor:
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue