move mixins to mixin dir (#13105)

* move mixins to mixin dir

* math
This commit is contained in:
George Hotz 2025-11-05 10:18:33 -08:00 committed by GitHub
commit 2d4f01fda0
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
8 changed files with 93 additions and 86 deletions

View file

@ -17,7 +17,7 @@ M = getenv("M", N)
K = getenv("K", N)
CNT = getenv("CNT", 10)
atol, rtol = {dtypes.bfloat16:(1e-3, 1e-2), dtypes.fp8e4m3:(1e-1, 1e-1), dtypes.fp8e5m2:(1.0, 5e-1)}.get(dtype_in, (1e-4, 3e-2))
atol, rtol = {dtypes.half:{1e-3, 1e-2}, dtypes.bfloat16:(1e-3, 1e-2), dtypes.fp8e4m3:(1e-1, 1e-1), dtypes.fp8e5m2:(1.0, 5e-1)}.get(dtype_in, (1e-4, 3e-2))
ATOL, RTOL = getenv("ATOL", atol), getenv("RTOL", rtol)
INT_LOW = getenv("INT_LOW", 0)

View file

@ -32,6 +32,7 @@ setup(name='tinygrad',
'tinygrad.codegen.opt',
'tinygrad.codegen.late',
'tinygrad.engine',
'tinygrad.mixin',
'tinygrad.nn',
'tinygrad.renderer',
'tinygrad.runtime',

View file

@ -517,7 +517,7 @@ class TestUOpStr(unittest.TestCase):
class TestUPatHelpers(unittest.TestCase):
def test_location(self):
self.assertEqual(sym.patterns[-1][0].location[0].replace("\\", "/").split("/")[-1], "mixins.py")
self.assertEqual(sym.patterns[-1][0].location[0].replace("\\", "/").split("/")[-1], "math.py")
self.assertEqual(shared_spec.patterns[0][0].location[0].replace("\\", "/").split("/")[-1], "spec.py")
test_upat = UPat(Ops.CONST, dtypes.bool)
self.assertEqual(test_upat.location[0].split("/")[-1], __file__.replace("\\", "/").split("/")[-1])

View file

@ -0,0 +1,4 @@
from tinygrad.mixin.math import MathMixin
from tinygrad.mixin.movement import MovementMixin
class OpMixin(MathMixin, MovementMixin): pass

View file

@ -1,11 +1,6 @@
# mixins add syntactic sugar to Tensor and UOp
from typing import TypeAlias, TYPE_CHECKING, Self
from typing import Self
from tinygrad.uop import Ops
from tinygrad.dtype import dtypes, ConstType
from tinygrad.helpers import prod, argfix
if TYPE_CHECKING:
from tinygrad.uop.ops import UOp
sint:TypeAlias = UOp|int
class MathMixin:
# required to implement
@ -175,76 +170,3 @@ class MathMixin:
def exp2(self): return self.alu(Ops.EXP2)
def pow(self, x:Self|ConstType): return self.alu(Ops.POW, self.ufix(x))
def __pow__(self, x:Self|ConstType): return self.pow(x)
class MovementMixin:
# required to implement
def _mop(self, op:Ops, arg) -> Self: raise NotImplementedError
@property
def shape(self) -> tuple["sint", ...]: raise NotImplementedError
# great functions you get!
@property
def ndim(self) -> int:
"""
Returns the number of dimensions in the tensor.
```python exec="true" source="above" session="tensor" result="python"
t = Tensor([[1, 2], [3, 4]])
print(t.ndim)
```
"""
return len(self.shape)
def numel(self) -> "sint":
"""
Returns the total number of elements in the tensor.
```python exec="true" source="above" session="tensor" result="python"
t = Tensor([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])
print(t.numel())
```
"""
return prod(self.shape)
def _resolve_dim(self, dim:int, *, extra:bool=False) -> int:
total = self.ndim + int(extra)
if not -max(1, total) <= dim <= max(1, total)-1: raise IndexError(f"{dim=} out of range {[-max(1, total), max(1, total)-1]}")
return dim + total if dim < 0 else dim
def view(self, shape, *args) -> Self:
"""`.view` is an alias for `.reshape`."""
return self.reshape(shape, *args)
def reshape(self, shape, *args) -> Self:
"""
Returns a tensor with the same data as the original tensor but with a different shape.
`shape` can be passed as a tuple or as separate arguments.
```python exec="true" source="above" session="tensor" result="python"
t = Tensor.arange(6)
print(t.reshape(2, 3).numpy())
```
"""
# resolve None and args
new_shape = tuple([s if s is not None else self.shape[i] for i,s in enumerate(argfix(shape, *args))])
# resolve -1
if (c := new_shape.count(-1)) > 1: raise RuntimeError(f"only one dimension can be inferred using -1, getting {new_shape}")
if c: new_shape = tuple([-prod(self.shape) // prod(new_shape) if s == -1 else s for s in new_shape])
if prod(self.shape) != prod(new_shape): raise ValueError(f"size mismatch, can't reshape ({self.shape}) -> ({new_shape})")
return self._mop(Ops.RESHAPE, arg=new_shape) if new_shape != self.shape else self
def flatten(self, start_dim=0, end_dim=-1) -> Self:
"""
Flattens the tensor by reshaping it into a one-dimensional tensor.
If `start_dim` or `end_dim` are passed, only dimensions starting with `start_dim` and ending with `end_dim` are flattened.
```python exec="true" source="above" session="tensor" result="python"
t = Tensor.arange(8).reshape(2, 2, 2)
print(t.flatten().numpy())
```
```python exec="true" source="above" session="tensor" result="python"
print(t.flatten(start_dim=1).numpy())
```
"""
start_dim, end_dim = self._resolve_dim(start_dim), self._resolve_dim(end_dim)
return self.reshape(self.shape[:start_dim] + (prod(self.shape[start_dim:end_dim+1]), ) + self.shape[end_dim+1:])

View file

@ -0,0 +1,80 @@
# mixins add syntactic sugar to Tensor and UOp
from typing import TypeAlias, TYPE_CHECKING, Self
from tinygrad.uop import Ops
from tinygrad.helpers import prod, argfix
if TYPE_CHECKING:
from tinygrad.uop.ops import UOp
sint:TypeAlias = UOp|int
class MovementMixin:
# required to implement
def _mop(self, op:Ops, arg) -> Self: raise NotImplementedError
@property
def shape(self) -> tuple["sint", ...]: raise NotImplementedError
# great functions you get!
@property
def ndim(self) -> int:
"""
Returns the number of dimensions in the tensor.
```python exec="true" source="above" session="tensor" result="python"
t = Tensor([[1, 2], [3, 4]])
print(t.ndim)
```
"""
return len(self.shape)
def numel(self) -> "sint":
"""
Returns the total number of elements in the tensor.
```python exec="true" source="above" session="tensor" result="python"
t = Tensor([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])
print(t.numel())
```
"""
return prod(self.shape)
def _resolve_dim(self, dim:int, *, extra:bool=False) -> int:
total = self.ndim + int(extra)
if not -max(1, total) <= dim <= max(1, total)-1: raise IndexError(f"{dim=} out of range {[-max(1, total), max(1, total)-1]}")
return dim + total if dim < 0 else dim
def view(self, shape, *args) -> Self:
"""`.view` is an alias for `.reshape`."""
return self.reshape(shape, *args)
def reshape(self, shape, *args) -> Self:
"""
Returns a tensor with the same data as the original tensor but with a different shape.
`shape` can be passed as a tuple or as separate arguments.
```python exec="true" source="above" session="tensor" result="python"
t = Tensor.arange(6)
print(t.reshape(2, 3).numpy())
```
"""
# resolve None and args
new_shape = tuple([s if s is not None else self.shape[i] for i,s in enumerate(argfix(shape, *args))])
# resolve -1
if (c := new_shape.count(-1)) > 1: raise RuntimeError(f"only one dimension can be inferred using -1, getting {new_shape}")
if c: new_shape = tuple([-prod(self.shape) // prod(new_shape) if s == -1 else s for s in new_shape])
if prod(self.shape) != prod(new_shape): raise ValueError(f"size mismatch, can't reshape ({self.shape}) -> ({new_shape})")
return self._mop(Ops.RESHAPE, arg=new_shape) if new_shape != self.shape else self
def flatten(self, start_dim=0, end_dim=-1) -> Self:
"""
Flattens the tensor by reshaping it into a one-dimensional tensor.
If `start_dim` or `end_dim` are passed, only dimensions starting with `start_dim` and ending with `end_dim` are flattened.
```python exec="true" source="above" session="tensor" result="python"
t = Tensor.arange(8).reshape(2, 2, 2)
print(t.flatten().numpy())
```
```python exec="true" source="above" session="tensor" result="python"
print(t.flatten(start_dim=1).numpy())
```
"""
start_dim, end_dim = self._resolve_dim(start_dim), self._resolve_dim(end_dim)
return self.reshape(self.shape[:start_dim] + (prod(self.shape[start_dim:end_dim+1]), ) + self.shape[end_dim+1:])

View file

@ -9,7 +9,7 @@ from tinygrad.helpers import argfix, make_tuple, flatten, prod, all_int, round_u
from tinygrad.helpers import IMAGE, WINO, Metadata, TRACEMETA, ceildiv, fetch, polyN, DEBUG, is_numpy_ndarray, SPEC
from tinygrad.helpers import suppress_finalizing
from tinygrad.gradient import compute_gradient
from tinygrad.uop.mixins import MathMixin, MovementMixin
from tinygrad.mixin import OpMixin
from tinygrad.uop.ops import smax, smin, resolve, UOp, Ops, sint, identity_element, all_metadata, _index_to_concrete_int, sint_to_uop
from tinygrad.uop.spec import type_verify, tensor_spec
from tinygrad.device import Device, Buffer
@ -100,7 +100,7 @@ def _flat_to_grouped(padding:Sequence[sint]) -> tuple[tuple[sint, sint], ...]: r
ReductionStr = Literal["mean", "sum", "none"]
class Tensor(MathMixin, MovementMixin):
class Tensor(OpMixin):
"""
A `Tensor` is a multi-dimensional matrix containing elements of a single data type.

View file

@ -4,7 +4,7 @@ import sys, time, functools, itertools, math, operator, hashlib, os, types, pick
from dataclasses import dataclass
from enum import Enum, auto
from tinygrad.uop import Ops, GroupOp
from tinygrad.uop.mixins import MathMixin, MovementMixin
from tinygrad.mixin import OpMixin
from tinygrad.dtype import ConstType, ImageDType, dtypes, DType, truncate, PtrDType, least_upper_dtype, Invalid, InvalidType, AddrSpace
from tinygrad.helpers import ContextVar, all_int, prod, getenv, all_same, Context, partition, temp, unwrap, T, argfix, Metadata, flatten, TRACEMETA
from tinygrad.helpers import PICKLE_BUFFERS, PROFILE, dedup, cdiv, cmod, diskcache_put, to_function_name, cpu_profile, TracingKey, VIZ, SPEC, CI
@ -104,7 +104,7 @@ class recursive_property(property):
# NOTE: this should be frozen, but frozen is slower
@dataclass(eq=False, slots=True)
class UOp(MathMixin, MovementMixin, metaclass=UOpMetaClass):
class UOp(OpMixin, metaclass=UOpMetaClass):
op:Ops
dtype:DType = dtypes.void
src:tuple[UOp, ...] = tuple()
@ -867,7 +867,7 @@ def printable(loc:tuple[str, int]) -> str:
try: return lines(loc[0])[loc[1]-1].strip()
except FileNotFoundError: return "<missing>"
class UPat(MathMixin, MovementMixin):
class UPat(OpMixin):
__slots__ = ("op", "dtype", "arg", "name", "src")
def __init__(self, op:Ops|tuple[Ops, ...]|set[Ops]|None=None, dtype:DType|tuple[DType, ...]|None=None,
src:tuple[UPat, ...]|list[UPat]|UPat|None=None, arg:Any=None,