make ops more like tensor [pr] (#7352)

* make ops more like tensor [pr]

* tensor is simple math trait

* no shifts
This commit is contained in:
George Hotz 2024-10-29 15:23:41 +07:00 committed by GitHub
commit 2bf55d8eda
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 66 additions and 59 deletions

View file

@ -36,50 +36,75 @@ class MetaOps(FastEnum):
EMPTY = auto(); CONST = auto(); COPY = auto(); CONTIGUOUS = auto(); ASSIGN = auto(); VIEW = auto() # noqa: E702
Op = Union[UnaryOps, BinaryOps, ReduceOps, MetaOps, TernaryOps]
class MathTrait:
class SimpleMathTrait:
# required to implement
def alu(self:T, arg:Union[UnaryOps, BinaryOps, TernaryOps], *src) -> T: raise NotImplementedError
def const_like(self, b:ConstType|Variable|Tuple[ConstType]): raise NotImplementedError
# great functions you get!
def ufix(self, x): return self.const_like(x) if not isinstance(x, MathTrait) else x
def _binop(self, op, x, reverse): return self.ufix(x).alu(op, self) if reverse else self.alu(op, self.ufix(x))
def logical_not(self): return self.ne(True)
def __neg__(self):
dtype = getattr(self, 'dtype', None)
def neg(self):
dtype: Optional[DType] = getattr(self, 'dtype', None)
assert dtype is not None, "MathTraits __neg__ requires a dtype"
return self.logical_not() if dtype.scalar() == dtypes.bool else self*(-1)
def __add__(self, x): return self.alu(BinaryOps.ADD, self.ufix(x))
def __radd__(self, x): return self.ufix(x).alu(BinaryOps.ADD, self)
def __sub__(self, x): return self.alu(BinaryOps.ADD, self.ufix(-x))
def __rsub__(self, x): return self.ufix(x).alu(BinaryOps.ADD, -self)
def __mul__(self, x): return self.alu(BinaryOps.MUL, self.ufix(x))
def __rmul__(self, x): return self.ufix(x).alu(BinaryOps.MUL, self)
def __lshift__(self, x): return self.alu(BinaryOps.SHL, self.ufix(x))
def __rlshift__(self, x): return self.ufix(x).alu(BinaryOps.SHL, self)
def __rshift__(self, x): return self.alu(BinaryOps.SHR, self.ufix(x))
def __rrshift__(self, x): return self.ufix(x).alu(BinaryOps.SHR, self)
def __floordiv__(self, x): return self.alu(BinaryOps.IDIV, self.ufix(x))
def __rfloordiv__(self, x): return self.ufix(x).alu(BinaryOps.IDIV, self)
def __truediv__(self, x): return self.alu(BinaryOps.MUL, self.ufix(x).alu(UnaryOps.RECIP))
def __rtruediv__(self, x): return self.ufix(x).alu(BinaryOps.MUL, self.alu(UnaryOps.RECIP))
def __mod__(self, x): return self.alu(BinaryOps.MOD, self.ufix(x))
def __rmod__(self, x): return self.ufix(x).alu(BinaryOps.MOD, self)
def __xor__(self, x): return self.alu(BinaryOps.XOR, self.ufix(x))
def __and__(self, x): return self.alu(BinaryOps.AND, self.ufix(x))
def __rand__(self, x): return self.ufix(x).alu(BinaryOps.AND, self)
def __or__(self, x): return self.alu(BinaryOps.OR, self.ufix(x))
def ne(self, x): return self.alu(BinaryOps.CMPNE, self.ufix(x))
def eq(self, x): return self.ne(x).logical_not()
def add(self, x, reverse=False): return self._binop(BinaryOps.ADD, x, reverse)
def mul(self, x, reverse=False): return self._binop(BinaryOps.MUL, x, reverse)
def bitwise_and(self, x, reverse=False): return self._binop(BinaryOps.AND, x, reverse)
def bitwise_or(self, x, reverse=False): return self._binop(BinaryOps.OR, x, reverse)
def xor(self, x, reverse=False): return self._binop(BinaryOps.XOR, x, reverse)
def idiv(self, x, reverse=False): return self._binop(BinaryOps.IDIV, x, reverse)
def sub(self, x, reverse=False): return self.ufix(x).alu(BinaryOps.ADD, -self) if reverse else self.alu(BinaryOps.ADD, self.ufix(-x))
def div(self, x, reverse=False): return (self.ufix(x)*self.alu(UnaryOps.RECIP)) if reverse else (self*self.ufix(x).alu(UnaryOps.RECIP))
def __neg__(self): return self.neg()
def __add__(self, x): return self.add(x)
def __sub__(self, x): return self.sub(x)
def __mul__(self, x): return self.mul(x)
def __truediv__(self, x): return self.div(x)
def __floordiv__(self, x): return self.idiv(x)
def __and__(self, x): return self.bitwise_and(x)
def __or__(self, x): return self.bitwise_or(x)
def __xor__(self, x): return self.xor(x)
def __radd__(self, x): return self.add(x, True)
def __rsub__(self, x): return self.sub(x, True)
def __rmul__(self, x): return self.mul(x, True)
def __rtruediv__(self, x): return self.div(x, True)
def __rfloordiv__(self, x): return self.idiv(x, True)
def __rand__(self, x): return self.bitwise_and(x, True)
def __ror__(self, x): return self.bitwise_or(x, True)
def __rxor__(self, x): return self.xor(x, True)
def lt(self, x): return self.alu(BinaryOps.CMPLT, self.ufix(x))
def gt(self, x): return self.ufix(x).alu(BinaryOps.CMPLT, self)
def ne(self, x): return self.alu(BinaryOps.CMPNE, self.ufix(x))
def ge(self, x): return self.lt(x).logical_not()
def le(self, x): return self.gt(x).logical_not()
# NOTE: __eq__ can't be overridden, and means the same thing as is
def __ne__(self, x): return self.ne(x)
def eq(self, x): return self.ne(x).logical_not()
def __lt__(self, x): return self.lt(x)
def __gt__(self, x): return self.gt(x)
def __ne__(self, x): return self.ne(x)
def __ge__(self, x): return self.ge(x)
def __le__(self, x): return self.le(x)
# NOTE: __eq__ isn't overridden, and means the same thing as is by default
class MathTrait(SimpleMathTrait): # pylint: disable=abstract-method
# TODO: move to Tensor when new backward is done
def lshift(self, x, reverse=False): return self._binop(BinaryOps.SHL, x, reverse)
def rshift(self, x, reverse=False): return self._binop(BinaryOps.SHR, x, reverse)
def __lshift__(self, x): return self.lshift(x)
def __rshift__(self, x): return self.rshift(x)
def __rlshift__(self, x): return self.lshift(x, True)
def __rrshift__(self, x): return self.rshift(x, True)
# not in Tensor
def __mod__(self, x): return self.alu(BinaryOps.MOD, self.ufix(x))
def __rmod__(self, x): return self.ufix(x).alu(BinaryOps.MOD, self)
def max(self, x): return self.alu(BinaryOps.MAX, self.ufix(x))
def min(self, x): return -(-self).max(-x)
def where(self, x, y): return self.alu(TernaryOps.WHERE, x, y)

View file

@ -9,7 +9,7 @@ from tinygrad.dtype import DType, DTypeLike, dtypes, ImageDType, ConstType, leas
from tinygrad.helpers import argfix, make_pair, flatten, prod, all_int, round_up, merge_dicts, argsort, getenv, all_same, fully_flatten, dedup
from tinygrad.helpers import IMAGE, DEBUG, WINO, _METADATA, Metadata, TRACEMETA, ceildiv, fetch
from tinygrad.multi import MultiLazyBuffer
from tinygrad.ops import MetaOps, smax, smin, resolve, UOp, UOps, BinaryOps, sint, Variable
from tinygrad.ops import MetaOps, smax, smin, resolve, UOp, UOps, BinaryOps, sint, Variable, SimpleMathTrait
from tinygrad.device import Device, Buffer, BufferOptions
from tinygrad.engine.lazy import LazyBuffer
from tinygrad.engine.realize import run_schedule
@ -99,7 +99,7 @@ def _broadcast_shape(*shapes:Tuple[sint, ...]) -> Tuple[sint, ...]:
ReductionStr = Literal["mean", "sum", "none"]
class Tensor:
class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
"""
A `Tensor` is a multi-dimensional matrix containing elements of a single data type.
@ -1137,16 +1137,16 @@ class Tensor:
# for advanced setitem, returns whole tensor with indices replaced
if v is not None:
v = v.cast(self.dtype)._broadcast_to(_broadcast_shape(ret.shape, v.shape))
vb = v.cast(self.dtype)._broadcast_to(_broadcast_shape(ret.shape, v.shape))
# add back reduced dims from sum
for dim in sum_axis: v = v.unsqueeze(dim)
for dim in sum_axis: vb = vb.unsqueeze(dim)
# axis to be reduced to match self.shape
axis = tuple(range(first_dim, first_dim + len(big_shape)))
# apply mask to v(broadcasted) and reduce such that if v contains repeated indices the last one remains
v = v * mask
for dim in axis: v = functools.reduce(lambda x,y: y.where(y, x), v.split(1, dim))
vb = vb * mask
for dim in axis: vb = functools.reduce(lambda x,y: y.where(y, x), vb.split(1, dim))
# reduce mask and select from v(get rid of extra dims from reduce) for each True element in mask else select from self
ret = mask.any(axis).where(v.squeeze(), self)
ret = mask.any(axis).where(vb.squeeze(), self)
return ret
@ -3018,31 +3018,14 @@ class Tensor:
# ***** op wrappers *****
def __neg__(self) -> Tensor: return self.neg()
def __add__(self, x) -> Tensor: return self.add(x)
def __sub__(self, x) -> Tensor: return self.sub(x)
def __mul__(self, x) -> Tensor: return self.mul(x)
def __pow__(self, x) -> Tensor: return self.pow(x)
def __truediv__(self, x) -> Tensor: return self.div(x)
def __floordiv__(self, x) -> Tensor: return self.idiv(x)
def __matmul__(self, x) -> Tensor: return self.matmul(x)
def __and__(self, x) -> Tensor: return self.bitwise_and(x)
def __or__(self, x) -> Tensor: return self.bitwise_or(x)
def __xor__(self, x) -> Tensor: return self.xor(x)
def __lshift__(self, x) -> Tensor: return self.lshift(x)
def __rshift__(self, x) -> Tensor: return self.rshift(x)
def __radd__(self, x) -> Tensor: return self.add(x, True)
def __rsub__(self, x) -> Tensor: return self.sub(x, True)
def __rmul__(self, x) -> Tensor: return self.mul(x, True)
def __pow__(self, x) -> Tensor: return self.pow(x)
def __matmul__(self, x) -> Tensor: return self.matmul(x)
def __rpow__(self, x) -> Tensor: return self.pow(x, True)
def __rtruediv__(self, x) -> Tensor: return self.div(x, True)
def __rfloordiv__(self, x) -> Tensor: return self.idiv(x, True)
def __rmatmul__(self, x) -> Tensor: return self.matmul(x, True)
def __rand__(self, x) -> Tensor: return self.bitwise_and(x, True)
def __ror__(self, x) -> Tensor: return self.bitwise_or(x, True)
def __rxor__(self, x) -> Tensor: return self.xor(x, True)
def __iadd__(self, x) -> Tensor: return self.assign(self.add(x))
def __isub__(self, x) -> Tensor: return self.assign(self.sub(x))
@ -3057,12 +3040,11 @@ class Tensor:
def __ilshift__(self, x) -> Tensor: return self.assign(self.lshift(x))
def __irshift__(self, x) -> Tensor: return self.assign(self.rshift(x))
def __lt__(self, x) -> Tensor: return F.Less.apply(*self._broadcasted(x, False))
def __gt__(self, x) -> Tensor: return F.Less.apply(*self._broadcasted(x, True))
def __ge__(self, x) -> Tensor: return (self<x).logical_not()
def __le__(self, x) -> Tensor: return (self>x).logical_not()
def __ne__(self, x) -> Tensor: return F.Neq.apply(*self._broadcasted(x)) # type: ignore[override]
def __eq__(self, x) -> Tensor: return (self!=x).logical_not() # type: ignore[override]
def lt(self, x) -> Tensor: return F.Less.apply(*self._broadcasted(x, False))
def gt(self, x) -> Tensor: return F.Less.apply(*self._broadcasted(x, True))
def ne(self, x) -> Tensor: return F.Neq.apply(*self._broadcasted(x)) # type: ignore[override]
def __eq__(self, x) -> Tensor: return self.eq(x) # type: ignore[override]
# ***** functional nn ops *****