mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
Compare commits
2 commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
c43f8edc3e | ||
|
|
f7a9805dcf |
3 changed files with 19 additions and 6 deletions
|
|
@ -1,6 +1,6 @@
|
|||
# mixins add syntactic sugar to Tensor and UOp
|
||||
import functools
|
||||
from typing import TypeAlias, TYPE_CHECKING, Self
|
||||
from typing import TypeAlias, TYPE_CHECKING, Self, Sequence
|
||||
from tinygrad.uop import Ops
|
||||
from tinygrad.helpers import prod, argfix, flatten, dedup, make_tuple, ceildiv
|
||||
from tinygrad.uop.ops import resolve, smax
|
||||
|
|
@ -16,6 +16,10 @@ def _align_left(*shapes: tuple[sint, ...]) -> tuple[tuple[sint, ...], ...]:
|
|||
return tuple((1,) * (max_dim - len(shape)) + shape for shape in shapes)
|
||||
|
||||
|
||||
# `(padding_left, padding_right, padding_top, padding_bottom, ...)` -> `(..., (padding_top, padding_bottom), (padding_left, padding_right))`
|
||||
def _flat_to_grouped(padding:Sequence[sint]) -> tuple[tuple[sint, sint], ...]: return tuple(zip(padding[-2::-2], padding[::-2]))
|
||||
|
||||
|
||||
class MovementMixin:
|
||||
# required to implement
|
||||
def _mop(self, op: Ops, arg) -> Self:
|
||||
|
|
@ -374,3 +378,14 @@ class MovementMixin:
|
|||
x = x.shrink_to(noop + flatten((k, o, 1) for k, o in zip(k_, o_))).reshape(noop + flatten((k, o) for k, o in zip(k_, o_)))
|
||||
# permute to move reduce to the end
|
||||
return x.permute(*range(len(noop)), *[len(noop) + i * 2 + 1 for i in range(len(i_))], *[len(noop) + i * 2 for i in range(len(i_))])
|
||||
|
||||
# **** pad ****
|
||||
|
||||
def pad(self, padding:Sequence[tuple[sint, sint]|None]) -> Self:
|
||||
"""
|
||||
Returns a tensor with constant zero padding applied based on the input `padding`.
|
||||
`padding` must have the same length as `self.ndim`. For each axis, padding can be `None` (no padding) or a tuple `(before, after)`.
|
||||
"""
|
||||
pX = tuple((0,0) if p is None else p for p in padding)
|
||||
if len(pX) != self.ndim: raise ValueError(f"padding length is improper, {padding=} {self.ndim=}")
|
||||
return self._mop(Ops.PAD, pX)
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@ from tinygrad.helpers import IMAGE, WINO, Metadata, TRACEMETA, ceildiv, fetch, p
|
|||
from tinygrad.helpers import suppress_finalizing, disable_gc
|
||||
from tinygrad.gradient import compute_gradient
|
||||
from tinygrad.mixin import OpMixin
|
||||
from tinygrad.mixin.movement import _align_left
|
||||
from tinygrad.mixin.movement import _align_left, _flat_to_grouped
|
||||
from tinygrad.uop.ops import smax, smin, resolve, UOp, Ops, sint, identity_element, all_metadata, _index_to_concrete_int, sint_to_uop, Variable
|
||||
from tinygrad.engine.schedule import ScheduleItem, complete_create_schedule_with_vars
|
||||
from tinygrad.device import Device, Buffer
|
||||
|
|
@ -91,8 +91,6 @@ def _masked_setitem(target:Tensor, values:Tensor, mask:Tensor, axes:tuple[int, .
|
|||
# select from values for each True element in mask else select from target
|
||||
return mask.where(values, target)
|
||||
|
||||
# `(padding_left, padding_right, padding_top, padding_bottom, ...)` -> `(..., (padding_top, padding_bottom), (padding_left, padding_right))`
|
||||
def _flat_to_grouped(padding:Sequence[sint]) -> tuple[tuple[sint, sint], ...]: return tuple(zip(padding[-2::-2], padding[::-2]))
|
||||
|
||||
ReductionStr = Literal["mean", "sum", "none"]
|
||||
|
||||
|
|
@ -1069,7 +1067,7 @@ class Tensor(OpMixin):
|
|||
X, pads = self, tuple((smax(pB,0), smax(pA,0)) for pB,pA in pX)
|
||||
if mode == "constant":
|
||||
def _constant(x:Tensor,px,v) -> Tensor:
|
||||
return x._apply_uop(UOp.pad, arg=px) if v == 0 else (x._apply_uop(UOp.pad, arg=px)+Tensor.ones_like(x)._apply_uop(UOp.pad, arg=px).where(0,v))
|
||||
return x._mop(Ops.PAD, px) if v == 0 else (x._mop(Ops.PAD, px)+Tensor.ones_like(x)._mop(Ops.PAD, px).where(0,v))
|
||||
return _constant(X, pX, value) if all(resolve(p >= 0) for p in flatten(pX)) else \
|
||||
_constant(X.shrink(tuple((-smin(pB,0),smin(pA+s,s)) for (pB,pA),s in zip(pX, X.shape))), pads, value)
|
||||
assert all_int(self.shape), f"does not support symbolic shape {self.shape}"
|
||||
|
|
|
|||
|
|
@ -590,7 +590,7 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
|
|||
#def reshape(self, arg:tuple[sint, ...]): return self._mop(Ops.RESHAPE, arg, same_shape_noop=True)
|
||||
#def expand(self, arg:tuple[sint, ...]): return self._mop(Ops.EXPAND, arg, same_shape_noop=True)
|
||||
#def shrink(self, arg:tuple[tuple[sint, sint], ...]): return self._mop(Ops.SHRINK, arg, same_shape_noop=True)
|
||||
def pad(self, arg:tuple[tuple[sint, sint], ...]): return self._mop(Ops.PAD, arg, same_shape_noop=True)
|
||||
#def pad(self, arg:tuple[tuple[sint, sint], ...]): return self._mop(Ops.PAD, arg, same_shape_noop=True) # now in MovementMixin
|
||||
|
||||
# in these two, we have custom logic to check if they are a no-op
|
||||
#def permute(self, arg:tuple[int, ...]): return self._mop(Ops.PERMUTE, arg, same_shape_noop=False) if arg != tuple(range(len(self.shape))) else self
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue