mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
Compare commits
2 commits
master
...
reshape_tr
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
1255eeec6d | ||
|
|
a1f88fea37 |
4 changed files with 45 additions and 33 deletions
|
|
@ -88,15 +88,15 @@ def hand_spec_kernel3():
|
||||||
# ---------------------------
|
# ---------------------------
|
||||||
# GLOBAL -> LOCAL (As, Bs)
|
# GLOBAL -> LOCAL (As, Bs)
|
||||||
# ---------------------------
|
# ---------------------------
|
||||||
b = b.reshape((N // BLOCK_K, BLOCK_K,
|
b = b.reshape(N // BLOCK_K, BLOCK_K,
|
||||||
N // BLOCK_N, BLOCK_N))
|
N // BLOCK_N, BLOCK_N)
|
||||||
i = UOp.range(BLOCK_N * BLOCK_K // THREADS_PER_BLOCK, 1)
|
i = UOp.range(BLOCK_N * BLOCK_K // THREADS_PER_BLOCK, 1)
|
||||||
index_x = tid % BLOCK_N
|
index_x = tid % BLOCK_N
|
||||||
index_y = (tid // BLOCK_N) + (THREADS_PER_BLOCK // BLOCK_N) * i
|
index_y = (tid // BLOCK_N) + (THREADS_PER_BLOCK // BLOCK_N) * i
|
||||||
Bs_store = Bs[index_y, index_x].store(b[k_tile_range, index_y, blockIdx_x, index_x]).end(i)
|
Bs_store = Bs[index_y, index_x].store(b[k_tile_range, index_y, blockIdx_x, index_x]).end(i)
|
||||||
|
|
||||||
a = a.reshape((N // BLOCK_M, BLOCK_M,
|
a = a.reshape(N // BLOCK_M, BLOCK_M,
|
||||||
N // BLOCK_K, BLOCK_K))
|
N // BLOCK_K, BLOCK_K)
|
||||||
i = UOp.range(BLOCK_M * BLOCK_K // THREADS_PER_BLOCK, 2)
|
i = UOp.range(BLOCK_M * BLOCK_K // THREADS_PER_BLOCK, 2)
|
||||||
index_x = tid % BLOCK_K
|
index_x = tid % BLOCK_K
|
||||||
index_y = (tid // BLOCK_K) + (THREADS_PER_BLOCK // BLOCK_K) * i
|
index_y = (tid // BLOCK_K) + (THREADS_PER_BLOCK // BLOCK_K) * i
|
||||||
|
|
@ -113,12 +113,12 @@ def hand_spec_kernel3():
|
||||||
# ---------------------------
|
# ---------------------------
|
||||||
# LOCAL -> REG (per-wave tiles)
|
# LOCAL -> REG (per-wave tiles)
|
||||||
# ---------------------------
|
# ---------------------------
|
||||||
Bs_view = Bs.reshape((BLOCK_K, WAVES_IN_BLOCK_X, ITERS_PER_WAVE_N, LANES_PER_WAVE_X, TN))
|
Bs_view = Bs.reshape(BLOCK_K, WAVES_IN_BLOCK_X, ITERS_PER_WAVE_N, LANES_PER_WAVE_X, TN)
|
||||||
iterWaveN = UOp.range(ITERS_PER_WAVE_N, 4)
|
iterWaveN = UOp.range(ITERS_PER_WAVE_N, 4)
|
||||||
i = UOp.range(TN, 5)
|
i = UOp.range(TN, 5)
|
||||||
B_row = B_row[iterWaveN, i].set(Bs_view[k, waveIdx, iterWaveN, idxInWave, i], end=(iterWaveN, i))
|
B_row = B_row[iterWaveN, i].set(Bs_view[k, waveIdx, iterWaveN, idxInWave, i], end=(iterWaveN, i))
|
||||||
|
|
||||||
As_view = As.reshape((BLOCK_K, WAVES_IN_BLOCK_Y, ITERS_PER_WAVE_M, LANES_PER_WAVE_Y, TM))
|
As_view = As.reshape(BLOCK_K, WAVES_IN_BLOCK_Y, ITERS_PER_WAVE_M, LANES_PER_WAVE_Y, TM)
|
||||||
iterWaveM = UOp.range(ITERS_PER_WAVE_M, 6)
|
iterWaveM = UOp.range(ITERS_PER_WAVE_M, 6)
|
||||||
i = UOp.range(TM, 7)
|
i = UOp.range(TM, 7)
|
||||||
A_col = A_col[iterWaveM, i].set(As_view[k, waveIdy, iterWaveM, idyInWave, i], end=(iterWaveM, i))
|
A_col = A_col[iterWaveM, i].set(As_view[k, waveIdy, iterWaveM, idyInWave, i], end=(iterWaveM, i))
|
||||||
|
|
@ -139,8 +139,8 @@ def hand_spec_kernel3():
|
||||||
# ---------------------------
|
# ---------------------------
|
||||||
# REG -> GLOBAL (epilogue)
|
# REG -> GLOBAL (epilogue)
|
||||||
# ---------------------------
|
# ---------------------------
|
||||||
c = c.reshape((N//BLOCK_M, WAVES_IN_BLOCK_Y, ITERS_PER_WAVE_M, LANES_PER_WAVE_Y, TM,
|
c = c.reshape(N//BLOCK_M, WAVES_IN_BLOCK_Y, ITERS_PER_WAVE_M, LANES_PER_WAVE_Y, TM,
|
||||||
N//BLOCK_N, WAVES_IN_BLOCK_X, ITERS_PER_WAVE_N, LANES_PER_WAVE_X, TN))
|
N//BLOCK_N, WAVES_IN_BLOCK_X, ITERS_PER_WAVE_N, LANES_PER_WAVE_X, TN)
|
||||||
iterWaveM = UOp.range(ITERS_PER_WAVE_M, 1000)
|
iterWaveM = UOp.range(ITERS_PER_WAVE_M, 1000)
|
||||||
yt = UOp.range(TM, 1001)
|
yt = UOp.range(TM, 1001)
|
||||||
iterWaveN = UOp.range(ITERS_PER_WAVE_N, 1002)
|
iterWaveN = UOp.range(ITERS_PER_WAVE_N, 1002)
|
||||||
|
|
|
||||||
|
|
@ -10,7 +10,7 @@ from tinygrad.helpers import IMAGE, WINO, Metadata, TRACEMETA, ceildiv, fetch, p
|
||||||
from tinygrad.helpers import suppress_finalizing
|
from tinygrad.helpers import suppress_finalizing
|
||||||
from tinygrad.gradient import compute_gradient
|
from tinygrad.gradient import compute_gradient
|
||||||
from tinygrad.uop.mathtraits import MathTrait
|
from tinygrad.uop.mathtraits import MathTrait
|
||||||
from tinygrad.uop.ops import smax, smin, resolve, UOp, Ops, sint, identity_element, all_metadata, _index_to_concrete_int, sint_to_uop, srender
|
from tinygrad.uop.ops import smax, smin, resolve, UOp, Ops, sint, identity_element, all_metadata, _index_to_concrete_int, sint_to_uop
|
||||||
from tinygrad.uop.spec import type_verify, tensor_spec
|
from tinygrad.uop.spec import type_verify, tensor_spec
|
||||||
from tinygrad.device import Device, Buffer
|
from tinygrad.device import Device, Buffer
|
||||||
from tinygrad.engine.realize import run_schedule
|
from tinygrad.engine.realize import run_schedule
|
||||||
|
|
@ -1038,28 +1038,7 @@ class Tensor(MathTrait):
|
||||||
|
|
||||||
# ***** movement low level ops *****
|
# ***** movement low level ops *****
|
||||||
|
|
||||||
def view(self, shape:tuple[sint, ...], *args) -> Tensor:
|
def _mop(self, op:Ops, arg) -> Tensor:
  """Run the movement op `op` with argument `arg` by routing `UOp._mop` through `_apply_uop`."""
  # `op` travels as a positional extra argument, `arg` as the keyword argument
  result = self._apply_uop(UOp._mop, extra_args=(op,), arg=arg)
  return result
|
||||||
"""`.view` is an alias for `.reshape`."""
|
|
||||||
return self.reshape(shape, *args)
|
|
||||||
|
|
||||||
def reshape(self, shape, *args) -> Tensor:
|
|
||||||
"""
|
|
||||||
Returns a tensor with the same data as the original tensor but with a different shape.
|
|
||||||
`shape` can be passed as a tuple or as separate arguments.
|
|
||||||
|
|
||||||
```python exec="true" source="above" session="tensor" result="python"
|
|
||||||
t = Tensor.arange(6)
|
|
||||||
print(t.reshape(2, 3).numpy())
|
|
||||||
```
|
|
||||||
"""
|
|
||||||
# resolve None and args
|
|
||||||
new_shape = tuple([s if s is not None else self.shape[i] for i,s in enumerate(argfix(shape, *args))])
|
|
||||||
# resolve -1
|
|
||||||
if (c := new_shape.count(-1)) > 1: raise RuntimeError(f"only one dimension can be inferred using -1, getting {new_shape}")
|
|
||||||
if c: new_shape = tuple([-prod(self.shape) // prod(new_shape) if s == -1 else s for s in new_shape])
|
|
||||||
if resolve(prod(self.shape) != prod(new_shape), True):
|
|
||||||
raise ValueError(f"size mismatch, can't reshape ({', '.join(srender(d) for d in self.shape)}) -> ({', '.join(srender(d) for d in new_shape)})")
|
|
||||||
return self._apply_uop(UOp.reshape, arg=new_shape) if new_shape != self.shape else self
|
|
||||||
|
|
||||||
def expand(self, shape, *args) -> Tensor:
|
def expand(self, shape, *args) -> Tensor:
|
||||||
"""
|
"""
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,10 @@
|
||||||
from typing import TypeVar
|
from typing import TypeVar, TypeAlias, TYPE_CHECKING
|
||||||
from tinygrad.uop import Ops
|
from tinygrad.uop import Ops
|
||||||
from tinygrad.dtype import dtypes, ConstType
|
from tinygrad.dtype import dtypes, ConstType
|
||||||
|
from tinygrad.helpers import prod, argfix
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from tinygrad.uop.ops import UOp
|
||||||
|
sint:TypeAlias = UOp|int
|
||||||
|
|
||||||
TMT = TypeVar("TMT", bound="MathTrait")
|
TMT = TypeVar("TMT", bound="MathTrait")
|
||||||
class MathTrait:
|
class MathTrait:
|
||||||
|
|
@ -171,3 +175,32 @@ class MathTrait:
|
||||||
def exp2(self): return self.alu(Ops.EXP2)
|
def exp2(self): return self.alu(Ops.EXP2)
|
||||||
def pow(self:TMT, x:TMT|ConstType): return self.alu(Ops.POW, self.ufix(x))
|
def pow(self:TMT, x:TMT|ConstType): return self.alu(Ops.POW, self.ufix(x))
|
||||||
def __pow__(self:TMT, x:TMT|ConstType): return self.pow(x)
|
def __pow__(self:TMT, x:TMT|ConstType): return self.pow(x)
|
||||||
|
|
||||||
|
# **** movement ops ****
|
||||||
|
|
||||||
|
# required to implement
|
||||||
|
def _mop(self:TMT, op:Ops, arg) -> TMT:
  """
  Apply the movement op `op` with argument `arg` and return the result.

  This base version is only a placeholder: implementers of the trait are required to override it.
  """
  raise NotImplementedError
|
||||||
|
@property
def shape(self) -> tuple["sint", ...]:
  """
  The shape of this object as a tuple of dimensions.

  This base version is only a placeholder: implementers of the trait are required to override it.
  """
  raise NotImplementedError
|
||||||
|
|
||||||
|
def view(self:TMT, shape, *args) -> TMT:
  """
  `.view` is an alias for `.reshape`.

  `shape` and any extra positional dimensions are forwarded to `.reshape` untouched.
  """
  reshaped = self.reshape(shape, *args)
  return reshaped
|
||||||
|
|
||||||
|
def reshape(self:TMT, shape, *args) -> TMT:
  """
  Returns a tensor with the same data as the original tensor but with a different shape.
  `shape` can be passed as a tuple or as separate arguments.

  A `None` entry keeps the dimension found at the same index of the current shape.
  At most one entry may be `-1`; it is inferred from the remaining dimensions.
  If the resolved shape equals the current shape, `self` is returned unchanged.

  Raises `RuntimeError` if more than one `-1` is given, and `ValueError` if the
  total number of elements of the two shapes differs.

  ```python exec="true" source="above" session="tensor" result="python"
  t = Tensor.arange(6)
  print(t.reshape(2, 3).numpy())
  ```
  """
  # resolve None and args
  new_shape = tuple([s if s is not None else self.shape[i] for i,s in enumerate(argfix(shape, *args))])
  # resolve -1
  if (c := new_shape.count(-1)) > 1: raise RuntimeError(f"only one dimension can be inferred using -1, getting {new_shape}")
  # prod(new_shape) still contains the -1 factor here, so the numerator is negated to get a positive dimension
  if c: new_shape = tuple([-prod(self.shape) // prod(new_shape) if s == -1 else s for s in new_shape])
  # both shapes are tuples and already render with their own parentheses, so they are not wrapped a second time
  if prod(self.shape) != prod(new_shape): raise ValueError(f"size mismatch, can't reshape {self.shape} -> {new_shape}")
  return self._mop(Ops.RESHAPE, arg=new_shape) if new_shape != self.shape else self
|
||||||
|
|
|
||||||
|
|
@ -533,7 +533,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
|
||||||
|
|
||||||
# in these four, if the shape doesn't change we can return self
|
# in expand, shrink and pad (same_shape_noop=True), if the shape doesn't change we can return self
|
||||||
def forced_reshape(self, arg:tuple[sint, ...]): return self._mop(Ops.RESHAPE, arg, same_shape_noop=False)
|
def forced_reshape(self, arg:tuple[sint, ...]): return self._mop(Ops.RESHAPE, arg, same_shape_noop=False)
|
||||||
def reshape(self, arg:tuple[sint, ...]): return self._mop(Ops.RESHAPE, arg, same_shape_noop=True)
|
#def reshape(self, arg:tuple[sint, ...]): return self._mop(Ops.RESHAPE, arg, same_shape_noop=True)
|
||||||
def expand(self, arg:tuple[sint, ...]): return self._mop(Ops.EXPAND, arg, same_shape_noop=True)
|
def expand(self, arg:tuple[sint, ...]): return self._mop(Ops.EXPAND, arg, same_shape_noop=True)
|
||||||
def shrink(self, arg:tuple[tuple[sint, sint], ...]): return self._mop(Ops.SHRINK, arg, same_shape_noop=True)
|
def shrink(self, arg:tuple[tuple[sint, sint], ...]): return self._mop(Ops.SHRINK, arg, same_shape_noop=True)
|
||||||
def pad(self, arg:tuple[tuple[sint, sint], ...]): return self._mop(Ops.PAD, arg, same_shape_noop=True)
|
def pad(self, arg:tuple[tuple[sint, sint], ...]): return self._mop(Ops.PAD, arg, same_shape_noop=True)
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue