mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
parent
47f18f4d60
commit
38fe84d92b
4 changed files with 53 additions and 42 deletions
|
|
@ -1143,9 +1143,9 @@ class TestOps(unittest.TestCase):
|
|||
self.helper_test_exception([], lambda: tor[tb,:,:,:,:].sum().backward(), lambda: ten.gather(ta, dim=0).sum().backward(), expected=(IndexError, RuntimeError)) # torch raises IndexError, Tensor raises RuntimeError
|
||||
|
||||
def test_scaled_product_attention(self):
|
||||
helper_test_op([(32,8,128,64), (32,8,128,64), (32,8,128,64)], lambda x,y,z: torch.nn.functional.scaled_dot_product_attention(x,y,z), lambda x,y,z: Tensor.scaled_dot_product_attention(x,y,z))
|
||||
helper_test_op([(32,8,128,64), (32,8,128,64), (32,8,128,64), (32,8,128,128)], lambda x,y,z,m: torch.nn.functional.scaled_dot_product_attention(x,y,z,attn_mask=m), lambda x,y,z,m: Tensor.scaled_dot_product_attention(x,y,z,attn_mask=m))
|
||||
helper_test_op([(32,8,128,64), (32,8,128,64), (32,8,128,64)], lambda x,y,z: torch.nn.functional.scaled_dot_product_attention(x,y,z,is_causal=True), lambda x,y,z: Tensor.scaled_dot_product_attention(x,y,z,is_causal=True))
|
||||
helper_test_op([(32,8,16,64), (32,8,16,64), (32,8,16,64)], lambda x,y,z: torch.nn.functional.scaled_dot_product_attention(x,y,z), lambda x,y,z: Tensor.scaled_dot_product_attention(x,y,z))
|
||||
helper_test_op([(32,8,16,64), (32,8,16,64), (32,8,16,64), (32,8,16,16)], lambda x,y,z,m: torch.nn.functional.scaled_dot_product_attention(x,y,z,attn_mask=m), lambda x,y,z,m: Tensor.scaled_dot_product_attention(x,y,z,attn_mask=m))
|
||||
helper_test_op([(32,8,16,64), (32,8,16,64), (32,8,16,64)], lambda x,y,z: torch.nn.functional.scaled_dot_product_attention(x,y,z,is_causal=True), lambda x,y,z: Tensor.scaled_dot_product_attention(x,y,z,is_causal=True))
|
||||
|
||||
if __name__ == '__main__':
|
||||
np.random.seed(1337)
|
||||
|
|
|
|||
|
|
@ -62,7 +62,7 @@ class TestUOps(unittest.TestCase):
|
|||
def test_div(self): self._test_bop_fxn(BinaryOps.DIV, lambda a,b: a/b)
|
||||
def test_max(self): self._test_bop_fxn(BinaryOps.MAX, lambda a,b: max(a,b))
|
||||
def test_cmplt(self): self._test_bop_fxn(BinaryOps.CMPLT, lambda a,b: float(a<b))
|
||||
# CMPLT and MOD aren't tested
|
||||
# MOD isn't tested
|
||||
|
||||
# doesn't work in LLVM
|
||||
#def test_add_int32(self): self._test_bop_fxn(BinaryOps.ADD, lambda a,b: a+b, dtypes.int32)
|
||||
|
|
|
|||
|
|
@ -195,8 +195,21 @@ class LazyBuffer:
|
|||
assert not arg[1] or self.dtype.itemsize == arg[0].itemsize, "can't bitcast mismatched dtype itemsizes"
|
||||
return elementwise_op(UnaryOps.CAST, self, arg=arg) if self.dtype != arg[0] else self
|
||||
def unary_op(self:LazyBuffer, op:UnaryOps) -> LazyBuffer: return elementwise_op(op, self)
|
||||
def binary_op(self:LazyBuffer, op:BinaryOps, y:LazyBuffer) -> LazyBuffer: return elementwise_op(op, self, y)
|
||||
def ternary_op(self:LazyBuffer, op:TernaryOps, y: LazyBuffer, z:LazyBuffer) -> LazyBuffer: return elementwise_op(op, self, y, z)
|
||||
def binary_op(self:LazyBuffer, op:BinaryOps, y:Union[LazyBuffer, float, int]) -> LazyBuffer: return elementwise_op(op, self, y)
|
||||
def ternary_op(self:LazyBuffer, op:TernaryOps, y:Union[LazyBuffer, float, int], z:Union[LazyBuffer, float, int]) -> LazyBuffer: return elementwise_op(op, self, y, z)
|
||||
|
||||
def __add__(self, y:Union[LazyBuffer, float, int]) -> LazyBuffer: return elementwise_op(BinaryOps.ADD, self, y)
|
||||
def __radd__(self, y:Union[LazyBuffer, float, int]) -> LazyBuffer: return elementwise_op(BinaryOps.ADD, y, self)
|
||||
def __mul__(self, y:Union[LazyBuffer, float, int]) -> LazyBuffer: return elementwise_op(BinaryOps.MUL, self, y)
|
||||
def __rmul__(self, y:Union[LazyBuffer, float, int]) -> LazyBuffer: return elementwise_op(BinaryOps.MUL, y, self)
|
||||
def __truediv__(self, y:Union[LazyBuffer, float, int]) -> LazyBuffer: return elementwise_op(BinaryOps.DIV, self, y)
|
||||
def __rtruediv__(self, y:Union[LazyBuffer, float, int]) -> LazyBuffer: return elementwise_op(BinaryOps.DIV, y, self)
|
||||
def __sub__(self, y:Union[LazyBuffer, float, int]) -> LazyBuffer: return elementwise_op(BinaryOps.SUB, self, y)
|
||||
def __rsub__(self, y:Union[LazyBuffer, float, int]) -> LazyBuffer: return elementwise_op(BinaryOps.SUB, y, self)
|
||||
def __lt__(self, y:Union[LazyBuffer, float, int]) -> LazyBuffer: return elementwise_op(BinaryOps.CMPLT, self, y)
|
||||
def __gt__(self, y:Union[LazyBuffer, float, int]) -> LazyBuffer: return elementwise_op(BinaryOps.CMPLT, y, self)
|
||||
def __neg__(self) -> LazyBuffer: return 0.0-self
|
||||
|
||||
def contiguous(self:LazyBuffer) -> LazyBuffer:
|
||||
if not self.realized and self.op.op == LoadOps.CONTIGUOUS: return self # two CONTIGUOUS in a row is one
|
||||
return create_lazybuffer(self.device, ShapeTracker(self.shape), LoadOps, LazyOp(LoadOps.CONTIGUOUS, (self,), None), self.dtype)
|
||||
|
|
@ -304,7 +317,11 @@ def _push_movement_ops(srcs:Tuple[LazyBuffer, ...]) -> Tuple[LazyBuffer, ...]:
|
|||
new_srcs.append(x)
|
||||
return tuple(new_srcs)
|
||||
|
||||
def elementwise_op(op:Union[UnaryOps, BinaryOps, TernaryOps], *srcs:LazyBuffer, arg:Optional[Any]=None) -> LazyBuffer:
|
||||
def elementwise_op(op:Union[UnaryOps, BinaryOps, TernaryOps], *_srcs:Union[LazyBuffer, float, int], arg:Optional[Any]=None) -> LazyBuffer:
|
||||
# make them all LazyBuffers
|
||||
first_src = [x for x in _srcs if isinstance(x, LazyBuffer)][0]
|
||||
srcs:Tuple[LazyBuffer, ...] = tuple(x if isinstance(x, LazyBuffer) else first_src.const_like(x) for x in _srcs)
|
||||
|
||||
# if we are separated from other binary ops by movement ops, we push those movement ops above those binaryops
|
||||
if SHUFFLE_MOVEMENT_OPS: srcs = _push_movement_ops(srcs)
|
||||
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
from typing import Tuple, Optional
|
||||
from tinygrad.helpers import argsort, ShapeType
|
||||
from tinygrad.helpers import argsort, ShapeType, DType
|
||||
from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps, ReduceOps
|
||||
from tinygrad.tensor import Function
|
||||
from tinygrad.lazy import LazyBuffer
|
||||
|
|
@ -11,50 +11,49 @@ class Contiguous(Function):
|
|||
|
||||
class Cast(Function):
|
||||
__slots__ = "input_dtype", "bitcast"
|
||||
def forward(self, x, dtype, bitcast=False):
|
||||
def forward(self, x:LazyBuffer, dtype:DType, bitcast=False):
|
||||
self.input_dtype, self.bitcast = x.dtype, bitcast
|
||||
return x.cast((dtype, bitcast))
|
||||
def backward(self, grad_output):
|
||||
def backward(self, grad_output:LazyBuffer):
|
||||
return grad_output.cast((self.input_dtype, self.bitcast))
|
||||
|
||||
# ************* unary ops *************
|
||||
|
||||
class Sin(Function):
|
||||
__slots__ = "x"
|
||||
def forward(self, x: LazyBuffer) -> LazyBuffer:
|
||||
def forward(self, x:LazyBuffer) -> LazyBuffer:
|
||||
self.x = x
|
||||
return x.unary_op(UnaryOps.SIN)
|
||||
def backward(self, grad: LazyBuffer) -> LazyBuffer:
|
||||
return self.x.const_like(math.pi / 2).binary_op(BinaryOps.SUB, self.x).unary_op(UnaryOps.SIN).binary_op(BinaryOps.MUL, grad)
|
||||
def backward(self, grad:LazyBuffer) -> LazyBuffer:
|
||||
return ((math.pi / 2) - self.x).unary_op(UnaryOps.SIN) * grad
|
||||
|
||||
# NOTE: maximum(x, 0) behaves differently where x=0
|
||||
class Relu(Function):
|
||||
__slots__ = "ret"
|
||||
def forward(self, x:LazyBuffer) -> LazyBuffer:
|
||||
self.ret = x.binary_op(BinaryOps.MAX, x.const_like(0))
|
||||
self.ret = x.binary_op(BinaryOps.MAX, 0)
|
||||
return self.ret
|
||||
|
||||
def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
|
||||
mask = self.ret.const_like(0).binary_op(BinaryOps.CMPLT, self.ret)
|
||||
return mask.binary_op(BinaryOps.MUL, grad_output)
|
||||
return (0 < self.ret) * grad_output
|
||||
|
||||
class Log(Function):
|
||||
__slots__ = "x"
|
||||
def forward(self, x:LazyBuffer) -> LazyBuffer:
|
||||
self.x = x
|
||||
return x.unary_op(UnaryOps.LOG2).binary_op(BinaryOps.MUL, x.const_like(math.log(2)))
|
||||
return x.unary_op(UnaryOps.LOG2) * math.log(2)
|
||||
|
||||
def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
|
||||
return grad_output.binary_op(BinaryOps.DIV, self.x)
|
||||
return grad_output / self.x
|
||||
|
||||
class Exp(Function):
|
||||
__slots__ = "ret"
|
||||
def forward(self, x:LazyBuffer) -> LazyBuffer:
|
||||
self.ret = x.binary_op(BinaryOps.MUL, x.const_like(1/math.log(2))).unary_op(UnaryOps.EXP2)
|
||||
self.ret = (x * (1/math.log(2))).unary_op(UnaryOps.EXP2)
|
||||
return self.ret
|
||||
|
||||
def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
|
||||
return self.ret.binary_op(BinaryOps.MUL, grad_output)
|
||||
return self.ret * grad_output
|
||||
|
||||
class Sqrt(Function):
|
||||
__slots__ = "ret"
|
||||
|
|
@ -63,7 +62,7 @@ class Sqrt(Function):
|
|||
return self.ret
|
||||
|
||||
def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
|
||||
return grad_output.binary_op(BinaryOps.DIV, self.ret.binary_op(BinaryOps.MUL, self.ret.const_like(2)))
|
||||
return grad_output / (self.ret * 2)
|
||||
|
||||
# NOTE: the implicit derivative of sigmoid is not stable
|
||||
# https://towardsdatascience.com/derivative-of-the-sigmoid-function-536880cf918e
|
||||
|
|
@ -71,11 +70,11 @@ class Sqrt(Function):
|
|||
class Sigmoid(Function):
|
||||
__slots__ = "ret"
|
||||
def forward(self, x:LazyBuffer) -> LazyBuffer:
|
||||
self.ret = x.const_like(1).binary_op(BinaryOps.DIV, x.const_like(1).binary_op(BinaryOps.ADD, x.binary_op(BinaryOps.MUL, x.const_like(-1/math.log(2))).unary_op(UnaryOps.EXP2)))
|
||||
self.ret = 1 / (1 + (x * (-1/math.log(2))).unary_op(UnaryOps.EXP2))
|
||||
return self.ret
|
||||
|
||||
def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
|
||||
return self.ret.binary_op(BinaryOps.MUL, self.ret.const_like(1).binary_op(BinaryOps.SUB, self.ret)).binary_op(BinaryOps.MUL, grad_output)
|
||||
return (self.ret * (1 - self.ret)) * grad_output
|
||||
|
||||
# ************* reduce ops *************
|
||||
|
||||
|
|
@ -96,24 +95,19 @@ class Max(Function):
|
|||
|
||||
def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
|
||||
# 1s in locations where the max was chosen (can be two locations)
|
||||
max_is_1s = self.x.const_like(1).binary_op(BinaryOps.SUB, self.x.binary_op(BinaryOps.CMPLT, self.ret.expand(self.x.shape)))
|
||||
|
||||
# sum of locations, averaged
|
||||
max_is_1s = 1.0 - (self.x < self.ret.expand(self.x.shape))
|
||||
div = max_is_1s.reduce_op(ReduceOps.SUM, grad_output.shape).expand(self.x.shape)
|
||||
max_is_amount = max_is_1s.binary_op(BinaryOps.DIV, div)
|
||||
|
||||
grad_output_expanded = grad_output.expand(self.x.shape)
|
||||
return max_is_amount.binary_op(BinaryOps.MUL, grad_output_expanded)
|
||||
return (max_is_1s / div) * grad_output.expand(self.x.shape)
|
||||
|
||||
# ************* binary ops *************
|
||||
|
||||
class Less(Function):
|
||||
def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer:
|
||||
return x.binary_op(BinaryOps.CMPLT, y)
|
||||
return x < y
|
||||
|
||||
class Add(Function):
|
||||
def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer:
|
||||
return x.binary_op(BinaryOps.ADD, y)
|
||||
return x + y
|
||||
|
||||
def backward(self, grad_output:LazyBuffer) -> Tuple[Optional[LazyBuffer], Optional[LazyBuffer]]:
|
||||
return grad_output if self.needs_input_grad[0] else None, \
|
||||
|
|
@ -121,31 +115,31 @@ class Add(Function):
|
|||
|
||||
class Sub(Function):
|
||||
def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer:
|
||||
return x.binary_op(BinaryOps.SUB, y)
|
||||
return x - y
|
||||
|
||||
def backward(self, grad_output:LazyBuffer) -> Tuple[Optional[LazyBuffer], Optional[LazyBuffer]]:
|
||||
return grad_output if self.needs_input_grad[0] else None, \
|
||||
grad_output.const_like(0).binary_op(BinaryOps.SUB, grad_output) if self.needs_input_grad[1] else None
|
||||
-grad_output if self.needs_input_grad[1] else None
|
||||
|
||||
class Mul(Function):
|
||||
__slots__ = 'x', 'y'
|
||||
def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer:
|
||||
self.x, self.y = x, y
|
||||
return x.binary_op(BinaryOps.MUL, y)
|
||||
return x * y
|
||||
|
||||
def backward(self, grad_output:LazyBuffer) -> Tuple[Optional[LazyBuffer], Optional[LazyBuffer]]:
|
||||
return self.y.binary_op(BinaryOps.MUL, grad_output) if self.needs_input_grad[0] else None, \
|
||||
self.x.binary_op(BinaryOps.MUL, grad_output) if self.needs_input_grad[1] else None
|
||||
return self.y * grad_output if self.needs_input_grad[0] else None, \
|
||||
self.x * grad_output if self.needs_input_grad[1] else None
|
||||
|
||||
class Div(Function):
|
||||
__slots__ = 'x', 'y'
|
||||
def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer:
|
||||
self.x, self.y = x, y
|
||||
return x.binary_op(BinaryOps.DIV, y)
|
||||
return x / y
|
||||
|
||||
def backward(self, grad_output:LazyBuffer) -> Tuple[Optional[LazyBuffer], Optional[LazyBuffer]]:
|
||||
return grad_output.binary_op(BinaryOps.DIV, self.y) if self.needs_input_grad[0] else None, \
|
||||
grad_output.const_like(0).binary_op(BinaryOps.SUB, grad_output).binary_op(BinaryOps.MUL, self.x).binary_op(BinaryOps.DIV, self.y.binary_op(BinaryOps.MUL, self.y)) if self.needs_input_grad[1] else None
|
||||
return grad_output / self.y if self.needs_input_grad[0] else None, \
|
||||
(-grad_output * self.x) / (self.y * self.y) if self.needs_input_grad[1] else None
|
||||
|
||||
# ************* ternary ops *************
|
||||
|
||||
|
|
@ -157,8 +151,8 @@ class Where(Function):
|
|||
|
||||
def backward(self, grad_output:LazyBuffer):
|
||||
return None, \
|
||||
self.x.ternary_op(TernaryOps.WHERE, grad_output, self.x.const_like(0)) if self.needs_input_grad[1] else None, \
|
||||
self.x.ternary_op(TernaryOps.WHERE, self.x.const_like(0), grad_output) if self.needs_input_grad[2] else None
|
||||
self.x.ternary_op(TernaryOps.WHERE, grad_output, 0) if self.needs_input_grad[1] else None, \
|
||||
self.x.ternary_op(TernaryOps.WHERE, 0, grad_output) if self.needs_input_grad[2] else None
|
||||
|
||||
# ************* movement ops *************
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue