mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
Merge branch 'master' into codegen2
This commit is contained in:
commit
d319b5f614
15 changed files with 493 additions and 492 deletions
|
|
@ -2674,14 +2674,14 @@ def custom_hk_mxfp8_gemm(C:UOp, A:UOp, B:UOp, scale_A:UOp, scale_B:UOp, *extra:U
|
|||
|
||||
def quantize_mxfp8(x:Tensor) -> tuple[Tensor, Tensor, Tensor]:
|
||||
# 1x32 block scaling along the last axis
|
||||
rows, K = x.shape
|
||||
*batch, K = x.shape
|
||||
scale_K, k_iters = K // 32, K // 128
|
||||
amax = x.detach().float().reshape(rows, scale_K, 32).abs().max(axis=-1)
|
||||
e8 = (amax.maximum(1e-38).log2().floor() + 127).clamp(0, 254).cast(dtypes.uint8)
|
||||
qscale = (127.0 - e8.cast(dtypes.float32)).exp2().reshape(rows, scale_K, 1).expand(rows, scale_K, 32).reshape(rows, K)
|
||||
qscale = (127.0 - e8.cast(dtypes.float32)).exp2().reshape(*batch, scale_K, 1).expand(*batch, scale_K, 32).reshape(*batch, K)
|
||||
x_scaled = x.float() * qscale
|
||||
x_clamped = x_scaled + (x_scaled.detach().clamp(-448.0, 448.0) - x_scaled.detach()) # STE
|
||||
return x_clamped.cast(FP8_DTYPE), e8, mx_pack(e8)
|
||||
return x_clamped.cast(FP8_DTYPE), e8, (mx_pack(e8) if len(batch) == 1 else None)
|
||||
|
||||
def mx_pack(e8:Tensor) -> Tensor:
|
||||
rows, scale_K = e8.shape
|
||||
|
|
|
|||
|
|
@ -317,6 +317,19 @@ class TestTensorUOpScatterReduce(unittest.TestCase):
|
|||
def test_mean_exclude_self(self):
|
||||
self._check(_t(3, 4).float(), Tensor([[0, 1, 0, 1]]*3, dtype=dtypes.int32), Tensor.ones(3, 4).float(), reduce="mean", include_self=False)
|
||||
|
||||
class TestTensorUOpMaskedSelect(unittest.TestCase):
|
||||
# only the fixed-size path is pure
|
||||
def _check(self, t, mask, **kw):
|
||||
self.assertIs(t.masked_select(mask, **kw).uop, t.uop.masked_select(mask.uop, **kw))
|
||||
def test_masked_select_1d(self): self._check(_t(6), Tensor([True, False, True, False, True, False]), size=4)
|
||||
def test_masked_select_2d(self):
|
||||
self._check(_t(3, 3), Tensor([[True, False, True], [False, True, False], [False, False, True]]), size=6, fill_value=-1)
|
||||
|
||||
class TestTensorUOpNonzero(unittest.TestCase):
|
||||
def _check(self, t, **kw): self.assertIs(t.nonzero(**kw).uop, t.uop.nonzero(**kw))
|
||||
def test_nonzero_1d(self): self._check(_t(5), size=3)
|
||||
def test_nonzero_2d(self): self._check(_t(2, 3), size=4)
|
||||
|
||||
class TestTensorUOpPool(unittest.TestCase):
|
||||
def test_avg_pool2d(self): _check(self, _t(1, 1, 5, 5).float(), lambda x: x.avg_pool2d())
|
||||
def test_avg_pool2d_padding(self): _check(self, _t(1, 1, 5, 5).float(), lambda x: x.avg_pool2d(padding=1))
|
||||
|
|
@ -462,6 +475,10 @@ class TestTensorUOpCreation(unittest.TestCase):
|
|||
self.assertIs(_strip_unique(Tensor.ones(2, 3).uop), _strip_unique(UOp.ones(2, 3)))
|
||||
def test_invalids(self):
|
||||
self.assertIs(_strip_unique(Tensor.invalids(2, 3, dtype=dtypes.int8).uop), _strip_unique(UOp.invalids((2, 3), dtype=dtypes.int8)))
|
||||
def test_empty_like(self):
|
||||
t = Tensor.empty(2, 3, dtype=dtypes.int8)
|
||||
self.assertIs(_strip_unique(t.empty_like().uop), _strip_unique(t.uop.empty_like()))
|
||||
self.assertIs(_strip_unique(t.empty_like(dtype=dtypes.float, device="NULL").uop), _strip_unique(t.uop.empty_like(dtypes.float, "NULL")))
|
||||
def test_arange(self):
|
||||
self.assertIs(Tensor.arange(5).uop, UOp.arange(5))
|
||||
def test_arange_empty(self):
|
||||
|
|
|
|||
|
|
@ -13,35 +13,26 @@ class TestWinograd(unittest.TestCase):
|
|||
def test_forward_kernels(self):
|
||||
x,w = Tensor.rand(1,4,9,9).realize(), Tensor.rand(4,4,3,3).realize()
|
||||
out = Tensor.conv2d(x,w)
|
||||
self.assertEqual(len(out.schedule_linear().src), 2)
|
||||
self.assertEqual(len(out.schedule_linear().src), 4)
|
||||
|
||||
def test_backward_kernels(self):
|
||||
x,w = Tensor.empty(1,4,9,9).realize(), Tensor.empty(4,4,3,3).realize()
|
||||
out = Tensor.conv2d(x,w, padding=1)
|
||||
out.mean().backward()
|
||||
backward_schedule = x.grad.schedule_linear(w.grad)
|
||||
self.assertEqual(len(backward_schedule.src), 2)
|
||||
self.assertEqual(len(backward_schedule.src), 4)
|
||||
|
||||
@unittest.skip("this requires optimizations")
|
||||
def test_counters(self):
|
||||
IC, OC, X, Y = 4,4,9,9
|
||||
x,w = Tensor.rand(1,IC,Y,X).realize(), Tensor.rand(OC,IC,3,3).realize()
|
||||
IC, OC, H = 64, 64, 28
|
||||
x,w = Tensor.empty(1,IC,H,H,device="NULL").realize(), Tensor.empty(OC,IC,3,3,device="NULL").realize()
|
||||
GlobalCounters.reset()
|
||||
with Context(WINO=1):
|
||||
Tensor.conv2d(x,w).realize()
|
||||
ops_wino, mem_wino = GlobalCounters.global_ops, GlobalCounters.global_mem
|
||||
with Context(NOOPT=0, WINO=1): Tensor.conv2d(x,w).realize()
|
||||
ops_wino = GlobalCounters.global_ops
|
||||
GlobalCounters.reset()
|
||||
with Context(WINO=0):
|
||||
Tensor.conv2d(x,w).realize()
|
||||
ops_normal, mem_normal = GlobalCounters.global_ops, GlobalCounters.global_mem
|
||||
|
||||
ops_ratio, mem_ratio = ops_wino/ops_normal, mem_wino/mem_normal
|
||||
print(f"ops: normal {ops_normal:9d} wino {ops_wino:9d} ratio {ops_ratio:.2f}")
|
||||
print(f"mem: normal {mem_normal:9d} wino {mem_wino:9d} ratio {mem_ratio:.2f}")
|
||||
|
||||
# TODO: what's optimal on this?
|
||||
self.assertLess(ops_ratio, 4.3)
|
||||
self.assertLess(mem_ratio, 4)
|
||||
with Context(NOOPT=0, WINO=0): Tensor.conv2d(x,w).realize()
|
||||
ops_normal = GlobalCounters.global_ops
|
||||
print(f"ops: normal {ops_normal} wino {ops_wino} ratio {ops_wino/ops_normal:.2f}")
|
||||
self.assertLess(ops_wino/ops_normal, 0.6)
|
||||
|
||||
def test_dtype(self):
|
||||
IC, OC, X, Y = 4,4,9,9
|
||||
|
|
|
|||
|
|
@ -222,7 +222,7 @@ class TestCallSchedule(unittest.TestCase):
|
|||
# find the FUNCTION nodes
|
||||
c0 = next(u for u in r0.uop.toposort() if u.op is Ops.FUNCTION)
|
||||
c1 = next(u for u in r1.uop.toposort() if u.op is Ops.FUNCTION)
|
||||
# the function bodies (src[0]) should have identical keys — unique consts must not leak through
|
||||
# the function bodies (src[0]) should have identical keys
|
||||
self.assertEqual(c0.src[0].key, c1.src[0].key)
|
||||
|
||||
def test_precompile_symbolic_2d(self):
|
||||
|
|
|
|||
|
|
@ -231,8 +231,7 @@ def _prepare_jit_inputs(args, kwargs):
|
|||
it = x if isinstance(x, (tuple,list)) else x.values() if isinstance(x, dict) else []
|
||||
tensors += [t for t in it if t.__class__ is Tensor and not any(t is y for y in tensors)]
|
||||
def get_input_uops() -> list[UOp]: return flatten([t.uop.src if t.uop.op is Ops.MULTI else [t.uop] for t in tensors])
|
||||
# TODO: drop the CONST branch once all CONST are deviceless
|
||||
if any(u.device is None or u.base.op is Ops.CONST for u in get_input_uops()): raise JitError("JIT inputs must be real buffers; use .clone()")
|
||||
if any(u.device is None for u in get_input_uops()): raise JitError("JIT inputs must be real buffers; use .clone()")
|
||||
if len(unrealized_tensors := [x for x in tensors if not x.uop.is_realized]): Tensor.realize(*unrealized_tensors)
|
||||
input_uops = get_input_uops()
|
||||
# collect buffer UOps (including MultiBuffer)
|
||||
|
|
|
|||
|
|
@ -1,13 +1,13 @@
|
|||
from __future__ import annotations
|
||||
import functools, itertools
|
||||
from typing import TYPE_CHECKING, Callable, Self, Sequence, Literal, get_args
|
||||
import functools, itertools, string
|
||||
from typing import TYPE_CHECKING, Callable, Self, Sequence, Literal, get_args, cast
|
||||
from tinygrad.mixin.elementwise import ElementwiseMixin
|
||||
from tinygrad.mixin.movement import MovementMixin
|
||||
from tinygrad.mixin.reduce import ReduceMixin
|
||||
from tinygrad.uop import Ops
|
||||
from tinygrad.uop.ops import _broadcast_shape, resolve, smax, smin, identity_element
|
||||
from tinygrad.dtype import ConstType, DTypeLike, PtrDType, PyConst, dtypes, least_upper_dtype, sum_acc_dtype, to_dtype
|
||||
from tinygrad.helpers import all_int, argfix, ceildiv, flatten, flat_to_grouped, fully_flatten, get_shape, make_tuple, prod
|
||||
from tinygrad.dtype import ConstType, DTypeLike, Invalid, PtrDType, PyConst, dtypes, least_upper_dtype, sum_acc_dtype, to_dtype
|
||||
from tinygrad.helpers import all_int, argfix, argsort, ceildiv, flatten, flat_to_grouped, fully_flatten, get_shape, make_tuple, merge_dicts, prod
|
||||
from tinygrad.helpers import resolve_pool_pads, round_up
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
|
@ -20,6 +20,37 @@ class OpMixin(ElementwiseMixin, ReduceMixin):
|
|||
@staticmethod
|
||||
def const(dtype, b): raise NotImplementedError
|
||||
|
||||
def data(self) -> memoryview: raise NotImplementedError("data requires Tensor realization to host memory")
|
||||
|
||||
def item(self) -> PyConst:
|
||||
"""
|
||||
Returns the value of this tensor as a standard Python number.
|
||||
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
t = Tensor(42)
|
||||
print(t.item())
|
||||
```
|
||||
"""
|
||||
assert self.numel() == 1, "must have one element for item"
|
||||
return self.data()[(0,) * len(self.shape)]
|
||||
|
||||
def _multi_like(self, fxn:Callable[[tuple[sint, ...], str|None], Self]) -> Self:
|
||||
from tinygrad.uop.ops import UOp
|
||||
assert isinstance(self.device, tuple), f"_multi_like needs a multi device tensor, got {self.device}"
|
||||
if self._uop.axis is None: return self._wrap_uop(fxn(self.shape, None)._uop.shard(self.device, None))
|
||||
return self._wrap_uop(UOp.mstack(*[fxn(self._uop.shard_shape, d)._uop for d in self.device]).multi(self._uop.axis))
|
||||
|
||||
@classmethod
|
||||
def invalids(cls, *shape, device:str|tuple[str, ...]|None=None, dtype:DTypeLike|None=None) -> Self:
|
||||
"""
|
||||
Creates a tensor with the given shape, filled with Invalid.
|
||||
|
||||
This is an alternative to Tensor.empty when you want an "anonymous" buffer.
|
||||
|
||||
Eventually Tensor.empty will be replaced by this.
|
||||
"""
|
||||
return cls.full(argfix(*shape), Invalid, dtype=dtype, device=device)
|
||||
|
||||
@classmethod
|
||||
def full(cls, shape:tuple[sint, ...], fill_value:ConstType|UOp, dtype:DTypeLike|None=None,
|
||||
device:str|tuple[str, ...]|None=None, buffer=True) -> Self:
|
||||
|
|
@ -45,7 +76,45 @@ class OpMixin(ElementwiseMixin, ReduceMixin):
|
|||
val = val.reshape((1,)*len(new_shape)).expand(new_shape)
|
||||
return val.clone(device=device) if buffer else val
|
||||
|
||||
def __getitem__(self, indices) -> Self: return self._getitem(indices)
|
||||
def __getitem__(self, indices) -> Self:
|
||||
"""
|
||||
Retrieves a sub-tensor using indexing.
|
||||
|
||||
Supported Index Types: `int | slice | Tensor | None | list | tuple | Ellipsis`
|
||||
|
||||
Examples:
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
t = Tensor.arange(12).reshape(3, 4)
|
||||
print(t.numpy())
|
||||
```
|
||||
|
||||
- Int Indexing: Select an element or sub-tensor using integers for each dimension.
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
print(t[1, 2].numpy())
|
||||
```
|
||||
|
||||
- Slice Indexing: Select a range of elements using slice notation (`start:end:stride`).
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
print(t[0:2, ::2].numpy())
|
||||
```
|
||||
|
||||
- Tensor Indexing: Use another tensor as indices for advanced indexing. Using `tuple` or `list` here also works.
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
print(t[Tensor([2, 0, 1]), Tensor([1, 2, 3])].numpy())
|
||||
```
|
||||
|
||||
- `None` Indexing: Add a new dimension to the tensor.
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
print(t[:, None].shape)
|
||||
```
|
||||
|
||||
NOTE: Out-of-bounds indexing results in a value of `0`.
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
t = Tensor([1, 2, 3])
|
||||
print(t[Tensor([4, 3, 2])].numpy())
|
||||
```
|
||||
"""
|
||||
return self._getitem(indices)
|
||||
|
||||
def _getitem(self, indices, v=None) -> Self:
|
||||
from tinygrad.uop.ops import UOp
|
||||
|
|
@ -367,7 +436,7 @@ class OpMixin(ElementwiseMixin, ReduceMixin):
|
|||
if mode in {"reflect", "replicate"}: return self._pad_reflect_replicate(pX, mode)
|
||||
raise NotImplementedError(f"{mode=} is not supported")
|
||||
|
||||
def _broadcasted(self, y, reverse=False) -> tuple[Self, Self]:
|
||||
def _broadcasted(self, y:Self|ConstType|UOp, reverse:bool=False) -> tuple[Self, Self]:
|
||||
if not isinstance(y, type(self)): y = self.ufix(y)
|
||||
x, y = (self, y) if not reverse else (y, self)
|
||||
# ValueError: unsized ptr has shape (-1,) which can't broadcast; RuntimeError: shape mismatch
|
||||
|
|
@ -423,6 +492,47 @@ class OpMixin(ElementwiseMixin, ReduceMixin):
|
|||
def __matmul__(self, x:Self) -> Self: return self.matmul(x)
|
||||
def __rmatmul__(self, x:Self) -> Self: return self.matmul(x, True)
|
||||
|
||||
@classmethod
|
||||
def einsum(cls, formula:str, *operands:Self|Sequence[Self], dtype:DTypeLike|None=None) -> Self:
|
||||
"""
|
||||
Sums the product of the elements of the input tensors according to a formula based on the Einstein summation convention.
|
||||
|
||||
See: https://pytorch.org/docs/stable/generated/torch.einsum.html
|
||||
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
x = Tensor([[1, 2], [3, 4]])
|
||||
y = Tensor([[5, 6], [7, 8]])
|
||||
print(Tensor.einsum("ij,ij->", x, y).numpy())
|
||||
```
|
||||
"""
|
||||
xs, formula = list(argfix(*operands)), formula.replace(" ", "")
|
||||
# expand ellipsis to letters, determine output
|
||||
if "..." in formula:
|
||||
ell, lhs = "".join(c for c in string.ascii_letters if c not in formula), (formula.split("->") + [""])[0]
|
||||
ell_n = [max(0, x.ndim - len(s) + 3) if "..." in s else 0 for s, x in zip(lhs.split(","), xs)]
|
||||
for i, (s, x) in enumerate(zip(inputs := lhs.split(","), xs)): inputs[i] = s.replace("...", ell[max(ell_n)-ell_n[i]:max(ell_n)])
|
||||
lhs, auto = ",".join(inputs), "".join(sorted(c for c in lhs if lhs.count(c) == 1 and c.isalpha() and c not in ell))
|
||||
formula = f"{lhs}->{formula.split('->')[1].replace('...', ell[:max(ell_n)]) if '->' in formula else ell[:max(ell_n)] + auto}"
|
||||
lhs, rhs = formula.split("->") if "->" in formula else (formula, "".join(sorted(c for c in formula if formula.count(c)==1 and c.isalpha())))
|
||||
inputs = lhs.split(",")
|
||||
if len(xs) != len(inputs): raise ValueError(f"number of operands doesn't match, expected {len(inputs)}, got {len(xs)}")
|
||||
# trace: take diagonal when letter repeats in single input
|
||||
for i, (s, x) in enumerate(zip(inputs, xs)):
|
||||
for c in set(s):
|
||||
while s.count(c) > 1:
|
||||
j, k, n = s.index(c), s.index(c, s.index(c)+1), cast(int, x.shape[s.index(c)])
|
||||
perm = [d for d in range(x.ndim) if d not in (j,k)]+[j,k]
|
||||
x = x.permute(perm).flatten(-2).pad(((0,0),)*(x.ndim-2)+((0,n),)).unflatten(-1,(n,n+1))[...,0] if x.ndim > 2 else x.diagonal()
|
||||
s = s[:k] + s[k+1:]
|
||||
inputs[i], xs[i] = s, x
|
||||
# check sizes and build sorted alphabet
|
||||
sz = merge_dicts([dict(zip(s, x.shape)) for s, x in zip(inputs, xs)])
|
||||
alpha = sorted(sz)
|
||||
# align all tensors to alphabet, multiply, sum non-output, permute to output order
|
||||
xs = [x.permute(*[s.index(c) for c in sorted(s)]).reshape([sz[c] if c in s else 1 for c in alpha]).expand([sz[c] for c in alpha]) if s else x
|
||||
for s, x in zip(inputs, xs)]
|
||||
return xs[0].uprod(*xs[1:]).sum([i for i,c in enumerate(alpha) if c not in rhs], dtype=dtype).permute(argsort(argsort(list(rhs))))
|
||||
|
||||
def gradient(self, *targets:Self, gradient:Self|None=None) -> list[Self]:
|
||||
"""
|
||||
Computes the gradient of the targets with respect to self.
|
||||
|
|
@ -1147,6 +1257,68 @@ class OpMixin(ElementwiseMixin, ReduceMixin):
|
|||
# select from values for each True element in mask else select from self
|
||||
return mask.where(values, self)
|
||||
|
||||
def masked_select(self, mask, size:int|None=None, fill_value:ConstType=0):
|
||||
"""
|
||||
Selects elements from `self` based on the boolean `mask`.
|
||||
|
||||
With `size=None` (default), output length equals the number of `True` values (not jittable).
|
||||
With `size=N`, output length is `N`, padded with `fill_value` or truncated (jittable).
|
||||
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
t = Tensor([[0, 1, 2], [3, 4, 5], [6, 7, 8]])
|
||||
mask = Tensor([[True, False, True], [False, True, False], [False, False, True]])
|
||||
print(t.numpy())
|
||||
print(mask.numpy())
|
||||
```
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
print(t.masked_select(mask).numpy())
|
||||
```
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
print(t.masked_select(mask, size=6, fill_value=-1).numpy())
|
||||
```
|
||||
"""
|
||||
if not dtypes.is_bool(mask.dtype): raise RuntimeError(f"masked_select expects bool mask tensor, got {mask.dtype}")
|
||||
x, mask = self.flatten(), mask._broadcast_to(self.shape).flatten()
|
||||
mask_cumsum = mask.cumsum()
|
||||
if size is None:
|
||||
counts = type(self).zeros(mask_cumsum[-1].item() if mask.numel() else 0, dtype=dtypes.int32, buffer=False)
|
||||
return x[counts.scatter(0, mask_cumsum, 1, reduce='add').cumsum()]
|
||||
counts = type(self).zeros(size, dtype=dtypes.int32, buffer=False).scatter(0, mask_cumsum, 1, reduce='add')
|
||||
return (type(self).arange(size) < mask.sum()).where(x[counts.cumsum()], fill_value).cast(self.dtype)
|
||||
|
||||
def nonzero(self, size:int|None=None, fill_value:ConstType=0) -> Self:
|
||||
"""
|
||||
Returns the indices of the elements that are non-zero.
|
||||
|
||||
With `size=None` (default), output shape is `(n_nonzero, ndim)` (not jittable).
|
||||
With `size=N`, output shape is `(N, ndim)`, padded with `fill_value` or truncated (jittable).
|
||||
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
t = Tensor([1, 0, 2, 0, 3])
|
||||
print(t.numpy())
|
||||
```
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
print(t.nonzero().numpy())
|
||||
```
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
t = Tensor([[1, 0], [0, 2]])
|
||||
print(t.numpy())
|
||||
```
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
print(t.nonzero().numpy())
|
||||
```
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
print(t.nonzero(size=3, fill_value=-1).numpy())
|
||||
```
|
||||
"""
|
||||
if self.ndim == 0:
|
||||
return type(self).zeros(size if size is not None else int(self.ne(0).item()), 0, dtype=dtypes.int32, device=self.device)
|
||||
mask = self.ne(0).flatten()
|
||||
indices = type(self).stack(*[type(self).arange(s).reshape(*[1]*i, s, *[1]*(self.ndim-i-1)).expand(self.shape).flatten()
|
||||
for i, s in enumerate(self.shape)], dim=-1)
|
||||
return indices.masked_select(mask.unsqueeze(-1).expand(*mask.shape, self.ndim),
|
||||
size=size*self.ndim if size is not None else None, fill_value=fill_value).reshape(-1, self.ndim)
|
||||
|
||||
# ***** functional nn ops *****
|
||||
|
||||
def sequential(self, ll:list[Callable[[Self], Self]]) -> Self:
|
||||
|
|
|
|||
|
|
@ -1,9 +1,16 @@
|
|||
from typing import Self
|
||||
from tinygrad.dtype import ConstType, DType
|
||||
from tinygrad.dtype import ConstType, DType, DTypeLike
|
||||
from tinygrad.mixin.dtype import DTypeMixin
|
||||
|
||||
class CreationMixin:
|
||||
def const_like(self, b: ConstType) -> Self: raise NotImplementedError
|
||||
def cast(self, dtype: DType) -> Self: raise NotImplementedError
|
||||
class CreationMixin(DTypeMixin):
|
||||
def const_like(self, b: ConstType) -> Self: return self._wrap_uop(self._uop.const_like(b))
|
||||
|
||||
def empty_like(self, dtype: DTypeLike|None=None, device: str|tuple[str, ...]|None=None) -> Self:
|
||||
"""
|
||||
Creates an empty tensor with the same shape as `self`.
|
||||
If `dtype` is not specified, the dtype of `self` is used.
|
||||
"""
|
||||
return self._wrap_uop(self._uop.empty_like(dtype, device))
|
||||
|
||||
def full_like(self, fill_value: ConstType, dtype: DType|None=None) -> Self:
|
||||
"""Creates a tensor with the same shape as `self`, filled with the given value."""
|
||||
|
|
|
|||
|
|
@ -1,13 +1,36 @@
|
|||
from typing import Self
|
||||
from tinygrad.dtype import DType, dtypes
|
||||
from typing import TYPE_CHECKING, Self
|
||||
from tinygrad.dtype import DType, DTypeLike, dtypes, to_dtype
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from tinygrad.uop.ops import UOp
|
||||
|
||||
class DTypeMixin:
|
||||
@property
|
||||
def dtype(self) -> DType: raise NotImplementedError
|
||||
@property
|
||||
def _uop(self) -> 'UOp': raise NotImplementedError
|
||||
def _wrap_uop(self, u:'UOp') -> Self: raise NotImplementedError
|
||||
|
||||
def cast(self, dtype:DType) -> Self: raise NotImplementedError
|
||||
def cast(self, dtype:DTypeLike) -> Self:
|
||||
"""
|
||||
Casts `self` to the given `dtype`.
|
||||
|
||||
def bitcast(self, dtype:DType) -> Self: raise NotImplementedError
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
t = Tensor([-1, 2.5, 3], dtype=dtypes.float)
|
||||
print(t.dtype, t.numpy())
|
||||
```
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
t = t.cast(dtypes.int32)
|
||||
print(t.dtype, t.numpy())
|
||||
```
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
t = t.cast(dtypes.uint8)
|
||||
print(t.dtype, t.numpy())
|
||||
```
|
||||
"""
|
||||
return self if self.dtype == (dt:=to_dtype(dtype)) else self._wrap_uop(self._uop.cast(dt))
|
||||
|
||||
def bitcast(self, dtype:DTypeLike) -> Self: raise NotImplementedError
|
||||
|
||||
def element_size(self) -> int:
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -3,23 +3,17 @@ from typing import TYPE_CHECKING, Literal, Self
|
|||
from tinygrad.uop import Ops
|
||||
from tinygrad.dtype import dtypes, ConstType, PyConst, least_upper_dtype, least_upper_float
|
||||
from tinygrad.helpers import argfix, polyN
|
||||
from tinygrad.mixin.dtype import DTypeMixin
|
||||
from tinygrad.mixin.creation import CreationMixin
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from tinygrad.uop.ops import UOp
|
||||
|
||||
|
||||
class ElementwiseMixin(DTypeMixin, CreationMixin):
|
||||
class ElementwiseMixin(CreationMixin):
|
||||
# required to implement
|
||||
def alu(self, op: Ops, *src: Self) -> Self:
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
def _uop(self) -> 'UOp': raise NotImplementedError
|
||||
|
||||
def _wrap_uop(self, u: 'UOp') -> Self: raise NotImplementedError
|
||||
|
||||
# great functions you get!
|
||||
def ufix(self, x: 'Self|ConstType|UOp') -> Self:
|
||||
return x if isinstance(x, type(self)) else self._wrap_uop(self._uop.ufix(x))
|
||||
|
|
|
|||
|
|
@ -33,7 +33,7 @@ def call_gradient(ctx:UOp, k:UOp, needed:set[int]) -> tuple[UOp|None, ...]:
|
|||
params = {x.arg.slot:x for x in fxn.toposort(enter_calls=False) if x.op == Ops.PARAM}
|
||||
grad_args = ctx.src
|
||||
root_grad = UOp(Ops.TUPLE, src=tuple(UOp(Ops.NOOP) if g.op is Ops.NOOP else
|
||||
g if g.base.op is Ops.CONST and g.device is None else g.param_like(len(args)+i) for i,g in enumerate(grad_args)))
|
||||
g if g.base.op is Ops.CONST else g.param_like(len(args)+i) for i,g in enumerate(grad_args)))
|
||||
grads = compute_gradient(fxn, root_grad, set(params.values()))
|
||||
# for precompiled calls, substitute forward outputs with params so intermediates aren't recomputed
|
||||
fwd_subs = {src: src.param_like(len(args)+len(grad_args)+i) for i, src in enumerate(fxn.src)} if k.arg.precompile else {}
|
||||
|
|
|
|||
|
|
@ -1,8 +1,10 @@
|
|||
from __future__ import annotations
|
||||
from typing import Self
|
||||
from tinygrad.dtype import DType, dtypes
|
||||
from tinygrad.helpers import ceildiv, prod
|
||||
import math
|
||||
from typing import Self, cast
|
||||
from tinygrad.dtype import DType, DTypeLike, dtypes, to_dtype
|
||||
from tinygrad.helpers import all_int, argfix, ceildiv, prod
|
||||
from tinygrad.mixin import OpMixin
|
||||
from tinygrad.device import canonicalize_device
|
||||
|
||||
|
||||
class RandMixin(OpMixin):
|
||||
|
|
@ -39,3 +41,235 @@ class RandMixin(OpMixin):
|
|||
bits = cls.random_bits(key, counter, ceildiv(prod(shape) * dtype.itemsize, 4))
|
||||
out = cls._bits_to_rand(bits, shape, dtype)
|
||||
return out.contiguous() if contiguous else out
|
||||
|
||||
@staticmethod
|
||||
def _next_counter(device:str, num:int):
|
||||
raise NotImplementedError("_next_counter requires the stateful per-device RNG counter, only implemented on Tensor")
|
||||
|
||||
@classmethod
|
||||
def rand(cls, *shape, device:str|None=None, dtype:DTypeLike|None=None, contiguous:bool=True) -> Self:
|
||||
"""
|
||||
Creates a tensor with the given shape, filled with random values from a uniform distribution over the interval `[0, 1)`.
|
||||
|
||||
You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
|
||||
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
Tensor.manual_seed(42)
|
||||
t = Tensor.rand(2, 3)
|
||||
print(t.numpy())
|
||||
```
|
||||
"""
|
||||
dt = to_dtype(dtype or dtypes.default_float)
|
||||
if not dtypes.is_float(dt): raise ValueError(f"rand only supports float dtypes, got {dt}")
|
||||
if not all_int(shape:=argfix(*shape)) or not all(s >= 0 for s in shape): raise ValueError(f"invalid input {shape=}")
|
||||
if device is not None and not isinstance(device, str): raise ValueError(f"rand only supports single device, got {device=}")
|
||||
device = cast(str, canonicalize_device(device))
|
||||
key, counter = cls._next_counter(device, ceildiv(prod(shape) * dt.itemsize, 4))
|
||||
return cls._rand(key, counter, shape, dt, contiguous=contiguous)
|
||||
|
||||
def rand_like(self, **kwargs) -> Self:
|
||||
"""
|
||||
Creates a tensor with the same shape and sharding as `self`, filled with random values from a uniform distribution over the interval `[0, 1)`.
|
||||
|
||||
You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
|
||||
Additionally, all other keyword arguments are passed to the constructor of the tensor.
|
||||
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
t = Tensor.ones(2, 3)
|
||||
print(Tensor.rand_like(t).numpy())
|
||||
```
|
||||
"""
|
||||
if isinstance(self.device, tuple):
|
||||
if kwargs.pop("device", None) is not None: raise RuntimeError("cannot specify `device` on `*_like` of a multi device tensor")
|
||||
dtype = kwargs.pop("dtype", self.dtype)
|
||||
return self._multi_like(lambda shape, dev: type(self).rand(*shape, dtype=dtype, device=dev, **kwargs))
|
||||
return type(self).rand(*self.shape, device=kwargs.pop("device", self.device), dtype=kwargs.pop("dtype", self.dtype), **kwargs)
|
||||
|
||||
def randn_like(self, dtype:DTypeLike|None=None, **kwargs) -> Self:
|
||||
"""
|
||||
Creates a tensor with the same shape and sharding as `self`, filled with random values from a normal distribution with mean 0 and variance 1.
|
||||
|
||||
You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
|
||||
Additionally, all other keyword arguments are passed to the constructor of the tensor.
|
||||
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
t = Tensor.ones(2, 3)
|
||||
print(Tensor.randn_like(t).numpy())
|
||||
```
|
||||
"""
|
||||
src = self.stack(self).rand_like(**{**kwargs, "dtype": dtypes.float32})
|
||||
# https://en.wikipedia.org/wiki/Box%E2%80%93Muller_transform
|
||||
return src[0].mul(2*math.pi).cos().mul((1 - src[1]).log().mul(-2).sqrt()).cast(to_dtype(dtype or self.dtype))
|
||||
|
||||
@classmethod
|
||||
def randn(cls, *shape, dtype:DTypeLike|None=None, **kwargs) -> Self:
|
||||
"""
|
||||
Creates a tensor with the given shape, filled with random values from a normal distribution with mean `0` and standard deviation `1`.
|
||||
If `dtype` is not specified, the default type is used.
|
||||
|
||||
You can pass in the `device` keyword argument to control device of the tensor.
|
||||
Additionally, all other keyword arguments are passed to the constructor of the tensor.
|
||||
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
Tensor.manual_seed(42)
|
||||
print(Tensor.randn(2, 3).numpy())
|
||||
```
|
||||
"""
|
||||
return cls.empty(*shape, **kwargs).randn_like(dtype=dtype) # type: ignore[attr-defined]
|
||||
|
||||
@classmethod
|
||||
def randint(cls, *shape, low=0, high=10, dtype=dtypes.int32, **kwargs) -> Self:
|
||||
"""
|
||||
Creates a tensor with the given shape, filled with random integer values generated uniformly from the interval `[low, high)`.
|
||||
Requires `low < high`. If `dtype` is not specified, the default type is used.
|
||||
|
||||
You can pass in the `device` keyword argument to control device of the tensor.
|
||||
Additionally, all other keyword arguments are passed to the constructor of the tensor.
|
||||
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
Tensor.manual_seed(42)
|
||||
print(Tensor.randint(2, 3, low=5, high=10).numpy())
|
||||
```
|
||||
"""
|
||||
if not all_int([low, high]): raise TypeError(f"{low=} and {high=} must be integers")
|
||||
if not dtypes.is_int(dtype := to_dtype(dtype)): raise TypeError(f"{dtype=} must be int")
|
||||
if low >= high: raise ValueError(f"Tensor.randint requires low < high, got {low=}, {high=}")
|
||||
return cls.uniform(*shape, low=low, high=high, dtype=dtype, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def normal(cls, *shape, mean=0.0, std=1.0, **kwargs) -> Self:
|
||||
"""
|
||||
Creates a tensor with the given shape, filled with random values from a normal distribution with the given `mean` and standard deviation `std`.
|
||||
Requires `std >= 0`.
|
||||
|
||||
You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
|
||||
Additionally, all other keyword arguments are passed to the constructor of the tensor.
|
||||
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
Tensor.manual_seed(42)
|
||||
print(Tensor.normal(2, 3, mean=10, std=2).numpy())
|
||||
```
|
||||
"""
|
||||
if std < 0: raise ValueError(f"Tensor.normal requires std >= 0, got {std=}")
|
||||
return std * cls.randn(*shape, **kwargs) + mean
|
||||
|
||||
@classmethod
|
||||
def uniform(cls, *shape, low=0.0, high=1.0, dtype:DTypeLike|None=None, **kwargs) -> Self:
|
||||
"""
|
||||
Creates a tensor with the given shape, filled with random values from a uniform distribution over the interval `[low, high)`.
|
||||
Requires `low < high`.
|
||||
|
||||
You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
|
||||
Additionally, all other keyword arguments are passed to the constructor of the tensor.
|
||||
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
Tensor.manual_seed(42)
|
||||
print(Tensor.uniform(2, 3, low=2, high=10).numpy())
|
||||
```
|
||||
"""
|
||||
if not all_int(shape:=argfix(*shape)) or not all(s >= 0 for s in shape): raise ValueError(f"invalid input {shape=}")
|
||||
if low >= high: raise ValueError(f"Tensor.uniform requires low < high, got {low=}, {high=}")
|
||||
return ((high-low) * cls.rand(*shape, **kwargs)).cast(dtype or dtypes.default_float) + low
|
||||
|
||||
@classmethod
|
||||
def scaled_uniform(cls, *shape, **kwargs) -> Self:
|
||||
"""
|
||||
Creates a tensor with the given shape, filled with random values from a uniform distribution
|
||||
over the interval `[-prod(shape)**-0.5, prod(shape)**-0.5)`.
|
||||
|
||||
You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
|
||||
Additionally, all other keyword arguments are passed to the constructor of the tensor.
|
||||
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
Tensor.manual_seed(42)
|
||||
print(Tensor.scaled_uniform(2, 3).numpy())
|
||||
```
|
||||
"""
|
||||
return cls.uniform(*shape, low=-1.0, high=1.0, **kwargs).mul(prod(argfix(*shape))**-0.5)
|
||||
|
||||
@classmethod
|
||||
def glorot_uniform(cls, *shape, **kwargs) -> Self:
|
||||
"""
|
||||
<https://www.tensorflow.org/api_docs/python/tf/keras/initializers/GlorotUniform>
|
||||
|
||||
You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
|
||||
Additionally, all other keyword arguments are passed to the constructor of the tensor.
|
||||
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
Tensor.manual_seed(42)
|
||||
print(Tensor.glorot_uniform(2, 3).numpy())
|
||||
```
|
||||
"""
|
||||
bound = (6 / (argfix(*shape)[0]+prod(argfix(*shape)[1:]))) ** 0.5
|
||||
return cls.uniform(*shape, low=-bound, high=bound, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def kaiming_uniform(cls, *shape, a:float = 0.01, **kwargs) -> Self:
|
||||
"""
|
||||
<https://pytorch.org/docs/stable/_modules/torch/nn/init.html#kaiming_uniform_>
|
||||
|
||||
You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
|
||||
Additionally, all other keyword arguments are passed to the constructor of the tensor.
|
||||
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
Tensor.manual_seed(42)
|
||||
print(Tensor.kaiming_uniform(2, 3).numpy())
|
||||
```
|
||||
"""
|
||||
bound = (6 / (1 + a ** 2) / prod(argfix(*shape)[1:])) ** 0.5
|
||||
return cls.uniform(*shape, low=-bound, high=bound, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def kaiming_normal(cls, *shape, a:float = 0.01, **kwargs) -> Self:
|
||||
"""
|
||||
<https://pytorch.org/docs/stable/_modules/torch/nn/init.html#kaiming_normal_>
|
||||
|
||||
You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
|
||||
Additionally, all other keyword arguments are passed to the constructor of the tensor.
|
||||
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
Tensor.manual_seed(42)
|
||||
print(Tensor.kaiming_normal(2, 3).numpy())
|
||||
```
|
||||
"""
|
||||
std = (2 / (1 + a ** 2) / prod(argfix(*shape)[1:])) ** 0.5
|
||||
return cls.normal(*shape, mean=0.0, std=std, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def randperm(cls, n:int, device=None, dtype=dtypes.int32, **kwargs) -> Self:
|
||||
"""
|
||||
Returns a tensor with a random permutation of integers from `0` to `n-1`.
|
||||
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
Tensor.manual_seed(42)
|
||||
print(Tensor.randperm(6).numpy())
|
||||
```
|
||||
"""
|
||||
return cls.rand(n, device=device, **kwargs).argsort().cast(dtype)
|
||||
|
||||
def multinomial(self, num_samples:int = 1, replacement:bool = False) -> Self:
|
||||
"""
|
||||
Returns a tensor with `num_samples` indices sampled from a multinomial distribution weighted by `self`.
|
||||
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
Tensor.manual_seed(42)
|
||||
t = Tensor([1, 2, 3, 4])
|
||||
print(t.multinomial(20, replacement=True).numpy())
|
||||
```
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
Tensor.manual_seed(42)
|
||||
t = Tensor([1, 2, 3, 4])
|
||||
print(t.multinomial(3, replacement=False).numpy())
|
||||
```
|
||||
"""
|
||||
assert 1 <= self.ndim <= 2 and num_samples > 0, f"{self.ndim=} must be 1 or 2 dim, {num_samples=} must be positive"
|
||||
weight = self.unsqueeze(0) if self.ndim == 1 else self
|
||||
assert replacement or num_samples <= weight.shape[1], "no replacement samples must not exceed population size"
|
||||
if replacement or num_samples == 1:
|
||||
cdf = (cw := weight.cumsum(1).float()) / cw[:, -1].unsqueeze(1)
|
||||
unif_samples = type(self).rand(num_samples, cdf.shape[0], 1).to(self.device) # type: ignore[attr-defined]
|
||||
indices = (unif_samples.expand((-1, -1, cdf.shape[1])) >= cdf).sum(2).permute((1, 0))
|
||||
else:
|
||||
# Efraimidis-Spirakis
|
||||
indices = (weight.rand_like(dtype=dtypes.float32).log2() / weight).topk(num_samples, dim=1)[1]
|
||||
return (indices.squeeze(0) if self.ndim == 1 else indices).cast(dtypes.int32)
|
||||
|
|
|
|||
|
|
@ -1,8 +1,7 @@
|
|||
import string
|
||||
from typing import Self, Sequence, cast
|
||||
from typing import Self, Sequence
|
||||
from tinygrad.uop import Ops
|
||||
from tinygrad.dtype import DTypeLike, dtypes, sum_acc_dtype, to_dtype
|
||||
from tinygrad.helpers import argfix, argsort, make_tuple, merge_dicts
|
||||
from tinygrad.helpers import make_tuple
|
||||
from tinygrad.mixin.dtype import DTypeMixin
|
||||
from tinygrad.mixin.movement import MovementMixin
|
||||
|
||||
|
|
@ -136,44 +135,3 @@ class ReduceMixin(DTypeMixin, MovementMixin):
|
|||
```
|
||||
"""
|
||||
return self.bool().prod(axis, keepdim)
|
||||
|
||||
@classmethod
|
||||
def einsum(cls, formula:str, *operands:Self|Sequence[Self], dtype:DTypeLike|None=None) -> Self:
|
||||
"""
|
||||
Sums the product of the elements of the input tensors according to a formula based on the Einstein summation convention.
|
||||
|
||||
See: https://pytorch.org/docs/stable/generated/torch.einsum.html
|
||||
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
x = Tensor([[1, 2], [3, 4]])
|
||||
y = Tensor([[5, 6], [7, 8]])
|
||||
print(Tensor.einsum("ij,ij->", x, y).numpy())
|
||||
```
|
||||
"""
|
||||
xs, formula = list(argfix(*operands)), formula.replace(" ", "")
|
||||
# expand ellipsis to letters, determine output
|
||||
if "..." in formula:
|
||||
ell, lhs = "".join(c for c in string.ascii_letters if c not in formula), (formula.split("->") + [""])[0]
|
||||
ell_n = [max(0, x.ndim - len(s) + 3) if "..." in s else 0 for s, x in zip(lhs.split(","), xs)]
|
||||
for i, (s, x) in enumerate(zip(inputs := lhs.split(","), xs)): inputs[i] = s.replace("...", ell[max(ell_n)-ell_n[i]:max(ell_n)])
|
||||
lhs, auto = ",".join(inputs), "".join(sorted(c for c in lhs if lhs.count(c) == 1 and c.isalpha() and c not in ell))
|
||||
formula = f"{lhs}->{formula.split('->')[1].replace('...', ell[:max(ell_n)]) if '->' in formula else ell[:max(ell_n)] + auto}"
|
||||
lhs, rhs = formula.split("->") if "->" in formula else (formula, "".join(sorted(c for c in formula if formula.count(c)==1 and c.isalpha())))
|
||||
inputs = lhs.split(",")
|
||||
if len(xs) != len(inputs): raise ValueError(f"number of operands doesn't match, expected {len(inputs)}, got {len(xs)}")
|
||||
# trace: take diagonal when letter repeats in single input
|
||||
for i, (s, x) in enumerate(zip(inputs, xs)):
|
||||
for c in set(s):
|
||||
while s.count(c) > 1:
|
||||
j, k, n = s.index(c), s.index(c, s.index(c)+1), cast(int, x.shape[s.index(c)])
|
||||
perm = [d for d in range(x.ndim) if d not in (j,k)]+[j,k]
|
||||
x = x.permute(perm).flatten(-2).pad(((0,0),)*(x.ndim-2)+((0,n),)).unflatten(-1,(n,n+1))[...,0] if x.ndim > 2 else x.diagonal()
|
||||
s = s[:k] + s[k+1:]
|
||||
inputs[i], xs[i] = s, x
|
||||
# check sizes and build sorted alphabet
|
||||
sz = merge_dicts([dict(zip(s, x.shape)) for s, x in zip(inputs, xs)])
|
||||
alpha = sorted(sz)
|
||||
# align all tensors to alphabet, multiply, sum non-output, permute to output order
|
||||
xs = [x.permute(*[s.index(c) for c in sorted(s)]).reshape([sz[c] if c in s else 1 for c in alpha]).expand([sz[c] for c in alpha]) if s else x
|
||||
for s, x in zip(inputs, xs)]
|
||||
return xs[0].uprod(*xs[1:]).sum([i for i,c in enumerate(alpha) if c not in rhs], dtype=dtype).permute(argsort(argsort(list(rhs))))
|
||||
|
|
|
|||
|
|
@ -543,9 +543,6 @@ to_define_global = PatternMatcher([
|
|||
# remove device from local BUFFERIZE
|
||||
(UPat(Ops.STAGE, name="b"), lambda b: b.replace(arg=replace(b.arg, device=None))),
|
||||
|
||||
# remove UNIQUE/DEVICE to dedup CONST
|
||||
(UPat(Ops.CONST, name="c"), lambda c: c.replace(src=()) if len(c.src) else None),
|
||||
|
||||
# renumber the ranges starting with 0 so that kernel deduping works
|
||||
(UPat(Ops.RANGE, name="r"), renumber_range),
|
||||
])
|
||||
|
|
|
|||
|
|
@ -136,24 +136,13 @@ class Tensor(RandMixin):
|
|||
all_tensors[weakref.ref(ret)] = None
|
||||
return ret
|
||||
|
||||
# alu and const_like are used by the mixins
|
||||
# alu, _uop, _wrap_uop and const are used by the mixins
|
||||
def alu(self, op: Ops, *src: Tensor) -> Tensor: return self._apply_uop(lambda *u: u[0].alu(op, *u[1:]), *src)
|
||||
@property
|
||||
def _uop(self) -> UOp: return self.uop
|
||||
def _wrap_uop(self, u:UOp) -> Tensor: return Tensor(u)
|
||||
def const_like(self, b:ConstType) -> Tensor: return Tensor(self.uop.const_like(b))
|
||||
@staticmethod
|
||||
def const(dtype:DType, b:ConstType|UOp) -> Tensor: return Tensor(UOp.const(dtype, b))
|
||||
@staticmethod
|
||||
def invalids(*shape, device:str|tuple[str, ...]|None=None, dtype:DTypeLike|None=None) -> Tensor:
|
||||
"""
|
||||
Creates a tensor with the given shape, filled with Invalid.
|
||||
|
||||
This is an alternative to Tensor.empty when you want an "anonymous" buffer.
|
||||
|
||||
Eventually Tensor.empty will be replaced by this.
|
||||
"""
|
||||
return Tensor(UOp.invalids(argfix(*shape), dtype, device))
|
||||
|
||||
def is_param_(self, is_param:bool=True) -> Tensor:
|
||||
self.is_param = is_param
|
||||
|
|
@ -292,18 +281,6 @@ class Tensor(RandMixin):
|
|||
assert self.dtype.base.fmt != "e" or sys.version_info >= (3, 12)
|
||||
return self._buffer().as_memoryview().cast(self.dtype.base.fmt, self.shape)
|
||||
|
||||
def item(self) -> PyConst:
|
||||
"""
|
||||
Returns the value of this tensor as a standard Python number.
|
||||
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
t = Tensor(42)
|
||||
print(t.item())
|
||||
```
|
||||
"""
|
||||
assert self.numel() == 1, "must have one element for item"
|
||||
return self.data()[(0,) * len(self.shape)]
|
||||
|
||||
# NOTE: list[Any] because return type is recursive (list[list[...]] for higher dimensions)
|
||||
def tolist(self) -> PyConst|list[Any]:
|
||||
"""
|
||||
|
|
@ -464,13 +441,6 @@ class Tensor(RandMixin):
|
|||
"""
|
||||
return Tensor(UOp.empty(argfix(*shape), dtype, device))
|
||||
|
||||
def empty_like(self, dtype:DTypeLike|None=None, device:str|tuple[str, ...]|None=None) -> Tensor:
|
||||
"""
|
||||
Creates an empty tensor with the same shape as `self`.
|
||||
If `dtype` is not specified, the dtype of `self` is used.
|
||||
"""
|
||||
return Tensor(self.uop.empty_like(dtype, device))
|
||||
|
||||
@staticmethod
|
||||
def from_blob(ptr:int, shape:tuple[int, ...], **kwargs) -> Tensor:
|
||||
"""
|
||||
|
|
@ -533,35 +503,8 @@ class Tensor(RandMixin):
|
|||
high = counter[1:2] - (num >> 32) - (counter[0] < (num & 0xffffffff))
|
||||
return Tensor._device_seeds[device], low.cat(high)
|
||||
|
||||
@staticmethod
|
||||
def rand(*shape, device:str|None=None, dtype:DTypeLike|None=None, contiguous:bool=True) -> Tensor:
|
||||
"""
|
||||
Creates a tensor with the given shape, filled with random values from a uniform distribution over the interval `[0, 1)`.
|
||||
|
||||
You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
|
||||
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
Tensor.manual_seed(42)
|
||||
t = Tensor.rand(2, 3)
|
||||
print(t.numpy())
|
||||
```
|
||||
"""
|
||||
dt = to_dtype(dtype or dtypes.default_float)
|
||||
if not dtypes.is_float(dt): raise ValueError(f"rand only supports float dtypes, got {dt}")
|
||||
if not all_int(shape:=argfix(*shape)) or not all(s >= 0 for s in shape): raise ValueError(f"invalid input {shape=}")
|
||||
if device is not None and not isinstance(device, str): raise ValueError(f"rand only supports single device, got {device=}")
|
||||
device = cast(str, canonicalize_device(device))
|
||||
key, counter = Tensor._next_counter(device, ceildiv(prod(shape) * dt.itemsize, 4))
|
||||
return Tensor._rand(key, counter, shape, dt, contiguous=contiguous)
|
||||
|
||||
# ***** creation helper functions *****
|
||||
|
||||
def _multi_like(self, fxn:Callable[[tuple[sint, ...], str|None], Tensor]) -> Tensor:
|
||||
assert isinstance(self.device, tuple), f"_multi_like needs a multi device tensor, got {self.device}"
|
||||
if self.uop.axis is None: return fxn(self.shape, None).shard(self.device)
|
||||
stacked = UOp.mstack(*[fxn(self.uop.shard_shape, d).uop for d in self.device])
|
||||
return Tensor(stacked.multi(self.uop.axis))
|
||||
|
||||
def full_like(self, fill_value:ConstType, dtype=None, device=None) -> Tensor:
|
||||
"""
|
||||
Creates a tensor with the same shape as `self`, filled with the given value.
|
||||
|
|
@ -579,215 +522,6 @@ class Tensor(RandMixin):
|
|||
return self._multi_like(lambda shape, dev: Tensor.full(shape, fill_value, dtype=dtype or self.dtype, device=dev))
|
||||
return Tensor.full(self.shape, fill_value, dtype=dtype or self.dtype, device=self.device if device is None else device)
|
||||
|
||||
def rand_like(self, **kwargs) -> Tensor:
|
||||
"""
|
||||
Creates a tensor with the same shape and sharding as `self`, filled with random values from a uniform distribution over the interval `[0, 1)`.
|
||||
|
||||
You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
|
||||
Additionally, all other keyword arguments are passed to the constructor of the tensor.
|
||||
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
t = Tensor.ones(2, 3)
|
||||
print(Tensor.rand_like(t).numpy())
|
||||
```
|
||||
"""
|
||||
if isinstance(self.device, tuple):
|
||||
if kwargs.pop("device", None) is not None: raise RuntimeError("cannot specify `device` on `*_like` of a multi device tensor")
|
||||
dtype = kwargs.pop("dtype", self.dtype)
|
||||
return self._multi_like(lambda shape, dev: Tensor.rand(*shape, dtype=dtype, device=dev, **kwargs))
|
||||
return Tensor.rand(*self.shape, device=kwargs.pop("device", self.device), dtype=kwargs.pop("dtype", self.dtype), **kwargs)
|
||||
|
||||
# ***** random functions *****
|
||||
|
||||
def randn_like(self, dtype:DTypeLike|None=None, **kwargs) -> Tensor:
|
||||
"""
|
||||
Creates a tensor with the same shape and sharding as `self`, filled with random values from a normal distribution with mean 0 and variance 1.
|
||||
|
||||
You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
|
||||
Additionally, all other keyword arguments are passed to the constructor of the tensor.
|
||||
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
t = Tensor.ones(2, 3)
|
||||
print(Tensor.randn_like(t).numpy())
|
||||
```
|
||||
"""
|
||||
src = self.stack(self).rand_like(**{**kwargs, "dtype": dtypes.float32})
|
||||
# https://en.wikipedia.org/wiki/Box%E2%80%93Muller_transform
|
||||
return src[0].mul(2*math.pi).cos().mul((1 - src[1]).log().mul(-2).sqrt()).cast(dtype or self.dtype)
|
||||
|
||||
@staticmethod
|
||||
def randn(*shape, dtype:DTypeLike|None=None, **kwargs) -> Tensor:
|
||||
"""
|
||||
Creates a tensor with the given shape, filled with random values from a normal distribution with mean `0` and standard deviation `1`.
|
||||
If `dtype` is not specified, the default type is used.
|
||||
|
||||
You can pass in the `device` keyword argument to control device of the tensor.
|
||||
Additionally, all other keyword arguments are passed to the constructor of the tensor.
|
||||
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
Tensor.manual_seed(42)
|
||||
print(Tensor.randn(2, 3).numpy())
|
||||
```
|
||||
"""
|
||||
return Tensor.empty(*shape, **kwargs).randn_like(dtype=dtype)
|
||||
|
||||
@staticmethod
|
||||
def randint(*shape, low=0, high=10, dtype=dtypes.int32, **kwargs) -> Tensor:
|
||||
"""
|
||||
Creates a tensor with the given shape, filled with random integer values generated uniformly from the interval `[low, high)`.
|
||||
Requires `low < high`. If `dtype` is not specified, the default type is used.
|
||||
|
||||
You can pass in the `device` keyword argument to control device of the tensor.
|
||||
Additionally, all other keyword arguments are passed to the constructor of the tensor.
|
||||
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
Tensor.manual_seed(42)
|
||||
print(Tensor.randint(2, 3, low=5, high=10).numpy())
|
||||
```
|
||||
"""
|
||||
if not all_int([low, high]): raise TypeError(f"{low=} and {high=} must be integers")
|
||||
if not dtypes.is_int(dtype := to_dtype(dtype)): raise TypeError(f"{dtype=} must be int")
|
||||
if low >= high: raise ValueError(f"Tensor.randint requires low < high, got {low=}, {high=}")
|
||||
return Tensor.uniform(*shape, low=low, high=high, dtype=dtype, **kwargs)
|
||||
|
||||
@staticmethod
|
||||
def normal(*shape, mean=0.0, std=1.0, **kwargs) -> Tensor:
|
||||
"""
|
||||
Creates a tensor with the given shape, filled with random values from a normal distribution with the given `mean` and standard deviation `std`.
|
||||
Requires `std >= 0`.
|
||||
|
||||
You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
|
||||
Additionally, all other keyword arguments are passed to the constructor of the tensor.
|
||||
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
Tensor.manual_seed(42)
|
||||
print(Tensor.normal(2, 3, mean=10, std=2).numpy())
|
||||
```
|
||||
"""
|
||||
if std < 0: raise ValueError(f"Tensor.normal requires std >= 0, got {std=}")
|
||||
return std * Tensor.randn(*shape, **kwargs) + mean
|
||||
|
||||
@staticmethod
|
||||
def uniform(*shape, low=0.0, high=1.0, dtype:DTypeLike|None=None, **kwargs) -> Tensor:
|
||||
"""
|
||||
Creates a tensor with the given shape, filled with random values from a uniform distribution over the interval `[low, high)`.
|
||||
Requires `low < high`.
|
||||
|
||||
You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
|
||||
Additionally, all other keyword arguments are passed to the constructor of the tensor.
|
||||
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
Tensor.manual_seed(42)
|
||||
print(Tensor.uniform(2, 3, low=2, high=10).numpy())
|
||||
```
|
||||
"""
|
||||
if not all_int(shape:=argfix(*shape)) or not all(s >= 0 for s in shape): raise ValueError(f"invalid input {shape=}")
|
||||
if low >= high: raise ValueError(f"Tensor.uniform requires low < high, got {low=}, {high=}")
|
||||
return ((high-low) * Tensor.rand(*shape, **kwargs)).cast(dtype or dtypes.default_float) + low
|
||||
|
||||
@staticmethod
|
||||
def scaled_uniform(*shape, **kwargs) -> Tensor:
|
||||
"""
|
||||
Creates a tensor with the given shape, filled with random values from a uniform distribution
|
||||
over the interval `[-prod(shape)**-0.5, prod(shape)**-0.5)`.
|
||||
|
||||
You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
|
||||
Additionally, all other keyword arguments are passed to the constructor of the tensor.
|
||||
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
Tensor.manual_seed(42)
|
||||
print(Tensor.scaled_uniform(2, 3).numpy())
|
||||
```
|
||||
"""
|
||||
return Tensor.uniform(*shape, low=-1.0, high=1.0, **kwargs).mul(prod(argfix(*shape))**-0.5)
|
||||
|
||||
@staticmethod
|
||||
def glorot_uniform(*shape, **kwargs) -> Tensor:
|
||||
"""
|
||||
<https://www.tensorflow.org/api_docs/python/tf/keras/initializers/GlorotUniform>
|
||||
|
||||
You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
|
||||
Additionally, all other keyword arguments are passed to the constructor of the tensor.
|
||||
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
Tensor.manual_seed(42)
|
||||
print(Tensor.glorot_uniform(2, 3).numpy())
|
||||
```
|
||||
"""
|
||||
bound = (6 / (argfix(*shape)[0]+prod(argfix(*shape)[1:]))) ** 0.5
|
||||
return Tensor.uniform(*shape, low=-bound, high=bound, **kwargs)
|
||||
|
||||
@staticmethod
|
||||
def kaiming_uniform(*shape, a:float = 0.01, **kwargs) -> Tensor:
|
||||
"""
|
||||
<https://pytorch.org/docs/stable/_modules/torch/nn/init.html#kaiming_uniform_>
|
||||
|
||||
You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
|
||||
Additionally, all other keyword arguments are passed to the constructor of the tensor.
|
||||
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
Tensor.manual_seed(42)
|
||||
print(Tensor.kaiming_uniform(2, 3).numpy())
|
||||
```
|
||||
"""
|
||||
bound = (6 / (1 + a ** 2) / prod(argfix(*shape)[1:])) ** 0.5
|
||||
return Tensor.uniform(*shape, low=-bound, high=bound, **kwargs)
|
||||
|
||||
@staticmethod
|
||||
def kaiming_normal(*shape, a:float = 0.01, **kwargs) -> Tensor:
|
||||
"""
|
||||
<https://pytorch.org/docs/stable/_modules/torch/nn/init.html#kaiming_normal_>
|
||||
|
||||
You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
|
||||
Additionally, all other keyword arguments are passed to the constructor of the tensor.
|
||||
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
Tensor.manual_seed(42)
|
||||
print(Tensor.kaiming_normal(2, 3).numpy())
|
||||
```
|
||||
"""
|
||||
std = (2 / (1 + a ** 2) / prod(argfix(*shape)[1:])) ** 0.5
|
||||
return Tensor.normal(*shape, mean=0.0, std=std, **kwargs)
|
||||
|
||||
@staticmethod
|
||||
def randperm(n:int, device=None, dtype=dtypes.int32, **kwargs) -> Tensor:
|
||||
"""
|
||||
Returns a tensor with a random permutation of integers from `0` to `n-1`.
|
||||
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
Tensor.manual_seed(42)
|
||||
print(Tensor.randperm(6).numpy())
|
||||
```
|
||||
"""
|
||||
return Tensor.rand(n, device=device, **kwargs).argsort().cast(dtype)
|
||||
|
||||
def multinomial(self:Tensor, num_samples:int = 1, replacement:bool = False) -> Tensor:
|
||||
"""
|
||||
Returns a tensor with `num_samples` indices sampled from a multinomial distribution weighted by `self`.
|
||||
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
Tensor.manual_seed(42)
|
||||
t = Tensor([1, 2, 3, 4])
|
||||
print(t.multinomial(20, replacement=True).numpy())
|
||||
```
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
Tensor.manual_seed(42)
|
||||
t = Tensor([1, 2, 3, 4])
|
||||
print(t.multinomial(3, replacement=False).numpy())
|
||||
```
|
||||
"""
|
||||
assert 1 <= self.ndim <= 2 and num_samples > 0, f"{self.ndim=} must be 1 or 2 dim, {num_samples=} must be positive"
|
||||
weight = self.unsqueeze(0) if self.ndim == 1 else self
|
||||
assert replacement or num_samples <= weight.shape[1], "no replacement samples must not exceed population size"
|
||||
if replacement or num_samples == 1:
|
||||
cdf = (cw := weight.cumsum(1).float()) / cw[:, -1].unsqueeze(1)
|
||||
unif_samples = Tensor.rand(num_samples, cdf.shape[0], 1).to(self.device)
|
||||
indices = (unif_samples.expand((-1, -1, cdf.shape[1])) >= cdf).sum(2).permute((1, 0))
|
||||
else:
|
||||
# Efraimidis–Spirakis
|
||||
indices = (weight.rand_like(dtype=dtypes.float32).log2() / weight).topk(num_samples, dim=1)[1]
|
||||
return (indices.squeeze(0) if self.ndim == 1 else indices).cast(dtypes.int32)
|
||||
|
||||
# ***** toposort and backward pass *****
|
||||
|
||||
def backward(self, gradient:Tensor|None=None) -> Tensor:
|
||||
|
|
@ -817,46 +551,6 @@ class Tensor(RandMixin):
|
|||
def _mop(self, op:Ops, arg) -> Tensor: return self._apply_uop(UOp._mop, extra_args=(op,), arg=arg)
|
||||
def _rop(self, op:Ops, axis:tuple[int, ...]) -> Tensor: return self._apply_uop(UOp._rop, op=op, axis=axis)
|
||||
|
||||
def __getitem__(self, indices) -> Tensor:
|
||||
"""
|
||||
Retrieves a sub-tensor using indexing.
|
||||
|
||||
Supported Index Types: `int | slice | Tensor | None | list | tuple | Ellipsis`
|
||||
|
||||
Examples:
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
t = Tensor.arange(12).reshape(3, 4)
|
||||
print(t.numpy())
|
||||
```
|
||||
|
||||
- Int Indexing: Select an element or sub-tensor using integers for each dimension.
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
print(t[1, 2].numpy())
|
||||
```
|
||||
|
||||
- Slice Indexing: Select a range of elements using slice notation (`start:end:stride`).
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
print(t[0:2, ::2].numpy())
|
||||
```
|
||||
|
||||
- Tensor Indexing: Use another tensor as indices for advanced indexing. Using `tuple` or `list` here also works.
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
print(t[Tensor([2, 0, 1]), Tensor([1, 2, 3])].numpy())
|
||||
```
|
||||
|
||||
- `None` Indexing: Add a new dimension to the tensor.
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
print(t[:, None].shape)
|
||||
```
|
||||
|
||||
NOTE: Out-of-bounds indexing results in a value of `0`.
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
t = Tensor([1, 2, 3])
|
||||
print(t[Tensor([4, 3, 2])].numpy())
|
||||
```
|
||||
"""
|
||||
return super().__getitem__(indices)
|
||||
|
||||
def __setitem__(self, indices, v:Tensor|PyConst|list|tuple) -> None:
|
||||
if isinstance(v, Tensor) and v.dtype != self.dtype: raise RuntimeError(f"setitem dtype mismatch: {self.dtype=} != {v.dtype=}")
|
||||
# raise if mutation would diverge from eager (allow only pure views of a realized buffer; exclude +=/-= RHS via v_uop/v_bw)
|
||||
|
|
@ -887,68 +581,6 @@ class Tensor(RandMixin):
|
|||
def __delitem__(self, indices) -> None:
|
||||
raise TypeError("Tensor does not support deleting items")
|
||||
|
||||
def masked_select(self, mask, size:int|None=None, fill_value:ConstType=0):
|
||||
"""
|
||||
Selects elements from `self` based on the boolean `mask`.
|
||||
|
||||
With `size=None` (default), output length equals the number of `True` values (not jittable).
|
||||
With `size=N`, output length is `N`, padded with `fill_value` or truncated (jittable).
|
||||
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
t = Tensor([[0, 1, 2], [3, 4, 5], [6, 7, 8]])
|
||||
mask = Tensor([[True, False, True], [False, True, False], [False, False, True]])
|
||||
print(t.numpy())
|
||||
print(mask.numpy())
|
||||
```
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
print(t.masked_select(mask).numpy())
|
||||
```
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
print(t.masked_select(mask, size=6, fill_value=-1).numpy())
|
||||
```
|
||||
"""
|
||||
if not dtypes.is_bool(mask.dtype): raise RuntimeError(f"masked_select expects bool mask tensor, got {mask.dtype}")
|
||||
x, mask = self.flatten(), mask._broadcast_to(self.shape).flatten()
|
||||
mask_cumsum = mask.cumsum()
|
||||
if size is None:
|
||||
counts = Tensor.zeros(mask_cumsum[-1].item() if mask.numel() else 0, dtype=dtypes.int32, buffer=False)
|
||||
return x[counts.scatter(0, mask_cumsum, 1, reduce='add').cumsum()]
|
||||
counts = Tensor.zeros(size, dtype=dtypes.int32, buffer=False).scatter(0, mask_cumsum, 1, reduce='add')
|
||||
return (Tensor.arange(size) < mask.sum()).where(x[counts.cumsum()], fill_value).cast(self.dtype)
|
||||
|
||||
def nonzero(self, size:int|None=None, fill_value:ConstType=0) -> Tensor:
|
||||
"""
|
||||
Returns the indices of the elements that are non-zero.
|
||||
|
||||
With `size=None` (default), output shape is `(n_nonzero, ndim)` (not jittable).
|
||||
With `size=N`, output shape is `(N, ndim)`, padded with `fill_value` or truncated (jittable).
|
||||
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
t = Tensor([1, 0, 2, 0, 3])
|
||||
print(t.numpy())
|
||||
```
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
print(t.nonzero().numpy())
|
||||
```
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
t = Tensor([[1, 0], [0, 2]])
|
||||
print(t.numpy())
|
||||
```
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
print(t.nonzero().numpy())
|
||||
```
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
print(t.nonzero(size=3, fill_value=-1).numpy())
|
||||
```
|
||||
"""
|
||||
if self.ndim == 0:
|
||||
return Tensor.zeros(size if size is not None else int((self != 0).item()), 0, dtype=dtypes.int32, device=self.device)
|
||||
mask = (self != 0).flatten()
|
||||
indices = Tensor.stack(*[Tensor.arange(s).reshape(*[1]*i, s, *[1]*(self.ndim-i-1)).expand(self.shape).flatten()
|
||||
for i, s in enumerate(self.shape)], dim=-1)
|
||||
return indices.masked_select(mask.unsqueeze(-1).expand(*mask.shape, self.ndim),
|
||||
size=size*self.ndim if size is not None else None, fill_value=fill_value).reshape(-1, self.ndim)
|
||||
|
||||
# ***** reduce ops *****
|
||||
|
||||
def keccak(self, cfg:str|tuple[int, int]="sha3_256"):
|
||||
|
|
@ -1054,11 +686,11 @@ class Tensor(RandMixin):
|
|||
|
||||
g = weight.permute(*range(len(weight.shape)-len(HW),len(weight.shape)), *range(len(weight.shape)-len(HW))) # move HW to the front
|
||||
|
||||
# compute 6x6 winograd tiles: GgGt, BtdB
|
||||
# compute 6x6 winograd tiles: GgGt, BtdB. contiguous so the transforms are materialized once
|
||||
# (HWI, groups * rcout, cin) -> (HWI, bs=1, groups, rcout, cin, tyx=(1,1))
|
||||
gfactors = _apply_winograd_matrix(winograd_G, g, len(HW)).reshape(*HWI, 1, groups, rcout, cin, *([1]*len(tyx)))
|
||||
gfactors = _apply_winograd_matrix(winograd_G, g, len(HW)).contiguous().reshape(*HWI, 1, groups, rcout, cin, *([1]*len(tyx)))
|
||||
# (HWI, bs, cin_, tyx) -> (HWI, bs, groups, 1 ,cin, *tyx)
|
||||
dfactors = _apply_winograd_matrix(winograd_Bt, d, len(HW)).reshape(*HWI, bs, groups, 1, cin, *tyx)
|
||||
dfactors = _apply_winograd_matrix(winograd_Bt, d, len(HW)).contiguous().reshape(*HWI, bs, groups, 1, cin, *tyx)
|
||||
|
||||
# matmul; sum across cin: (HWI, bs, groups, rcout, *tyx); then HWI -> HWO: (HWO, bs, groups, rcout, *tyx)
|
||||
ret = _apply_winograd_matrix(winograd_At, (gfactors * dfactors).sum(axis=-1-len(HW), dtype=dtype), len(HW))
|
||||
|
|
@ -1227,25 +859,6 @@ class Tensor(RandMixin):
|
|||
|
||||
# ***** cast ops *****
|
||||
|
||||
def cast(self, dtype:DTypeLike) -> Tensor:
|
||||
"""
|
||||
Casts `self` to the given `dtype`.
|
||||
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
t = Tensor([-1, 2.5, 3], dtype=dtypes.float)
|
||||
print(t.dtype, t.numpy())
|
||||
```
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
t = t.cast(dtypes.int32)
|
||||
print(t.dtype, t.numpy())
|
||||
```
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
t = t.cast(dtypes.uint8)
|
||||
print(t.dtype, t.numpy())
|
||||
```
|
||||
"""
|
||||
return self if self.dtype == (dt:=to_dtype(dtype)) else self._apply_uop(UOp.cast, dtype=dt)
|
||||
|
||||
def bitcast(self, dtype:DTypeLike) -> Tensor:
|
||||
"""
|
||||
Bitcasts `self` to the given `dtype` of the same itemsize.
|
||||
|
|
|
|||
|
|
@ -561,10 +561,6 @@ class UOp(RandMixin, metaclass=UOpMetaClass):
|
|||
ret = UOp(Ops.CONST, dtype, arg=dtype.const(b), src=())
|
||||
return ret.reshape((1,)*len(shape)).expand(shape) if shape is not None and shape != () and ret.shape != shape else ret
|
||||
@staticmethod
|
||||
def invalids(shape:tuple[sint, ...]|None=None, dtype:DTypeLike|None=None, device:str|tuple[str, ...]|None=None) -> UOp:
|
||||
dt = to_dtype(dtype) if dtype is not None else dtypes.from_py(Invalid)
|
||||
return UOp.const(dt, Invalid, shape=shape).clone(device=device)
|
||||
@staticmethod
|
||||
def range(end:sint, axis_id, axis_type=AxisType.LOOP, *arg, dtype=dtypes.weakint, src=(), **kwargs):
|
||||
return UOp(Ops.RANGE, dtype=dtype, src=(sint_to_uop(end, dtype),)+src, arg=(axis_id, axis_type)+arg, **kwargs)
|
||||
@staticmethod
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue