mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
remove Function class [pr] (#8753)
* remove Function class [pr] * actually remove function * fix docs
This commit is contained in:
parent
ac70f63d4b
commit
a6e496b195
5 changed files with 53 additions and 203 deletions
|
|
@ -9,7 +9,7 @@ There is a good [bunch of tutorials](https://mesozoic-egg.github.io/tinygrad-not
|
|||
|
||||
## Frontend
|
||||
|
||||
Everything in [Tensor](../tensor/index.md) is syntactic sugar around [function.py](function.md), where the forwards and backwards passes are implemented for the different functions. There's about 25 of them, implemented using about 20 basic ops. Those basic ops go on to construct a graph of [UOps](../developer/uop.md).
|
||||
Everything in [Tensor](../tensor/index.md) is syntactic sugar around constructing a graph of [UOps](../developer/uop.md).
|
||||
|
||||
The `UOp` graph specifies the compute in terms of low level tinygrad ops. Not all UOps will actually become realized. There's two types of UOps, base and view. base contains compute into a contiguous buffer, and view is a view (specified by a ShapeTracker). Inputs to a base can be either base or view, inputs to a view can only be a single base.
|
||||
|
||||
|
|
|
|||
|
|
@ -1,33 +0,0 @@
|
|||
::: tinygrad.function
|
||||
options:
|
||||
members: [
|
||||
"Contiguous",
|
||||
"ContiguousBackward",
|
||||
"Cast",
|
||||
"Neg",
|
||||
"Reciprocal",
|
||||
"Sin",
|
||||
"Relu",
|
||||
"Log",
|
||||
"Exp",
|
||||
"Sqrt",
|
||||
"Sigmoid",
|
||||
"Sign",
|
||||
"Less",
|
||||
"Eq",
|
||||
"Xor",
|
||||
"Add",
|
||||
"Sub",
|
||||
"Mul",
|
||||
"Div",
|
||||
"Where",
|
||||
"Sum",
|
||||
"Max",
|
||||
"Expand",
|
||||
"Reshape",
|
||||
"Permute",
|
||||
"Pad",
|
||||
"Shrink",
|
||||
"Flip",
|
||||
]
|
||||
show_source: false
|
||||
|
|
@ -22,7 +22,6 @@ nav:
|
|||
- Runtime: runtime.md
|
||||
- Developer:
|
||||
- Intro: developer/developer.md
|
||||
- Function (autodiff): developer/function.md
|
||||
- UOp: developer/uop.md
|
||||
- Runtime:
|
||||
- developer/runtime.md
|
||||
|
|
|
|||
|
|
@ -1,108 +0,0 @@
|
|||
"""This is where the forwards and backwards passes live."""
|
||||
import math
|
||||
from tinygrad.dtype import DType
|
||||
from tinygrad.ops import Ops, sint, UOp
|
||||
from tinygrad.tensor import Function
|
||||
|
||||
class Contiguous(Function):
|
||||
def forward(self, x:UOp) -> UOp: return x.contiguous()
|
||||
|
||||
class ContiguousBackward(Function):
|
||||
def forward(self, x:UOp) -> UOp: return x.contiguous_backward()
|
||||
|
||||
class Cast(Function):
|
||||
def forward(self, x:UOp, dtype:DType, bitcast:bool=False) -> UOp: return x.bitcast(dtype) if bitcast else x.cast(dtype)
|
||||
|
||||
# ************* unary ops *************
|
||||
|
||||
class Reciprocal(Function):
|
||||
def forward(self, x:UOp) -> UOp: return x.reciprocal()
|
||||
|
||||
class Sin(Function):
|
||||
def forward(self, x:UOp) -> UOp: return x.sin()
|
||||
|
||||
class Relu(Function):
|
||||
def forward(self, x:UOp) -> UOp: return (x>0).where(x, 0)
|
||||
|
||||
class Log(Function):
|
||||
def forward(self, x:UOp) -> UOp: return x.log2() * math.log(2)
|
||||
|
||||
class Exp(Function):
|
||||
def forward(self, x:UOp) -> UOp: return (x * (1/math.log(2))).exp2()
|
||||
|
||||
class Sqrt(Function):
|
||||
def forward(self, x:UOp) -> UOp: return x.sqrt()
|
||||
|
||||
class Sign(Function):
|
||||
# NOTE: the x*0 is to match torch behavior without function.py
|
||||
def forward(self, x:UOp) -> UOp: return x.ne(0).where((x<0).where(x.const_like(-1), x.const_like(1)), x.const_like(0)) + x*0
|
||||
|
||||
# ************* binary ops *************
|
||||
|
||||
class Less(Function):
|
||||
def forward(self, x:UOp, y:UOp) -> UOp: return x<y
|
||||
|
||||
class Neq(Function):
|
||||
def forward(self, x:UOp, y:UOp) -> UOp: return x.ne(y)
|
||||
|
||||
class Xor(Function):
|
||||
def forward(self, x:UOp, y:UOp) -> UOp: return x^y
|
||||
|
||||
class BitwiseAnd(Function):
|
||||
def forward(self, x:UOp, y:UOp) -> UOp: return x&y
|
||||
|
||||
class BitwiseOr(Function):
|
||||
def forward(self, x:UOp, y:UOp) -> UOp: return x|y
|
||||
|
||||
class Threefry(Function):
|
||||
def forward(self, x:UOp, seed:UOp) -> UOp: return x.threefry(seed)
|
||||
|
||||
class Add(Function):
|
||||
def forward(self, x:UOp, y:UOp) -> UOp: return x+y
|
||||
|
||||
class Mul(Function):
|
||||
def forward(self, x:UOp, y:UOp) -> UOp: return x * y
|
||||
|
||||
class IDiv(Function):
|
||||
def forward(self, x:UOp, y:UOp) -> UOp: return x // y
|
||||
|
||||
class Mod(Function):
|
||||
def forward(self, x:UOp, y:UOp) -> UOp: return x % y
|
||||
|
||||
# ************* ternary ops *************
|
||||
|
||||
class Where(Function):
|
||||
def forward(self, x:UOp, y:UOp, z:UOp) -> UOp: return x.where(y, z)
|
||||
|
||||
|
||||
# ************* reduce ops *************
|
||||
|
||||
class Sum(Function):
|
||||
def forward(self, x:UOp, axis:tuple[int, ...]) -> UOp: return x.r(Ops.ADD, axis)
|
||||
|
||||
class Prod(Function):
|
||||
def forward(self, x:UOp, axis:tuple[int, ...]) -> UOp: return x.r(Ops.MUL, axis)
|
||||
|
||||
class Max(Function):
|
||||
def forward(self, x:UOp, axis:tuple[int, ...]) -> UOp: return x.r(Ops.MAX, axis)
|
||||
|
||||
# ************* movement ops *************
|
||||
|
||||
# NOTE: this is sum in reverse
|
||||
class Expand(Function):
|
||||
def forward(self, x:UOp, shape:tuple[int, ...]) -> UOp: return x.expand(shape)
|
||||
|
||||
class Reshape(Function):
|
||||
def forward(self, x:UOp, shape:tuple[int, ...]) -> UOp: return x.reshape(shape)
|
||||
|
||||
class Permute(Function):
|
||||
def forward(self, x:UOp, order:tuple[int, ...]) -> UOp: return x.permute(order)
|
||||
|
||||
class Pad(Function):
|
||||
def forward(self, x:UOp, arg:tuple[tuple[int, int], ...]) -> UOp: return x.pad(arg)
|
||||
|
||||
class Shrink(Function):
|
||||
def forward(self, x:UOp, arg:tuple[tuple[sint, sint], ...]) -> UOp: return x.shrink(arg)
|
||||
|
||||
class Flip(Function):
|
||||
def forward(self, x:UOp, axis:tuple[int, ...]) -> UOp: return x.stride(tuple([-1 if i in axis else 1 for i in range(len(x.shape))]))
|
||||
|
|
@ -2,7 +2,7 @@
|
|||
from __future__ import annotations
|
||||
import time, math, itertools, functools, struct, sys, inspect, pathlib, string, hashlib, weakref
|
||||
from contextlib import ContextDecorator
|
||||
from typing import List, Tuple, Callable, Optional, ClassVar, Type, Union, Sequence, cast, get_args, Literal, TYPE_CHECKING, SupportsIndex
|
||||
from typing import List, Tuple, Callable, Optional, ClassVar, Union, Sequence, cast, get_args, Literal, TYPE_CHECKING, SupportsIndex
|
||||
from tinygrad.dtype import DType, DTypeLike, dtypes, ImageDType, ConstType, least_upper_float, least_upper_dtype, sum_acc_dtype, to_dtype, truncate
|
||||
from tinygrad.helpers import argfix, make_tuple, flatten, prod, all_int, round_up, merge_dicts, argsort, getenv, all_same, fully_flatten, dedup
|
||||
from tinygrad.helpers import IMAGE, WINO, _METADATA, Metadata, TRACEMETA, ceildiv, fetch, polyN, unwrap
|
||||
|
|
@ -42,26 +42,7 @@ def _apply_map_to_tensors(applied_map:dict[UOp, UOp]) -> None:
|
|||
if s is ns: continue
|
||||
t.lazydata = ns
|
||||
|
||||
# **** start with two base classes, Tensor and Function ****
|
||||
|
||||
class Function:
|
||||
def __init__(self, device:Union[str, tuple[str, ...]], *tensors:Tensor, metadata:Optional[Metadata]=None):
|
||||
self.device = device
|
||||
self.needs_input_grad = [t.requires_grad for t in tensors]
|
||||
self.requires_grad = True if any(self.needs_input_grad) else None if None in self.needs_input_grad else False
|
||||
if self.requires_grad: self.parents = tensors
|
||||
self.metadata = metadata
|
||||
|
||||
def forward(self, *args, **kwargs): raise NotImplementedError(f"forward not implemented for {type(self)}")
|
||||
|
||||
@classmethod
|
||||
def apply(fxn:Type[Function], *x:Tensor, **kwargs) -> Tensor:
|
||||
ctx = fxn(x[0].device, *x, metadata=_METADATA.get())
|
||||
ret = Tensor.__new__(Tensor)
|
||||
ret.lazydata, ret.requires_grad, ret.grad = ctx.forward(*[t.lazydata for t in x], **kwargs), ctx.requires_grad, None
|
||||
return ret
|
||||
|
||||
import tinygrad.function as F
|
||||
# **** Tensor helper functions ****
|
||||
|
||||
def _metaop(op, shape:tuple[sint,...], dtype:DType, device:Union[str, tuple[str, ...]], arg=None):
|
||||
if isinstance(device, str): return UOp.metaop(op, shape, dtype, device, arg)
|
||||
|
|
@ -239,6 +220,17 @@ class Tensor(SimpleMathTrait):
|
|||
@property
|
||||
def dtype(self) -> DType: return self.lazydata.dtype
|
||||
|
||||
def _apply_uop(self, fxn:Callable, *x:Tensor, **kwargs) -> Tensor:
|
||||
ret = Tensor.__new__(Tensor)
|
||||
needs_input_grad = [t.requires_grad for t in (self,)+x]
|
||||
ret.requires_grad, ret.grad = True if any(needs_input_grad) else None if None in needs_input_grad else False, None
|
||||
ret.lazydata = fxn(*[t.lazydata for t in (self,)+x], **kwargs)
|
||||
return ret
|
||||
|
||||
def _apply_broadcasted_uop(self, fxn:Callable, x:Union[Tensor, ConstType], reverse=False) -> Tensor:
|
||||
lhs,rhs = self._broadcasted(x, reverse)
|
||||
return lhs._apply_uop(fxn, rhs)
|
||||
|
||||
# ***** data handlers ****
|
||||
|
||||
def schedule_with_vars(self, *lst:Tensor) -> tuple[list[ScheduleItem], dict[Variable, int]]:
|
||||
|
|
@ -497,7 +489,7 @@ class Tensor(SimpleMathTrait):
|
|||
@staticmethod
|
||||
def _threefry_random_bits(key:Tensor, counts0:Tensor, counts1:Tensor):
|
||||
x = (counts1.cast(dtypes.uint64) << 32) | counts0.cast(dtypes.uint64)
|
||||
x = F.Threefry.apply(x, (key[1]._broadcast_to(x.shape).cast(dtypes.uint64) << 32) | key[0]._broadcast_to(x.shape).cast(dtypes.uint64))
|
||||
x = x._apply_uop(UOp.threefry, (key[1]._broadcast_to(x.shape).cast(dtypes.uint64) << 32) | key[0]._broadcast_to(x.shape).cast(dtypes.uint64))
|
||||
counts0, counts1 = (x & 0xffffffff).cast(dtypes.uint32), ((x >> 32) & 0xffffffff).cast(dtypes.uint32)
|
||||
return counts0.cat(counts1)
|
||||
|
||||
|
|
@ -961,7 +953,7 @@ class Tensor(SimpleMathTrait):
|
|||
# resolve -1
|
||||
if (c := new_shape.count(-1)) > 1: raise RuntimeError(f"only one dimension can be inferred using -1, getting {new_shape}")
|
||||
if c: new_shape = tuple([-prod(self.shape) // prod(new_shape) if s == -1 else s for s in new_shape])
|
||||
return F.Reshape.apply(self, shape=new_shape) if new_shape != self.shape else self
|
||||
return self._apply_uop(UOp.reshape, arg=new_shape) if new_shape != self.shape else self
|
||||
|
||||
def expand(self, shape, *args) -> Tensor:
|
||||
"""
|
||||
|
|
@ -994,7 +986,7 @@ class Tensor(SimpleMathTrait):
|
|||
"""
|
||||
order_arg = tuple(self._resolve_dim(x) for x in argfix(order, *args))
|
||||
if sorted(order_arg) != list(range(self.ndim)): raise RuntimeError(f"order is not a valid permutation, getting {order_arg}")
|
||||
return F.Permute.apply(self, order=order_arg)
|
||||
return self._apply_uop(UOp.permute, arg=order_arg)
|
||||
|
||||
def flip(self, axis, *args) -> Tensor:
|
||||
"""
|
||||
|
|
@ -1014,7 +1006,7 @@ class Tensor(SimpleMathTrait):
|
|||
"""
|
||||
axis_arg = tuple(self._resolve_dim(x) for x in argfix(axis, *args))
|
||||
if len(axis_arg) != len(dedup(axis_arg)): raise RuntimeError(f"dim can appear at most once, getting {axis_arg}")
|
||||
return F.Flip.apply(self, axis=axis_arg)
|
||||
return self._apply_uop(UOp.stride, arg=tuple([-1 if i in axis_arg else 1 for i in range(len(self.shape))]))
|
||||
|
||||
def shrink(self, arg:tuple[Optional[tuple[sint, sint]], ...]) -> Tensor:
|
||||
"""
|
||||
|
|
@ -1034,7 +1026,7 @@ class Tensor(SimpleMathTrait):
|
|||
```
|
||||
"""
|
||||
if (shrink_arg:=[x if x is not None else (0,s) for x,s in zip(arg, self.shape)]) == [(0,s) for s in self.shape]: return self
|
||||
return F.Shrink.apply(self, arg=tuple(shrink_arg))
|
||||
return self._apply_uop(UOp.shrink, arg=tuple(shrink_arg))
|
||||
|
||||
def pad(self, padding:Union[Sequence[sint], Sequence[Optional[tuple[sint, sint]]]], mode:str="constant", value:float=0.0) -> Tensor:
|
||||
"""
|
||||
|
|
@ -1078,7 +1070,8 @@ class Tensor(SimpleMathTrait):
|
|||
if len(pX) != self.ndim: raise ValueError(f"padding length is improper, {padding=} {self.ndim=}")
|
||||
X, pads = self, tuple((smax(pB,0), smax(pA,0)) for pB,pA in pX)
|
||||
if mode == "constant":
|
||||
def _constant(x,px,v): return F.Pad.apply(x, arg=px) if v == 0 else F.Pad.apply(x, arg=px) + F.Pad.apply(Tensor.ones_like(x), arg=px).where(0,v)
|
||||
def _constant(x:Tensor,px,v):
|
||||
return x._apply_uop(UOp.pad, arg=px) if v == 0 else (x._apply_uop(UOp.pad, arg=px)+Tensor.ones_like(x)._apply_uop(UOp.pad, arg=px).where(0,v))
|
||||
return _constant(X, pX, value) if all(resolve(p >= 0) for p in flatten(pX)) else \
|
||||
_constant(X.shrink(tuple((-smin(pB,0),smin(pA+s,s)) for (pB,pA),s in zip(pX, X.shape))), pads, value)
|
||||
assert all_int(self.shape), f"does not support symbolic shape {self.shape}"
|
||||
|
|
@ -1568,10 +1561,10 @@ class Tensor(SimpleMathTrait):
|
|||
|
||||
# ***** reduce ops *****
|
||||
|
||||
def _reduce(self, fxn:Type[Function], axis:Optional[Union[int, Sequence[int]]]=None, keepdim=False) -> Tensor:
|
||||
def _reduce(self, op:Ops, axis:Optional[Union[int, Sequence[int]]]=None, keepdim=False) -> Tensor:
|
||||
axis = tuple(self._resolve_dim(x) for x in (range(self.ndim) if axis is None else make_tuple(axis, 1)))
|
||||
if self.ndim == 0: axis = ()
|
||||
ret = fxn.apply(self, axis=axis)
|
||||
ret = self._apply_uop(UOp.r, op=op, axis=axis)
|
||||
return ret if keepdim else ret.reshape(tuple(s for i,s in enumerate(self.shape) if i not in axis))
|
||||
|
||||
def sum(self, axis:Optional[Union[int, Sequence[int]]]=None, keepdim=False, acc_dtype:Optional[DTypeLike]=None):
|
||||
|
|
@ -1598,7 +1591,7 @@ class Tensor(SimpleMathTrait):
|
|||
print(t.sum(axis=1).numpy())
|
||||
```
|
||||
"""
|
||||
ret = self.cast(sum_acc_dtype(self.dtype) if acc_dtype is None else acc_dtype)._reduce(F.Sum, axis, keepdim)
|
||||
ret = self.cast(sum_acc_dtype(self.dtype) if acc_dtype is None else acc_dtype)._reduce(Ops.ADD, axis, keepdim)
|
||||
return ret.cast(self.dtype) if acc_dtype is None and self.dtype in (dtypes.float16, dtypes.bfloat16) else ret
|
||||
|
||||
def prod(self, axis:Optional[Union[int, Sequence[int]]]=None, keepdim=False, acc_dtype:Optional[DTypeLike]=None):
|
||||
|
|
@ -1625,7 +1618,7 @@ class Tensor(SimpleMathTrait):
|
|||
print(t.prod(axis=1).numpy())
|
||||
```
|
||||
"""
|
||||
return self.cast(acc_dtype if acc_dtype is not None else self.dtype)._reduce(F.Prod, axis, keepdim)
|
||||
return self.cast(acc_dtype if acc_dtype is not None else self.dtype)._reduce(Ops.MUL, axis, keepdim)
|
||||
|
||||
def max(self, axis:Optional[Union[int, Sequence[int]]]=None, keepdim=False):
|
||||
"""
|
||||
|
|
@ -1648,7 +1641,7 @@ class Tensor(SimpleMathTrait):
|
|||
print(t.max(axis=1, keepdim=True).numpy())
|
||||
```
|
||||
"""
|
||||
return self._reduce(F.Max, axis, keepdim)
|
||||
return self._reduce(Ops.MAX, axis, keepdim)
|
||||
|
||||
def _inverse(self): return -self if self.is_floating_point() else ~self if dtypes.is_int(self.dtype) else self.logical_not()
|
||||
|
||||
|
|
@ -2485,7 +2478,7 @@ class Tensor(SimpleMathTrait):
|
|||
print(Tensor([False, True]).logical_not().numpy())
|
||||
```
|
||||
"""
|
||||
return F.Neq.apply(*self.cast(dtypes.bool)._broadcasted(True))
|
||||
return self.cast(dtypes.bool)._apply_broadcasted_uop(UOp.ne, True)
|
||||
def neg(self):
|
||||
"""
|
||||
Negates the tensor element-wise.
|
||||
|
|
@ -2499,12 +2492,12 @@ class Tensor(SimpleMathTrait):
|
|||
"""
|
||||
Returns a contiguous tensor.
|
||||
"""
|
||||
return F.Contiguous.apply(self)
|
||||
return self._apply_uop(UOp.contiguous)
|
||||
def contiguous_backward(self):
|
||||
"""
|
||||
Inserts a contiguous operation in the backward pass.
|
||||
"""
|
||||
return F.ContiguousBackward.apply(self)
|
||||
return self._apply_uop(UOp.contiguous_backward)
|
||||
def log(self):
|
||||
"""
|
||||
Computes the natural logarithm element-wise.
|
||||
|
|
@ -2515,7 +2508,7 @@ class Tensor(SimpleMathTrait):
|
|||
print(Tensor([1., 2., 4., 8.]).log().numpy())
|
||||
```
|
||||
"""
|
||||
return F.Log.apply(self.cast(least_upper_float(self.dtype)))
|
||||
return self.log2()*math.log(2)
|
||||
def log2(self):
|
||||
"""
|
||||
Computes the base-2 logarithm element-wise.
|
||||
|
|
@ -2526,7 +2519,7 @@ class Tensor(SimpleMathTrait):
|
|||
print(Tensor([1., 2., 4., 8.]).log2().numpy())
|
||||
```
|
||||
"""
|
||||
return self.log()/math.log(2)
|
||||
return self.cast(least_upper_float(self.dtype))._apply_uop(UOp.log2)
|
||||
def exp(self):
|
||||
"""
|
||||
Computes the exponential function element-wise.
|
||||
|
|
@ -2537,7 +2530,7 @@ class Tensor(SimpleMathTrait):
|
|||
print(Tensor([0., 1., 2., 3.]).exp().numpy())
|
||||
```
|
||||
"""
|
||||
return F.Exp.apply(self.cast(least_upper_float(self.dtype)))
|
||||
return self.mul(1/math.log(2)).exp2()
|
||||
def exp2(self):
|
||||
"""
|
||||
Computes the base-2 exponential function element-wise.
|
||||
|
|
@ -2548,8 +2541,7 @@ class Tensor(SimpleMathTrait):
|
|||
print(Tensor([0., 1., 2., 3.]).exp2().numpy())
|
||||
```
|
||||
"""
|
||||
return F.Exp.apply(self*math.log(2))
|
||||
|
||||
return self.cast(least_upper_float(self.dtype))._apply_uop(UOp.exp2)
|
||||
def relu(self):
|
||||
"""
|
||||
Applies the Rectified Linear Unit (ReLU) function element-wise.
|
||||
|
|
@ -2560,7 +2552,7 @@ class Tensor(SimpleMathTrait):
|
|||
print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).relu().numpy())
|
||||
```
|
||||
"""
|
||||
return F.Relu.apply(self)
|
||||
return (self>0).where(self, 0)
|
||||
|
||||
def sigmoid(self):
|
||||
"""
|
||||
|
|
@ -2596,7 +2588,7 @@ class Tensor(SimpleMathTrait):
|
|||
print(Tensor([1., 2., 3., 4.]).sqrt().numpy())
|
||||
```
|
||||
"""
|
||||
return F.Sqrt.apply(self.cast(least_upper_float(self.dtype)))
|
||||
return self.cast(least_upper_float(self.dtype))._apply_uop(UOp.sqrt)
|
||||
def rsqrt(self):
|
||||
"""
|
||||
Computes the reciprocal of the square root of the tensor element-wise.
|
||||
|
|
@ -2614,7 +2606,7 @@ class Tensor(SimpleMathTrait):
|
|||
print(Tensor([0., math.pi/2, math.pi, 3*math.pi/2, 2*math.pi]).sin().numpy())
|
||||
```
|
||||
"""
|
||||
return F.Sin.apply(self.cast(least_upper_float(self.dtype)))
|
||||
return self.cast(least_upper_float(self.dtype))._apply_uop(UOp.sin)
|
||||
def cos(self):
|
||||
"""
|
||||
Computes the cosine of the tensor element-wise.
|
||||
|
|
@ -2773,7 +2765,7 @@ class Tensor(SimpleMathTrait):
|
|||
print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).sign().numpy())
|
||||
```
|
||||
"""
|
||||
return F.Sign.apply(self)
|
||||
return self.ne(0).where((self<0).where(self.full_like(-1), self.full_like(1)), self.full_like(0)) + self*0
|
||||
def abs(self):
|
||||
"""
|
||||
Computes the absolute value of the tensor element-wise.
|
||||
|
|
@ -2791,7 +2783,7 @@ class Tensor(SimpleMathTrait):
|
|||
print(Tensor([1., 2., 3., 4.]).reciprocal().numpy())
|
||||
```
|
||||
"""
|
||||
return F.Reciprocal.apply(self.cast(least_upper_float(self.dtype)))
|
||||
return self.cast(least_upper_float(self.dtype))._apply_uop(UOp.reciprocal)
|
||||
|
||||
# ***** activation functions *****
|
||||
|
||||
|
|
@ -3069,7 +3061,7 @@ class Tensor(SimpleMathTrait):
|
|||
# for each dimension, check either dim is 1, or it does not change
|
||||
if not all(resolve(s == ns) or resolve(s == 1) for s,ns in zip(shape, new_shape)):
|
||||
raise ValueError(f"cannot broadcast {self.shape} to {new_shape=}")
|
||||
return F.Expand.apply(self.reshape(shape), shape=new_shape)
|
||||
return self.reshape(shape)._apply_uop(UOp.expand, arg=new_shape)
|
||||
|
||||
def _broadcasted(self, y:Union[Tensor, UOp, ConstType], reverse:bool=False, match_dtype:bool=True) -> tuple[Tensor, Tensor]:
|
||||
x: Tensor = self
|
||||
|
|
@ -3113,7 +3105,7 @@ class Tensor(SimpleMathTrait):
|
|||
print(t.add(Tensor([[2.0], [3.5]])).numpy())
|
||||
```
|
||||
"""
|
||||
return F.Add.apply(*self._broadcasted(x, reverse))
|
||||
return self._apply_broadcasted_uop(UOp.add, x, reverse)
|
||||
|
||||
def sub(self, x:Union[Tensor, ConstType], reverse=False) -> Tensor:
|
||||
"""
|
||||
|
|
@ -3154,7 +3146,7 @@ class Tensor(SimpleMathTrait):
|
|||
print(t.mul(Tensor([[-1.0], [2.0]])).numpy())
|
||||
```
|
||||
"""
|
||||
return F.Mul.apply(*self._broadcasted(x, reverse))
|
||||
return self._apply_broadcasted_uop(UOp.mul, x, reverse)
|
||||
|
||||
def idiv(self, x:Union[Tensor, ConstType], reverse=False) -> Tensor:
|
||||
"""
|
||||
|
|
@ -3167,7 +3159,7 @@ class Tensor(SimpleMathTrait):
|
|||
print(Tensor([-4, 7, 5, 4, -7, 8]).idiv(Tensor([2, -3, 8, -2, 3, 5])).numpy())
|
||||
```
|
||||
"""
|
||||
return F.IDiv.apply(*self._broadcasted(x, reverse))
|
||||
return self._apply_broadcasted_uop(UOp.idiv, x, reverse)
|
||||
|
||||
def div(self, x:Union[Tensor, ConstType], reverse=False) -> Tensor:
|
||||
"""
|
||||
|
|
@ -3202,7 +3194,7 @@ class Tensor(SimpleMathTrait):
|
|||
```
|
||||
"""
|
||||
a, b = self._broadcasted(x, reverse)
|
||||
return (r := F.Mod.apply(a, b)) + b * (((r < 0) & (b > 0)) | ((r > 0) & (b < 0)))
|
||||
return (r := a._apply_uop(UOp.mod, b)) + b * (((r < 0) & (b > 0)) | ((r > 0) & (b < 0)))
|
||||
|
||||
def xor(self, x:Union[Tensor, ConstType], reverse=False) -> Tensor:
|
||||
"""
|
||||
|
|
@ -3218,7 +3210,7 @@ class Tensor(SimpleMathTrait):
|
|||
```
|
||||
"""
|
||||
if self.dtype != dtypes.bool and not dtypes.is_int(self.dtype): raise RuntimeError(f"{self.dtype} is not supported")
|
||||
return F.Xor.apply(*self._broadcasted(x, reverse))
|
||||
return self._apply_broadcasted_uop(UOp.xor, x, reverse)
|
||||
|
||||
def bitwise_and(self, x:Union[Tensor, ConstType], reverse=False) -> Tensor:
|
||||
"""
|
||||
|
|
@ -3233,7 +3225,7 @@ class Tensor(SimpleMathTrait):
|
|||
```
|
||||
"""
|
||||
if self.dtype != dtypes.bool and not dtypes.is_int(self.dtype): raise RuntimeError(f"{self.dtype} is not supported")
|
||||
return F.BitwiseAnd.apply(*self._broadcasted(x, reverse))
|
||||
return self._apply_broadcasted_uop(UOp.bitwise_and, x, reverse)
|
||||
|
||||
def bitwise_or(self, x:Union[Tensor, ConstType], reverse=False) -> Tensor:
|
||||
"""
|
||||
|
|
@ -3248,7 +3240,7 @@ class Tensor(SimpleMathTrait):
|
|||
```
|
||||
"""
|
||||
if self.dtype != dtypes.bool and not dtypes.is_int(self.dtype): raise RuntimeError(f"{self.dtype} is not supported")
|
||||
return F.BitwiseOr.apply(*self._broadcasted(x, reverse))
|
||||
return self._apply_broadcasted_uop(UOp.bitwise_or, x, reverse)
|
||||
|
||||
def bitwise_not(self) -> Tensor:
|
||||
"""
|
||||
|
|
@ -3379,7 +3371,7 @@ class Tensor(SimpleMathTrait):
|
|||
elif isinstance(y, Tensor): y, x = y._broadcasted(x)
|
||||
cond, x = self._broadcasted(x, match_dtype=False)
|
||||
cond, y = cond._broadcasted(y, match_dtype=False)
|
||||
return F.Where.apply(cond.cast(dtypes.bool), *x._broadcasted(y))
|
||||
return cond.cast(dtypes.bool)._apply_uop(UOp.where, *x._broadcasted(y))
|
||||
|
||||
def masked_fill(self:Tensor, mask:Tensor, value:Union[Tensor, ConstType]): return mask.where(value, self)
|
||||
|
||||
|
|
@ -3409,9 +3401,9 @@ class Tensor(SimpleMathTrait):
|
|||
def __ilshift__(self, x) -> Tensor: return self.assign(self.lshift(x))
|
||||
def __irshift__(self, x) -> Tensor: return self.assign(self.rshift(x))
|
||||
|
||||
def __lt__(self, x) -> Tensor: return F.Less.apply(*self._broadcasted(x, False))
|
||||
def __gt__(self, x) -> Tensor: return F.Less.apply(*self._broadcasted(x, True))
|
||||
def ne(self, x) -> Tensor: return F.Neq.apply(*self._broadcasted(x))
|
||||
def __lt__(self, x) -> Tensor: return self._apply_broadcasted_uop(UOp.__lt__, x, False)
|
||||
def __gt__(self, x) -> Tensor: return self._apply_broadcasted_uop(UOp.__lt__, x, True)
|
||||
def ne(self, x) -> Tensor: return self._apply_broadcasted_uop(UOp.ne, x, False)
|
||||
|
||||
def __eq__(self, x) -> Tensor: return self.eq(x) # type: ignore[override]
|
||||
|
||||
|
|
@ -3757,8 +3749,8 @@ class Tensor(SimpleMathTrait):
|
|||
"""
|
||||
if (dt:=to_dtype(dtype)) in {dtypes.uint8, dtypes.uint16} and dtypes.is_float(self.dtype):
|
||||
# NOTE: values within the int32 range and outside the unsigned dtype range will cause values to wrap around
|
||||
return F.Cast.apply(F.Cast.apply(self, dtype=dtypes.int32), dtype=dt)
|
||||
return self if self.dtype == dt else F.Cast.apply(self, dtype=dt)
|
||||
return self._apply_uop(UOp.cast, dtype=dtypes.int32)._apply_uop(UOp.cast, dtype=dt)
|
||||
return self if self.dtype == dt else self._apply_uop(UOp.cast, dtype=dt)
|
||||
|
||||
def bitcast(self, dtype:DTypeLike) -> Tensor:
|
||||
"""
|
||||
|
|
@ -3783,7 +3775,7 @@ class Tensor(SimpleMathTrait):
|
|||
tmp = self.bitcast(old_uint)
|
||||
if ns > os: return functools.reduce(Tensor.add, (tmp[..., i::ns//os].cast(new_uint) << 8*i*os for i in range(ns//os))).bitcast(dtype)
|
||||
return Tensor.stack(*(tmp>>8*i*ns for i in range(os//ns)), dim=-1).flatten(-2).cast(new_uint).bitcast(dtype)
|
||||
return F.Cast.apply(self, dtype=dt, bitcast=True) if self.dtype != dt else self
|
||||
return self._apply_uop(UOp.bitcast, dtype=dt) if self.dtype != dt else self
|
||||
|
||||
def float(self) -> Tensor:
|
||||
"""
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue