Merge branch 'master' into codegen2

This commit is contained in:
George Hotz 2026-06-21 19:18:24 -07:00 committed by GitHub
commit d319b5f614
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
15 changed files with 493 additions and 492 deletions

View file

@ -2674,14 +2674,14 @@ def custom_hk_mxfp8_gemm(C:UOp, A:UOp, B:UOp, scale_A:UOp, scale_B:UOp, *extra:U
def quantize_mxfp8(x:Tensor) -> tuple[Tensor, Tensor, Tensor|None]:
  """
  Quantize `x` with 1x32 block scaling along the last axis.

  All leading axes are kept as batch axes. The last axis `K` is split into `K // 32` blocks of 32 values and every block
  gets one uint8 exponent `clamp(floor(log2(amax)) + 127, 0, 254)` computed from the block's absolute maximum.

  Returns a tuple `(q, e8, packed)`:
  - `q`: `x` scaled by `2 ** (127 - e8)` per block, clamped to `[-448, 448]` and cast to `FP8_DTYPE`
  - `e8`: the per-block exponents, shape `(*batch, K // 32)`
  - `packed`: `mx_pack(e8)` when `x` is 2-D, otherwise `None` (`mx_pack` unpacks a 2-D shape)
  """
  *batch, K = x.shape
  scale_K = K // 32
  # reshape with *batch (not a single row count) so the same code serves any number of leading axes
  amax = x.detach().float().reshape(*batch, scale_K, 32).abs().max(axis=-1)
  # the 1e-38 floor keeps log2 finite for all-zero blocks
  e8 = (amax.maximum(1e-38).log2().floor() + 127).clamp(0, 254).cast(dtypes.uint8)
  # broadcast each block's scale back over its 32 elements
  qscale = (127.0 - e8.cast(dtypes.float32)).exp2().reshape(*batch, scale_K, 1).expand(*batch, scale_K, 32).reshape(*batch, K)
  x_scaled = x.float() * qscale
  # STE: the forward value is the clamped one, the detach()ed correction leaves the gradient of x_scaled untouched
  # NOTE(review): 448 is presumably the largest finite value of FP8_DTYPE - confirm
  x_clamped = x_scaled + (x_scaled.detach().clamp(-448.0, 448.0) - x_scaled.detach())
  return x_clamped.cast(FP8_DTYPE), e8, (mx_pack(e8) if len(batch) == 1 else None)
def mx_pack(e8:Tensor) -> Tensor:
rows, scale_K = e8.shape

View file

@ -317,6 +317,19 @@ class TestTensorUOpScatterReduce(unittest.TestCase):
def test_mean_exclude_self(self):
self._check(_t(3, 4).float(), Tensor([[0, 1, 0, 1]]*3, dtype=dtypes.int32), Tensor.ones(3, 4).float(), reduce="mean", include_self=False)
class TestTensorUOpMaskedSelect(unittest.TestCase):
# only the fixed-size path is pure
def _check(self, t, mask, **kw):
self.assertIs(t.masked_select(mask, **kw).uop, t.uop.masked_select(mask.uop, **kw))
def test_masked_select_1d(self): self._check(_t(6), Tensor([True, False, True, False, True, False]), size=4)
def test_masked_select_2d(self):
self._check(_t(3, 3), Tensor([[True, False, True], [False, True, False], [False, False, True]]), size=6, fill_value=-1)
class TestTensorUOpNonzero(unittest.TestCase):
def _check(self, t, **kw): self.assertIs(t.nonzero(**kw).uop, t.uop.nonzero(**kw))
def test_nonzero_1d(self): self._check(_t(5), size=3)
def test_nonzero_2d(self): self._check(_t(2, 3), size=4)
class TestTensorUOpPool(unittest.TestCase):
def test_avg_pool2d(self): _check(self, _t(1, 1, 5, 5).float(), lambda x: x.avg_pool2d())
def test_avg_pool2d_padding(self): _check(self, _t(1, 1, 5, 5).float(), lambda x: x.avg_pool2d(padding=1))
@ -462,6 +475,10 @@ class TestTensorUOpCreation(unittest.TestCase):
self.assertIs(_strip_unique(Tensor.ones(2, 3).uop), _strip_unique(UOp.ones(2, 3)))
def test_invalids(self):
self.assertIs(_strip_unique(Tensor.invalids(2, 3, dtype=dtypes.int8).uop), _strip_unique(UOp.invalids((2, 3), dtype=dtypes.int8)))
def test_empty_like(self):
  base = Tensor.empty(2, 3, dtype=dtypes.int8)
  # no dtype/device override
  got, want = base.empty_like().uop, base.uop.empty_like()
  self.assertIs(_strip_unique(got), _strip_unique(want))
  # explicit dtype and device (keywords on the Tensor side, positionals on the UOp side)
  got, want = base.empty_like(dtype=dtypes.float, device="NULL").uop, base.uop.empty_like(dtypes.float, "NULL")
  self.assertIs(_strip_unique(got), _strip_unique(want))
def test_arange(self):
self.assertIs(Tensor.arange(5).uop, UOp.arange(5))
def test_arange_empty(self):

View file

@ -13,35 +13,26 @@ class TestWinograd(unittest.TestCase):
def test_forward_kernels(self):
x,w = Tensor.rand(1,4,9,9).realize(), Tensor.rand(4,4,3,3).realize()
out = Tensor.conv2d(x,w)
self.assertEqual(len(out.schedule_linear().src), 2)
self.assertEqual(len(out.schedule_linear().src), 4)
def test_backward_kernels(self):
x,w = Tensor.empty(1,4,9,9).realize(), Tensor.empty(4,4,3,3).realize()
out = Tensor.conv2d(x,w, padding=1)
out.mean().backward()
backward_schedule = x.grad.schedule_linear(w.grad)
self.assertEqual(len(backward_schedule.src), 2)
self.assertEqual(len(backward_schedule.src), 4)
@unittest.skip("this requires optimizations")
def test_counters(self):
IC, OC, X, Y = 4,4,9,9
x,w = Tensor.rand(1,IC,Y,X).realize(), Tensor.rand(OC,IC,3,3).realize()
IC, OC, H = 64, 64, 28
x,w = Tensor.empty(1,IC,H,H,device="NULL").realize(), Tensor.empty(OC,IC,3,3,device="NULL").realize()
GlobalCounters.reset()
with Context(WINO=1):
Tensor.conv2d(x,w).realize()
ops_wino, mem_wino = GlobalCounters.global_ops, GlobalCounters.global_mem
with Context(NOOPT=0, WINO=1): Tensor.conv2d(x,w).realize()
ops_wino = GlobalCounters.global_ops
GlobalCounters.reset()
with Context(WINO=0):
Tensor.conv2d(x,w).realize()
ops_normal, mem_normal = GlobalCounters.global_ops, GlobalCounters.global_mem
ops_ratio, mem_ratio = ops_wino/ops_normal, mem_wino/mem_normal
print(f"ops: normal {ops_normal:9d} wino {ops_wino:9d} ratio {ops_ratio:.2f}")
print(f"mem: normal {mem_normal:9d} wino {mem_wino:9d} ratio {mem_ratio:.2f}")
# TODO: what's optimal on this?
self.assertLess(ops_ratio, 4.3)
self.assertLess(mem_ratio, 4)
with Context(NOOPT=0, WINO=0): Tensor.conv2d(x,w).realize()
ops_normal = GlobalCounters.global_ops
print(f"ops: normal {ops_normal} wino {ops_wino} ratio {ops_wino/ops_normal:.2f}")
self.assertLess(ops_wino/ops_normal, 0.6)
def test_dtype(self):
IC, OC, X, Y = 4,4,9,9

View file

@ -222,7 +222,7 @@ class TestCallSchedule(unittest.TestCase):
# find the FUNCTION nodes
c0 = next(u for u in r0.uop.toposort() if u.op is Ops.FUNCTION)
c1 = next(u for u in r1.uop.toposort() if u.op is Ops.FUNCTION)
# the function bodies (src[0]) should have identical keys — unique consts must not leak through
# the function bodies (src[0]) should have identical keys
self.assertEqual(c0.src[0].key, c1.src[0].key)
def test_precompile_symbolic_2d(self):

View file

@ -231,8 +231,7 @@ def _prepare_jit_inputs(args, kwargs):
it = x if isinstance(x, (tuple,list)) else x.values() if isinstance(x, dict) else []
tensors += [t for t in it if t.__class__ is Tensor and not any(t is y for y in tensors)]
def get_input_uops() -> list[UOp]: return flatten([t.uop.src if t.uop.op is Ops.MULTI else [t.uop] for t in tensors])
# TODO: drop the CONST branch once all CONST are deviceless
if any(u.device is None or u.base.op is Ops.CONST for u in get_input_uops()): raise JitError("JIT inputs must be real buffers; use .clone()")
if any(u.device is None for u in get_input_uops()): raise JitError("JIT inputs must be real buffers; use .clone()")
if len(unrealized_tensors := [x for x in tensors if not x.uop.is_realized]): Tensor.realize(*unrealized_tensors)
input_uops = get_input_uops()
# collect buffer UOps (including MultiBuffer)

View file

@ -1,13 +1,13 @@
from __future__ import annotations
import functools, itertools
from typing import TYPE_CHECKING, Callable, Self, Sequence, Literal, get_args
import functools, itertools, string
from typing import TYPE_CHECKING, Callable, Self, Sequence, Literal, get_args, cast
from tinygrad.mixin.elementwise import ElementwiseMixin
from tinygrad.mixin.movement import MovementMixin
from tinygrad.mixin.reduce import ReduceMixin
from tinygrad.uop import Ops
from tinygrad.uop.ops import _broadcast_shape, resolve, smax, smin, identity_element
from tinygrad.dtype import ConstType, DTypeLike, PtrDType, PyConst, dtypes, least_upper_dtype, sum_acc_dtype, to_dtype
from tinygrad.helpers import all_int, argfix, ceildiv, flatten, flat_to_grouped, fully_flatten, get_shape, make_tuple, prod
from tinygrad.dtype import ConstType, DTypeLike, Invalid, PtrDType, PyConst, dtypes, least_upper_dtype, sum_acc_dtype, to_dtype
from tinygrad.helpers import all_int, argfix, argsort, ceildiv, flatten, flat_to_grouped, fully_flatten, get_shape, make_tuple, merge_dicts, prod
from tinygrad.helpers import resolve_pool_pads, round_up
if TYPE_CHECKING:
@ -20,6 +20,37 @@ class OpMixin(ElementwiseMixin, ReduceMixin):
@staticmethod
def const(dtype, b): raise NotImplementedError
def data(self) -> memoryview: raise NotImplementedError("data requires Tensor realization to host memory")
def item(self) -> PyConst:
  """
  Returns the only element of this tensor as a plain Python number.

  The tensor must hold exactly one element, whatever its rank.

  ```python exec="true" source="above" session="tensor" result="python"
  t = Tensor(42)
  print(t.item())
  ```
  """
  assert self.numel() == 1, "must have one element for item"
  host = self.data()
  # index position 0 on every axis so any rank works
  origin = (0,) * len(self.shape)
  return host[origin]
def _multi_like(self, fxn:Callable[[tuple[sint, ...], str|None], Self]) -> Self:
  # Builds a multi-device result laid out like `self`, using `fxn(shape, device)` to create the data.
  from tinygrad.uop.ops import UOp
  assert isinstance(self.device, tuple), f"_multi_like needs a multi device tensor, got {self.device}"
  shard_axis = self._uop.axis
  if shard_axis is None:
    # no shard axis: create once at the full shape with no device, then shard over all devices
    whole = fxn(self.shape, None)
    return self._wrap_uop(whole._uop.shard(self.device, None))
  # sharded: create one piece per device at the shard shape, stack them and tag the shard axis
  pieces = [fxn(self._uop.shard_shape, dev)._uop for dev in self.device]
  return self._wrap_uop(UOp.mstack(*pieces).multi(shard_axis))
@classmethod
def invalids(cls, *shape, device:str|tuple[str, ...]|None=None, dtype:DTypeLike|None=None) -> Self:
  """
  Creates a tensor of the given shape where every element is `Invalid`.

  Use it instead of Tensor.empty when an "anonymous" buffer is wanted.
  Tensor.empty is expected to be replaced by this eventually.
  """
  dims = argfix(*shape)
  return cls.full(dims, Invalid, dtype=dtype, device=device)
@classmethod
def full(cls, shape:tuple[sint, ...], fill_value:ConstType|UOp, dtype:DTypeLike|None=None,
device:str|tuple[str, ...]|None=None, buffer=True) -> Self:
@ -45,7 +76,45 @@ class OpMixin(ElementwiseMixin, ReduceMixin):
val = val.reshape((1,)*len(new_shape)).expand(new_shape)
return val.clone(device=device) if buffer else val
def __getitem__(self, indices) -> Self: return self._getitem(indices)
def __getitem__(self, indices) -> Self:
"""
Retrieves a sub-tensor using indexing.
Supported Index Types: `int | slice | Tensor | None | list | tuple | Ellipsis`
Examples:
```python exec="true" source="above" session="tensor" result="python"
t = Tensor.arange(12).reshape(3, 4)
print(t.numpy())
```
- Int Indexing: Select an element or sub-tensor using integers for each dimension.
```python exec="true" source="above" session="tensor" result="python"
print(t[1, 2].numpy())
```
- Slice Indexing: Select a range of elements using slice notation (`start:end:stride`).
```python exec="true" source="above" session="tensor" result="python"
print(t[0:2, ::2].numpy())
```
- Tensor Indexing: Use another tensor as indices for advanced indexing. Using `tuple` or `list` here also works.
```python exec="true" source="above" session="tensor" result="python"
print(t[Tensor([2, 0, 1]), Tensor([1, 2, 3])].numpy())
```
- `None` Indexing: Add a new dimension to the tensor.
```python exec="true" source="above" session="tensor" result="python"
print(t[:, None].shape)
```
NOTE: Out-of-bounds indexing results in a value of `0`.
```python exec="true" source="above" session="tensor" result="python"
t = Tensor([1, 2, 3])
print(t[Tensor([4, 3, 2])].numpy())
```
"""
return self._getitem(indices)
def _getitem(self, indices, v=None) -> Self:
from tinygrad.uop.ops import UOp
@ -367,7 +436,7 @@ class OpMixin(ElementwiseMixin, ReduceMixin):
if mode in {"reflect", "replicate"}: return self._pad_reflect_replicate(pX, mode)
raise NotImplementedError(f"{mode=} is not supported")
def _broadcasted(self, y, reverse=False) -> tuple[Self, Self]:
def _broadcasted(self, y:Self|ConstType|UOp, reverse:bool=False) -> tuple[Self, Self]:
if not isinstance(y, type(self)): y = self.ufix(y)
x, y = (self, y) if not reverse else (y, self)
# ValueError: unsized ptr has shape (-1,) which can't broadcast; RuntimeError: shape mismatch
@ -423,6 +492,47 @@ class OpMixin(ElementwiseMixin, ReduceMixin):
def __matmul__(self, x:Self) -> Self: return self.matmul(x)
def __rmatmul__(self, x:Self) -> Self: return self.matmul(x, True)
@classmethod
def einsum(cls, formula:str, *operands:Self|Sequence[Self], dtype:DTypeLike|None=None) -> Self:
  """
  Sums the product of the elements of the input tensors according to a formula based on the Einstein summation convention.

  See: https://pytorch.org/docs/stable/generated/torch.einsum.html

  `formula` may contain spaces (they are removed), an ellipsis `...`, and an optional `->output` part.
  Without `->`, the output is made of the letters that occur exactly once, in sorted order.
  `dtype` is forwarded to the final `sum`.
  Raises `ValueError` when the number of operands differs from the number of comma separated inputs in `formula`.

  ```python exec="true" source="above" session="tensor" result="python"
  x = Tensor([[1, 2], [3, 4]])
  y = Tensor([[5, 6], [7, 8]])
  print(Tensor.einsum("ij,ij->", x, y).numpy())
  ```
  """
  xs, formula = list(argfix(*operands)), formula.replace(" ", "")
  # expand ellipsis to letters, determine output
  if "..." in formula:
    # ell: letters not used anywhere in the formula, free to name the ellipsis dims; lhs: the input side of the formula
    ell, lhs = "".join(c for c in string.ascii_letters if c not in formula), (formula.split("->") + [""])[0]
    # number of dims each operand's "..." stands for: ndim minus its explicit letters ("..." itself is 3 characters)
    ell_n = [max(0, x.ndim - len(s) + 3) if "..." in s else 0 for s, x in zip(lhs.split(","), xs)]
    # an operand with fewer ellipsis dims takes the trailing letters, so ellipsis dims are right-aligned across operands
    for i, (s, x) in enumerate(zip(inputs := lhs.split(","), xs)): inputs[i] = s.replace("...", ell[max(ell_n)-ell_n[i]:max(ell_n)])
    # auto: explicit letters that occur exactly once on the input side (the whole right side is evaluated with the old lhs)
    lhs, auto = ",".join(inputs), "".join(sorted(c for c in lhs if lhs.count(c) == 1 and c.isalpha() and c not in ell))
    # explicit output: substitute its "..."; implicit output: ellipsis letters first, then the auto letters
    formula = f"{lhs}->{formula.split('->')[1].replace('...', ell[:max(ell_n)]) if '->' in formula else ell[:max(ell_n)] + auto}"
  lhs, rhs = formula.split("->") if "->" in formula else (formula, "".join(sorted(c for c in formula if formula.count(c)==1 and c.isalpha())))
  inputs = lhs.split(",")
  if len(xs) != len(inputs): raise ValueError(f"number of operands doesn't match, expected {len(inputs)}, got {len(xs)}")
  # trace: take diagonal when letter repeats in single input
  for i, (s, x) in enumerate(zip(inputs, xs)):
    for c in set(s):
      while s.count(c) > 1:
        # j, k: positions of the first two occurrences of c; n: size of that axis
        j, k, n = s.index(c), s.index(c, s.index(c)+1), cast(int, x.shape[s.index(c)])
        perm = [d for d in range(x.ndim) if d not in (j,k)]+[j,k]
        # move both axes last and flatten them to n*n, pad to n*(n+1), view as (n, n+1): column 0 holds flat index i*(n+1), i.e. element (i, i)
        x = x.permute(perm).flatten(-2).pad(((0,0),)*(x.ndim-2)+((0,n),)).unflatten(-1,(n,n+1))[...,0] if x.ndim > 2 else x.diagonal()
        # drop the second occurrence of the letter
        s = s[:k] + s[k+1:]
    inputs[i], xs[i] = s, x
  # check sizes and build sorted alphabet
  sz = merge_dicts([dict(zip(s, x.shape)) for s, x in zip(inputs, xs)])
  alpha = sorted(sz)
  # align all tensors to alphabet, multiply, sum non-output, permute to output order
  xs = [x.permute(*[s.index(c) for c in sorted(s)]).reshape([sz[c] if c in s else 1 for c in alpha]).expand([sz[c] for c in alpha]) if s else x
        for s, x in zip(inputs, xs)]
  # after the sum the remaining axes are in alphabetical order; argsort(argsort(rhs)) reorders them as written in rhs
  return xs[0].uprod(*xs[1:]).sum([i for i,c in enumerate(alpha) if c not in rhs], dtype=dtype).permute(argsort(argsort(list(rhs))))
def gradient(self, *targets:Self, gradient:Self|None=None) -> list[Self]:
"""
Computes the gradient of the targets with respect to self.
@ -1147,6 +1257,68 @@ class OpMixin(ElementwiseMixin, ReduceMixin):
# select from values for each True element in mask else select from self
return mask.where(values, self)
def masked_select(self, mask, size:int|None=None, fill_value:ConstType=0):
  """
  Picks the elements of `self` where the boolean `mask` is `True` and returns them as a 1-D tensor.

  `size=None` (default): the output length is the number of `True` values (not jittable).
  `size=N`: the output length is always `N`; the result is truncated or padded with `fill_value` (jittable).
  Raises `RuntimeError` if `mask` is not a bool tensor.

  ```python exec="true" source="above" session="tensor" result="python"
  t = Tensor([[0, 1, 2], [3, 4, 5], [6, 7, 8]])
  mask = Tensor([[True, False, True], [False, True, False], [False, False, True]])
  print(t.numpy())
  print(mask.numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  print(t.masked_select(mask).numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  print(t.masked_select(mask, size=6, fill_value=-1).numpy())
  ```
  """
  if not dtypes.is_bool(mask.dtype): raise RuntimeError(f"masked_select expects bool mask tensor, got {mask.dtype}")
  kind = type(self)
  flat, flat_mask = self.flatten(), mask._broadcast_to(self.shape).flatten()
  running = flat_mask.cumsum()
  # hist[j] counts the positions whose running True-count equals j, so hist.cumsum()[j] is the number of positions
  # with running count <= j, which is the flat index of the (j+1)-th True element
  if size is not None:
    hist = kind.zeros(size, dtype=dtypes.int32, buffer=False).scatter(0, running, 1, reduce='add')
    # output slots past the number of True values receive fill_value
    in_range = kind.arange(size) < flat_mask.sum()
    return in_range.where(flat[hist.cumsum()], fill_value).cast(self.dtype)
  # dynamic length: the True-count has to be read back with item()
  n_true = running[-1].item() if flat_mask.numel() else 0
  hist = kind.zeros(n_true, dtype=dtypes.int32, buffer=False)
  return flat[hist.scatter(0, running, 1, reduce='add').cumsum()]
def nonzero(self, size:int|None=None, fill_value:ConstType=0) -> Self:
  """
  Returns the coordinates of every non-zero element, one row per element.

  `size=None` (default): the output shape is `(n_nonzero, ndim)` (not jittable).
  `size=N`: the output shape is `(N, ndim)`; rows are truncated or padded with `fill_value` (jittable).

  ```python exec="true" source="above" session="tensor" result="python"
  t = Tensor([1, 0, 2, 0, 3])
  print(t.numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  print(t.nonzero().numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  t = Tensor([[1, 0], [0, 2]])
  print(t.numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  print(t.nonzero().numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  print(t.nonzero(size=3, fill_value=-1).numpy())
  ```
  """
  kind, rank = type(self), self.ndim
  if rank == 0:
    # a scalar has no coordinates: the result has zero columns
    rows = size if size is not None else int(self.ne(0).item())
    return kind.zeros(rows, 0, dtype=dtypes.int32, device=self.device)
  is_nz = self.ne(0).flatten()
  # one flattened coordinate grid per axis
  columns = []
  for axis, extent in enumerate(self.shape):
    view = [1]*axis + [extent] + [1]*(rank-axis-1)
    columns.append(kind.arange(extent).reshape(*view).expand(self.shape).flatten())
  table = kind.stack(*columns, dim=-1)
  # repeat the mask over the coordinate columns so whole rows are kept or dropped
  keep = is_nz.unsqueeze(-1).expand(*is_nz.shape, rank)
  flat_size = size*rank if size is not None else None
  return table.masked_select(keep, size=flat_size, fill_value=fill_value).reshape(-1, rank)
# ***** functional nn ops *****
def sequential(self, ll:list[Callable[[Self], Self]]) -> Self:

View file

@ -1,9 +1,16 @@
from typing import Self
from tinygrad.dtype import ConstType, DType
from tinygrad.dtype import ConstType, DType, DTypeLike
from tinygrad.mixin.dtype import DTypeMixin
class CreationMixin:
def const_like(self, b: ConstType) -> Self: raise NotImplementedError
def cast(self, dtype: DType) -> Self: raise NotImplementedError
class CreationMixin(DTypeMixin):
def const_like(self, b: ConstType) -> Self: return self._wrap_uop(self._uop.const_like(b))
def empty_like(self, dtype: DTypeLike|None=None, device: str|tuple[str, ...]|None=None) -> Self:
  """
  Creates an empty tensor that has the shape of `self`.

  When `dtype` is not given, the dtype of `self` is used.
  """
  # the work is done on the UOp; `dtype` and `device` are forwarded positionally
  like = self._uop.empty_like(dtype, device)
  return self._wrap_uop(like)
def full_like(self, fill_value: ConstType, dtype: DType|None=None) -> Self:
"""Creates a tensor with the same shape as `self`, filled with the given value."""

View file

@ -1,13 +1,36 @@
from typing import Self
from tinygrad.dtype import DType, dtypes
from typing import TYPE_CHECKING, Self
from tinygrad.dtype import DType, DTypeLike, dtypes, to_dtype
if TYPE_CHECKING:
from tinygrad.uop.ops import UOp
class DTypeMixin:
@property
def dtype(self) -> DType: raise NotImplementedError
@property
def _uop(self) -> 'UOp': raise NotImplementedError
def _wrap_uop(self, u:'UOp') -> Self: raise NotImplementedError
def cast(self, dtype:DType) -> Self: raise NotImplementedError
def cast(self, dtype:DTypeLike) -> Self:
"""
Casts `self` to the given `dtype`.
def bitcast(self, dtype:DType) -> Self: raise NotImplementedError
```python exec="true" source="above" session="tensor" result="python"
t = Tensor([-1, 2.5, 3], dtype=dtypes.float)
print(t.dtype, t.numpy())
```
```python exec="true" source="above" session="tensor" result="python"
t = t.cast(dtypes.int32)
print(t.dtype, t.numpy())
```
```python exec="true" source="above" session="tensor" result="python"
t = t.cast(dtypes.uint8)
print(t.dtype, t.numpy())
```
"""
return self if self.dtype == (dt:=to_dtype(dtype)) else self._wrap_uop(self._uop.cast(dt))
def bitcast(self, dtype:DTypeLike) -> Self: raise NotImplementedError
def element_size(self) -> int:
"""

View file

@ -3,23 +3,17 @@ from typing import TYPE_CHECKING, Literal, Self
from tinygrad.uop import Ops
from tinygrad.dtype import dtypes, ConstType, PyConst, least_upper_dtype, least_upper_float
from tinygrad.helpers import argfix, polyN
from tinygrad.mixin.dtype import DTypeMixin
from tinygrad.mixin.creation import CreationMixin
if TYPE_CHECKING:
from tinygrad.uop.ops import UOp
class ElementwiseMixin(DTypeMixin, CreationMixin):
class ElementwiseMixin(CreationMixin):
# required to implement
def alu(self, op: Ops, *src: Self) -> Self:
raise NotImplementedError
@property
def _uop(self) -> 'UOp': raise NotImplementedError
def _wrap_uop(self, u: 'UOp') -> Self: raise NotImplementedError
# great functions you get!
def ufix(self, x: 'Self|ConstType|UOp') -> Self:
return x if isinstance(x, type(self)) else self._wrap_uop(self._uop.ufix(x))

View file

@ -33,7 +33,7 @@ def call_gradient(ctx:UOp, k:UOp, needed:set[int]) -> tuple[UOp|None, ...]:
params = {x.arg.slot:x for x in fxn.toposort(enter_calls=False) if x.op == Ops.PARAM}
grad_args = ctx.src
root_grad = UOp(Ops.TUPLE, src=tuple(UOp(Ops.NOOP) if g.op is Ops.NOOP else
g if g.base.op is Ops.CONST and g.device is None else g.param_like(len(args)+i) for i,g in enumerate(grad_args)))
g if g.base.op is Ops.CONST else g.param_like(len(args)+i) for i,g in enumerate(grad_args)))
grads = compute_gradient(fxn, root_grad, set(params.values()))
# for precompiled calls, substitute forward outputs with params so intermediates aren't recomputed
fwd_subs = {src: src.param_like(len(args)+len(grad_args)+i) for i, src in enumerate(fxn.src)} if k.arg.precompile else {}

View file

@ -1,8 +1,10 @@
from __future__ import annotations
from typing import Self
from tinygrad.dtype import DType, dtypes
from tinygrad.helpers import ceildiv, prod
import math
from typing import Self, cast
from tinygrad.dtype import DType, DTypeLike, dtypes, to_dtype
from tinygrad.helpers import all_int, argfix, ceildiv, prod
from tinygrad.mixin import OpMixin
from tinygrad.device import canonicalize_device
class RandMixin(OpMixin):
@ -39,3 +41,235 @@ class RandMixin(OpMixin):
bits = cls.random_bits(key, counter, ceildiv(prod(shape) * dtype.itemsize, 4))
out = cls._bits_to_rand(bits, shape, dtype)
return out.contiguous() if contiguous else out
@staticmethod
def _next_counter(device:str, num:int):
  # Hook for the stateful per-device RNG counter. `rand` calls it as `cls._next_counter(device, num)` and unpacks the
  # result into `(key, counter)`. This base version always raises: the state is only implemented on Tensor.
  raise NotImplementedError("_next_counter requires the stateful per-device RNG counter, only implemented on Tensor")
@classmethod
def rand(cls, *shape, device:str|None=None, dtype:DTypeLike|None=None, contiguous:bool=True) -> Self:
  """
  Creates a tensor of the given shape with values drawn uniformly from the interval `[0, 1)`.

  `dtype` must be a float dtype (the default float is used when omitted) and `device` must be a single device.
  Raises `ValueError` for a non-float dtype, a non-integer or negative dimension, or a non-string device.

  ```python exec="true" source="above" session="tensor" result="python"
  Tensor.manual_seed(42)
  t = Tensor.rand(2, 3)
  print(t.numpy())
  ```
  """
  dt = to_dtype(dtype or dtypes.default_float)
  if not dtypes.is_float(dt): raise ValueError(f"rand only supports float dtypes, got {dt}")
  shape = argfix(*shape)
  if not (all_int(shape) and all(s >= 0 for s in shape)): raise ValueError(f"invalid input {shape=}")
  if not (device is None or isinstance(device, str)): raise ValueError(f"rand only supports single device, got {device=}")
  device = cast(str, canonicalize_device(device))
  # size of the output in 4-byte words, rounded up
  n_words = ceildiv(prod(shape) * dt.itemsize, 4)
  key, counter = cls._next_counter(device, n_words)
  return cls._rand(key, counter, shape, dt, contiguous=contiguous)
def rand_like(self, **kwargs) -> Self:
  """
  Creates a tensor with the shape and sharding of `self`, with values drawn uniformly from the interval `[0, 1)`.

  `dtype` and `device` keyword arguments override those of `self`; every other keyword argument is forwarded to `rand`.
  Raises `RuntimeError` if `device` is given while `self` lives on multiple devices.

  ```python exec="true" source="above" session="tensor" result="python"
  t = Tensor.ones(2, 3)
  print(Tensor.rand_like(t).numpy())
  ```
  """
  if not isinstance(self.device, tuple):
    # single device: a plain rand with the overrides applied
    device = kwargs.pop("device", self.device)
    dtype = kwargs.pop("dtype", self.dtype)
    return type(self).rand(*self.shape, device=device, dtype=dtype, **kwargs)
  if kwargs.pop("device", None) is not None: raise RuntimeError("cannot specify `device` on `*_like` of a multi device tensor")
  dtype = kwargs.pop("dtype", self.dtype)
  # multi device: _multi_like chooses the shape and device of every piece
  def make(shape, dev): return type(self).rand(*shape, dtype=dtype, device=dev, **kwargs)
  return self._multi_like(make)
def randn_like(self, dtype:DTypeLike|None=None, **kwargs) -> Self:
  """
  Creates a tensor with the shape and sharding of `self`, with values drawn from a normal distribution with mean 0 and variance 1.

  The result is cast to `dtype`, or to the dtype of `self` when `dtype` is omitted.
  The remaining keyword arguments are forwarded to `rand_like`.

  ```python exec="true" source="above" session="tensor" result="python"
  t = Tensor.ones(2, 3)
  print(Tensor.randn_like(t).numpy())
  ```
  """
  # two independent float32 uniform tensors, stacked along a new leading axis
  uniform_kwargs = {**kwargs, "dtype": dtypes.float32}
  u = self.stack(self).rand_like(**uniform_kwargs)
  # https://en.wikipedia.org/wiki/Box%E2%80%93Muller_transform
  angle = u[0].mul(2*math.pi).cos()
  radius = (1 - u[1]).log().mul(-2).sqrt()
  return angle.mul(radius).cast(to_dtype(dtype or self.dtype))
@classmethod
def randn(cls, *shape, dtype:DTypeLike|None=None, **kwargs) -> Self:
  """
  Creates a tensor of the given shape with values drawn from a normal distribution with mean `0` and standard deviation `1`.

  The result has the given `dtype`, or the dtype of the `empty` template when `dtype` is omitted.
  All other keyword arguments (for example `device`) are forwarded to `empty`.

  ```python exec="true" source="above" session="tensor" result="python"
  Tensor.manual_seed(42)
  print(Tensor.randn(2, 3).numpy())
  ```
  """
  # an empty tensor serves as the shape/device template for randn_like
  template = cls.empty(*shape, **kwargs) # type: ignore[attr-defined]
  return template.randn_like(dtype=dtype)
@classmethod
def randint(cls, *shape, low=0, high=10, dtype=dtypes.int32, **kwargs) -> Self:
  """
  Creates a tensor of the given shape with random integers drawn uniformly from the interval `[low, high)`.

  Requires integer `low < high` and an integer `dtype` (`dtypes.int32` when omitted).
  All other keyword arguments (for example `device`) are forwarded to `uniform`.
  Raises `TypeError` for non-integer bounds or a non-integer dtype, `ValueError` when `low >= high`.

  ```python exec="true" source="above" session="tensor" result="python"
  Tensor.manual_seed(42)
  print(Tensor.randint(2, 3, low=5, high=10).numpy())
  ```
  """
  if not all_int([low, high]): raise TypeError(f"{low=} and {high=} must be integers")
  dtype = to_dtype(dtype)
  if not dtypes.is_int(dtype): raise TypeError(f"{dtype=} must be int")
  if low >= high: raise ValueError(f"Tensor.randint requires low < high, got {low=}, {high=}")
  return cls.uniform(*shape, low=low, high=high, dtype=dtype, **kwargs)
@classmethod
def normal(cls, *shape, mean=0.0, std=1.0, **kwargs) -> Self:
  """
  Creates a tensor of the given shape with values drawn from a normal distribution with the given `mean` and standard deviation `std`.

  Requires `std >= 0`, otherwise `ValueError` is raised.
  All other keyword arguments (for example `dtype` and `device`) are forwarded to `randn`.

  ```python exec="true" source="above" session="tensor" result="python"
  Tensor.manual_seed(42)
  print(Tensor.normal(2, 3, mean=10, std=2).numpy())
  ```
  """
  if std < 0: raise ValueError(f"Tensor.normal requires std >= 0, got {std=}")
  # scale and shift a standard normal sample
  standard = cls.randn(*shape, **kwargs)
  return std * standard + mean
@classmethod
def uniform(cls, *shape, low=0.0, high=1.0, dtype:DTypeLike|None=None, **kwargs) -> Self:
  """
  Creates a tensor of the given shape with values drawn uniformly from the interval `[low, high)`.

  Requires `low < high`. The result is cast to `dtype`, or to the default float when `dtype` is omitted.
  All other keyword arguments (for example `device`) are forwarded to `rand`.
  Raises `ValueError` for a non-integer or negative dimension, or when `low >= high`.

  ```python exec="true" source="above" session="tensor" result="python"
  Tensor.manual_seed(42)
  print(Tensor.uniform(2, 3, low=2, high=10).numpy())
  ```
  """
  shape = argfix(*shape)
  if not (all_int(shape) and all(s >= 0 for s in shape)): raise ValueError(f"invalid input {shape=}")
  if low >= high: raise ValueError(f"Tensor.uniform requires low < high, got {low=}, {high=}")
  # stretch a [0, 1) sample to the requested width, cast, then shift by low
  span = high - low
  unit = cls.rand(*shape, **kwargs)
  return (span * unit).cast(dtype or dtypes.default_float) + low
@classmethod
def scaled_uniform(cls, *shape, **kwargs) -> Self:
  """
  Creates a tensor of the given shape with values drawn uniformly from the interval
  `[-prod(shape)**-0.5, prod(shape)**-0.5)`.

  All keyword arguments (for example `dtype` and `device`) are forwarded to `uniform`.

  ```python exec="true" source="above" session="tensor" result="python"
  Tensor.manual_seed(42)
  print(Tensor.scaled_uniform(2, 3).numpy())
  ```
  """
  # sample in [-1, 1) and shrink by the inverse square root of the element count
  unit = cls.uniform(*shape, low=-1.0, high=1.0, **kwargs)
  scale = prod(argfix(*shape))**-0.5
  return unit.mul(scale)
@classmethod
def glorot_uniform(cls, *shape, **kwargs) -> Self:
  """
  <https://www.tensorflow.org/api_docs/python/tf/keras/initializers/GlorotUniform>

  Samples uniformly from `[-bound, bound)` with `bound = sqrt(6 / (shape[0] + prod(shape[1:])))`.
  All keyword arguments (for example `dtype` and `device`) are forwarded to `uniform`.

  ```python exec="true" source="above" session="tensor" result="python"
  Tensor.manual_seed(42)
  print(Tensor.glorot_uniform(2, 3).numpy())
  ```
  """
  dims = argfix(*shape)
  # first dimension plus the product of all remaining dimensions
  fan_sum = dims[0] + prod(dims[1:])
  bound = (6 / fan_sum) ** 0.5
  return cls.uniform(*shape, low=-bound, high=bound, **kwargs)
@classmethod
def kaiming_uniform(cls, *shape, a:float = 0.01, **kwargs) -> Self:
  """
  <https://pytorch.org/docs/stable/_modules/torch/nn/init.html#kaiming_uniform_>

  Samples uniformly from `[-bound, bound)` with `bound = sqrt(6 / (1 + a**2) / prod(shape[1:]))`.
  All other keyword arguments (for example `dtype` and `device`) are forwarded to `uniform`.

  ```python exec="true" source="above" session="tensor" result="python"
  Tensor.manual_seed(42)
  print(Tensor.kaiming_uniform(2, 3).numpy())
  ```
  """
  # product of every dimension after the first
  tail_size = prod(argfix(*shape)[1:])
  bound = (6 / (1 + a ** 2) / tail_size) ** 0.5
  return cls.uniform(*shape, low=-bound, high=bound, **kwargs)
@classmethod
def kaiming_normal(cls, *shape, a:float = 0.01, **kwargs) -> Self:
  """
  <https://pytorch.org/docs/stable/_modules/torch/nn/init.html#kaiming_normal_>

  Samples from a normal distribution with mean `0` and `std = sqrt(2 / (1 + a**2) / prod(shape[1:]))`.
  All other keyword arguments (for example `dtype` and `device`) are forwarded to `normal`.

  ```python exec="true" source="above" session="tensor" result="python"
  Tensor.manual_seed(42)
  print(Tensor.kaiming_normal(2, 3).numpy())
  ```
  """
  # product of every dimension after the first
  tail_size = prod(argfix(*shape)[1:])
  std = (2 / (1 + a ** 2) / tail_size) ** 0.5
  return cls.normal(*shape, mean=0.0, std=std, **kwargs)
@classmethod
def randperm(cls, n:int, device=None, dtype=dtypes.int32, **kwargs) -> Self:
  """
  Returns a random permutation of the integers from `0` to `n-1`, cast to `dtype`.

  `device` and all other keyword arguments are forwarded to `rand`.

  ```python exec="true" source="above" session="tensor" result="python"
  Tensor.manual_seed(42)
  print(Tensor.randperm(6).numpy())
  ```
  """
  # the argsort of n random keys is the permutation
  keys = cls.rand(n, device=device, **kwargs)
  return keys.argsort().cast(dtype)
def multinomial(self, num_samples:int = 1, replacement:bool = False) -> Self:
  """
  Returns `num_samples` indices sampled from the multinomial distribution whose weights are given by `self`.

  `self` must be 1-D, or 2-D with one distribution per row. The result is int32.
  Without replacement, `num_samples` must not exceed the number of weights.

  ```python exec="true" source="above" session="tensor" result="python"
  Tensor.manual_seed(42)
  t = Tensor([1, 2, 3, 4])
  print(t.multinomial(20, replacement=True).numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  Tensor.manual_seed(42)
  t = Tensor([1, 2, 3, 4])
  print(t.multinomial(3, replacement=False).numpy())
  ```
  """
  assert 1 <= self.ndim <= 2 and num_samples > 0, f"{self.ndim=} must be 1 or 2 dim, {num_samples=} must be positive"
  is_vector = self.ndim == 1
  # work on a 2-D (rows, population) view
  weight = self.unsqueeze(0) if is_vector else self
  assert replacement or num_samples <= weight.shape[1], "no replacement samples must not exceed population size"
  if not (replacement or num_samples == 1):
    # Efraimidis-Spirakis
    keys = weight.rand_like(dtype=dtypes.float32).log2() / weight
    indices = keys.topk(num_samples, dim=1)[1]
  else:
    # normalised cumulative weights; the sampled index is the number of cdf entries that are <= the uniform draw
    running = weight.cumsum(1).float()
    cdf = running / running[:, -1].unsqueeze(1)
    draws = type(self).rand(num_samples, cdf.shape[0], 1).to(self.device) # type: ignore[attr-defined]
    indices = (draws.expand((-1, -1, cdf.shape[1])) >= cdf).sum(2).permute((1, 0))
  if is_vector: indices = indices.squeeze(0)
  return indices.cast(dtypes.int32)

View file

@ -1,8 +1,7 @@
import string
from typing import Self, Sequence, cast
from typing import Self, Sequence
from tinygrad.uop import Ops
from tinygrad.dtype import DTypeLike, dtypes, sum_acc_dtype, to_dtype
from tinygrad.helpers import argfix, argsort, make_tuple, merge_dicts
from tinygrad.helpers import make_tuple
from tinygrad.mixin.dtype import DTypeMixin
from tinygrad.mixin.movement import MovementMixin
@ -136,44 +135,3 @@ class ReduceMixin(DTypeMixin, MovementMixin):
```
"""
return self.bool().prod(axis, keepdim)
@classmethod
def einsum(cls, formula:str, *operands:Self|Sequence[Self], dtype:DTypeLike|None=None) -> Self:
"""
Sums the product of the elements of the input tensors according to a formula based on the Einstein summation convention.
`operands` may be passed as separate arguments or as a single sequence. `dtype` is forwarded to the final `sum`.
Raises `ValueError` when the number of operands does not match the number of comma-separated inputs in `formula`.
See: https://pytorch.org/docs/stable/generated/torch.einsum.html
```python exec="true" source="above" session="tensor" result="python"
x = Tensor([[1, 2], [3, 4]])
y = Tensor([[5, 6], [7, 8]])
print(Tensor.einsum("ij,ij->", x, y).numpy())
```
"""
# normalize the operands into a list and strip spaces from the formula
xs, formula = list(argfix(*operands)), formula.replace(" ", "")
# expand ellipsis to letters, determine output
if "..." in formula:
# ell: letters not used anywhere in the formula (stand-ins for "..."); lhs: the text before "->"
ell, lhs = "".join(c for c in string.ascii_letters if c not in formula), (formula.split("->") + [""])[0]
# ell_n[i]: how many dims the ellipsis covers in operand i (len(s) counts the three dots, hence the +3)
ell_n = [max(0, x.ndim - len(s) + 3) if "..." in s else 0 for s, x in zip(lhs.split(","), xs)]
# each operand takes the last ell_n[i] of the first max(ell_n) stand-in letters, so ellipsis dims line up from the right
for i, (s, x) in enumerate(zip(inputs := lhs.split(","), xs)): inputs[i] = s.replace("...", ell[max(ell_n)-ell_n[i]:max(ell_n)])
# auto: sorted letters occurring exactly once on the original input side, used when no explicit output is given
lhs, auto = ",".join(inputs), "".join(sorted(c for c in lhs if lhs.count(c) == 1 and c.isalpha() and c not in ell))
# rebuild the formula with an explicit output: the given one with "..." expanded, else ellipsis letters followed by auto
formula = f"{lhs}->{formula.split('->')[1].replace('...', ell[:max(ell_n)]) if '->' in formula else ell[:max(ell_n)] + auto}"
# without "->" the output is the sorted set of letters that appear exactly once in the formula
lhs, rhs = formula.split("->") if "->" in formula else (formula, "".join(sorted(c for c in formula if formula.count(c)==1 and c.isalpha())))
inputs = lhs.split(",")
if len(xs) != len(inputs): raise ValueError(f"number of operands doesn't match, expected {len(inputs)}, got {len(xs)}")
# trace: take diagonal when letter repeats in single input
for i, (s, x) in enumerate(zip(inputs, xs)):
for c in set(s):
while s.count(c) > 1:
# j, k: positions of the first two occurrences of c; n: size of the dim at the first occurrence
j, k, n = s.index(c), s.index(c, s.index(c)+1), cast(int, x.shape[s.index(c)])
# move the two repeated dims to the end
perm = [d for d in range(x.ndim) if d not in (j,k)]+[j,k]
# for ndim > 2: flatten the last two dims, pad by n, view as (n, n+1) and keep column 0; otherwise use x.diagonal()
x = x.permute(perm).flatten(-2).pad(((0,0),)*(x.ndim-2)+((0,n),)).unflatten(-1,(n,n+1))[...,0] if x.ndim > 2 else x.diagonal()
# drop the second occurrence of c from the subscript string
s = s[:k] + s[k+1:]
inputs[i], xs[i] = s, x
# check sizes and build sorted alphabet
# NOTE(review): presumably merge_dicts rejects a letter mapped to two different sizes - confirm in tinygrad.helpers
sz = merge_dicts([dict(zip(s, x.shape)) for s, x in zip(inputs, xs)])
alpha = sorted(sz)
# align all tensors to alphabet, multiply, sum non-output, permute to output order
xs = [x.permute(*[s.index(c) for c in sorted(s)]).reshape([sz[c] if c in s else 1 for c in alpha]).expand([sz[c] for c in alpha]) if s else x
for s, x in zip(inputs, xs)]
return xs[0].uprod(*xs[1:]).sum([i for i,c in enumerate(alpha) if c not in rhs], dtype=dtype).permute(argsort(argsort(list(rhs))))

View file

@ -543,9 +543,6 @@ to_define_global = PatternMatcher([
# remove device from local BUFFERIZE
(UPat(Ops.STAGE, name="b"), lambda b: b.replace(arg=replace(b.arg, device=None))),
# remove UNIQUE/DEVICE to dedup CONST
(UPat(Ops.CONST, name="c"), lambda c: c.replace(src=()) if len(c.src) else None),
# renumber the ranges starting with 0 so that kernel deduping works
(UPat(Ops.RANGE, name="r"), renumber_range),
])

View file

@ -136,24 +136,13 @@ class Tensor(RandMixin):
all_tensors[weakref.ref(ret)] = None
return ret
# alu and const_like are used by the mixins
# alu, _uop, _wrap_uop and const are used by the mixins
def alu(self, op: Ops, *src: Tensor) -> Tensor:
  """Runs `op` through `_apply_uop`: the first UOp's `alu` is called with `op` and the remaining UOps."""
  def _on_uops(*uops): return uops[0].alu(op, *uops[1:])
  return self._apply_uop(_on_uops, *src)
@property
def _uop(self) -> UOp:
  """Read-only alias for `self.uop`."""
  backing = self.uop
  return backing
def _wrap_uop(self, u:UOp) -> Tensor:
  """Wraps the UOp `u` in a new `Tensor`."""
  wrapped = Tensor(u)
  return wrapped
def const_like(self, b:ConstType) -> Tensor:
  """Returns a `Tensor` wrapping `self.uop.const_like(b)`."""
  const_uop = self.uop.const_like(b)
  return Tensor(const_uop)
@staticmethod
def const(dtype:DType, b:ConstType|UOp) -> Tensor:
  """Returns a `Tensor` wrapping `UOp.const(dtype, b)`."""
  const_uop = UOp.const(dtype, b)
  return Tensor(const_uop)
@staticmethod
def invalids(*shape, device:str|tuple[str, ...]|None=None, dtype:DTypeLike|None=None) -> Tensor:
  """
  Creates a tensor of the given shape whose elements are all Invalid.

  This is an alternative to Tensor.empty when you want an "anonymous" buffer.
  Eventually Tensor.empty will be replaced by this.
  """
  dims = argfix(*shape)
  invalid_uop = UOp.invalids(dims, dtype, device)
  return Tensor(invalid_uop)
def is_param_(self, is_param:bool=True) -> Tensor:
self.is_param = is_param
@ -292,18 +281,6 @@ class Tensor(RandMixin):
assert self.dtype.base.fmt != "e" or sys.version_info >= (3, 12)
return self._buffer().as_memoryview().cast(self.dtype.base.fmt, self.shape)
def item(self) -> PyConst:
  """
  Returns the single element of this tensor as a standard Python number.

  The tensor must hold exactly one element (asserted).

  ```python exec="true" source="above" session="tensor" result="python"
  t = Tensor(42)
  print(t.item())
  ```
  """
  assert self.numel() == 1, "must have one element for item"
  # index 0 along every dimension addresses the only element
  origin = (0,) * len(self.shape)
  return self.data()[origin]
# NOTE: list[Any] because return type is recursive (list[list[...]] for higher dimensions)
def tolist(self) -> PyConst|list[Any]:
"""
@ -464,13 +441,6 @@ class Tensor(RandMixin):
"""
return Tensor(UOp.empty(argfix(*shape), dtype, device))
def empty_like(self, dtype:DTypeLike|None=None, device:str|tuple[str, ...]|None=None) -> Tensor:
  """
  Creates an empty tensor shaped like `self`.

  If `dtype` is not specified, the dtype of `self` is used. `dtype` and `device` are forwarded to `self.uop.empty_like`.
  """
  empty_uop = self.uop.empty_like(dtype, device)
  return Tensor(empty_uop)
@staticmethod
def from_blob(ptr:int, shape:tuple[int, ...], **kwargs) -> Tensor:
"""
@ -533,35 +503,8 @@ class Tensor(RandMixin):
high = counter[1:2] - (num >> 32) - (counter[0] < (num & 0xffffffff))
return Tensor._device_seeds[device], low.cat(high)
@staticmethod
def rand(*shape, device:str|None=None, dtype:DTypeLike|None=None, contiguous:bool=True) -> Tensor:
  """
  Creates a tensor of the given shape with random values drawn uniformly from the interval `[0, 1)`.

  You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
  Raises `ValueError` for a non-float dtype, a shape that is not all non-negative ints, or a non-string device.

  ```python exec="true" source="above" session="tensor" result="python"
  Tensor.manual_seed(42)
  t = Tensor.rand(2, 3)
  print(t.numpy())
  ```
  """
  dt = to_dtype(dtype or dtypes.default_float)
  if not dtypes.is_float(dt): raise ValueError(f"rand only supports float dtypes, got {dt}")
  shape = argfix(*shape)
  if not (all_int(shape) and all(s >= 0 for s in shape)): raise ValueError(f"invalid input {shape=}")
  if not (device is None or isinstance(device, str)): raise ValueError(f"rand only supports single device, got {device=}")
  device = cast(str, canonicalize_device(device))
  # the counter is advanced by the byte size of the output rounded up to a multiple of 4
  n_words = ceildiv(prod(shape) * dt.itemsize, 4)
  key, counter = Tensor._next_counter(device, n_words)
  return Tensor._rand(key, counter, shape, dt, contiguous=contiguous)
# ***** creation helper functions *****
def _multi_like(self, fxn:Callable[[tuple[sint, ...], str|None], Tensor]) -> Tensor:
  """
  Builds a multi-device tensor laid out like `self` from `fxn(shape, device)`.

  With no shard axis, `fxn` is called once with the full shape and `None`, and the result is sharded over `self.device`.
  Otherwise `fxn` is called once per device with the shard shape, and the results are stacked along the shard axis.
  """
  assert isinstance(self.device, tuple), f"_multi_like needs a multi device tensor, got {self.device}"
  axis = self.uop.axis
  if axis is None:
    return fxn(self.shape, None).shard(self.device)
  per_device = [fxn(self.uop.shard_shape, d).uop for d in self.device]
  stacked = UOp.mstack(*per_device)
  return Tensor(stacked.multi(axis))
def full_like(self, fill_value:ConstType, dtype=None, device=None) -> Tensor:
"""
Creates a tensor with the same shape as `self`, filled with the given value.
@ -579,215 +522,6 @@ class Tensor(RandMixin):
return self._multi_like(lambda shape, dev: Tensor.full(shape, fill_value, dtype=dtype or self.dtype, device=dev))
return Tensor.full(self.shape, fill_value, dtype=dtype or self.dtype, device=self.device if device is None else device)
def rand_like(self, **kwargs) -> Tensor:
  """
  Creates a tensor with the same shape and sharding as `self`, with random values drawn uniformly from `[0, 1)`.

  You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
  Additionally, all other keyword arguments are passed to `Tensor.rand`.
  Raises `RuntimeError` if `device` is given for a multi device tensor.

  ```python exec="true" source="above" session="tensor" result="python"
  t = Tensor.ones(2, 3)
  print(Tensor.rand_like(t).numpy())
  ```
  """
  if not isinstance(self.device, tuple):
    device = kwargs.pop("device", self.device)
    dtype = kwargs.pop("dtype", self.dtype)
    return Tensor.rand(*self.shape, device=device, dtype=dtype, **kwargs)
  # multi device: one rand per shard, so an explicit device makes no sense
  if kwargs.pop("device", None) is not None: raise RuntimeError("cannot specify `device` on `*_like` of a multi device tensor")
  dtype = kwargs.pop("dtype", self.dtype)
  return self._multi_like(lambda shape, dev: Tensor.rand(*shape, dtype=dtype, device=dev, **kwargs))
# ***** random functions *****
def randn_like(self, dtype:DTypeLike|None=None, **kwargs) -> Tensor:
  """
  Creates a tensor with the same shape and sharding as `self`, with random values from a normal distribution
  with mean 0 and variance 1.

  You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
  Additionally, all other keyword arguments are passed to `rand_like`. The uniform samples are always drawn as float32
  and the result is cast to `dtype` (or `self.dtype` when `dtype` is not given).

  ```python exec="true" source="above" session="tensor" result="python"
  t = Tensor.ones(2, 3)
  print(Tensor.randn_like(t).numpy())
  ```
  """
  # two independent uniform samples per output element
  opts = {**kwargs, "dtype": dtypes.float32}
  src = self.stack(self).rand_like(**opts)
  # https://en.wikipedia.org/wiki/Box%E2%80%93Muller_transform
  angle_term = src[0].mul(2*math.pi).cos()
  radius_term = (1 - src[1]).log().mul(-2).sqrt()
  return angle_term.mul(radius_term).cast(dtype or self.dtype)
@staticmethod
def randn(*shape, dtype:DTypeLike|None=None, **kwargs) -> Tensor:
  """
  Creates a tensor of the given shape with random values from a normal distribution with mean `0` and standard deviation `1`.

  If `dtype` is not specified, the default type is used.
  You can pass in the `device` keyword argument to control device of the tensor.
  Additionally, all other keyword arguments are passed to `Tensor.empty`.

  ```python exec="true" source="above" session="tensor" result="python"
  Tensor.manual_seed(42)
  print(Tensor.randn(2, 3).numpy())
  ```
  """
  # an empty tensor only serves as the shape/device template for randn_like
  template = Tensor.empty(*shape, **kwargs)
  return template.randn_like(dtype=dtype)
@staticmethod
def randint(*shape, low=0, high=10, dtype=dtypes.int32, **kwargs) -> Tensor:
  """
  Creates a tensor of the given shape with random integers drawn uniformly from the interval `[low, high)`.

  Requires `low < high`. `dtype` defaults to `dtypes.int32` and must be an integer dtype.
  You can pass in the `device` keyword argument to control device of the tensor.
  Additionally, all other keyword arguments are passed to `Tensor.uniform`.
  Raises `TypeError` for non-integer bounds or dtype, and `ValueError` when `low >= high`.

  ```python exec="true" source="above" session="tensor" result="python"
  Tensor.manual_seed(42)
  print(Tensor.randint(2, 3, low=5, high=10).numpy())
  ```
  """
  if not all_int([low, high]): raise TypeError(f"{low=} and {high=} must be integers")
  dtype = to_dtype(dtype)
  if not dtypes.is_int(dtype): raise TypeError(f"{dtype=} must be int")
  if low >= high: raise ValueError(f"Tensor.randint requires low < high, got {low=}, {high=}")
  return Tensor.uniform(*shape, low=low, high=high, dtype=dtype, **kwargs)
@staticmethod
def normal(*shape, mean=0.0, std=1.0, **kwargs) -> Tensor:
  """
  Creates a tensor of the given shape with random values from a normal distribution with the given `mean`
  and standard deviation `std`.

  Requires `std >= 0`, otherwise `ValueError` is raised.
  You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
  Additionally, all other keyword arguments are passed to `Tensor.randn`.

  ```python exec="true" source="above" session="tensor" result="python"
  Tensor.manual_seed(42)
  print(Tensor.normal(2, 3, mean=10, std=2).numpy())
  ```
  """
  if std < 0: raise ValueError(f"Tensor.normal requires std >= 0, got {std=}")
  # scale and shift standard normal samples
  standard = Tensor.randn(*shape, **kwargs)
  return std * standard + mean
@staticmethod
def uniform(*shape, low=0.0, high=1.0, dtype:DTypeLike|None=None, **kwargs) -> Tensor:
  """
  Creates a tensor of the given shape with random values drawn uniformly from the interval `[low, high)`.

  Requires `low < high`.
  You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
  Additionally, all other keyword arguments are passed to `Tensor.rand`.
  Raises `ValueError` for an invalid shape or when `low >= high`.

  ```python exec="true" source="above" session="tensor" result="python"
  Tensor.manual_seed(42)
  print(Tensor.uniform(2, 3, low=2, high=10).numpy())
  ```
  """
  shape = argfix(*shape)
  if not (all_int(shape) and all(s >= 0 for s in shape)): raise ValueError(f"invalid input {shape=}")
  if low >= high: raise ValueError(f"Tensor.uniform requires low < high, got {low=}, {high=}")
  # stretch [0, 1) samples to the interval width, cast, then shift by low
  stretched = (high-low) * Tensor.rand(*shape, **kwargs)
  return stretched.cast(dtype or dtypes.default_float) + low
@staticmethod
def scaled_uniform(*shape, **kwargs) -> Tensor:
  """
  Creates a tensor of the given shape with random values from a uniform distribution
  over the interval `[-prod(shape)**-0.5, prod(shape)**-0.5)`.

  You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
  Additionally, all other keyword arguments are passed to `Tensor.uniform`.

  ```python exec="true" source="above" session="tensor" result="python"
  Tensor.manual_seed(42)
  print(Tensor.scaled_uniform(2, 3).numpy())
  ```
  """
  base = Tensor.uniform(*shape, low=-1.0, high=1.0, **kwargs)
  # shrink the [-1, 1) samples by the inverse square root of the element count
  scale = prod(argfix(*shape))**-0.5
  return base.mul(scale)
@staticmethod
def glorot_uniform(*shape, **kwargs) -> Tensor:
  """
  <https://www.tensorflow.org/api_docs/python/tf/keras/initializers/GlorotUniform>

  Samples uniformly from `[-bound, bound)` where `bound = sqrt(6 / (shape[0] + prod(shape[1:])))`.
  You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
  Additionally, all other keyword arguments are passed to `Tensor.uniform`.

  ```python exec="true" source="above" session="tensor" result="python"
  Tensor.manual_seed(42)
  print(Tensor.glorot_uniform(2, 3).numpy())
  ```
  """
  dims = argfix(*shape)
  denom = dims[0]+prod(dims[1:])
  bound = (6 / denom) ** 0.5
  return Tensor.uniform(*shape, low=-bound, high=bound, **kwargs)
@staticmethod
def kaiming_uniform(*shape, a:float = 0.01, **kwargs) -> Tensor:
  """
  <https://pytorch.org/docs/stable/_modules/torch/nn/init.html#kaiming_uniform_>

  Samples uniformly from `[-bound, bound)` where `bound = sqrt(6 / (1 + a**2) / prod(shape[1:]))`.
  You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
  Additionally, all other keyword arguments are passed to `Tensor.uniform`.

  ```python exec="true" source="above" session="tensor" result="python"
  Tensor.manual_seed(42)
  print(Tensor.kaiming_uniform(2, 3).numpy())
  ```
  """
  # product of every dimension after the first
  trailing = prod(argfix(*shape)[1:])
  bound = (6 / (1 + a ** 2) / trailing) ** 0.5
  return Tensor.uniform(*shape, low=-bound, high=bound, **kwargs)
@staticmethod
def kaiming_normal(*shape, a:float = 0.01, **kwargs) -> Tensor:
  """
  <https://pytorch.org/docs/stable/_modules/torch/nn/init.html#kaiming_normal_>

  Samples from a zero-mean normal distribution with `std = sqrt(2 / (1 + a**2) / prod(shape[1:]))`.
  You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
  Additionally, all other keyword arguments are passed to `Tensor.normal`.

  ```python exec="true" source="above" session="tensor" result="python"
  Tensor.manual_seed(42)
  print(Tensor.kaiming_normal(2, 3).numpy())
  ```
  """
  # product of every dimension after the first
  trailing = prod(argfix(*shape)[1:])
  std = (2 / (1 + a ** 2) / trailing) ** 0.5
  return Tensor.normal(*shape, mean=0.0, std=std, **kwargs)
@staticmethod
def randperm(n:int, device=None, dtype=dtypes.int32, **kwargs) -> Tensor:
  """
  Builds a random permutation of the integers `0` through `n-1`.

  `n` random keys are drawn with `Tensor.rand` (on `device`, extra `kwargs` forwarded); the indices that sort those
  keys form the permutation, which is cast to `dtype`.

  ```python exec="true" source="above" session="tensor" result="python"
  Tensor.manual_seed(42)
  print(Tensor.randperm(6).numpy())
  ```
  """
  keys = Tensor.rand(n, device=device, **kwargs)
  # the sort order of random keys is the permutation
  order = keys.argsort()
  return order.cast(dtype)
def multinomial(self:Tensor, num_samples:int = 1, replacement:bool = False) -> Tensor:
"""
Returns a tensor with `num_samples` indices sampled from a multinomial distribution weighted by `self`.
`self` must be 1D or 2D (a 2D input is sampled per row), and without replacement `num_samples` must not exceed the number of weights.
The returned indices are cast to `dtypes.int32`.
```python exec="true" source="above" session="tensor" result="python"
Tensor.manual_seed(42)
t = Tensor([1, 2, 3, 4])
print(t.multinomial(20, replacement=True).numpy())
```
```python exec="true" source="above" session="tensor" result="python"
Tensor.manual_seed(42)
t = Tensor([1, 2, 3, 4])
print(t.multinomial(3, replacement=False).numpy())
```
"""
assert 1 <= self.ndim <= 2 and num_samples > 0, f"{self.ndim=} must be 1 or 2 dim, {num_samples=} must be positive"
# treat a 1D input as a single row so both ranks share one code path
weight = self.unsqueeze(0) if self.ndim == 1 else self
assert replacement or num_samples <= weight.shape[1], "no replacement samples must not exceed population size"
if replacement or num_samples == 1:
# cumulative weights normalized by each row's total
cdf = (cw := weight.cumsum(1).float()) / cw[:, -1].unsqueeze(1)
# one uniform draw per (sample, row), created by Tensor.rand and then moved to self.device
unif_samples = Tensor.rand(num_samples, cdf.shape[0], 1).to(self.device)
# the sampled index is the number of cdf entries the draw has reached; permute to (row, sample)
indices = (unif_samples.expand((-1, -1, cdf.shape[1])) >= cdf).sum(2).permute((1, 0))
else:
# Efraimidis-Spirakis: keep the num_samples largest values of log2(u) / weight per row
indices = (weight.rand_like(dtype=dtypes.float32).log2() / weight).topk(num_samples, dim=1)[1]
# drop the row dimension again for a 1D input
return (indices.squeeze(0) if self.ndim == 1 else indices).cast(dtypes.int32)
# ***** toposort and backward pass *****
def backward(self, gradient:Tensor|None=None) -> Tensor:
@ -817,46 +551,6 @@ class Tensor(RandMixin):
def _mop(self, op:Ops, arg) -> Tensor:
  """Delegates to `UOp._mop` through `_apply_uop`, passing `op` as an extra positional argument and `arg` by keyword."""
  extra = (op,)
  return self._apply_uop(UOp._mop, extra_args=extra, arg=arg)
def _rop(self, op:Ops, axis:tuple[int, ...]) -> Tensor:
  """Delegates to `UOp._rop` through `_apply_uop`, forwarding `op` and `axis` as keyword arguments."""
  result = self._apply_uop(UOp._rop, op=op, axis=axis)
  return result
def __getitem__(self, indices) -> Tensor:
  """
  Retrieves a sub-tensor using indexing. The work is done by the parent class implementation.

  Supported Index Types: `int | slice | Tensor | None | list | tuple | Ellipsis`

  Examples:
  ```python exec="true" source="above" session="tensor" result="python"
  t = Tensor.arange(12).reshape(3, 4)
  print(t.numpy())
  ```
  - Int Indexing: Select an element or sub-tensor using integers for each dimension.
  ```python exec="true" source="above" session="tensor" result="python"
  print(t[1, 2].numpy())
  ```
  - Slice Indexing: Select a range of elements using slice notation (`start:end:stride`).
  ```python exec="true" source="above" session="tensor" result="python"
  print(t[0:2, ::2].numpy())
  ```
  - Tensor Indexing: Use another tensor as indices for advanced indexing. Using `tuple` or `list` here also works.
  ```python exec="true" source="above" session="tensor" result="python"
  print(t[Tensor([2, 0, 1]), Tensor([1, 2, 3])].numpy())
  ```
  - `None` Indexing: Add a new dimension to the tensor.
  ```python exec="true" source="above" session="tensor" result="python"
  print(t[:, None].shape)
  ```

  NOTE: Out-of-bounds indexing results in a value of `0`.
  ```python exec="true" source="above" session="tensor" result="python"
  t = Tensor([1, 2, 3])
  print(t[Tensor([4, 3, 2])].numpy())
  ```
  """
  selected = super().__getitem__(indices)
  return selected
def __setitem__(self, indices, v:Tensor|PyConst|list|tuple) -> None:
if isinstance(v, Tensor) and v.dtype != self.dtype: raise RuntimeError(f"setitem dtype mismatch: {self.dtype=} != {v.dtype=}")
# raise if mutation would diverge from eager (allow only pure views of a realized buffer; exclude +=/-= RHS via v_uop/v_bw)
@ -887,68 +581,6 @@ class Tensor(RandMixin):
def __delitem__(self, indices) -> None:
  """Always raises `TypeError`: deleting items from a Tensor is not supported."""
  reason = "Tensor does not support deleting items"
  raise TypeError(reason)
def masked_select(self, mask, size:int|None=None, fill_value:ConstType=0):
"""
Selects elements from `self` based on the boolean `mask`.
With `size=None` (default), output length equals the number of `True` values (not jittable).
With `size=N`, output length is `N`, padded with `fill_value` or truncated (jittable).
The result is one-dimensional. Raises `RuntimeError` if `mask` is not a bool tensor.
```python exec="true" source="above" session="tensor" result="python"
t = Tensor([[0, 1, 2], [3, 4, 5], [6, 7, 8]])
mask = Tensor([[True, False, True], [False, True, False], [False, False, True]])
print(t.numpy())
print(mask.numpy())
```
```python exec="true" source="above" session="tensor" result="python"
print(t.masked_select(mask).numpy())
```
```python exec="true" source="above" session="tensor" result="python"
print(t.masked_select(mask, size=6, fill_value=-1).numpy())
```
"""
if not dtypes.is_bool(mask.dtype): raise RuntimeError(f"masked_select expects bool mask tensor, got {mask.dtype}")
# flatten self, and flatten the mask after broadcasting it to self's shape
x, mask = self.flatten(), mask._broadcast_to(self.shape).flatten()
# running count of True values up to and including each position
mask_cumsum = mask.cumsum()
if size is None:
# the output length is read back with .item() (0 for an empty mask)
counts = Tensor.zeros(mask_cumsum[-1].item() if mask.numel() else 0, dtype=dtypes.int32, buffer=False)
# scatter-add 1 at each running count, then cumsum to get the gather index for every output slot
# NOTE(review): presumably scatter indices equal to len(counts) are ignored - confirm against scatter's semantics
return x[counts.scatter(0, mask_cumsum, 1, reduce='add').cumsum()]
# fixed-size path: same scatter/cumsum index construction on a buffer of length `size`
counts = Tensor.zeros(size, dtype=dtypes.int32, buffer=False).scatter(0, mask_cumsum, 1, reduce='add')
# slots at or beyond the number of True values take fill_value; the result is cast back to self.dtype
return (Tensor.arange(size) < mask.sum()).where(x[counts.cumsum()], fill_value).cast(self.dtype)
def nonzero(self, size:int|None=None, fill_value:ConstType=0) -> Tensor:
  """
  Returns the indices of the elements that are non-zero.

  With `size=None` (default), output shape is `(n_nonzero, ndim)` (not jittable).
  With `size=N`, output shape is `(N, ndim)`, padded with `fill_value` or truncated (jittable).

  ```python exec="true" source="above" session="tensor" result="python"
  t = Tensor([1, 0, 2, 0, 3])
  print(t.numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  print(t.nonzero().numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  t = Tensor([[1, 0], [0, 2]])
  print(t.numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  print(t.nonzero().numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  print(t.nonzero(size=3, fill_value=-1).numpy())
  ```
  """
  if self.ndim == 0:
    # a scalar has no coordinates: the result has zero columns and `size` (or 0/1) rows
    n_rows = size if size is not None else int((self != 0).item())
    return Tensor.zeros(n_rows, 0, dtype=dtypes.int32, device=self.device)
  mask = (self != 0).flatten()
  per_axis = []
  for i, s in enumerate(self.shape):
    # arange along axis i, broadcast over the full shape, then flattened
    axis_index = Tensor.arange(s).reshape(*[1]*i, s, *[1]*(self.ndim-i-1))
    per_axis.append(axis_index.expand(self.shape).flatten())
  # one row of coordinates per flattened element
  indices = Tensor.stack(*per_axis, dim=-1)
  row_mask = mask.unsqueeze(-1).expand(*mask.shape, self.ndim)
  flat_size = size*self.ndim if size is not None else None
  return indices.masked_select(row_mask, size=flat_size, fill_value=fill_value).reshape(-1, self.ndim)
# ***** reduce ops *****
def keccak(self, cfg:str|tuple[int, int]="sha3_256"):
@ -1054,11 +686,11 @@ class Tensor(RandMixin):
g = weight.permute(*range(len(weight.shape)-len(HW),len(weight.shape)), *range(len(weight.shape)-len(HW))) # move HW to the front
# compute 6x6 winograd tiles: GgGt, BtdB
# compute 6x6 winograd tiles: GgGt, BtdB. contiguous so the transforms are materialized once
# (HWI, groups * rcout, cin) -> (HWI, bs=1, groups, rcout, cin, tyx=(1,1))
gfactors = _apply_winograd_matrix(winograd_G, g, len(HW)).reshape(*HWI, 1, groups, rcout, cin, *([1]*len(tyx)))
gfactors = _apply_winograd_matrix(winograd_G, g, len(HW)).contiguous().reshape(*HWI, 1, groups, rcout, cin, *([1]*len(tyx)))
# (HWI, bs, cin_, tyx) -> (HWI, bs, groups, 1 ,cin, *tyx)
dfactors = _apply_winograd_matrix(winograd_Bt, d, len(HW)).reshape(*HWI, bs, groups, 1, cin, *tyx)
dfactors = _apply_winograd_matrix(winograd_Bt, d, len(HW)).contiguous().reshape(*HWI, bs, groups, 1, cin, *tyx)
# matmul; sum across cin: (HWI, bs, groups, rcout, *tyx); then HWI -> HWO: (HWO, bs, groups, rcout, *tyx)
ret = _apply_winograd_matrix(winograd_At, (gfactors * dfactors).sum(axis=-1-len(HW), dtype=dtype), len(HW))
@ -1227,25 +859,6 @@ class Tensor(RandMixin):
# ***** cast ops *****
def cast(self, dtype:DTypeLike) -> Tensor:
  """
  Casts `self` to the given `dtype`. Returns `self` unchanged when it already has that dtype.

  ```python exec="true" source="above" session="tensor" result="python"
  t = Tensor([-1, 2.5, 3], dtype=dtypes.float)
  print(t.dtype, t.numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  t = t.cast(dtypes.int32)
  print(t.dtype, t.numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  t = t.cast(dtypes.uint8)
  print(t.dtype, t.numpy())
  ```
  """
  dt = to_dtype(dtype)
  if self.dtype == dt:
    return self
  return self._apply_uop(UOp.cast, dtype=dt)
def bitcast(self, dtype:DTypeLike) -> Tensor:
"""
Bitcasts `self` to the given `dtype` of the same itemsize.

View file

@ -561,10 +561,6 @@ class UOp(RandMixin, metaclass=UOpMetaClass):
ret = UOp(Ops.CONST, dtype, arg=dtype.const(b), src=())
return ret.reshape((1,)*len(shape)).expand(shape) if shape is not None and shape != () and ret.shape != shape else ret
@staticmethod
def invalids(shape:tuple[sint, ...]|None=None, dtype:DTypeLike|None=None, device:str|tuple[str, ...]|None=None) -> UOp:
  """
  Builds an Invalid constant UOp of the given `shape`, cloned with the given `device`.

  When `dtype` is not given, the dtype is `dtypes.from_py(Invalid)`.
  """
  dt = dtypes.from_py(Invalid) if dtype is None else to_dtype(dtype)
  invalid_const = UOp.const(dt, Invalid, shape=shape)
  return invalid_const.clone(device=device)
@staticmethod
def range(end:sint, axis_id, axis_type=AxisType.LOOP, *arg, dtype=dtypes.weakint, src=(), **kwargs):
  """
  Builds an `Ops.RANGE` UOp.

  `end` is converted with `sint_to_uop` and placed first in the sources, followed by `src`.
  The op's `arg` is `(axis_id, axis_type)` followed by any extra positional `arg` values; `kwargs` go to the `UOp` constructor.
  """
  end_uop = sint_to_uop(end, dtype)
  sources = (end_uop,)+src
  range_arg = (axis_id, axis_type)+arg
  return UOp(Ops.RANGE, dtype=dtype, src=sources, arg=range_arg, **kwargs)
@staticmethod