more method to CreationMixin [PR] (#16698)

This commit is contained in:
chenyu 2026-06-22 00:01:22 -04:00 committed by GitHub
commit 97da54b9d6
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 80 additions and 76 deletions

View file

@ -6,7 +6,7 @@ from tinygrad.mixin.movement import MovementMixin
from tinygrad.mixin.reduce import ReduceMixin
from tinygrad.uop import Ops
from tinygrad.uop.ops import _broadcast_shape, resolve, smax, smin, identity_element
from tinygrad.dtype import ConstType, DTypeLike, Invalid, PtrDType, PyConst, dtypes, least_upper_dtype, sum_acc_dtype, to_dtype
from tinygrad.dtype import ConstType, DTypeLike, PtrDType, PyConst, dtypes, least_upper_dtype, sum_acc_dtype, to_dtype
from tinygrad.helpers import all_int, argfix, argsort, ceildiv, flatten, flat_to_grouped, fully_flatten, get_shape, make_tuple, merge_dicts, prod
from tinygrad.helpers import resolve_pool_pads, round_up
@ -17,9 +17,6 @@ ReductionStr = Literal["mean", "sum", "none"]
class OpMixin(ElementwiseMixin, ReduceMixin):
@staticmethod
def const(dtype, b): raise NotImplementedError
def data(self) -> memoryview: raise NotImplementedError("data requires Tensor realization to host memory")
def item(self) -> PyConst:
@ -40,42 +37,6 @@ class OpMixin(ElementwiseMixin, ReduceMixin):
if self._uop.axis is None: return self._wrap_uop(fxn(self.shape, None)._uop.shard(self.device, None))
return self._wrap_uop(UOp.mstack(*[fxn(self._uop.shard_shape, d)._uop for d in self.device]).multi(self._uop.axis))
@classmethod
def invalids(cls, *shape, device:str|tuple[str, ...]|None=None, dtype:DTypeLike|None=None) -> Self:
"""
Creates a tensor with the given shape, filled with Invalid.
This is an alternative to Tensor.empty when you want an "anonymous" buffer.
Eventually Tensor.empty will be replaced by this.
"""
return cls.full(argfix(*shape), Invalid, dtype=dtype, device=device)
@classmethod
def full(cls, shape:tuple[sint, ...], fill_value:ConstType|UOp, dtype:DTypeLike|None=None,
device:str|tuple[str, ...]|None=None, buffer=True) -> Self:
"""
Creates a tensor with the given shape, filled with the given value.
You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
Pass `buffer=False` to get a broadcast const value instead of a materialized buffer.
```python exec="true" source="above" session="tensor" result="python"
print(Tensor.full((2, 3), 42).numpy())
```
```python exec="true" source="above" session="tensor" result="python"
print(Tensor.full((2, 3), False).numpy())
```
"""
# TODO: enable this check
# if not buffer: assert device is None, "buffer=False does not support device specification"
from tinygrad.uop.ops import UOp
new_shape = argfix(shape)
dt = to_dtype(dtype) if dtype is not None else None
val = cls.const(dt or (fill_value.dtype if isinstance(fill_value, UOp) else dtypes.from_py(fill_value)), fill_value)
val = val.reshape((1,)*len(new_shape)).expand(new_shape)
return val.clone(device=device) if buffer else val
def __getitem__(self, indices) -> Self:
"""
Retrieves a sub-tensor using indexing.
@ -207,40 +168,6 @@ class OpMixin(ElementwiseMixin, ReduceMixin):
vb = vb.pad(tuple((m['boundary'][0], self.shape[d] - m['boundary'][1]) for d, m in enumerate(mops)))
return (type(self).uprod(*per_dim) if per_dim else type(self).const(dtypes.bool, True)).where(vb, self)
@classmethod
def zeros(cls, *shape, **kwargs) -> Self:
"""
Creates a tensor with the given shape, filled with zeros.
You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
Additionally, all other keyword arguments are passed to the constructor of the tensor.
```python exec="true" source="above" session="tensor" result="python"
print(Tensor.zeros(2, 3).numpy())
```
```python exec="true" source="above" session="tensor" result="python"
print(Tensor.zeros(2, 3, dtype=dtypes.int32).numpy())
```
"""
return cls.full(argfix(*shape), 0.0, **kwargs)
@classmethod
def ones(cls, *shape, **kwargs) -> Self:
"""
Creates a tensor with the given shape, filled with ones.
You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
Additionally, all other keyword arguments are passed to the constructor of the tensor.
```python exec="true" source="above" session="tensor" result="python"
print(Tensor.ones(2, 3).numpy())
```
```python exec="true" source="above" session="tensor" result="python"
print(Tensor.ones(2, 3, dtype=dtypes.int32).numpy())
```
"""
return cls.full(argfix(*shape), 1.0, **kwargs)
@classmethod
def arange(cls, start, stop=None, step=1, dtype:DTypeLike|None=None) -> Self:
"""

View file

@ -1,8 +1,15 @@
from typing import Self
from tinygrad.dtype import ConstType, DType, DTypeLike
from typing import TYPE_CHECKING, Self
from tinygrad.dtype import ConstType, DType, DTypeLike, Invalid, dtypes, to_dtype
from tinygrad.helpers import argfix
from tinygrad.mixin.dtype import DTypeMixin
if TYPE_CHECKING:
from tinygrad.uop.ops import sint, UOp
class CreationMixin(DTypeMixin):
@staticmethod
def const(dtype, b): raise NotImplementedError
def const_like(self, b: ConstType) -> Self: return self._wrap_uop(self._uop.const_like(b))
def empty_like(self, dtype: DTypeLike|None=None, device: str|tuple[str, ...]|None=None) -> Self:
@ -12,10 +19,63 @@ class CreationMixin(DTypeMixin):
"""
return self._wrap_uop(self._uop.empty_like(dtype, device))
@classmethod
def invalids(cls, *shape, device:str|tuple[str, ...]|None=None, dtype:DTypeLike|None=None) -> Self:
"""
Creates a tensor with the given shape, filled with Invalid.
This is an alternative to Tensor.empty when you want an "anonymous" buffer.
Eventually Tensor.empty will be replaced by this.
"""
return cls.full(argfix(*shape), Invalid, dtype=dtype, device=device)
@classmethod
def full(cls, shape:'tuple[sint, ...]', fill_value:'ConstType|UOp', dtype:DTypeLike|None=None,
device:str|tuple[str, ...]|None=None, buffer=True) -> Self:
"""
Creates a tensor with the given shape, filled with the given value.
You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
Pass `buffer=False` to get a broadcast const value instead of a materialized buffer.
```python exec="true" source="above" session="tensor" result="python"
print(Tensor.full((2, 3), 42).numpy())
```
```python exec="true" source="above" session="tensor" result="python"
print(Tensor.full((2, 3), False).numpy())
```
"""
# TODO: enable this check
# if not buffer: assert device is None, "buffer=False does not support device specification"
from tinygrad.uop.ops import UOp
new_shape = argfix(shape)
dt = to_dtype(dtype) if dtype is not None else None
val = cls.const(dt or (fill_value.dtype if isinstance(fill_value, UOp) else dtypes.from_py(fill_value)), fill_value)
val = val.reshape((1,)*len(new_shape)).expand(new_shape)
return val.clone(device=device) if buffer else val
def full_like(self, fill_value: ConstType, dtype: DType|None=None) -> Self:
"""Creates a tensor with the same shape as `self`, filled with the given value."""
return self.const_like(fill_value) if dtype is None else self.const_like(fill_value).cast(dtype)
@classmethod
def zeros(cls, *shape, **kwargs) -> Self:
"""
Creates a tensor with the given shape, filled with zeros.
You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
Additionally, all other keyword arguments are passed to the constructor of the tensor.
```python exec="true" source="above" session="tensor" result="python"
print(Tensor.zeros(2, 3).numpy())
```
```python exec="true" source="above" session="tensor" result="python"
print(Tensor.zeros(2, 3, dtype=dtypes.int32).numpy())
```
"""
return cls.full(argfix(*shape), 0.0, **kwargs)
def zeros_like(self, **kwargs) -> Self:
"""
Creates a tensor with the same shape as `self`, filled with zeros.
@ -29,6 +89,23 @@ class CreationMixin(DTypeMixin):
"""
return self.full_like(0, **kwargs)
@classmethod
def ones(cls, *shape, **kwargs) -> Self:
"""
Creates a tensor with the given shape, filled with ones.
You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
Additionally, all other keyword arguments are passed to the constructor of the tensor.
```python exec="true" source="above" session="tensor" result="python"
print(Tensor.ones(2, 3).numpy())
```
```python exec="true" source="above" session="tensor" result="python"
print(Tensor.ones(2, 3, dtype=dtypes.int32).numpy())
```
"""
return cls.full(argfix(*shape), 1.0, **kwargs)
def ones_like(self, **kwargs) -> Self:
"""
Creates a tensor with the same shape as `self`, filled with ones.