mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
bounty: Remove Tensor._pool alternative implementation and verify kernels remain the same (#13164)
This commit is contained in:
parent
ffb9e8396f
commit
b41541bc44
1 changed files with 11 additions and 17 deletions
|
|
@ -2100,23 +2100,17 @@ class Tensor(OpMixin):
|
|||
noop, i_ = [None] * (self.ndim-len(k_)), self.shape[-len(k_):]
|
||||
assert all(resolve(d*(k-1)+1 <= i) for k,d,i in zip(k_,d_,i_)), "kernel size cannot be greater than actual input size"
|
||||
o_ = [ceildiv(i-d*(k-1), s) for i,d,k,s in zip(i_,d_,k_,s_)]
|
||||
if getenv("ONE_POOL") or any(resolve(k > s) for k,s in zip(k_,s_)) or any(d != 1 for d in d_):
|
||||
# input size scaling factor to make sure shrink for stride is possible
|
||||
f_ = [smax(1, ceildiv(o*s - d, i)) for o,s,i,d in zip(o_,s_,i_,d_)]
|
||||
# repeats such that we don't need padding
|
||||
x = self.repeat([1]*len(noop) + [ceildiv(k*(i*f+d),i) for k,i,d,f in zip(k_,i_,d_,f_)])
|
||||
# handle dilation
|
||||
x = x.shrink_to(noop + [k*(i*f+d) for k,i,d,f in zip(k_,i_,d_,f_)]).reshape(noop + flatten((k,(i*f+d)) for k,i,d,f in zip(k_,i_,d_,f_)))
|
||||
# handle stride
|
||||
x = x.shrink_to(noop + flatten((k,o*s) for k,o,s in zip(k_,o_,s_))).reshape(noop + flatten((k,o,s) for k,o,s in zip(k_,o_,s_)))
|
||||
x = x.shrink_to(noop + flatten((k,o,1) for k,o in zip(k_,o_))).reshape(noop + flatten((k,o) for k,o in zip(k_,o_)))
|
||||
# permute to move reduce to the end
|
||||
return x.permute(*range(len(noop)), *[len(noop)+i*2+1 for i in range(len(i_))], *[len(noop)+i*2 for i in range(len(i_))])
|
||||
# TODO: once the shapetracker can optimize well, remove this alternative implementation
|
||||
x = self.pad(tuple(noop + [(0, max(0,o*s-i)) for i,o,s in zip(i_,o_,s_)])).shrink(tuple(noop + [(0,o*s) for o,s in zip(o_,s_)]))
|
||||
x = x.reshape(noop + flatten(((o,s) for o,s in zip(o_,s_))))
|
||||
x = x.shrink_to(noop + flatten((o,k) for o,k in zip(o_,k_)))
|
||||
return x.permute(*range(len(noop)), *[len(noop)+i*2 for i in range(len(i_))], *[len(noop)+i*2+1 for i in range(len(i_))])
|
||||
# input size scaling factor to make sure shrink for stride is possible
|
||||
f_ = [smax(1, ceildiv(o*s - d, i)) for o,s,i,d in zip(o_,s_,i_,d_)]
|
||||
# repeats such that we don't need padding
|
||||
x = self.repeat([1]*len(noop) + [ceildiv(k*(i*f+d),i) for k,i,d,f in zip(k_,i_,d_,f_)])
|
||||
# handle dilation
|
||||
x = x.shrink_to(noop + [k*(i*f+d) for k,i,d,f in zip(k_,i_,d_,f_)]).reshape(noop + flatten((k,(i*f+d)) for k,i,d,f in zip(k_,i_,d_,f_)))
|
||||
# handle stride
|
||||
x = x.shrink_to(noop + flatten((k,o*s) for k,o,s in zip(k_,o_,s_))).reshape(noop + flatten((k,o,s) for k,o,s in zip(k_,o_,s_)))
|
||||
x = x.shrink_to(noop + flatten((k,o,1) for k,o in zip(k_,o_))).reshape(noop + flatten((k,o) for k,o in zip(k_,o_)))
|
||||
# permute to move reduce to the end
|
||||
return x.permute(*range(len(noop)), *[len(noop)+i*2+1 for i in range(len(i_))], *[len(noop)+i*2 for i in range(len(i_))])
|
||||
|
||||
def _resolve_pool_pads(self, padding:int|Sequence[int], dims:int) -> Sequence[int]:
|
||||
if not isinstance(padding, int) and not (len(padding) == 2*dims or len(padding) == dims):
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue