move _pool to MovementMixins (#13257)

This commit is contained in:
George Hotz 2025-11-13 07:28:52 -08:00 committed by GitHub
commit 5efa727b83
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 26 additions and 21 deletions

View file

@ -2,7 +2,8 @@
import functools
from typing import TypeAlias, TYPE_CHECKING, Self
from tinygrad.uop import Ops
from tinygrad.helpers import prod, argfix, flatten, dedup
from tinygrad.helpers import prod, argfix, flatten, dedup, make_tuple, ceildiv
from tinygrad.uop.ops import resolve, smax
if TYPE_CHECKING: from tinygrad.uop.ops import UOp
sint: TypeAlias = "UOp | int"
@ -326,3 +327,24 @@ class MovementMixin:
expanded_shape = flatten([[s] if r == 1 else [r, s] for r,s in zip(repeats, base_shape)])
final_shape = [r*s for r,s in zip(repeats, base_shape)]
return self.reshape(unsqueezed_shape).expand(expanded_shape).reshape(final_shape)
# **** pool level ****
def _pool(self, k_:tuple[sint, ...], stride:int|tuple[int, ...]=1, dilation:int|tuple[int, ...]=1) -> Self:
assert len(self.shape) >= len(k_), f"can't pool {self.shape} with {k_}"
s_, d_ = make_tuple(stride, len(k_)), make_tuple(dilation, len(k_))
assert len(k_) == len(s_) == len(d_), f"stride/dilation mismatch kernel:{k_} stride:{s_} dilation:{d_}"
noop, i_ = [None] * (self.ndim-len(k_)), self.shape[-len(k_):]
assert all(resolve(d*(k-1)+1 <= i) for k,d,i in zip(k_,d_,i_)), "kernel size cannot be greater than actual input size"
o_ = [ceildiv(i-d*(k-1), s) for i,d,k,s in zip(i_,d_,k_,s_)]
# input size scaling factor to make sure shrink for stride is possible
f_ = [smax(1, ceildiv(o*s - d, i)) for o,s,i,d in zip(o_,s_,i_,d_)]
# repeats such that we don't need padding
x = self.repeat([1]*len(noop) + [ceildiv(k*(i*f+d),i) for k,i,d,f in zip(k_,i_,d_,f_)])
# handle dilation
x = x.shrink_to(noop + [k*(i*f+d) for k,i,d,f in zip(k_,i_,d_,f_)]).reshape(noop + flatten((k,(i*f+d)) for k,i,d,f in zip(k_,i_,d_,f_)))
# handle stride
x = x.shrink_to(noop + flatten((k,o*s) for k,o,s in zip(k_,o_,s_))).reshape(noop + flatten((k,o,s) for k,o,s in zip(k_,o_,s_)))
x = x.shrink_to(noop + flatten((k,o,1) for k,o in zip(k_,o_))).reshape(noop + flatten((k,o) for k,o in zip(k_,o_)))
# permute to move reduce to the end
return x.permute(*range(len(noop)), *[len(noop)+i*2+1 for i in range(len(i_))], *[len(noop)+i*2 for i in range(len(i_))])

View file

@ -2093,25 +2093,6 @@ class Tensor(OpMixin):
# ***** processing ops *****
def _pool(self, k_:tuple[sint, ...], stride:int|tuple[int, ...]=1, dilation:int|tuple[int, ...]=1) -> Tensor:
assert len(self.shape) >= len(k_), f"can't pool {self.shape} with {k_}"
s_, d_ = make_tuple(stride, len(k_)), make_tuple(dilation, len(k_))
assert len(k_) == len(s_) == len(d_), f"stride/dilation mismatch kernel:{k_} stride:{s_} dilation:{d_}"
noop, i_ = [None] * (self.ndim-len(k_)), self.shape[-len(k_):]
assert all(resolve(d*(k-1)+1 <= i) for k,d,i in zip(k_,d_,i_)), "kernel size cannot be greater than actual input size"
o_ = [ceildiv(i-d*(k-1), s) for i,d,k,s in zip(i_,d_,k_,s_)]
# input size scaling factor to make sure shrink for stride is possible
f_ = [smax(1, ceildiv(o*s - d, i)) for o,s,i,d in zip(o_,s_,i_,d_)]
# repeats such that we don't need padding
x = self.repeat([1]*len(noop) + [ceildiv(k*(i*f+d),i) for k,i,d,f in zip(k_,i_,d_,f_)])
# handle dilation
x = x.shrink_to(noop + [k*(i*f+d) for k,i,d,f in zip(k_,i_,d_,f_)]).reshape(noop + flatten((k,(i*f+d)) for k,i,d,f in zip(k_,i_,d_,f_)))
# handle stride
x = x.shrink_to(noop + flatten((k,o*s) for k,o,s in zip(k_,o_,s_))).reshape(noop + flatten((k,o,s) for k,o,s in zip(k_,o_,s_)))
x = x.shrink_to(noop + flatten((k,o,1) for k,o in zip(k_,o_))).reshape(noop + flatten((k,o) for k,o in zip(k_,o_)))
# permute to move reduce to the end
return x.permute(*range(len(noop)), *[len(noop)+i*2+1 for i in range(len(i_))], *[len(noop)+i*2 for i in range(len(i_))])
def _resolve_pool_pads(self, padding:int|Sequence[int], dims:int) -> Sequence[int]:
if not isinstance(padding, int) and not (len(padding) == 2*dims or len(padding) == dims):
raise ValueError(f"Padding must be an int or a sequence of length {dims} or {2*dims}, but got {padding=} for {self.shape=} with {dims=}.")

View file

@ -4,7 +4,6 @@ import sys, time, functools, itertools, math, operator, hashlib, os, types, pick
from dataclasses import dataclass
from enum import Enum, auto
from tinygrad.uop import Ops, GroupOp
from tinygrad.mixin import OpMixin
from tinygrad.dtype import ConstType, ImageDType, dtypes, DType, truncate, PtrDType, least_upper_dtype, Invalid, InvalidType, AddrSpace
from tinygrad.helpers import ContextVar, all_int, prod, getenv, all_same, Context, partition, temp, unwrap, T, argfix, Metadata, flatten, TRACEMETA
from tinygrad.helpers import PICKLE_BUFFERS, PROFILE, dedup, cdiv, cmod, diskcache_put, to_function_name, cpu_profile, TracingKey, VIZ, SPEC, CI
@ -107,6 +106,9 @@ class recursive_property(property):
s.__dict__[self.nm] = val = self.fxn(s)
return val
# we import this late so we can use resolve/smax in mixins
from tinygrad.mixin import OpMixin
# NOTE: this should be frozen, but frozen is slower
@dataclass(eq=False, slots=True)
class UOp(OpMixin, metaclass=UOpMetaClass):