some rand method to RandMixin [PR] (#16693)

This commit is contained in:
chenyu 2026-06-21 12:16:51 -04:00 committed by GitHub
commit d80a41d559
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 243 additions and 239 deletions

View file

@ -20,6 +20,12 @@ class OpMixin(ElementwiseMixin, ReduceMixin):
@staticmethod
def const(dtype, b): raise NotImplementedError
def _multi_like(self, fxn:Callable[[tuple[sint, ...], str|None], Self]) -> Self:
from tinygrad.uop.ops import UOp
assert isinstance(self.device, tuple), f"_multi_like needs a multi device tensor, got {self.device}"
if self._uop.axis is None: return self._wrap_uop(fxn(self.shape, None)._uop.shard(self.device, None))
return self._wrap_uop(UOp.mstack(*[fxn(self._uop.shard_shape, d)._uop for d in self.device]).multi(self._uop.axis))
@classmethod
def invalids(cls, *shape, device:str|tuple[str, ...]|None=None, dtype:DTypeLike|None=None) -> Self:
"""

View file

@ -1,8 +1,10 @@
from __future__ import annotations
from typing import Self
from tinygrad.dtype import DType, dtypes
from tinygrad.helpers import ceildiv, prod
import math
from typing import Self, cast
from tinygrad.dtype import DType, DTypeLike, dtypes, to_dtype
from tinygrad.helpers import all_int, argfix, ceildiv, prod
from tinygrad.mixin import OpMixin
from tinygrad.device import canonicalize_device
class RandMixin(OpMixin):
@ -39,3 +41,235 @@ class RandMixin(OpMixin):
bits = cls.random_bits(key, counter, ceildiv(prod(shape) * dtype.itemsize, 4))
out = cls._bits_to_rand(bits, shape, dtype)
return out.contiguous() if contiguous else out
@staticmethod
def _next_counter(device:str, num:int):
raise NotImplementedError("_next_counter requires the stateful per-device RNG counter, only implemented on Tensor")
@classmethod
def rand(cls, *shape, device:str|None=None, dtype:DTypeLike|None=None, contiguous:bool=True) -> Self:
"""
Creates a tensor with the given shape, filled with random values from a uniform distribution over the interval `[0, 1)`.
You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
```python exec="true" source="above" session="tensor" result="python"
Tensor.manual_seed(42)
t = Tensor.rand(2, 3)
print(t.numpy())
```
"""
dt = to_dtype(dtype or dtypes.default_float)
if not dtypes.is_float(dt): raise ValueError(f"rand only supports float dtypes, got {dt}")
if not all_int(shape:=argfix(*shape)) or not all(s >= 0 for s in shape): raise ValueError(f"invalid input {shape=}")
if device is not None and not isinstance(device, str): raise ValueError(f"rand only supports single device, got {device=}")
device = cast(str, canonicalize_device(device))
key, counter = cls._next_counter(device, ceildiv(prod(shape) * dt.itemsize, 4))
return cls._rand(key, counter, shape, dt, contiguous=contiguous)
def rand_like(self, **kwargs) -> Self:
"""
Creates a tensor with the same shape and sharding as `self`, filled with random values from a uniform distribution over the interval `[0, 1)`.
You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
Additionally, all other keyword arguments are passed to the constructor of the tensor.
```python exec="true" source="above" session="tensor" result="python"
t = Tensor.ones(2, 3)
print(Tensor.rand_like(t).numpy())
```
"""
if isinstance(self.device, tuple):
if kwargs.pop("device", None) is not None: raise RuntimeError("cannot specify `device` on `*_like` of a multi device tensor")
dtype = kwargs.pop("dtype", self.dtype)
return self._multi_like(lambda shape, dev: type(self).rand(*shape, dtype=dtype, device=dev, **kwargs))
return type(self).rand(*self.shape, device=kwargs.pop("device", self.device), dtype=kwargs.pop("dtype", self.dtype), **kwargs)
def randn_like(self, dtype:DTypeLike|None=None, **kwargs) -> Self:
"""
Creates a tensor with the same shape and sharding as `self`, filled with random values from a normal distribution with mean 0 and variance 1.
You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
Additionally, all other keyword arguments are passed to the constructor of the tensor.
```python exec="true" source="above" session="tensor" result="python"
t = Tensor.ones(2, 3)
print(Tensor.randn_like(t).numpy())
```
"""
src = self.stack(self).rand_like(**{**kwargs, "dtype": dtypes.float32})
# https://en.wikipedia.org/wiki/Box%E2%80%93Muller_transform
return src[0].mul(2*math.pi).cos().mul((1 - src[1]).log().mul(-2).sqrt()).cast(to_dtype(dtype or self.dtype))
@classmethod
def randn(cls, *shape, dtype:DTypeLike|None=None, **kwargs) -> Self:
"""
Creates a tensor with the given shape, filled with random values from a normal distribution with mean `0` and standard deviation `1`.
If `dtype` is not specified, the default type is used.
You can pass in the `device` keyword argument to control device of the tensor.
Additionally, all other keyword arguments are passed to the constructor of the tensor.
```python exec="true" source="above" session="tensor" result="python"
Tensor.manual_seed(42)
print(Tensor.randn(2, 3).numpy())
```
"""
return cls.empty(*shape, **kwargs).randn_like(dtype=dtype) # type: ignore[attr-defined]
@classmethod
def randint(cls, *shape, low=0, high=10, dtype=dtypes.int32, **kwargs) -> Self:
"""
Creates a tensor with the given shape, filled with random integer values generated uniformly from the interval `[low, high)`.
Requires `low < high`. If `dtype` is not specified, the default type is used.
You can pass in the `device` keyword argument to control device of the tensor.
Additionally, all other keyword arguments are passed to the constructor of the tensor.
```python exec="true" source="above" session="tensor" result="python"
Tensor.manual_seed(42)
print(Tensor.randint(2, 3, low=5, high=10).numpy())
```
"""
if not all_int([low, high]): raise TypeError(f"{low=} and {high=} must be integers")
if not dtypes.is_int(dtype := to_dtype(dtype)): raise TypeError(f"{dtype=} must be int")
if low >= high: raise ValueError(f"Tensor.randint requires low < high, got {low=}, {high=}")
return cls.uniform(*shape, low=low, high=high, dtype=dtype, **kwargs)
@classmethod
def normal(cls, *shape, mean=0.0, std=1.0, **kwargs) -> Self:
"""
Creates a tensor with the given shape, filled with random values from a normal distribution with the given `mean` and standard deviation `std`.
Requires `std >= 0`.
You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
Additionally, all other keyword arguments are passed to the constructor of the tensor.
```python exec="true" source="above" session="tensor" result="python"
Tensor.manual_seed(42)
print(Tensor.normal(2, 3, mean=10, std=2).numpy())
```
"""
if std < 0: raise ValueError(f"Tensor.normal requires std >= 0, got {std=}")
return std * cls.randn(*shape, **kwargs) + mean
@classmethod
def uniform(cls, *shape, low=0.0, high=1.0, dtype:DTypeLike|None=None, **kwargs) -> Self:
"""
Creates a tensor with the given shape, filled with random values from a uniform distribution over the interval `[low, high)`.
Requires `low < high`.
You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
Additionally, all other keyword arguments are passed to the constructor of the tensor.
```python exec="true" source="above" session="tensor" result="python"
Tensor.manual_seed(42)
print(Tensor.uniform(2, 3, low=2, high=10).numpy())
```
"""
if not all_int(shape:=argfix(*shape)) or not all(s >= 0 for s in shape): raise ValueError(f"invalid input {shape=}")
if low >= high: raise ValueError(f"Tensor.uniform requires low < high, got {low=}, {high=}")
return ((high-low) * cls.rand(*shape, **kwargs)).cast(dtype or dtypes.default_float) + low
@classmethod
def scaled_uniform(cls, *shape, **kwargs) -> Self:
"""
Creates a tensor with the given shape, filled with random values from a uniform distribution
over the interval `[-prod(shape)**-0.5, prod(shape)**-0.5)`.
You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
Additionally, all other keyword arguments are passed to the constructor of the tensor.
```python exec="true" source="above" session="tensor" result="python"
Tensor.manual_seed(42)
print(Tensor.scaled_uniform(2, 3).numpy())
```
"""
return cls.uniform(*shape, low=-1.0, high=1.0, **kwargs).mul(prod(argfix(*shape))**-0.5)
@classmethod
def glorot_uniform(cls, *shape, **kwargs) -> Self:
"""
<https://www.tensorflow.org/api_docs/python/tf/keras/initializers/GlorotUniform>
You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
Additionally, all other keyword arguments are passed to the constructor of the tensor.
```python exec="true" source="above" session="tensor" result="python"
Tensor.manual_seed(42)
print(Tensor.glorot_uniform(2, 3).numpy())
```
"""
bound = (6 / (argfix(*shape)[0]+prod(argfix(*shape)[1:]))) ** 0.5
return cls.uniform(*shape, low=-bound, high=bound, **kwargs)
@classmethod
def kaiming_uniform(cls, *shape, a:float = 0.01, **kwargs) -> Self:
"""
<https://pytorch.org/docs/stable/_modules/torch/nn/init.html#kaiming_uniform_>
You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
Additionally, all other keyword arguments are passed to the constructor of the tensor.
```python exec="true" source="above" session="tensor" result="python"
Tensor.manual_seed(42)
print(Tensor.kaiming_uniform(2, 3).numpy())
```
"""
bound = (6 / (1 + a ** 2) / prod(argfix(*shape)[1:])) ** 0.5
return cls.uniform(*shape, low=-bound, high=bound, **kwargs)
@classmethod
def kaiming_normal(cls, *shape, a:float = 0.01, **kwargs) -> Self:
"""
<https://pytorch.org/docs/stable/_modules/torch/nn/init.html#kaiming_normal_>
You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
Additionally, all other keyword arguments are passed to the constructor of the tensor.
```python exec="true" source="above" session="tensor" result="python"
Tensor.manual_seed(42)
print(Tensor.kaiming_normal(2, 3).numpy())
```
"""
std = (2 / (1 + a ** 2) / prod(argfix(*shape)[1:])) ** 0.5
return cls.normal(*shape, mean=0.0, std=std, **kwargs)
@classmethod
def randperm(cls, n:int, device=None, dtype=dtypes.int32, **kwargs) -> Self:
"""
Returns a tensor with a random permutation of integers from `0` to `n-1`.
```python exec="true" source="above" session="tensor" result="python"
Tensor.manual_seed(42)
print(Tensor.randperm(6).numpy())
```
"""
return cls.rand(n, device=device, **kwargs).argsort().cast(dtype)
def multinomial(self, num_samples:int = 1, replacement:bool = False) -> Self:
"""
Returns a tensor with `num_samples` indices sampled from a multinomial distribution weighted by `self`.
```python exec="true" source="above" session="tensor" result="python"
Tensor.manual_seed(42)
t = Tensor([1, 2, 3, 4])
print(t.multinomial(20, replacement=True).numpy())
```
```python exec="true" source="above" session="tensor" result="python"
Tensor.manual_seed(42)
t = Tensor([1, 2, 3, 4])
print(t.multinomial(3, replacement=False).numpy())
```
"""
assert 1 <= self.ndim <= 2 and num_samples > 0, f"{self.ndim=} must be 1 or 2 dim, {num_samples=} must be positive"
weight = self.unsqueeze(0) if self.ndim == 1 else self
assert replacement or num_samples <= weight.shape[1], "no replacement samples must not exceed population size"
if replacement or num_samples == 1:
cdf = (cw := weight.cumsum(1).float()) / cw[:, -1].unsqueeze(1)
unif_samples = type(self).rand(num_samples, cdf.shape[0], 1).to(self.device) # type: ignore[attr-defined]
indices = (unif_samples.expand((-1, -1, cdf.shape[1])) >= cdf).sum(2).permute((1, 0))
else:
# Efraimidis-Spirakis
indices = (weight.rand_like(dtype=dtypes.float32).log2() / weight).topk(num_samples, dim=1)[1]
return (indices.squeeze(0) if self.ndim == 1 else indices).cast(dtypes.int32)

View file

@ -515,35 +515,8 @@ class Tensor(RandMixin):
high = counter[1:2] - (num >> 32) - (counter[0] < (num & 0xffffffff))
return Tensor._device_seeds[device], low.cat(high)
@staticmethod
def rand(*shape, device:str|None=None, dtype:DTypeLike|None=None, contiguous:bool=True) -> Tensor:
"""
Creates a tensor with the given shape, filled with random values from a uniform distribution over the interval `[0, 1)`.
You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
```python exec="true" source="above" session="tensor" result="python"
Tensor.manual_seed(42)
t = Tensor.rand(2, 3)
print(t.numpy())
```
"""
dt = to_dtype(dtype or dtypes.default_float)
if not dtypes.is_float(dt): raise ValueError(f"rand only supports float dtypes, got {dt}")
if not all_int(shape:=argfix(*shape)) or not all(s >= 0 for s in shape): raise ValueError(f"invalid input {shape=}")
if device is not None and not isinstance(device, str): raise ValueError(f"rand only supports single device, got {device=}")
device = cast(str, canonicalize_device(device))
key, counter = Tensor._next_counter(device, ceildiv(prod(shape) * dt.itemsize, 4))
return Tensor._rand(key, counter, shape, dt, contiguous=contiguous)
# ***** creation helper functions *****
def _multi_like(self, fxn:Callable[[tuple[sint, ...], str|None], Tensor]) -> Tensor:
assert isinstance(self.device, tuple), f"_multi_like needs a multi device tensor, got {self.device}"
if self.uop.axis is None: return fxn(self.shape, None).shard(self.device)
stacked = UOp.mstack(*[fxn(self.uop.shard_shape, d).uop for d in self.device])
return Tensor(stacked.multi(self.uop.axis))
def full_like(self, fill_value:ConstType, dtype=None, device=None) -> Tensor:
"""
Creates a tensor with the same shape as `self`, filled with the given value.
@ -561,215 +534,6 @@ class Tensor(RandMixin):
return self._multi_like(lambda shape, dev: Tensor.full(shape, fill_value, dtype=dtype or self.dtype, device=dev))
return Tensor.full(self.shape, fill_value, dtype=dtype or self.dtype, device=self.device if device is None else device)
def rand_like(self, **kwargs) -> Tensor:
"""
Creates a tensor with the same shape and sharding as `self`, filled with random values from a uniform distribution over the interval `[0, 1)`.
You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
Additionally, all other keyword arguments are passed to the constructor of the tensor.
```python exec="true" source="above" session="tensor" result="python"
t = Tensor.ones(2, 3)
print(Tensor.rand_like(t).numpy())
```
"""
if isinstance(self.device, tuple):
if kwargs.pop("device", None) is not None: raise RuntimeError("cannot specify `device` on `*_like` of a multi device tensor")
dtype = kwargs.pop("dtype", self.dtype)
return self._multi_like(lambda shape, dev: Tensor.rand(*shape, dtype=dtype, device=dev, **kwargs))
return Tensor.rand(*self.shape, device=kwargs.pop("device", self.device), dtype=kwargs.pop("dtype", self.dtype), **kwargs)
# ***** random functions *****
def randn_like(self, dtype:DTypeLike|None=None, **kwargs) -> Tensor:
"""
Creates a tensor with the same shape and sharding as `self`, filled with random values from a normal distribution with mean 0 and variance 1.
You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
Additionally, all other keyword arguments are passed to the constructor of the tensor.
```python exec="true" source="above" session="tensor" result="python"
t = Tensor.ones(2, 3)
print(Tensor.randn_like(t).numpy())
```
"""
src = self.stack(self).rand_like(**{**kwargs, "dtype": dtypes.float32})
# https://en.wikipedia.org/wiki/Box%E2%80%93Muller_transform
return src[0].mul(2*math.pi).cos().mul((1 - src[1]).log().mul(-2).sqrt()).cast(dtype or self.dtype)
@staticmethod
def randn(*shape, dtype:DTypeLike|None=None, **kwargs) -> Tensor:
"""
Creates a tensor with the given shape, filled with random values from a normal distribution with mean `0` and standard deviation `1`.
If `dtype` is not specified, the default type is used.
You can pass in the `device` keyword argument to control device of the tensor.
Additionally, all other keyword arguments are passed to the constructor of the tensor.
```python exec="true" source="above" session="tensor" result="python"
Tensor.manual_seed(42)
print(Tensor.randn(2, 3).numpy())
```
"""
return Tensor.empty(*shape, **kwargs).randn_like(dtype=dtype)
@staticmethod
def randint(*shape, low=0, high=10, dtype=dtypes.int32, **kwargs) -> Tensor:
"""
Creates a tensor with the given shape, filled with random integer values generated uniformly from the interval `[low, high)`.
Requires `low < high`. If `dtype` is not specified, the default type is used.
You can pass in the `device` keyword argument to control device of the tensor.
Additionally, all other keyword arguments are passed to the constructor of the tensor.
```python exec="true" source="above" session="tensor" result="python"
Tensor.manual_seed(42)
print(Tensor.randint(2, 3, low=5, high=10).numpy())
```
"""
if not all_int([low, high]): raise TypeError(f"{low=} and {high=} must be integers")
if not dtypes.is_int(dtype := to_dtype(dtype)): raise TypeError(f"{dtype=} must be int")
if low >= high: raise ValueError(f"Tensor.randint requires low < high, got {low=}, {high=}")
return Tensor.uniform(*shape, low=low, high=high, dtype=dtype, **kwargs)
@staticmethod
def normal(*shape, mean=0.0, std=1.0, **kwargs) -> Tensor:
"""
Creates a tensor with the given shape, filled with random values from a normal distribution with the given `mean` and standard deviation `std`.
Requires `std >= 0`.
You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
Additionally, all other keyword arguments are passed to the constructor of the tensor.
```python exec="true" source="above" session="tensor" result="python"
Tensor.manual_seed(42)
print(Tensor.normal(2, 3, mean=10, std=2).numpy())
```
"""
if std < 0: raise ValueError(f"Tensor.normal requires std >= 0, got {std=}")
return std * Tensor.randn(*shape, **kwargs) + mean
@staticmethod
def uniform(*shape, low=0.0, high=1.0, dtype:DTypeLike|None=None, **kwargs) -> Tensor:
"""
Creates a tensor with the given shape, filled with random values from a uniform distribution over the interval `[low, high)`.
Requires `low < high`.
You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
Additionally, all other keyword arguments are passed to the constructor of the tensor.
```python exec="true" source="above" session="tensor" result="python"
Tensor.manual_seed(42)
print(Tensor.uniform(2, 3, low=2, high=10).numpy())
```
"""
if not all_int(shape:=argfix(*shape)) or not all(s >= 0 for s in shape): raise ValueError(f"invalid input {shape=}")
if low >= high: raise ValueError(f"Tensor.uniform requires low < high, got {low=}, {high=}")
return ((high-low) * Tensor.rand(*shape, **kwargs)).cast(dtype or dtypes.default_float) + low
@staticmethod
def scaled_uniform(*shape, **kwargs) -> Tensor:
"""
Creates a tensor with the given shape, filled with random values from a uniform distribution
over the interval `[-prod(shape)**-0.5, prod(shape)**-0.5)`.
You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
Additionally, all other keyword arguments are passed to the constructor of the tensor.
```python exec="true" source="above" session="tensor" result="python"
Tensor.manual_seed(42)
print(Tensor.scaled_uniform(2, 3).numpy())
```
"""
return Tensor.uniform(*shape, low=-1.0, high=1.0, **kwargs).mul(prod(argfix(*shape))**-0.5)
@staticmethod
def glorot_uniform(*shape, **kwargs) -> Tensor:
"""
<https://www.tensorflow.org/api_docs/python/tf/keras/initializers/GlorotUniform>
You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
Additionally, all other keyword arguments are passed to the constructor of the tensor.
```python exec="true" source="above" session="tensor" result="python"
Tensor.manual_seed(42)
print(Tensor.glorot_uniform(2, 3).numpy())
```
"""
bound = (6 / (argfix(*shape)[0]+prod(argfix(*shape)[1:]))) ** 0.5
return Tensor.uniform(*shape, low=-bound, high=bound, **kwargs)
@staticmethod
def kaiming_uniform(*shape, a:float = 0.01, **kwargs) -> Tensor:
"""
<https://pytorch.org/docs/stable/_modules/torch/nn/init.html#kaiming_uniform_>
You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
Additionally, all other keyword arguments are passed to the constructor of the tensor.
```python exec="true" source="above" session="tensor" result="python"
Tensor.manual_seed(42)
print(Tensor.kaiming_uniform(2, 3).numpy())
```
"""
bound = (6 / (1 + a ** 2) / prod(argfix(*shape)[1:])) ** 0.5
return Tensor.uniform(*shape, low=-bound, high=bound, **kwargs)
@staticmethod
def kaiming_normal(*shape, a:float = 0.01, **kwargs) -> Tensor:
"""
<https://pytorch.org/docs/stable/_modules/torch/nn/init.html#kaiming_normal_>
You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
Additionally, all other keyword arguments are passed to the constructor of the tensor.
```python exec="true" source="above" session="tensor" result="python"
Tensor.manual_seed(42)
print(Tensor.kaiming_normal(2, 3).numpy())
```
"""
std = (2 / (1 + a ** 2) / prod(argfix(*shape)[1:])) ** 0.5
return Tensor.normal(*shape, mean=0.0, std=std, **kwargs)
@staticmethod
def randperm(n:int, device=None, dtype=dtypes.int32, **kwargs) -> Tensor:
"""
Returns a tensor with a random permutation of integers from `0` to `n-1`.
```python exec="true" source="above" session="tensor" result="python"
Tensor.manual_seed(42)
print(Tensor.randperm(6).numpy())
```
"""
return Tensor.rand(n, device=device, **kwargs).argsort().cast(dtype)
def multinomial(self:Tensor, num_samples:int = 1, replacement:bool = False) -> Tensor:
"""
Returns a tensor with `num_samples` indices sampled from a multinomial distribution weighted by `self`.
```python exec="true" source="above" session="tensor" result="python"
Tensor.manual_seed(42)
t = Tensor([1, 2, 3, 4])
print(t.multinomial(20, replacement=True).numpy())
```
```python exec="true" source="above" session="tensor" result="python"
Tensor.manual_seed(42)
t = Tensor([1, 2, 3, 4])
print(t.multinomial(3, replacement=False).numpy())
```
"""
assert 1 <= self.ndim <= 2 and num_samples > 0, f"{self.ndim=} must be 1 or 2 dim, {num_samples=} must be positive"
weight = self.unsqueeze(0) if self.ndim == 1 else self
assert replacement or num_samples <= weight.shape[1], "no replacement samples must not exceed population size"
if replacement or num_samples == 1:
cdf = (cw := weight.cumsum(1).float()) / cw[:, -1].unsqueeze(1)
unif_samples = Tensor.rand(num_samples, cdf.shape[0], 1).to(self.device)
indices = (unif_samples.expand((-1, -1, cdf.shape[1])) >= cdf).sum(2).permute((1, 0))
else:
# EfraimidisSpirakis
indices = (weight.rand_like(dtype=dtypes.float32).log2() / weight).topk(num_samples, dim=1)[1]
return (indices.squeeze(0) if self.ndim == 1 else indices).cast(dtypes.int32)
# ***** toposort and backward pass *****
def backward(self, gradient:Tensor|None=None) -> Tensor: