mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
full_like to CreationMixin [PR] (#16702)
This commit is contained in:
parent
97da54b9d6
commit
267af9c601
4 changed files with 28 additions and 32 deletions
|
|
@@ -66,7 +66,7 @@ def block_128x128_gemm(c:UOp, a:UOp, b:UOp) -> UOp:
|
|||
|
||||
# accumulator (unified: both paths use (TM, TN) with scalar dtypes.float)
|
||||
acc = UOp.placeholder((TM, TN), dtypes.float, slot=2, addrspace=AddrSpace.REG)
|
||||
acc = acc.after(acc.store(acc.zeros_like()))
|
||||
acc = acc.after(acc.store(acc.zeros_like(buffer=False)))
|
||||
|
||||
if use_wmma:
|
||||
k = UOp.range(BLOCK_K // WMMA_K, 101, AxisType.REDUCE)
|
||||
|
|
|
|||
|
|
@@ -31,12 +31,6 @@ class OpMixin(ElementwiseMixin, ReduceMixin):
|
|||
assert self.numel() == 1, "must have one element for item"
|
||||
return self.data()[(0,) * len(self.shape)]
|
||||
|
||||
def _multi_like(self, fxn:Callable[[tuple[sint, ...], str|None], Self]) -> Self:
  """
  Builds a multi device result laid out like `self` by calling `fxn(shape, device)`.

  If `self._uop.axis` is None, `fxn` is called once with `self.shape` and a `None` device, and its uop is
  sharded over `self.device` with a `None` axis. Otherwise `fxn` is called once per entry of `self.device`
  with `self._uop.shard_shape`, and the resulting uops are combined with `UOp.mstack(...).multi(axis)`.
  """
  from tinygrad.uop.ops import UOp
  assert isinstance(self.device, tuple), f"_multi_like needs a multi device tensor, got {self.device}"
  shard_axis = self._uop.axis
  if shard_axis is None:
    whole = fxn(self.shape, None)
    return self._wrap_uop(whole._uop.shard(self.device, None))
  per_device_uops = [fxn(self._uop.shard_shape, dev)._uop for dev in self.device]
  return self._wrap_uop(UOp.mstack(*per_device_uops).multi(shard_axis))
|
||||
|
||||
def __getitem__(self, indices) -> Self:
|
||||
"""
|
||||
Retrieves a sub-tensor using indexing.
|
||||
|
|
|
|||
|
|
@@ -1,17 +1,24 @@
|
|||
from typing import TYPE_CHECKING, Self
|
||||
from tinygrad.dtype import ConstType, DType, DTypeLike, Invalid, dtypes, to_dtype
|
||||
from typing import TYPE_CHECKING, Callable, Self
|
||||
from tinygrad.dtype import ConstType, DTypeLike, Invalid, dtypes, to_dtype
|
||||
from tinygrad.helpers import argfix
|
||||
from tinygrad.mixin.dtype import DTypeMixin
|
||||
from tinygrad.mixin.movement import MovementMixin
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from tinygrad.uop.ops import sint, UOp
|
||||
|
||||
class CreationMixin(DTypeMixin):
|
||||
class CreationMixin(DTypeMixin, MovementMixin):
|
||||
@staticmethod
|
||||
def const(dtype, b): raise NotImplementedError
|
||||
|
||||
def const_like(self, b: ConstType) -> Self: return self._wrap_uop(self._uop.const_like(b))
|
||||
|
||||
def _multi_like(self, fxn:'Callable[[tuple[sint, ...], str|None], Self]') -> Self:
  """
  Creates a multi device value shaped like `self` from the factory `fxn(shape, device)`.

  With no shard axis (`self._uop.axis is None`) the factory runs once on the full `self.shape` with device `None`
  and the result is sharded across `self.device`. With a shard axis the factory runs once per device on
  `self._uop.shard_shape`; the pieces are stacked with `UOp.mstack` and marked `.multi(axis)`.
  """
  from tinygrad.uop.ops import UOp
  assert isinstance(self.device, tuple), f"_multi_like needs a multi device tensor, got {self.device}"
  axis = self._uop.axis
  if axis is not None:
    shards = [fxn(self._uop.shard_shape, d)._uop for d in self.device]
    return self._wrap_uop(UOp.mstack(*shards).multi(axis))
  unsharded = fxn(self.shape, None)._uop
  return self._wrap_uop(unsharded.shard(self.device, None))
|
||||
|
||||
def empty_like(self, dtype: DTypeLike|None=None, device: str|tuple[str, ...]|None=None) -> Self:
|
||||
"""
|
||||
Creates an empty tensor with the same shape as `self`.
|
||||
|
|
@@ -55,9 +62,23 @@ class CreationMixin(DTypeMixin):
|
|||
val = val.reshape((1,)*len(new_shape)).expand(new_shape)
|
||||
return val.clone(device=device) if buffer else val
|
||||
|
||||
def full_like(self, fill_value: ConstType, dtype: DType|None=None) -> Self:
|
||||
"""Creates a tensor with the same shape as `self`, filled with the given value."""
|
||||
return self.const_like(fill_value) if dtype is None else self.const_like(fill_value).cast(dtype)
|
||||
def full_like(self, fill_value:ConstType, dtype:DTypeLike|None=None, device:str|tuple[str, ...]|None=None, buffer=True) -> Self:
|
||||
"""
|
||||
Creates a tensor with the same shape as `self`, filled with the given value.
|
||||
If `dtype` is not specified, the dtype of `self` is used.
|
||||
|
||||
You can pass in the `device` keyword argument to control device of the tensor.
|
||||
Pass `buffer=False` to get a broadcast const value instead of a materialized buffer.
|
||||
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
t = Tensor.ones(2, 3)
|
||||
print(Tensor.full_like(t, 42).numpy())
|
||||
```
|
||||
"""
|
||||
if isinstance(self.device, tuple):
|
||||
if device is not None: raise RuntimeError("cannot specify `device` on `*_like` of a multi device tensor")
|
||||
return self._multi_like(lambda shape, dev: type(self).full(shape, fill_value, dtype=dtype or self.dtype, device=dev, buffer=buffer))
|
||||
return type(self).full(self.shape, fill_value, dtype=dtype or self.dtype, device=self.device if device is None else device, buffer=buffer)
|
||||
|
||||
@classmethod
|
||||
def zeros(cls, *shape, **kwargs) -> Self:
|
||||
|
|
|
|||
|
|
@@ -503,25 +503,6 @@ class Tensor(RandMixin):
|
|||
high = counter[1:2] - (num >> 32) - (counter[0] < (num & 0xffffffff))
|
||||
return Tensor._device_seeds[device], low.cat(high)
|
||||
|
||||
# ***** creation helper functions *****
|
||||
|
||||
def full_like(self, fill_value:ConstType, dtype=None, device=None) -> Tensor:
  """
  Creates a tensor with the same shape as `self`, filled with the given value.
  When `dtype` is omitted, the dtype of `self` is used.

  The `device` keyword argument selects the device of the result; it defaults to the device of `self`.
  For a multi device `self` (its `device` is a tuple) `device` must stay `None`, otherwise a `RuntimeError` is raised.

  ```python exec="true" source="above" session="tensor" result="python"
  t = Tensor.ones(2, 3)
  print(Tensor.full_like(t, 42).numpy())
  ```
  """
  if not isinstance(self.device, tuple):
    return Tensor.full(self.shape, fill_value, dtype=dtype or self.dtype, device=self.device if device is None else device)
  if device is not None: raise RuntimeError("cannot specify `device` on `*_like` of a multi device tensor")
  # one `Tensor.full` per shard: `_multi_like` supplies the shape and device for each call
  def per_device(shape, dev): return Tensor.full(shape, fill_value, dtype=dtype or self.dtype, device=dev)
  return self._multi_like(per_device)
|
||||
|
||||
# ***** toposort and backward pass *****
|
||||
|
||||
def backward(self, gradient:Tensor|None=None) -> Tensor:
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue