full_like to CreationMixin [PR] (#16702)

This commit is contained in:
chenyu 2026-06-22 09:33:23 -04:00 committed by GitHub
commit 267af9c601
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 28 additions and 32 deletions

View file

@ -66,7 +66,7 @@ def block_128x128_gemm(c:UOp, a:UOp, b:UOp) -> UOp:
# accumulator (unified: both paths use (TM, TN) with scalar dtypes.float)
acc = UOp.placeholder((TM, TN), dtypes.float, slot=2, addrspace=AddrSpace.REG)
acc = acc.after(acc.store(acc.zeros_like()))
acc = acc.after(acc.store(acc.zeros_like(buffer=False)))
if use_wmma:
k = UOp.range(BLOCK_K // WMMA_K, 101, AxisType.REDUCE)

View file

@ -31,12 +31,6 @@ class OpMixin(ElementwiseMixin, ReduceMixin):
assert self.numel() == 1, "must have one element for item"
return self.data()[(0,) * len(self.shape)]
def _multi_like(self, fxn:Callable[[tuple[sint, ...], str|None], Self]) -> Self:
from tinygrad.uop.ops import UOp
assert isinstance(self.device, tuple), f"_multi_like needs a multi device tensor, got {self.device}"
if self._uop.axis is None: return self._wrap_uop(fxn(self.shape, None)._uop.shard(self.device, None))
return self._wrap_uop(UOp.mstack(*[fxn(self._uop.shard_shape, d)._uop for d in self.device]).multi(self._uop.axis))
def __getitem__(self, indices) -> Self:
"""
Retrieves a sub-tensor using indexing.

View file

@ -1,17 +1,24 @@
from typing import TYPE_CHECKING, Self
from tinygrad.dtype import ConstType, DType, DTypeLike, Invalid, dtypes, to_dtype
from typing import TYPE_CHECKING, Callable, Self
from tinygrad.dtype import ConstType, DTypeLike, Invalid, dtypes, to_dtype
from tinygrad.helpers import argfix
from tinygrad.mixin.dtype import DTypeMixin
from tinygrad.mixin.movement import MovementMixin
if TYPE_CHECKING:
from tinygrad.uop.ops import sint, UOp
class CreationMixin(DTypeMixin):
class CreationMixin(DTypeMixin, MovementMixin):
@staticmethod
def const(dtype, b): raise NotImplementedError
def const_like(self, b: ConstType) -> Self: return self._wrap_uop(self._uop.const_like(b))
def _multi_like(self, fxn:'Callable[[tuple[sint, ...], str|None], Self]') -> Self:
from tinygrad.uop.ops import UOp
assert isinstance(self.device, tuple), f"_multi_like needs a multi device tensor, got {self.device}"
if self._uop.axis is None: return self._wrap_uop(fxn(self.shape, None)._uop.shard(self.device, None))
return self._wrap_uop(UOp.mstack(*[fxn(self._uop.shard_shape, d)._uop for d in self.device]).multi(self._uop.axis))
def empty_like(self, dtype: DTypeLike|None=None, device: str|tuple[str, ...]|None=None) -> Self:
"""
Creates an empty tensor with the same shape as `self`.
@ -55,9 +62,23 @@ class CreationMixin(DTypeMixin):
val = val.reshape((1,)*len(new_shape)).expand(new_shape)
return val.clone(device=device) if buffer else val
def full_like(self, fill_value: ConstType, dtype: DType|None=None) -> Self:
"""Creates a tensor with the same shape as `self`, filled with the given value."""
return self.const_like(fill_value) if dtype is None else self.const_like(fill_value).cast(dtype)
def full_like(self, fill_value:ConstType, dtype:DTypeLike|None=None, device:str|tuple[str, ...]|None=None, buffer=True) -> Self:
"""
Creates a tensor with the same shape as `self`, filled with the given value.
If `dtype` is not specified, the dtype of `self` is used.
You can pass in the `device` keyword argument to control device of the tensor.
Pass `buffer=False` to get a broadcast const value instead of a materialized buffer.
```python exec="true" source="above" session="tensor" result="python"
t = Tensor.ones(2, 3)
print(Tensor.full_like(t, 42).numpy())
```
"""
if isinstance(self.device, tuple):
if device is not None: raise RuntimeError("cannot specify `device` on `*_like` of a multi device tensor")
return self._multi_like(lambda shape, dev: type(self).full(shape, fill_value, dtype=dtype or self.dtype, device=dev, buffer=buffer))
return type(self).full(self.shape, fill_value, dtype=dtype or self.dtype, device=self.device if device is None else device, buffer=buffer)
@classmethod
def zeros(cls, *shape, **kwargs) -> Self:

View file

@ -503,25 +503,6 @@ class Tensor(RandMixin):
high = counter[1:2] - (num >> 32) - (counter[0] < (num & 0xffffffff))
return Tensor._device_seeds[device], low.cat(high)
# ***** creation helper functions *****
def full_like(self, fill_value:ConstType, dtype=None, device=None) -> Tensor:
"""
Creates a tensor with the same shape as `self`, filled with the given value.
If `dtype` is not specified, the dtype of `self` is used.
You can pass in the `device` keyword argument to control device of the tensor.
```python exec="true" source="above" session="tensor" result="python"
t = Tensor.ones(2, 3)
print(Tensor.full_like(t, 42).numpy())
```
"""
if isinstance(self.device, tuple):
if device is not None: raise RuntimeError("cannot specify `device` on `*_like` of a multi device tensor")
return self._multi_like(lambda shape, dev: Tensor.full(shape, fill_value, dtype=dtype or self.dtype, device=dev))
return Tensor.full(self.shape, fill_value, dtype=dtype or self.dtype, device=self.device if device is None else device)
# ***** toposort and backward pass *****
def backward(self, gradient:Tensor|None=None) -> Tensor: