mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
move _pool to MovementMixins (#13257)
This commit is contained in:
parent
bcdfc109b5
commit
5efa727b83
3 changed files with 26 additions and 21 deletions
|
|
@ -2,7 +2,8 @@
|
|||
import functools
|
||||
from typing import TypeAlias, TYPE_CHECKING, Self
|
||||
from tinygrad.uop import Ops
|
||||
from tinygrad.helpers import prod, argfix, flatten, dedup
|
||||
from tinygrad.helpers import prod, argfix, flatten, dedup, make_tuple, ceildiv
|
||||
from tinygrad.uop.ops import resolve, smax
|
||||
if TYPE_CHECKING: from tinygrad.uop.ops import UOp
|
||||
sint: TypeAlias = "UOp | int"
|
||||
|
||||
|
|
@ -326,3 +327,24 @@ class MovementMixin:
|
|||
expanded_shape = flatten([[s] if r == 1 else [r, s] for r,s in zip(repeats, base_shape)])
|
||||
final_shape = [r*s for r,s in zip(repeats, base_shape)]
|
||||
return self.reshape(unsqueezed_shape).expand(expanded_shape).reshape(final_shape)
|
||||
|
||||
# **** pool level ****
|
||||
|
||||
def _pool(self, k_:tuple[sint, ...], stride:int|tuple[int, ...]=1, dilation:int|tuple[int, ...]=1) -> Self:
|
||||
assert len(self.shape) >= len(k_), f"can't pool {self.shape} with {k_}"
|
||||
s_, d_ = make_tuple(stride, len(k_)), make_tuple(dilation, len(k_))
|
||||
assert len(k_) == len(s_) == len(d_), f"stride/dilation mismatch kernel:{k_} stride:{s_} dilation:{d_}"
|
||||
noop, i_ = [None] * (self.ndim-len(k_)), self.shape[-len(k_):]
|
||||
assert all(resolve(d*(k-1)+1 <= i) for k,d,i in zip(k_,d_,i_)), "kernel size cannot be greater than actual input size"
|
||||
o_ = [ceildiv(i-d*(k-1), s) for i,d,k,s in zip(i_,d_,k_,s_)]
|
||||
# input size scaling factor to make sure shrink for stride is possible
|
||||
f_ = [smax(1, ceildiv(o*s - d, i)) for o,s,i,d in zip(o_,s_,i_,d_)]
|
||||
# repeats such that we don't need padding
|
||||
x = self.repeat([1]*len(noop) + [ceildiv(k*(i*f+d),i) for k,i,d,f in zip(k_,i_,d_,f_)])
|
||||
# handle dilation
|
||||
x = x.shrink_to(noop + [k*(i*f+d) for k,i,d,f in zip(k_,i_,d_,f_)]).reshape(noop + flatten((k,(i*f+d)) for k,i,d,f in zip(k_,i_,d_,f_)))
|
||||
# handle stride
|
||||
x = x.shrink_to(noop + flatten((k,o*s) for k,o,s in zip(k_,o_,s_))).reshape(noop + flatten((k,o,s) for k,o,s in zip(k_,o_,s_)))
|
||||
x = x.shrink_to(noop + flatten((k,o,1) for k,o in zip(k_,o_))).reshape(noop + flatten((k,o) for k,o in zip(k_,o_)))
|
||||
# permute to move reduce to the end
|
||||
return x.permute(*range(len(noop)), *[len(noop)+i*2+1 for i in range(len(i_))], *[len(noop)+i*2 for i in range(len(i_))])
|
||||
|
|
|
|||
|
|
@ -2093,25 +2093,6 @@ class Tensor(OpMixin):
|
|||
|
||||
# ***** processing ops *****
|
||||
|
||||
def _pool(self, k_:tuple[sint, ...], stride:int|tuple[int, ...]=1, dilation:int|tuple[int, ...]=1) -> Tensor:
|
||||
assert len(self.shape) >= len(k_), f"can't pool {self.shape} with {k_}"
|
||||
s_, d_ = make_tuple(stride, len(k_)), make_tuple(dilation, len(k_))
|
||||
assert len(k_) == len(s_) == len(d_), f"stride/dilation mismatch kernel:{k_} stride:{s_} dilation:{d_}"
|
||||
noop, i_ = [None] * (self.ndim-len(k_)), self.shape[-len(k_):]
|
||||
assert all(resolve(d*(k-1)+1 <= i) for k,d,i in zip(k_,d_,i_)), "kernel size cannot be greater than actual input size"
|
||||
o_ = [ceildiv(i-d*(k-1), s) for i,d,k,s in zip(i_,d_,k_,s_)]
|
||||
# input size scaling factor to make sure shrink for stride is possible
|
||||
f_ = [smax(1, ceildiv(o*s - d, i)) for o,s,i,d in zip(o_,s_,i_,d_)]
|
||||
# repeats such that we don't need padding
|
||||
x = self.repeat([1]*len(noop) + [ceildiv(k*(i*f+d),i) for k,i,d,f in zip(k_,i_,d_,f_)])
|
||||
# handle dilation
|
||||
x = x.shrink_to(noop + [k*(i*f+d) for k,i,d,f in zip(k_,i_,d_,f_)]).reshape(noop + flatten((k,(i*f+d)) for k,i,d,f in zip(k_,i_,d_,f_)))
|
||||
# handle stride
|
||||
x = x.shrink_to(noop + flatten((k,o*s) for k,o,s in zip(k_,o_,s_))).reshape(noop + flatten((k,o,s) for k,o,s in zip(k_,o_,s_)))
|
||||
x = x.shrink_to(noop + flatten((k,o,1) for k,o in zip(k_,o_))).reshape(noop + flatten((k,o) for k,o in zip(k_,o_)))
|
||||
# permute to move reduce to the end
|
||||
return x.permute(*range(len(noop)), *[len(noop)+i*2+1 for i in range(len(i_))], *[len(noop)+i*2 for i in range(len(i_))])
|
||||
|
||||
def _resolve_pool_pads(self, padding:int|Sequence[int], dims:int) -> Sequence[int]:
|
||||
if not isinstance(padding, int) and not (len(padding) == 2*dims or len(padding) == dims):
|
||||
raise ValueError(f"Padding must be an int or a sequence of length {dims} or {2*dims}, but got {padding=} for {self.shape=} with {dims=}.")
|
||||
|
|
|
|||
|
|
@ -4,7 +4,6 @@ import sys, time, functools, itertools, math, operator, hashlib, os, types, pick
|
|||
from dataclasses import dataclass
|
||||
from enum import Enum, auto
|
||||
from tinygrad.uop import Ops, GroupOp
|
||||
from tinygrad.mixin import OpMixin
|
||||
from tinygrad.dtype import ConstType, ImageDType, dtypes, DType, truncate, PtrDType, least_upper_dtype, Invalid, InvalidType, AddrSpace
|
||||
from tinygrad.helpers import ContextVar, all_int, prod, getenv, all_same, Context, partition, temp, unwrap, T, argfix, Metadata, flatten, TRACEMETA
|
||||
from tinygrad.helpers import PICKLE_BUFFERS, PROFILE, dedup, cdiv, cmod, diskcache_put, to_function_name, cpu_profile, TracingKey, VIZ, SPEC, CI
|
||||
|
|
@ -107,6 +106,9 @@ class recursive_property(property):
|
|||
s.__dict__[self.nm] = val = self.fxn(s)
|
||||
return val
|
||||
|
||||
# we import this late so we can use resolve/smax in mixins
|
||||
from tinygrad.mixin import OpMixin
|
||||
|
||||
# NOTE: this should be frozen, but frozen is slower
|
||||
@dataclass(eq=False, slots=True)
|
||||
class UOp(OpMixin, metaclass=UOpMetaClass):
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue