mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
move Tensor.__getitem__ to mixin [PR] (#16689)
This commit is contained in:
parent
4618d27129
commit
b50da5c205
2 changed files with 39 additions and 41 deletions
|
|
@ -45,7 +45,45 @@ class OpMixin(ElementwiseMixin, ReduceMixin):
|
|||
val = val.reshape((1,)*len(new_shape)).expand(new_shape)
|
||||
return val.clone(device=device) if buffer else val
|
||||
|
||||
def __getitem__(self, indices) -> Self: return self._getitem(indices)
|
||||
def __getitem__(self, indices) -> Self:
|
||||
"""
|
||||
Retrieves a sub-tensor using indexing.
|
||||
|
||||
Supported Index Types: `int | slice | Tensor | None | list | tuple | Ellipsis`
|
||||
|
||||
Examples:
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
t = Tensor.arange(12).reshape(3, 4)
|
||||
print(t.numpy())
|
||||
```
|
||||
|
||||
- Int Indexing: Select an element or sub-tensor using integers for each dimension.
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
print(t[1, 2].numpy())
|
||||
```
|
||||
|
||||
- Slice Indexing: Select a range of elements using slice notation (`start:end:stride`).
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
print(t[0:2, ::2].numpy())
|
||||
```
|
||||
|
||||
- Tensor Indexing: Use another tensor as indices for advanced indexing. Using `tuple` or `list` here also works.
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
print(t[Tensor([2, 0, 1]), Tensor([1, 2, 3])].numpy())
|
||||
```
|
||||
|
||||
- `None` Indexing: Add a new dimension to the tensor.
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
print(t[:, None].shape)
|
||||
```
|
||||
|
||||
NOTE: Out-of-bounds indexing results in a value of `0`.
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
t = Tensor([1, 2, 3])
|
||||
print(t[Tensor([4, 3, 2])].numpy())
|
||||
```
|
||||
"""
|
||||
return self._getitem(indices)
|
||||
|
||||
def _getitem(self, indices, v=None) -> Self:
|
||||
from tinygrad.uop.ops import UOp
|
||||
|
|
|
|||
|
|
@ -817,46 +817,6 @@ class Tensor(RandMixin):
|
|||
def _mop(self, op:Ops, arg) -> Tensor: return self._apply_uop(UOp._mop, extra_args=(op,), arg=arg)
|
||||
def _rop(self, op:Ops, axis:tuple[int, ...]) -> Tensor: return self._apply_uop(UOp._rop, op=op, axis=axis)
|
||||
|
||||
def __getitem__(self, indices) -> Tensor:
|
||||
"""
|
||||
Retrieves a sub-tensor using indexing.
|
||||
|
||||
Supported Index Types: `int | slice | Tensor | None | list | tuple | Ellipsis`
|
||||
|
||||
Examples:
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
t = Tensor.arange(12).reshape(3, 4)
|
||||
print(t.numpy())
|
||||
```
|
||||
|
||||
- Int Indexing: Select an element or sub-tensor using integers for each dimension.
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
print(t[1, 2].numpy())
|
||||
```
|
||||
|
||||
- Slice Indexing: Select a range of elements using slice notation (`start:end:stride`).
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
print(t[0:2, ::2].numpy())
|
||||
```
|
||||
|
||||
- Tensor Indexing: Use another tensor as indices for advanced indexing. Using `tuple` or `list` here also works.
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
print(t[Tensor([2, 0, 1]), Tensor([1, 2, 3])].numpy())
|
||||
```
|
||||
|
||||
- `None` Indexing: Add a new dimension to the tensor.
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
print(t[:, None].shape)
|
||||
```
|
||||
|
||||
NOTE: Out-of-bounds indexing results in a value of `0`.
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
t = Tensor([1, 2, 3])
|
||||
print(t[Tensor([4, 3, 2])].numpy())
|
||||
```
|
||||
"""
|
||||
return super().__getitem__(indices)
|
||||
|
||||
def __setitem__(self, indices, v:Tensor|PyConst|list|tuple) -> None:
|
||||
if isinstance(v, Tensor) and v.dtype != self.dtype: raise RuntimeError(f"setitem dtype mismatch: {self.dtype=} != {v.dtype=}")
|
||||
# raise if mutation would diverge from eager (allow only pure views of a realized buffer; exclude +=/-= RHS via v_uop/v_bw)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue