Compare commits

...

2 commits

Author SHA1 Message Date
George Hotz
c43f8edc3e small diff 2025-12-16 17:09:47 -04:00
George Hotz
f7a9805dcf move pad to mixin 2025-12-16 17:02:45 -04:00
3 changed files with 19 additions and 6 deletions

View file

@ -1,6 +1,6 @@
# mixins add syntactic sugar to Tensor and UOp
import functools
from typing import TypeAlias, TYPE_CHECKING, Self
from typing import TypeAlias, TYPE_CHECKING, Self, Sequence
from tinygrad.uop import Ops
from tinygrad.helpers import prod, argfix, flatten, dedup, make_tuple, ceildiv
from tinygrad.uop.ops import resolve, smax
@ -16,6 +16,10 @@ def _align_left(*shapes: tuple[sint, ...]) -> tuple[tuple[sint, ...], ...]:
return tuple((1,) * (max_dim - len(shape)) + shape for shape in shapes)
# `(padding_left, padding_right, padding_top, padding_bottom, ...)` -> `(..., (padding_top, padding_bottom), (padding_left, padding_right))`
def _flat_to_grouped(padding:Sequence[sint]) -> tuple[tuple[sint, sint], ...]: return tuple(zip(padding[-2::-2], padding[::-2]))
class MovementMixin:
# required to implement
def _mop(self, op: Ops, arg) -> Self:
@ -374,3 +378,14 @@ class MovementMixin:
x = x.shrink_to(noop + flatten((k, o, 1) for k, o in zip(k_, o_))).reshape(noop + flatten((k, o) for k, o in zip(k_, o_)))
# permute to move reduce to the end
return x.permute(*range(len(noop)), *[len(noop) + i * 2 + 1 for i in range(len(i_))], *[len(noop) + i * 2 for i in range(len(i_))])
# **** pad ****
def pad(self, padding:Sequence[tuple[sint, sint]|None]) -> Self:
"""
Returns a tensor with constant zero padding applied based on the input `padding`.
`padding` must have the same length as `self.ndim`. For each axis, padding can be `None` (no padding) or a tuple `(before, after)`.
"""
pX = tuple((0,0) if p is None else p for p in padding)
if len(pX) != self.ndim: raise ValueError(f"padding length is improper, {padding=} {self.ndim=}")
return self._mop(Ops.PAD, pX)

View file

@ -10,7 +10,7 @@ from tinygrad.helpers import IMAGE, WINO, Metadata, TRACEMETA, ceildiv, fetch, p
from tinygrad.helpers import suppress_finalizing, disable_gc
from tinygrad.gradient import compute_gradient
from tinygrad.mixin import OpMixin
from tinygrad.mixin.movement import _align_left
from tinygrad.mixin.movement import _align_left, _flat_to_grouped
from tinygrad.uop.ops import smax, smin, resolve, UOp, Ops, sint, identity_element, all_metadata, _index_to_concrete_int, sint_to_uop, Variable
from tinygrad.engine.schedule import ScheduleItem, complete_create_schedule_with_vars
from tinygrad.device import Device, Buffer
@ -91,8 +91,6 @@ def _masked_setitem(target:Tensor, values:Tensor, mask:Tensor, axes:tuple[int, .
# select from values for each True element in mask else select from target
return mask.where(values, target)
# `(padding_left, padding_right, padding_top, padding_bottom, ...)` -> `(..., (padding_top, padding_bottom), (padding_left, padding_right))`
def _flat_to_grouped(padding:Sequence[sint]) -> tuple[tuple[sint, sint], ...]: return tuple(zip(padding[-2::-2], padding[::-2]))
ReductionStr = Literal["mean", "sum", "none"]
@ -1069,7 +1067,7 @@ class Tensor(OpMixin):
X, pads = self, tuple((smax(pB,0), smax(pA,0)) for pB,pA in pX)
if mode == "constant":
def _constant(x:Tensor,px,v) -> Tensor:
return x._apply_uop(UOp.pad, arg=px) if v == 0 else (x._apply_uop(UOp.pad, arg=px)+Tensor.ones_like(x)._apply_uop(UOp.pad, arg=px).where(0,v))
return x._mop(Ops.PAD, px) if v == 0 else (x._mop(Ops.PAD, px)+Tensor.ones_like(x)._mop(Ops.PAD, px).where(0,v))
return _constant(X, pX, value) if all(resolve(p >= 0) for p in flatten(pX)) else \
_constant(X.shrink(tuple((-smin(pB,0),smin(pA+s,s)) for (pB,pA),s in zip(pX, X.shape))), pads, value)
assert all_int(self.shape), f"does not support symbolic shape {self.shape}"

View file

@ -590,7 +590,7 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
#def reshape(self, arg:tuple[sint, ...]): return self._mop(Ops.RESHAPE, arg, same_shape_noop=True)
#def expand(self, arg:tuple[sint, ...]): return self._mop(Ops.EXPAND, arg, same_shape_noop=True)
#def shrink(self, arg:tuple[tuple[sint, sint], ...]): return self._mop(Ops.SHRINK, arg, same_shape_noop=True)
def pad(self, arg:tuple[tuple[sint, sint], ...]): return self._mop(Ops.PAD, arg, same_shape_noop=True)
#def pad(self, arg:tuple[tuple[sint, sint], ...]): return self._mop(Ops.PAD, arg, same_shape_noop=True) # now in MovementMixin
# in these two, we have custom logic to check if they are a no-op
#def permute(self, arg:tuple[int, ...]): return self._mop(Ops.PERMUTE, arg, same_shape_noop=False) if arg != tuple(range(len(self.shape))) else self