make _pool simpler (#13161)

* make _pool simpler

* just syntax

* more correct and smaller

* try this now

* Revert "try this now"

This reverts commit 607cdc2164.

* ONE_POOL
This commit is contained in:
George Hotz 2025-11-07 15:58:44 -08:00 committed by GitHub
commit 2413311289
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -2100,22 +2100,22 @@ class Tensor(OpMixin):
noop, i_ = [None] * (self.ndim-len(k_)), self.shape[-len(k_):]
assert all(resolve(d*(k-1)+1 <= i) for k,d,i in zip(k_,d_,i_)), "kernel size cannot be greater than actual input size"
o_ = [ceildiv(i-d*(k-1), s) for i,d,k,s in zip(i_,d_,k_,s_)]
if any(resolve(k > s) for k,s in zip(k_,s_)) or any(d != 1 for d in d_):
if getenv("ONE_POOL") or any(resolve(k > s) for k,s in zip(k_,s_)) or any(d != 1 for d in d_):
# input size scaling factor to make sure shrink for stride is possible
f_ = [1 + int(resolve(o*s > (i - d*(k-1)))) for o,s,i,d,k in zip(o_,s_,i_,d_,k_)]
# # repeats such that we don't need padding
f_ = [smax(1, ceildiv(o*s - d, i)) for o,s,i,d in zip(o_,s_,i_,d_)]
# repeats such that we don't need padding
x = self.repeat([1]*len(noop) + [ceildiv(k*(i*f+d),i) for k,i,d,f in zip(k_,i_,d_,f_)])
# handle dilation
x = x.shrink(tuple(noop + [(0,k*(i*f+d)) for k,i,d,f in zip(k_,i_,d_,f_)])).reshape(noop + flatten((k,(i*f+d)) for k,i,d,f in zip(k_,i_,d_,f_)))
x = x.shrink_to(noop + [k*(i*f+d) for k,i,d,f in zip(k_,i_,d_,f_)]).reshape(noop + flatten((k,(i*f+d)) for k,i,d,f in zip(k_,i_,d_,f_)))
# handle stride
x = x.shrink(tuple(noop + flatten(((0,k), (0,o*s)) for k,o,s in zip(k_,o_,s_)))).reshape(noop + flatten((k,o,s) for k,o,s in zip(k_,o_,s_)))
x = x.shrink(tuple(noop + flatten(((0,k), (0,o), (0,1)) for k,o in zip(k_,o_)))).reshape(noop + flatten((k,o) for k,o in zip(k_,o_)))
x = x.shrink_to(noop + flatten((k,o*s) for k,o,s in zip(k_,o_,s_))).reshape(noop + flatten((k,o,s) for k,o,s in zip(k_,o_,s_)))
x = x.shrink_to(noop + flatten((k,o,1) for k,o in zip(k_,o_))).reshape(noop + flatten((k,o) for k,o in zip(k_,o_)))
# permute to move reduce to the end
return x.permute(*range(len(noop)), *[len(noop)+i*2+1 for i in range(len(i_))], *[len(noop)+i*2 for i in range(len(i_))])
# TODO: once the shapetracker can optimize well, remove this alternative implementation
x = self.pad(tuple(noop + [(0, max(0,o*s-i)) for i,o,s in zip(i_,o_,s_)])).shrink(tuple(noop + [(0,o*s) for o,s in zip(o_,s_)]))
x = x.reshape(noop + flatten(((o,s) for o,s in zip(o_,s_))))
x = x.shrink(tuple(noop + flatten(((0,o), (0,k)) for o,k in zip(o_,k_))))
x = x.shrink_to(noop + flatten((o,k) for o,k in zip(o_,k_)))
return x.permute(*range(len(noop)), *[len(noop)+i*2 for i in range(len(i_))], *[len(noop)+i*2+1 for i in range(len(i_))])
def _resolve_pool_pads(self, padding:int|Sequence[int], dims:int) -> Sequence[int]: