remove Function class [pr] (#8753)

* remove Function class [pr]

* actually remove function

* fix docs
This commit is contained in:
George Hotz 2025-01-26 18:58:02 +09:00 committed by GitHub
commit a6e496b195
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 53 additions and 203 deletions

View file

@ -9,7 +9,7 @@ There is a good [bunch of tutorials](https://mesozoic-egg.github.io/tinygrad-not
## Frontend
Everything in [Tensor](../tensor/index.md) is syntactic sugar around [function.py](function.md), where the forwards and backwards passes are implemented for the different functions. There's about 25 of them, implemented using about 20 basic ops. Those basic ops go on to construct a graph of [UOps](../developer/uop.md).
Everything in [Tensor](../tensor/index.md) is syntactic sugar around constructing a graph of [UOps](../developer/uop.md).
The `UOp` graph specifies the compute in terms of low-level tinygrad ops. Not all UOps will actually become realized. There are two types of UOps: base and view. A base contains compute into a contiguous buffer, and a view is a view (specified by a ShapeTracker). Inputs to a base can be either base or view; inputs to a view can only be a single base.

View file

@ -1,33 +0,0 @@
::: tinygrad.function
options:
members: [
"Contiguous",
"ContiguousBackward",
"Cast",
"Neg",
"Reciprocal",
"Sin",
"Relu",
"Log",
"Exp",
"Sqrt",
"Sigmoid",
"Sign",
"Less",
"Eq",
"Xor",
"Add",
"Sub",
"Mul",
"Div",
"Where",
"Sum",
"Max",
"Expand",
"Reshape",
"Permute",
"Pad",
"Shrink",
"Flip",
]
show_source: false

View file

@ -22,7 +22,6 @@ nav:
- Runtime: runtime.md
- Developer:
- Intro: developer/developer.md
- Function (autodiff): developer/function.md
- UOp: developer/uop.md
- Runtime:
- developer/runtime.md

View file

@ -1,108 +0,0 @@
"""This is where the forwards and backwards passes live."""
import math
from tinygrad.dtype import DType
from tinygrad.ops import Ops, sint, UOp
from tinygrad.tensor import Function
class Contiguous(Function):
  """Function whose forward pass delegates to `UOp.contiguous`."""
  def forward(self, x:UOp) -> UOp:
    return x.contiguous()
class ContiguousBackward(Function):
  """Function whose forward pass delegates to `UOp.contiguous_backward`."""
  def forward(self, x:UOp) -> UOp:
    return x.contiguous_backward()
class Cast(Function):
  """Converts `x` to `dtype`, either by value (`cast`) or by reinterpreting bits (`bitcast`)."""
  def forward(self, x:UOp, dtype:DType, bitcast:bool=False) -> UOp:
    # bitcast=True selects UOp.bitcast, otherwise a regular UOp.cast is used
    if bitcast: return x.bitcast(dtype)
    return x.cast(dtype)
# ************* unary ops *************
class Reciprocal(Function):
  """Function whose forward pass delegates to `UOp.reciprocal`."""
  def forward(self, x:UOp) -> UOp:
    return x.reciprocal()
class Sin(Function):
  """Function whose forward pass delegates to `UOp.sin`."""
  def forward(self, x:UOp) -> UOp:
    return x.sin()
class Relu(Function):
  """Keeps `x` where `x > 0` and selects 0 everywhere else."""
  def forward(self, x:UOp) -> UOp:
    is_positive = x>0
    return is_positive.where(x, 0)
class Log(Function):
  """Natural logarithm, built from `UOp.log2` as log2(x) * ln(2)."""
  def forward(self, x:UOp) -> UOp:
    ln2 = math.log(2)
    return x.log2() * ln2
class Exp(Function):
  """Natural exponential, built from `UOp.exp2` as exp2(x * (1/ln(2)))."""
  def forward(self, x:UOp) -> UOp:
    inv_ln2 = 1/math.log(2)
    return (x * inv_ln2).exp2()
class Sqrt(Function):
  """Function whose forward pass delegates to `UOp.sqrt`."""
  def forward(self, x:UOp) -> UOp:
    return x.sqrt()
class Sign(Function):
  """Selects -1 where `x < 0`, 1 where `x` is otherwise non-zero, and 0 where `x` equals 0, then adds `x*0`."""
  # NOTE: the x*0 is to match torch behavior without function.py
  def forward(self, x:UOp) -> UOp:
    is_nonzero = x.ne(0)
    plus_or_minus_one = (x<0).where(x.const_like(-1), x.const_like(1))
    return is_nonzero.where(plus_or_minus_one, x.const_like(0)) + x*0
# ************* binary ops *************
class Less(Function):
  """Binary function returning `x < y`."""
  def forward(self, x:UOp, y:UOp) -> UOp:
    return x<y
class Neq(Function):
  """Binary function whose forward pass delegates to `UOp.ne`."""
  def forward(self, x:UOp, y:UOp) -> UOp:
    return x.ne(y)
class Xor(Function):
  """Binary function returning `x ^ y`."""
  def forward(self, x:UOp, y:UOp) -> UOp:
    return x^y
class BitwiseAnd(Function):
  """Binary function returning `x & y`."""
  def forward(self, x:UOp, y:UOp) -> UOp:
    return x&y
class BitwiseOr(Function):
  """Binary function returning `x | y`."""
  def forward(self, x:UOp, y:UOp) -> UOp:
    return x|y
class Threefry(Function):
  """Binary function whose forward pass delegates to `UOp.threefry`, passing `seed` as the argument."""
  def forward(self, x:UOp, seed:UOp) -> UOp:
    return x.threefry(seed)
class Add(Function):
  """Binary function returning `x + y`."""
  def forward(self, x:UOp, y:UOp) -> UOp:
    return x+y
class Mul(Function):
  """Binary function returning `x * y`."""
  def forward(self, x:UOp, y:UOp) -> UOp:
    return x * y
class IDiv(Function):
  """Binary function returning `x // y`."""
  def forward(self, x:UOp, y:UOp) -> UOp:
    return x // y
class Mod(Function):
  """Binary function returning `x % y`."""
  def forward(self, x:UOp, y:UOp) -> UOp:
    return x % y
# ************* ternary ops *************
class Where(Function):
  """Ternary function whose forward pass calls `x.where(y, z)`, with `x` acting as the condition."""
  def forward(self, x:UOp, y:UOp, z:UOp) -> UOp:
    return x.where(y, z)
# ************* reduce ops *************
class Sum(Function):
  """Reduces `x` over `axis` using `Ops.ADD`."""
  def forward(self, x:UOp, axis:tuple[int, ...]) -> UOp:
    return x.r(Ops.ADD, axis)
class Prod(Function):
  """Reduces `x` over `axis` using `Ops.MUL`."""
  def forward(self, x:UOp, axis:tuple[int, ...]) -> UOp:
    return x.r(Ops.MUL, axis)
class Max(Function):
  """Reduces `x` over `axis` using `Ops.MAX`."""
  def forward(self, x:UOp, axis:tuple[int, ...]) -> UOp:
    return x.r(Ops.MAX, axis)
# ************* movement ops *************
# NOTE: this is sum in reverse
class Expand(Function):
  """Movement function whose forward pass delegates to `UOp.expand` with the target `shape`."""
  def forward(self, x:UOp, shape:tuple[int, ...]) -> UOp:
    return x.expand(shape)
class Reshape(Function):
  """Movement function whose forward pass delegates to `UOp.reshape` with the target `shape`."""
  def forward(self, x:UOp, shape:tuple[int, ...]) -> UOp:
    return x.reshape(shape)
class Permute(Function):
  """Movement function whose forward pass delegates to `UOp.permute` with the axis `order`."""
  def forward(self, x:UOp, order:tuple[int, ...]) -> UOp:
    return x.permute(order)
class Pad(Function):
  """Movement function whose forward pass delegates to `UOp.pad`; `arg` holds one pair of ints per dimension."""
  def forward(self, x:UOp, arg:tuple[tuple[int, int], ...]) -> UOp:
    return x.pad(arg)
class Shrink(Function):
  """Movement function whose forward pass delegates to `UOp.shrink`; `arg` holds one pair of `sint` per dimension."""
  def forward(self, x:UOp, arg:tuple[tuple[sint, sint], ...]) -> UOp:
    return x.shrink(arg)
class Flip(Function):
  """Flips `x` along the dimensions in `axis` by calling `UOp.stride` with -1 for those dimensions and 1 elsewhere."""
  def forward(self, x:UOp, axis:tuple[int, ...]) -> UOp:
    strides = tuple(-1 if dim in axis else 1 for dim in range(len(x.shape)))
    return x.stride(strides)

View file

@ -2,7 +2,7 @@
from __future__ import annotations
import time, math, itertools, functools, struct, sys, inspect, pathlib, string, hashlib, weakref
from contextlib import ContextDecorator
from typing import List, Tuple, Callable, Optional, ClassVar, Type, Union, Sequence, cast, get_args, Literal, TYPE_CHECKING, SupportsIndex
from typing import List, Tuple, Callable, Optional, ClassVar, Union, Sequence, cast, get_args, Literal, TYPE_CHECKING, SupportsIndex
from tinygrad.dtype import DType, DTypeLike, dtypes, ImageDType, ConstType, least_upper_float, least_upper_dtype, sum_acc_dtype, to_dtype, truncate
from tinygrad.helpers import argfix, make_tuple, flatten, prod, all_int, round_up, merge_dicts, argsort, getenv, all_same, fully_flatten, dedup
from tinygrad.helpers import IMAGE, WINO, _METADATA, Metadata, TRACEMETA, ceildiv, fetch, polyN, unwrap
@ -42,26 +42,7 @@ def _apply_map_to_tensors(applied_map:dict[UOp, UOp]) -> None:
if s is ns: continue
t.lazydata = ns
# **** start with two base classes, Tensor and Function ****
class Function:
  """
  Base class for operations applied to Tensors.

  An instance records the device, the per-input `requires_grad` flags and optional metadata.
  Subclasses implement `forward`; `apply` builds the instance, runs `forward` on the inputs'
  `lazydata` and wraps the result in a new Tensor.
  """
  def __init__(self, device:Union[str, tuple[str, ...]], *tensors:Tensor, metadata:Optional[Metadata]=None):
    self.device = device
    self.metadata = metadata
    self.needs_input_grad = [t.requires_grad for t in tensors]
    # tri-state: True if any input requires grad, else None if any flag is None, else False
    if any(self.needs_input_grad): self.requires_grad = True
    elif None in self.needs_input_grad: self.requires_grad = None
    else: self.requires_grad = False
    # the inputs are only kept when a gradient is required
    if self.requires_grad: self.parents = tensors

  def forward(self, *args, **kwargs):
    """Must be overridden by subclasses; the base implementation always raises NotImplementedError."""
    raise NotImplementedError(f"forward not implemented for {type(self)}")

  @classmethod
  def apply(fxn:Type[Function], *x:Tensor, **kwargs) -> Tensor:
    """Instantiate `fxn` on the device of `x[0]`, run `forward` on each input's `lazydata`, and return the result as a Tensor."""
    ctx = fxn(x[0].device, *x, metadata=_METADATA.get())
    # Tensor.__new__ skips Tensor.__init__; the fields are filled in by hand below
    out = Tensor.__new__(Tensor)
    out.lazydata = ctx.forward(*[t.lazydata for t in x], **kwargs)
    out.requires_grad = ctx.requires_grad
    out.grad = None
    return out
import tinygrad.function as F
# **** Tensor helper functions ****
def _metaop(op, shape:tuple[sint,...], dtype:DType, device:Union[str, tuple[str, ...]], arg=None):
if isinstance(device, str): return UOp.metaop(op, shape, dtype, device, arg)
@ -239,6 +220,17 @@ class Tensor(SimpleMathTrait):
@property
def dtype(self) -> DType: return self.lazydata.dtype
def _apply_uop(self, fxn:Callable, *x:Tensor, **kwargs) -> Tensor:
  """
  Build a new Tensor by calling `fxn` on the `lazydata` of `self` followed by each tensor in `x`.

  `kwargs` are forwarded to `fxn` unchanged. The result's `requires_grad` is True if any input
  requires grad, otherwise None if any input's flag is None, otherwise False. `grad` starts as None.
  """
  # Tensor.__new__ skips Tensor.__init__; the fields are filled in by hand below
  out = Tensor.__new__(Tensor)
  inputs = (self,)+x
  grad_flags = [t.requires_grad for t in inputs]
  if any(grad_flags): out.requires_grad = True
  elif None in grad_flags: out.requires_grad = None
  else: out.requires_grad = False
  out.grad = None
  out.lazydata = fxn(*[t.lazydata for t in inputs], **kwargs)
  return out
def _apply_broadcasted_uop(self, fxn:Callable, x:Union[Tensor, ConstType], reverse=False) -> Tensor:
  """
  Apply the two-operand `fxn` to `self` and `x` after passing them through `_broadcasted`.

  `reverse` is forwarded to `_broadcasted`; the pair it returns is used as (first, second) operand.
  """
  first, second = self._broadcasted(x, reverse)
  return first._apply_uop(fxn, second)
# ***** data handlers ****
def schedule_with_vars(self, *lst:Tensor) -> tuple[list[ScheduleItem], dict[Variable, int]]:
@ -497,7 +489,7 @@ class Tensor(SimpleMathTrait):
@staticmethod
def _threefry_random_bits(key:Tensor, counts0:Tensor, counts1:Tensor):
x = (counts1.cast(dtypes.uint64) << 32) | counts0.cast(dtypes.uint64)
x = F.Threefry.apply(x, (key[1]._broadcast_to(x.shape).cast(dtypes.uint64) << 32) | key[0]._broadcast_to(x.shape).cast(dtypes.uint64))
x = x._apply_uop(UOp.threefry, (key[1]._broadcast_to(x.shape).cast(dtypes.uint64) << 32) | key[0]._broadcast_to(x.shape).cast(dtypes.uint64))
counts0, counts1 = (x & 0xffffffff).cast(dtypes.uint32), ((x >> 32) & 0xffffffff).cast(dtypes.uint32)
return counts0.cat(counts1)
@ -961,7 +953,7 @@ class Tensor(SimpleMathTrait):
# resolve -1
if (c := new_shape.count(-1)) > 1: raise RuntimeError(f"only one dimension can be inferred using -1, getting {new_shape}")
if c: new_shape = tuple([-prod(self.shape) // prod(new_shape) if s == -1 else s for s in new_shape])
return F.Reshape.apply(self, shape=new_shape) if new_shape != self.shape else self
return self._apply_uop(UOp.reshape, arg=new_shape) if new_shape != self.shape else self
def expand(self, shape, *args) -> Tensor:
"""
@ -994,7 +986,7 @@ class Tensor(SimpleMathTrait):
"""
order_arg = tuple(self._resolve_dim(x) for x in argfix(order, *args))
if sorted(order_arg) != list(range(self.ndim)): raise RuntimeError(f"order is not a valid permutation, getting {order_arg}")
return F.Permute.apply(self, order=order_arg)
return self._apply_uop(UOp.permute, arg=order_arg)
def flip(self, axis, *args) -> Tensor:
"""
@ -1014,7 +1006,7 @@ class Tensor(SimpleMathTrait):
"""
axis_arg = tuple(self._resolve_dim(x) for x in argfix(axis, *args))
if len(axis_arg) != len(dedup(axis_arg)): raise RuntimeError(f"dim can appear at most once, getting {axis_arg}")
return F.Flip.apply(self, axis=axis_arg)
return self._apply_uop(UOp.stride, arg=tuple([-1 if i in axis_arg else 1 for i in range(len(self.shape))]))
def shrink(self, arg:tuple[Optional[tuple[sint, sint]], ...]) -> Tensor:
"""
@ -1034,7 +1026,7 @@ class Tensor(SimpleMathTrait):
```
"""
if (shrink_arg:=[x if x is not None else (0,s) for x,s in zip(arg, self.shape)]) == [(0,s) for s in self.shape]: return self
return F.Shrink.apply(self, arg=tuple(shrink_arg))
return self._apply_uop(UOp.shrink, arg=tuple(shrink_arg))
def pad(self, padding:Union[Sequence[sint], Sequence[Optional[tuple[sint, sint]]]], mode:str="constant", value:float=0.0) -> Tensor:
"""
@ -1078,7 +1070,8 @@ class Tensor(SimpleMathTrait):
if len(pX) != self.ndim: raise ValueError(f"padding length is improper, {padding=} {self.ndim=}")
X, pads = self, tuple((smax(pB,0), smax(pA,0)) for pB,pA in pX)
if mode == "constant":
def _constant(x,px,v): return F.Pad.apply(x, arg=px) if v == 0 else F.Pad.apply(x, arg=px) + F.Pad.apply(Tensor.ones_like(x), arg=px).where(0,v)
def _constant(x:Tensor,px,v):
return x._apply_uop(UOp.pad, arg=px) if v == 0 else (x._apply_uop(UOp.pad, arg=px)+Tensor.ones_like(x)._apply_uop(UOp.pad, arg=px).where(0,v))
return _constant(X, pX, value) if all(resolve(p >= 0) for p in flatten(pX)) else \
_constant(X.shrink(tuple((-smin(pB,0),smin(pA+s,s)) for (pB,pA),s in zip(pX, X.shape))), pads, value)
assert all_int(self.shape), f"does not support symbolic shape {self.shape}"
@ -1568,10 +1561,10 @@ class Tensor(SimpleMathTrait):
# ***** reduce ops *****
def _reduce(self, fxn:Type[Function], axis:Optional[Union[int, Sequence[int]]]=None, keepdim=False) -> Tensor:
def _reduce(self, op:Ops, axis:Optional[Union[int, Sequence[int]]]=None, keepdim=False) -> Tensor:
axis = tuple(self._resolve_dim(x) for x in (range(self.ndim) if axis is None else make_tuple(axis, 1)))
if self.ndim == 0: axis = ()
ret = fxn.apply(self, axis=axis)
ret = self._apply_uop(UOp.r, op=op, axis=axis)
return ret if keepdim else ret.reshape(tuple(s for i,s in enumerate(self.shape) if i not in axis))
def sum(self, axis:Optional[Union[int, Sequence[int]]]=None, keepdim=False, acc_dtype:Optional[DTypeLike]=None):
@ -1598,7 +1591,7 @@ class Tensor(SimpleMathTrait):
print(t.sum(axis=1).numpy())
```
"""
ret = self.cast(sum_acc_dtype(self.dtype) if acc_dtype is None else acc_dtype)._reduce(F.Sum, axis, keepdim)
ret = self.cast(sum_acc_dtype(self.dtype) if acc_dtype is None else acc_dtype)._reduce(Ops.ADD, axis, keepdim)
return ret.cast(self.dtype) if acc_dtype is None and self.dtype in (dtypes.float16, dtypes.bfloat16) else ret
def prod(self, axis:Optional[Union[int, Sequence[int]]]=None, keepdim=False, acc_dtype:Optional[DTypeLike]=None):
@ -1625,7 +1618,7 @@ class Tensor(SimpleMathTrait):
print(t.prod(axis=1).numpy())
```
"""
return self.cast(acc_dtype if acc_dtype is not None else self.dtype)._reduce(F.Prod, axis, keepdim)
return self.cast(acc_dtype if acc_dtype is not None else self.dtype)._reduce(Ops.MUL, axis, keepdim)
def max(self, axis:Optional[Union[int, Sequence[int]]]=None, keepdim=False):
"""
@ -1648,7 +1641,7 @@ class Tensor(SimpleMathTrait):
print(t.max(axis=1, keepdim=True).numpy())
```
"""
return self._reduce(F.Max, axis, keepdim)
return self._reduce(Ops.MAX, axis, keepdim)
def _inverse(self):
  """Return `-self` for floating point tensors, `~self` for integer tensors, and `self.logical_not()` otherwise."""
  if self.is_floating_point(): return -self
  if dtypes.is_int(self.dtype): return ~self
  return self.logical_not()
@ -2485,7 +2478,7 @@ class Tensor(SimpleMathTrait):
print(Tensor([False, True]).logical_not().numpy())
```
"""
return F.Neq.apply(*self.cast(dtypes.bool)._broadcasted(True))
return self.cast(dtypes.bool)._apply_broadcasted_uop(UOp.ne, True)
def neg(self):
"""
Negates the tensor element-wise.
@ -2499,12 +2492,12 @@ class Tensor(SimpleMathTrait):
"""
Returns a contiguous tensor.
"""
return F.Contiguous.apply(self)
return self._apply_uop(UOp.contiguous)
def contiguous_backward(self):
"""
Inserts a contiguous operation in the backward pass.
"""
return F.ContiguousBackward.apply(self)
return self._apply_uop(UOp.contiguous_backward)
def log(self):
"""
Computes the natural logarithm element-wise.
@ -2515,7 +2508,7 @@ class Tensor(SimpleMathTrait):
print(Tensor([1., 2., 4., 8.]).log().numpy())
```
"""
return F.Log.apply(self.cast(least_upper_float(self.dtype)))
return self.log2()*math.log(2)
def log2(self):
"""
Computes the base-2 logarithm element-wise.
@ -2526,7 +2519,7 @@ class Tensor(SimpleMathTrait):
print(Tensor([1., 2., 4., 8.]).log2().numpy())
```
"""
return self.log()/math.log(2)
return self.cast(least_upper_float(self.dtype))._apply_uop(UOp.log2)
def exp(self):
"""
Computes the exponential function element-wise.
@ -2537,7 +2530,7 @@ class Tensor(SimpleMathTrait):
print(Tensor([0., 1., 2., 3.]).exp().numpy())
```
"""
return F.Exp.apply(self.cast(least_upper_float(self.dtype)))
return self.mul(1/math.log(2)).exp2()
def exp2(self):
"""
Computes the base-2 exponential function element-wise.
@ -2548,8 +2541,7 @@ class Tensor(SimpleMathTrait):
print(Tensor([0., 1., 2., 3.]).exp2().numpy())
```
"""
return F.Exp.apply(self*math.log(2))
return self.cast(least_upper_float(self.dtype))._apply_uop(UOp.exp2)
def relu(self):
"""
Applies the Rectified Linear Unit (ReLU) function element-wise.
@ -2560,7 +2552,7 @@ class Tensor(SimpleMathTrait):
print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).relu().numpy())
```
"""
return F.Relu.apply(self)
return (self>0).where(self, 0)
def sigmoid(self):
"""
@ -2596,7 +2588,7 @@ class Tensor(SimpleMathTrait):
print(Tensor([1., 2., 3., 4.]).sqrt().numpy())
```
"""
return F.Sqrt.apply(self.cast(least_upper_float(self.dtype)))
return self.cast(least_upper_float(self.dtype))._apply_uop(UOp.sqrt)
def rsqrt(self):
"""
Computes the reciprocal of the square root of the tensor element-wise.
@ -2614,7 +2606,7 @@ class Tensor(SimpleMathTrait):
print(Tensor([0., math.pi/2, math.pi, 3*math.pi/2, 2*math.pi]).sin().numpy())
```
"""
return F.Sin.apply(self.cast(least_upper_float(self.dtype)))
return self.cast(least_upper_float(self.dtype))._apply_uop(UOp.sin)
def cos(self):
"""
Computes the cosine of the tensor element-wise.
@ -2773,7 +2765,7 @@ class Tensor(SimpleMathTrait):
print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).sign().numpy())
```
"""
return F.Sign.apply(self)
return self.ne(0).where((self<0).where(self.full_like(-1), self.full_like(1)), self.full_like(0)) + self*0
def abs(self):
"""
Computes the absolute value of the tensor element-wise.
@ -2791,7 +2783,7 @@ class Tensor(SimpleMathTrait):
print(Tensor([1., 2., 3., 4.]).reciprocal().numpy())
```
"""
return F.Reciprocal.apply(self.cast(least_upper_float(self.dtype)))
return self.cast(least_upper_float(self.dtype))._apply_uop(UOp.reciprocal)
# ***** activation functions *****
@ -3069,7 +3061,7 @@ class Tensor(SimpleMathTrait):
# for each dimension, check either dim is 1, or it does not change
if not all(resolve(s == ns) or resolve(s == 1) for s,ns in zip(shape, new_shape)):
raise ValueError(f"cannot broadcast {self.shape} to {new_shape=}")
return F.Expand.apply(self.reshape(shape), shape=new_shape)
return self.reshape(shape)._apply_uop(UOp.expand, arg=new_shape)
def _broadcasted(self, y:Union[Tensor, UOp, ConstType], reverse:bool=False, match_dtype:bool=True) -> tuple[Tensor, Tensor]:
x: Tensor = self
@ -3113,7 +3105,7 @@ class Tensor(SimpleMathTrait):
print(t.add(Tensor([[2.0], [3.5]])).numpy())
```
"""
return F.Add.apply(*self._broadcasted(x, reverse))
return self._apply_broadcasted_uop(UOp.add, x, reverse)
def sub(self, x:Union[Tensor, ConstType], reverse=False) -> Tensor:
"""
@ -3154,7 +3146,7 @@ class Tensor(SimpleMathTrait):
print(t.mul(Tensor([[-1.0], [2.0]])).numpy())
```
"""
return F.Mul.apply(*self._broadcasted(x, reverse))
return self._apply_broadcasted_uop(UOp.mul, x, reverse)
def idiv(self, x:Union[Tensor, ConstType], reverse=False) -> Tensor:
"""
@ -3167,7 +3159,7 @@ class Tensor(SimpleMathTrait):
print(Tensor([-4, 7, 5, 4, -7, 8]).idiv(Tensor([2, -3, 8, -2, 3, 5])).numpy())
```
"""
return F.IDiv.apply(*self._broadcasted(x, reverse))
return self._apply_broadcasted_uop(UOp.idiv, x, reverse)
def div(self, x:Union[Tensor, ConstType], reverse=False) -> Tensor:
"""
@ -3202,7 +3194,7 @@ class Tensor(SimpleMathTrait):
```
"""
a, b = self._broadcasted(x, reverse)
return (r := F.Mod.apply(a, b)) + b * (((r < 0) & (b > 0)) | ((r > 0) & (b < 0)))
return (r := a._apply_uop(UOp.mod, b)) + b * (((r < 0) & (b > 0)) | ((r > 0) & (b < 0)))
def xor(self, x:Union[Tensor, ConstType], reverse=False) -> Tensor:
"""
@ -3218,7 +3210,7 @@ class Tensor(SimpleMathTrait):
```
"""
if self.dtype != dtypes.bool and not dtypes.is_int(self.dtype): raise RuntimeError(f"{self.dtype} is not supported")
return F.Xor.apply(*self._broadcasted(x, reverse))
return self._apply_broadcasted_uop(UOp.xor, x, reverse)
def bitwise_and(self, x:Union[Tensor, ConstType], reverse=False) -> Tensor:
"""
@ -3233,7 +3225,7 @@ class Tensor(SimpleMathTrait):
```
"""
if self.dtype != dtypes.bool and not dtypes.is_int(self.dtype): raise RuntimeError(f"{self.dtype} is not supported")
return F.BitwiseAnd.apply(*self._broadcasted(x, reverse))
return self._apply_broadcasted_uop(UOp.bitwise_and, x, reverse)
def bitwise_or(self, x:Union[Tensor, ConstType], reverse=False) -> Tensor:
"""
@ -3248,7 +3240,7 @@ class Tensor(SimpleMathTrait):
```
"""
if self.dtype != dtypes.bool and not dtypes.is_int(self.dtype): raise RuntimeError(f"{self.dtype} is not supported")
return F.BitwiseOr.apply(*self._broadcasted(x, reverse))
return self._apply_broadcasted_uop(UOp.bitwise_or, x, reverse)
def bitwise_not(self) -> Tensor:
"""
@ -3379,7 +3371,7 @@ class Tensor(SimpleMathTrait):
elif isinstance(y, Tensor): y, x = y._broadcasted(x)
cond, x = self._broadcasted(x, match_dtype=False)
cond, y = cond._broadcasted(y, match_dtype=False)
return F.Where.apply(cond.cast(dtypes.bool), *x._broadcasted(y))
return cond.cast(dtypes.bool)._apply_uop(UOp.where, *x._broadcasted(y))
def masked_fill(self:Tensor, mask:Tensor, value:Union[Tensor, ConstType]):
  """Return `mask.where(value, self)`: `value` is passed as the first branch of `mask.where` and `self` as the second."""
  return mask.where(value, self)
@ -3409,9 +3401,9 @@ class Tensor(SimpleMathTrait):
def __ilshift__(self, x) -> Tensor: return self.assign(self.lshift(x))
def __irshift__(self, x) -> Tensor: return self.assign(self.rshift(x))
def __lt__(self, x) -> Tensor: return F.Less.apply(*self._broadcasted(x, False))
def __gt__(self, x) -> Tensor: return F.Less.apply(*self._broadcasted(x, True))
def ne(self, x) -> Tensor: return F.Neq.apply(*self._broadcasted(x))
def __lt__(self, x) -> Tensor: return self._apply_broadcasted_uop(UOp.__lt__, x, False)
def __gt__(self, x) -> Tensor: return self._apply_broadcasted_uop(UOp.__lt__, x, True)
def ne(self, x) -> Tensor: return self._apply_broadcasted_uop(UOp.ne, x, False)
def __eq__(self, x) -> Tensor: return self.eq(x) # type: ignore[override]
@ -3757,8 +3749,8 @@ class Tensor(SimpleMathTrait):
"""
if (dt:=to_dtype(dtype)) in {dtypes.uint8, dtypes.uint16} and dtypes.is_float(self.dtype):
# NOTE: values within the int32 range and outside the unsigned dtype range will cause values to wrap around
return F.Cast.apply(F.Cast.apply(self, dtype=dtypes.int32), dtype=dt)
return self if self.dtype == dt else F.Cast.apply(self, dtype=dt)
return self._apply_uop(UOp.cast, dtype=dtypes.int32)._apply_uop(UOp.cast, dtype=dt)
return self if self.dtype == dt else self._apply_uop(UOp.cast, dtype=dt)
def bitcast(self, dtype:DTypeLike) -> Tensor:
"""
@ -3783,7 +3775,7 @@ class Tensor(SimpleMathTrait):
tmp = self.bitcast(old_uint)
if ns > os: return functools.reduce(Tensor.add, (tmp[..., i::ns//os].cast(new_uint) << 8*i*os for i in range(ns//os))).bitcast(dtype)
return Tensor.stack(*(tmp>>8*i*ns for i in range(os//ns)), dim=-1).flatten(-2).cast(new_uint).bitcast(dtype)
return F.Cast.apply(self, dtype=dt, bitcast=True) if self.dtype != dt else self
return self._apply_uop(UOp.bitcast, dtype=dt) if self.dtype != dt else self
def float(self) -> Tensor:
"""