cleanup mlops (#1521)

* cleanup mlops

* that line belongs there
This commit is contained in:
George Hotz 2023-08-10 19:53:28 -07:00 committed by GitHub
commit 38fe84d92b
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 53 additions and 42 deletions

View file

@ -1143,9 +1143,9 @@ class TestOps(unittest.TestCase):
self.helper_test_exception([], lambda: tor[tb,:,:,:,:].sum().backward(), lambda: ten.gather(ta, dim=0).sum().backward(), expected=(IndexError, RuntimeError)) # torch raises IndexError, Tensor raises RuntimeError
def test_scaled_product_attention(self):
helper_test_op([(32,8,128,64), (32,8,128,64), (32,8,128,64)], lambda x,y,z: torch.nn.functional.scaled_dot_product_attention(x,y,z), lambda x,y,z: Tensor.scaled_dot_product_attention(x,y,z))
helper_test_op([(32,8,128,64), (32,8,128,64), (32,8,128,64), (32,8,128,128)], lambda x,y,z,m: torch.nn.functional.scaled_dot_product_attention(x,y,z,attn_mask=m), lambda x,y,z,m: Tensor.scaled_dot_product_attention(x,y,z,attn_mask=m))
helper_test_op([(32,8,128,64), (32,8,128,64), (32,8,128,64)], lambda x,y,z: torch.nn.functional.scaled_dot_product_attention(x,y,z,is_causal=True), lambda x,y,z: Tensor.scaled_dot_product_attention(x,y,z,is_causal=True))
helper_test_op([(32,8,16,64), (32,8,16,64), (32,8,16,64)], lambda x,y,z: torch.nn.functional.scaled_dot_product_attention(x,y,z), lambda x,y,z: Tensor.scaled_dot_product_attention(x,y,z))
helper_test_op([(32,8,16,64), (32,8,16,64), (32,8,16,64), (32,8,16,16)], lambda x,y,z,m: torch.nn.functional.scaled_dot_product_attention(x,y,z,attn_mask=m), lambda x,y,z,m: Tensor.scaled_dot_product_attention(x,y,z,attn_mask=m))
helper_test_op([(32,8,16,64), (32,8,16,64), (32,8,16,64)], lambda x,y,z: torch.nn.functional.scaled_dot_product_attention(x,y,z,is_causal=True), lambda x,y,z: Tensor.scaled_dot_product_attention(x,y,z,is_causal=True))
if __name__ == '__main__':
np.random.seed(1337)

View file

@ -62,7 +62,7 @@ class TestUOps(unittest.TestCase):
def test_div(self): self._test_bop_fxn(BinaryOps.DIV, lambda a,b: a/b)
def test_max(self): self._test_bop_fxn(BinaryOps.MAX, lambda a,b: max(a,b))
def test_cmplt(self): self._test_bop_fxn(BinaryOps.CMPLT, lambda a,b: float(a<b))
# CMPLT and MOD aren't tested
# MOD isn't tested
# doesn't work in LLVM
#def test_add_int32(self): self._test_bop_fxn(BinaryOps.ADD, lambda a,b: a+b, dtypes.int32)

View file

@ -195,8 +195,21 @@ class LazyBuffer:
assert not arg[1] or self.dtype.itemsize == arg[0].itemsize, "can't bitcast mismatched dtype itemsizes"
return elementwise_op(UnaryOps.CAST, self, arg=arg) if self.dtype != arg[0] else self
def unary_op(self:LazyBuffer, op:UnaryOps) -> LazyBuffer: return elementwise_op(op, self)
def binary_op(self:LazyBuffer, op:BinaryOps, y:LazyBuffer) -> LazyBuffer: return elementwise_op(op, self, y)
def ternary_op(self:LazyBuffer, op:TernaryOps, y: LazyBuffer, z:LazyBuffer) -> LazyBuffer: return elementwise_op(op, self, y, z)
def binary_op(self:LazyBuffer, op:BinaryOps, y:Union[LazyBuffer, float, int]) -> LazyBuffer: return elementwise_op(op, self, y)
def ternary_op(self:LazyBuffer, op:TernaryOps, y:Union[LazyBuffer, float, int], z:Union[LazyBuffer, float, int]) -> LazyBuffer: return elementwise_op(op, self, y, z)
def __add__(self, y:Union[LazyBuffer, float, int]) -> LazyBuffer: return elementwise_op(BinaryOps.ADD, self, y)
def __radd__(self, y:Union[LazyBuffer, float, int]) -> LazyBuffer: return elementwise_op(BinaryOps.ADD, y, self)
def __mul__(self, y:Union[LazyBuffer, float, int]) -> LazyBuffer: return elementwise_op(BinaryOps.MUL, self, y)
def __rmul__(self, y:Union[LazyBuffer, float, int]) -> LazyBuffer: return elementwise_op(BinaryOps.MUL, y, self)
def __truediv__(self, y:Union[LazyBuffer, float, int]) -> LazyBuffer: return elementwise_op(BinaryOps.DIV, self, y)
def __rtruediv__(self, y:Union[LazyBuffer, float, int]) -> LazyBuffer: return elementwise_op(BinaryOps.DIV, y, self)
def __sub__(self, y:Union[LazyBuffer, float, int]) -> LazyBuffer: return elementwise_op(BinaryOps.SUB, self, y)
def __rsub__(self, y:Union[LazyBuffer, float, int]) -> LazyBuffer: return elementwise_op(BinaryOps.SUB, y, self)
def __lt__(self, y:Union[LazyBuffer, float, int]) -> LazyBuffer: return elementwise_op(BinaryOps.CMPLT, self, y)
def __gt__(self, y:Union[LazyBuffer, float, int]) -> LazyBuffer: return elementwise_op(BinaryOps.CMPLT, y, self)
def __neg__(self) -> LazyBuffer: return 0.0-self
def contiguous(self:LazyBuffer) -> LazyBuffer:
if not self.realized and self.op.op == LoadOps.CONTIGUOUS: return self # two CONTIGUOUS in a row is one
return create_lazybuffer(self.device, ShapeTracker(self.shape), LoadOps, LazyOp(LoadOps.CONTIGUOUS, (self,), None), self.dtype)
@ -304,7 +317,11 @@ def _push_movement_ops(srcs:Tuple[LazyBuffer, ...]) -> Tuple[LazyBuffer, ...]:
new_srcs.append(x)
return tuple(new_srcs)
def elementwise_op(op:Union[UnaryOps, BinaryOps, TernaryOps], *srcs:LazyBuffer, arg:Optional[Any]=None) -> LazyBuffer:
def elementwise_op(op:Union[UnaryOps, BinaryOps, TernaryOps], *_srcs:Union[LazyBuffer, float, int], arg:Optional[Any]=None) -> LazyBuffer:
# make them all LazyBuffers
first_src = [x for x in _srcs if isinstance(x, LazyBuffer)][0]
srcs:Tuple[LazyBuffer, ...] = tuple(x if isinstance(x, LazyBuffer) else first_src.const_like(x) for x in _srcs)
# if we are separated from other binary ops by movement ops, we push those movement ops above those binaryops
if SHUFFLE_MOVEMENT_OPS: srcs = _push_movement_ops(srcs)

View file

@ -1,5 +1,5 @@
from typing import Tuple, Optional
from tinygrad.helpers import argsort, ShapeType
from tinygrad.helpers import argsort, ShapeType, DType
from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps, ReduceOps
from tinygrad.tensor import Function
from tinygrad.lazy import LazyBuffer
@ -11,50 +11,49 @@ class Contiguous(Function):
class Cast(Function):
__slots__ = "input_dtype", "bitcast"
def forward(self, x, dtype, bitcast=False):
def forward(self, x:LazyBuffer, dtype:DType, bitcast=False):
self.input_dtype, self.bitcast = x.dtype, bitcast
return x.cast((dtype, bitcast))
def backward(self, grad_output):
def backward(self, grad_output:LazyBuffer):
return grad_output.cast((self.input_dtype, self.bitcast))
# ************* unary ops *************
class Sin(Function):
__slots__ = "x"
def forward(self, x: LazyBuffer) -> LazyBuffer:
def forward(self, x:LazyBuffer) -> LazyBuffer:
self.x = x
return x.unary_op(UnaryOps.SIN)
def backward(self, grad: LazyBuffer) -> LazyBuffer:
return self.x.const_like(math.pi / 2).binary_op(BinaryOps.SUB, self.x).unary_op(UnaryOps.SIN).binary_op(BinaryOps.MUL, grad)
def backward(self, grad:LazyBuffer) -> LazyBuffer:
return ((math.pi / 2) - self.x).unary_op(UnaryOps.SIN) * grad
# NOTE: maximum(x, 0) behaves differently where x=0
class Relu(Function):
__slots__ = "ret"
def forward(self, x:LazyBuffer) -> LazyBuffer:
self.ret = x.binary_op(BinaryOps.MAX, x.const_like(0))
self.ret = x.binary_op(BinaryOps.MAX, 0)
return self.ret
def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
mask = self.ret.const_like(0).binary_op(BinaryOps.CMPLT, self.ret)
return mask.binary_op(BinaryOps.MUL, grad_output)
return (0 < self.ret) * grad_output
class Log(Function):
__slots__ = "x"
def forward(self, x:LazyBuffer) -> LazyBuffer:
self.x = x
return x.unary_op(UnaryOps.LOG2).binary_op(BinaryOps.MUL, x.const_like(math.log(2)))
return x.unary_op(UnaryOps.LOG2) * math.log(2)
def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
return grad_output.binary_op(BinaryOps.DIV, self.x)
return grad_output / self.x
class Exp(Function):
__slots__ = "ret"
def forward(self, x:LazyBuffer) -> LazyBuffer:
self.ret = x.binary_op(BinaryOps.MUL, x.const_like(1/math.log(2))).unary_op(UnaryOps.EXP2)
self.ret = (x * (1/math.log(2))).unary_op(UnaryOps.EXP2)
return self.ret
def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
return self.ret.binary_op(BinaryOps.MUL, grad_output)
return self.ret * grad_output
class Sqrt(Function):
__slots__ = "ret"
@ -63,7 +62,7 @@ class Sqrt(Function):
return self.ret
def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
return grad_output.binary_op(BinaryOps.DIV, self.ret.binary_op(BinaryOps.MUL, self.ret.const_like(2)))
return grad_output / (self.ret * 2)
# NOTE: the implicit derivative of sigmoid is not stable
# https://towardsdatascience.com/derivative-of-the-sigmoid-function-536880cf918e
@ -71,11 +70,11 @@ class Sqrt(Function):
class Sigmoid(Function):
__slots__ = "ret"
def forward(self, x:LazyBuffer) -> LazyBuffer:
self.ret = x.const_like(1).binary_op(BinaryOps.DIV, x.const_like(1).binary_op(BinaryOps.ADD, x.binary_op(BinaryOps.MUL, x.const_like(-1/math.log(2))).unary_op(UnaryOps.EXP2)))
self.ret = 1 / (1 + (x * (-1/math.log(2))).unary_op(UnaryOps.EXP2))
return self.ret
def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
return self.ret.binary_op(BinaryOps.MUL, self.ret.const_like(1).binary_op(BinaryOps.SUB, self.ret)).binary_op(BinaryOps.MUL, grad_output)
return (self.ret * (1 - self.ret)) * grad_output
# ************* reduce ops *************
@ -96,24 +95,19 @@ class Max(Function):
def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
# 1s in locations where the max was chosen (can be two locations)
max_is_1s = self.x.const_like(1).binary_op(BinaryOps.SUB, self.x.binary_op(BinaryOps.CMPLT, self.ret.expand(self.x.shape)))
# sum of locations, averaged
max_is_1s = 1.0 - (self.x < self.ret.expand(self.x.shape))
div = max_is_1s.reduce_op(ReduceOps.SUM, grad_output.shape).expand(self.x.shape)
max_is_amount = max_is_1s.binary_op(BinaryOps.DIV, div)
grad_output_expanded = grad_output.expand(self.x.shape)
return max_is_amount.binary_op(BinaryOps.MUL, grad_output_expanded)
return (max_is_1s / div) * grad_output.expand(self.x.shape)
# ************* binary ops *************
class Less(Function):
def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer:
return x.binary_op(BinaryOps.CMPLT, y)
return x < y
class Add(Function):
def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer:
return x.binary_op(BinaryOps.ADD, y)
return x + y
def backward(self, grad_output:LazyBuffer) -> Tuple[Optional[LazyBuffer], Optional[LazyBuffer]]:
return grad_output if self.needs_input_grad[0] else None, \
@ -121,31 +115,31 @@ class Add(Function):
class Sub(Function):
def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer:
return x.binary_op(BinaryOps.SUB, y)
return x - y
def backward(self, grad_output:LazyBuffer) -> Tuple[Optional[LazyBuffer], Optional[LazyBuffer]]:
return grad_output if self.needs_input_grad[0] else None, \
grad_output.const_like(0).binary_op(BinaryOps.SUB, grad_output) if self.needs_input_grad[1] else None
-grad_output if self.needs_input_grad[1] else None
class Mul(Function):
__slots__ = 'x', 'y'
def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer:
self.x, self.y = x, y
return x.binary_op(BinaryOps.MUL, y)
return x * y
def backward(self, grad_output:LazyBuffer) -> Tuple[Optional[LazyBuffer], Optional[LazyBuffer]]:
return self.y.binary_op(BinaryOps.MUL, grad_output) if self.needs_input_grad[0] else None, \
self.x.binary_op(BinaryOps.MUL, grad_output) if self.needs_input_grad[1] else None
return self.y * grad_output if self.needs_input_grad[0] else None, \
self.x * grad_output if self.needs_input_grad[1] else None
class Div(Function):
__slots__ = 'x', 'y'
def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer:
self.x, self.y = x, y
return x.binary_op(BinaryOps.DIV, y)
return x / y
def backward(self, grad_output:LazyBuffer) -> Tuple[Optional[LazyBuffer], Optional[LazyBuffer]]:
return grad_output.binary_op(BinaryOps.DIV, self.y) if self.needs_input_grad[0] else None, \
grad_output.const_like(0).binary_op(BinaryOps.SUB, grad_output).binary_op(BinaryOps.MUL, self.x).binary_op(BinaryOps.DIV, self.y.binary_op(BinaryOps.MUL, self.y)) if self.needs_input_grad[1] else None
return grad_output / self.y if self.needs_input_grad[0] else None, \
(-grad_output * self.x) / (self.y * self.y) if self.needs_input_grad[1] else None
# ************* ternary ops *************
@ -157,8 +151,8 @@ class Where(Function):
def backward(self, grad_output:LazyBuffer):
return None, \
self.x.ternary_op(TernaryOps.WHERE, grad_output, self.x.const_like(0)) if self.needs_input_grad[1] else None, \
self.x.ternary_op(TernaryOps.WHERE, self.x.const_like(0), grad_output) if self.needs_input_grad[2] else None
self.x.ternary_op(TernaryOps.WHERE, grad_output, 0) if self.needs_input_grad[1] else None, \
self.x.ternary_op(TernaryOps.WHERE, 0, grad_output) if self.needs_input_grad[2] else None
# ************* movement ops *************