const_like and invalids to mixin [PR] (#16690)

* const_like and invalids to mixin [PR]

* empty_like

* einsum

* type
This commit is contained in:
chenyu 2026-06-21 00:02:29 -04:00 committed by GitHub
commit 58ff75272e
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 81 additions and 80 deletions

View file

@ -462,6 +462,10 @@ class TestTensorUOpCreation(unittest.TestCase):
self.assertIs(_strip_unique(Tensor.ones(2, 3).uop), _strip_unique(UOp.ones(2, 3)))
def test_invalids(self):
self.assertIs(_strip_unique(Tensor.invalids(2, 3, dtype=dtypes.int8).uop), _strip_unique(UOp.invalids((2, 3), dtype=dtypes.int8)))
def test_empty_like(self):
t = Tensor.empty(2, 3, dtype=dtypes.int8)
self.assertIs(_strip_unique(t.empty_like().uop), _strip_unique(t.uop.empty_like()))
self.assertIs(_strip_unique(t.empty_like(dtype=dtypes.float, device="NULL").uop), _strip_unique(t.uop.empty_like(dtypes.float, "NULL")))
def test_arange(self):
self.assertIs(Tensor.arange(5).uop, UOp.arange(5))
def test_arange_empty(self):

View file

@ -1,13 +1,13 @@
from __future__ import annotations
import functools, itertools
from typing import TYPE_CHECKING, Callable, Self, Sequence, Literal, get_args
import functools, itertools, string
from typing import TYPE_CHECKING, Callable, Self, Sequence, Literal, get_args, cast
from tinygrad.mixin.elementwise import ElementwiseMixin
from tinygrad.mixin.movement import MovementMixin
from tinygrad.mixin.reduce import ReduceMixin
from tinygrad.uop import Ops
from tinygrad.uop.ops import _broadcast_shape, resolve, smax, smin, identity_element
from tinygrad.dtype import ConstType, DTypeLike, PtrDType, PyConst, dtypes, least_upper_dtype, sum_acc_dtype, to_dtype
from tinygrad.helpers import all_int, argfix, ceildiv, flatten, flat_to_grouped, fully_flatten, get_shape, make_tuple, prod
from tinygrad.dtype import ConstType, DTypeLike, Invalid, PtrDType, PyConst, dtypes, least_upper_dtype, sum_acc_dtype, to_dtype
from tinygrad.helpers import all_int, argfix, argsort, ceildiv, flatten, flat_to_grouped, fully_flatten, get_shape, make_tuple, merge_dicts, prod
from tinygrad.helpers import resolve_pool_pads, round_up
if TYPE_CHECKING:
@ -20,6 +20,17 @@ class OpMixin(ElementwiseMixin, ReduceMixin):
@staticmethod
def const(dtype, b): raise NotImplementedError
@classmethod
def invalids(cls, *shape, device:str|tuple[str, ...]|None=None, dtype:DTypeLike|None=None) -> Self:
"""
Creates a tensor with the given shape, filled with Invalid.
This is an alternative to Tensor.empty when you want an "anonymous" buffer.
Eventually Tensor.empty will be replaced by this.
"""
return cls.full(argfix(*shape), Invalid, dtype=dtype, device=device)
@classmethod
def full(cls, shape:tuple[sint, ...], fill_value:ConstType|UOp, dtype:DTypeLike|None=None,
device:str|tuple[str, ...]|None=None, buffer=True) -> Self:
@ -405,7 +416,7 @@ class OpMixin(ElementwiseMixin, ReduceMixin):
if mode in {"reflect", "replicate"}: return self._pad_reflect_replicate(pX, mode)
raise NotImplementedError(f"{mode=} is not supported")
def _broadcasted(self, y, reverse=False) -> tuple[Self, Self]:
def _broadcasted(self, y:Self|ConstType|UOp, reverse:bool=False) -> tuple[Self, Self]:
if not isinstance(y, type(self)): y = self.ufix(y)
x, y = (self, y) if not reverse else (y, self)
# ValueError: unsized ptr has shape (-1,) which can't broadcast; RuntimeError: shape mismatch
@ -461,6 +472,47 @@ class OpMixin(ElementwiseMixin, ReduceMixin):
def __matmul__(self, x:Self) -> Self: return self.matmul(x)
def __rmatmul__(self, x:Self) -> Self: return self.matmul(x, True)
@classmethod
def einsum(cls, formula:str, *operands:Self|Sequence[Self], dtype:DTypeLike|None=None) -> Self:
"""
Sums the product of the elements of the input tensors according to a formula based on the Einstein summation convention.
See: https://pytorch.org/docs/stable/generated/torch.einsum.html
```python exec="true" source="above" session="tensor" result="python"
x = Tensor([[1, 2], [3, 4]])
y = Tensor([[5, 6], [7, 8]])
print(Tensor.einsum("ij,ij->", x, y).numpy())
```
"""
xs, formula = list(argfix(*operands)), formula.replace(" ", "")
# expand ellipsis to letters, determine output
if "..." in formula:
ell, lhs = "".join(c for c in string.ascii_letters if c not in formula), (formula.split("->") + [""])[0]
ell_n = [max(0, x.ndim - len(s) + 3) if "..." in s else 0 for s, x in zip(lhs.split(","), xs)]
for i, (s, x) in enumerate(zip(inputs := lhs.split(","), xs)): inputs[i] = s.replace("...", ell[max(ell_n)-ell_n[i]:max(ell_n)])
lhs, auto = ",".join(inputs), "".join(sorted(c for c in lhs if lhs.count(c) == 1 and c.isalpha() and c not in ell))
formula = f"{lhs}->{formula.split('->')[1].replace('...', ell[:max(ell_n)]) if '->' in formula else ell[:max(ell_n)] + auto}"
lhs, rhs = formula.split("->") if "->" in formula else (formula, "".join(sorted(c for c in formula if formula.count(c)==1 and c.isalpha())))
inputs = lhs.split(",")
if len(xs) != len(inputs): raise ValueError(f"number of operands doesn't match, expected {len(inputs)}, got {len(xs)}")
# trace: take diagonal when letter repeats in single input
for i, (s, x) in enumerate(zip(inputs, xs)):
for c in set(s):
while s.count(c) > 1:
j, k, n = s.index(c), s.index(c, s.index(c)+1), cast(int, x.shape[s.index(c)])
perm = [d for d in range(x.ndim) if d not in (j,k)]+[j,k]
x = x.permute(perm).flatten(-2).pad(((0,0),)*(x.ndim-2)+((0,n),)).unflatten(-1,(n,n+1))[...,0] if x.ndim > 2 else x.diagonal()
s = s[:k] + s[k+1:]
inputs[i], xs[i] = s, x
# check sizes and build sorted alphabet
sz = merge_dicts([dict(zip(s, x.shape)) for s, x in zip(inputs, xs)])
alpha = sorted(sz)
# align all tensors to alphabet, multiply, sum non-output, permute to output order
xs = [x.permute(*[s.index(c) for c in sorted(s)]).reshape([sz[c] if c in s else 1 for c in alpha]).expand([sz[c] for c in alpha]) if s else x
for s, x in zip(inputs, xs)]
return xs[0].uprod(*xs[1:]).sum([i for i,c in enumerate(alpha) if c not in rhs], dtype=dtype).permute(argsort(argsort(list(rhs))))
def gradient(self, *targets:Self, gradient:Self|None=None) -> list[Self]:
"""
Computes the gradient of the targets with respect to self.

View file

@ -1,10 +1,24 @@
from typing import Self
from tinygrad.dtype import ConstType, DType
from typing import TYPE_CHECKING, Self
from tinygrad.dtype import ConstType, DType, DTypeLike
if TYPE_CHECKING:
from tinygrad.uop.ops import UOp
class CreationMixin:
def const_like(self, b: ConstType) -> Self: raise NotImplementedError
@property
def _uop(self) -> 'UOp': raise NotImplementedError
def _wrap_uop(self, u: 'UOp') -> Self: raise NotImplementedError
def cast(self, dtype: DType) -> Self: raise NotImplementedError
def const_like(self, b: ConstType) -> Self: return self._wrap_uop(self._uop.const_like(b))
def empty_like(self, dtype: DTypeLike|None=None, device: str|tuple[str, ...]|None=None) -> Self:
"""
Creates an empty tensor with the same shape as `self`.
If `dtype` is not specified, the dtype of `self` is used.
"""
return self._wrap_uop(self._uop.empty_like(dtype, device))
def full_like(self, fill_value: ConstType, dtype: DType|None=None) -> Self:
"""Creates a tensor with the same shape as `self`, filled with the given value."""
return self.const_like(fill_value) if dtype is None else self.const_like(fill_value).cast(dtype)

View file

@ -15,11 +15,6 @@ class ElementwiseMixin(DTypeMixin, CreationMixin):
def alu(self, op: Ops, *src: Self) -> Self:
raise NotImplementedError
@property
def _uop(self) -> 'UOp': raise NotImplementedError
def _wrap_uop(self, u: 'UOp') -> Self: raise NotImplementedError
# great functions you get!
def ufix(self, x: 'Self|ConstType|UOp') -> Self:
return x if isinstance(x, type(self)) else self._wrap_uop(self._uop.ufix(x))

View file

@ -1,8 +1,7 @@
import string
from typing import Self, Sequence, cast
from typing import Self, Sequence
from tinygrad.uop import Ops
from tinygrad.dtype import DTypeLike, dtypes, sum_acc_dtype, to_dtype
from tinygrad.helpers import argfix, argsort, make_tuple, merge_dicts
from tinygrad.helpers import make_tuple
from tinygrad.mixin.dtype import DTypeMixin
from tinygrad.mixin.movement import MovementMixin
@ -136,44 +135,3 @@ class ReduceMixin(DTypeMixin, MovementMixin):
```
"""
return self.bool().prod(axis, keepdim)
@classmethod
def einsum(cls, formula:str, *operands:Self|Sequence[Self], dtype:DTypeLike|None=None) -> Self:
"""
Sums the product of the elements of the input tensors according to a formula based on the Einstein summation convention.
See: https://pytorch.org/docs/stable/generated/torch.einsum.html
```python exec="true" source="above" session="tensor" result="python"
x = Tensor([[1, 2], [3, 4]])
y = Tensor([[5, 6], [7, 8]])
print(Tensor.einsum("ij,ij->", x, y).numpy())
```
"""
xs, formula = list(argfix(*operands)), formula.replace(" ", "")
# expand ellipsis to letters, determine output
if "..." in formula:
ell, lhs = "".join(c for c in string.ascii_letters if c not in formula), (formula.split("->") + [""])[0]
ell_n = [max(0, x.ndim - len(s) + 3) if "..." in s else 0 for s, x in zip(lhs.split(","), xs)]
for i, (s, x) in enumerate(zip(inputs := lhs.split(","), xs)): inputs[i] = s.replace("...", ell[max(ell_n)-ell_n[i]:max(ell_n)])
lhs, auto = ",".join(inputs), "".join(sorted(c for c in lhs if lhs.count(c) == 1 and c.isalpha() and c not in ell))
formula = f"{lhs}->{formula.split('->')[1].replace('...', ell[:max(ell_n)]) if '->' in formula else ell[:max(ell_n)] + auto}"
lhs, rhs = formula.split("->") if "->" in formula else (formula, "".join(sorted(c for c in formula if formula.count(c)==1 and c.isalpha())))
inputs = lhs.split(",")
if len(xs) != len(inputs): raise ValueError(f"number of operands doesn't match, expected {len(inputs)}, got {len(xs)}")
# trace: take diagonal when letter repeats in single input
for i, (s, x) in enumerate(zip(inputs, xs)):
for c in set(s):
while s.count(c) > 1:
j, k, n = s.index(c), s.index(c, s.index(c)+1), cast(int, x.shape[s.index(c)])
perm = [d for d in range(x.ndim) if d not in (j,k)]+[j,k]
x = x.permute(perm).flatten(-2).pad(((0,0),)*(x.ndim-2)+((0,n),)).unflatten(-1,(n,n+1))[...,0] if x.ndim > 2 else x.diagonal()
s = s[:k] + s[k+1:]
inputs[i], xs[i] = s, x
# check sizes and build sorted alphabet
sz = merge_dicts([dict(zip(s, x.shape)) for s, x in zip(inputs, xs)])
alpha = sorted(sz)
# align all tensors to alphabet, multiply, sum non-output, permute to output order
xs = [x.permute(*[s.index(c) for c in sorted(s)]).reshape([sz[c] if c in s else 1 for c in alpha]).expand([sz[c] for c in alpha]) if s else x
for s, x in zip(inputs, xs)]
return xs[0].uprod(*xs[1:]).sum([i for i,c in enumerate(alpha) if c not in rhs], dtype=dtype).permute(argsort(argsort(list(rhs))))

View file

@ -136,24 +136,13 @@ class Tensor(RandMixin):
all_tensors[weakref.ref(ret)] = None
return ret
# alu and const_like are used by the mixins
# alu, _uop, _wrap_uop and const are used by the mixins
def alu(self, op: Ops, *src: Tensor) -> Tensor: return self._apply_uop(lambda *u: u[0].alu(op, *u[1:]), *src)
@property
def _uop(self) -> UOp: return self.uop
def _wrap_uop(self, u:UOp) -> Tensor: return Tensor(u)
def const_like(self, b:ConstType) -> Tensor: return Tensor(self.uop.const_like(b))
@staticmethod
def const(dtype:DType, b:ConstType|UOp) -> Tensor: return Tensor(UOp.const(dtype, b))
@staticmethod
def invalids(*shape, device:str|tuple[str, ...]|None=None, dtype:DTypeLike|None=None) -> Tensor:
"""
Creates a tensor with the given shape, filled with Invalid.
This is an alternative to Tensor.empty when you want an "anonymous" buffer.
Eventually Tensor.empty will be replaced by this.
"""
return Tensor(UOp.invalids(argfix(*shape), dtype, device))
def is_param_(self, is_param:bool=True) -> Tensor:
self.is_param = is_param
@ -464,13 +453,6 @@ class Tensor(RandMixin):
"""
return Tensor(UOp.empty(argfix(*shape), dtype, device))
def empty_like(self, dtype:DTypeLike|None=None, device:str|tuple[str, ...]|None=None) -> Tensor:
"""
Creates an empty tensor with the same shape as `self`.
If `dtype` is not specified, the dtype of `self` is used.
"""
return Tensor(self.uop.empty_like(dtype, device))
@staticmethod
def from_blob(ptr:int, shape:tuple[int, ...], **kwargs) -> Tensor:
"""

View file

@ -560,10 +560,6 @@ class UOp(RandMixin, metaclass=UOpMetaClass):
ret = UOp(Ops.CONST, dtype, arg=dtype.const(b), src=())
return ret.reshape((1,)*len(shape)).expand(shape) if shape is not None and shape != () and ret.shape != shape else ret
@staticmethod
def invalids(shape:tuple[sint, ...]|None=None, dtype:DTypeLike|None=None, device:str|tuple[str, ...]|None=None) -> UOp:
dt = to_dtype(dtype) if dtype is not None else dtypes.from_py(Invalid)
return UOp.const(dt, Invalid, shape=shape).clone(device=device)
@staticmethod
def range(end:sint, axis_id, axis_type=AxisType.LOOP, *arg, dtype=dtypes.weakint, src=(), **kwargs):
return UOp(Ops.RANGE, dtype=dtype, src=(sint_to_uop(end, dtype),)+src, arg=(axis_id, axis_type)+arg, **kwargs)
@staticmethod