mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
Compare commits
6 commits
master
...
simple_poo
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
7c4971f345 | ||
|
|
a5fd297df5 | ||
|
|
607cdc2164 | ||
|
|
d62c733b3f | ||
|
|
a6e1cc3c65 | ||
|
|
6f2dd96df9 |
1 changed file with 7 additions and 7 deletions
|
|
@ -2100,22 +2100,22 @@ class Tensor(OpMixin):
|
||||||
noop, i_ = [None] * (self.ndim-len(k_)), self.shape[-len(k_):]
|
noop, i_ = [None] * (self.ndim-len(k_)), self.shape[-len(k_):]
|
||||||
assert all(resolve(d*(k-1)+1 <= i) for k,d,i in zip(k_,d_,i_)), "kernel size cannot be greater than actual input size"
|
assert all(resolve(d*(k-1)+1 <= i) for k,d,i in zip(k_,d_,i_)), "kernel size cannot be greater than actual input size"
|
||||||
o_ = [ceildiv(i-d*(k-1), s) for i,d,k,s in zip(i_,d_,k_,s_)]
|
o_ = [ceildiv(i-d*(k-1), s) for i,d,k,s in zip(i_,d_,k_,s_)]
|
||||||
if any(resolve(k > s) for k,s in zip(k_,s_)) or any(d != 1 for d in d_):
|
if getenv("ONE_POOL") or any(resolve(k > s) for k,s in zip(k_,s_)) or any(d != 1 for d in d_):
|
||||||
# input size scaling factor to make sure shrink for stride is possible
|
# input size scaling factor to make sure shrink for stride is possible
|
||||||
f_ = [1 + int(resolve(o*s > (i - d*(k-1)))) for o,s,i,d,k in zip(o_,s_,i_,d_,k_)]
|
f_ = [smax(1, ceildiv(o*s - d, i)) for o,s,i,d in zip(o_,s_,i_,d_)]
|
||||||
# # repeats such that we don't need padding
|
# repeats such that we don't need padding
|
||||||
x = self.repeat([1]*len(noop) + [ceildiv(k*(i*f+d),i) for k,i,d,f in zip(k_,i_,d_,f_)])
|
x = self.repeat([1]*len(noop) + [ceildiv(k*(i*f+d),i) for k,i,d,f in zip(k_,i_,d_,f_)])
|
||||||
# handle dilation
|
# handle dilation
|
||||||
x = x.shrink(tuple(noop + [(0,k*(i*f+d)) for k,i,d,f in zip(k_,i_,d_,f_)])).reshape(noop + flatten((k,(i*f+d)) for k,i,d,f in zip(k_,i_,d_,f_)))
|
x = x.shrink_to(noop + [k*(i*f+d) for k,i,d,f in zip(k_,i_,d_,f_)]).reshape(noop + flatten((k,(i*f+d)) for k,i,d,f in zip(k_,i_,d_,f_)))
|
||||||
# handle stride
|
# handle stride
|
||||||
x = x.shrink(tuple(noop + flatten(((0,k), (0,o*s)) for k,o,s in zip(k_,o_,s_)))).reshape(noop + flatten((k,o,s) for k,o,s in zip(k_,o_,s_)))
|
x = x.shrink_to(noop + flatten((k,o*s) for k,o,s in zip(k_,o_,s_))).reshape(noop + flatten((k,o,s) for k,o,s in zip(k_,o_,s_)))
|
||||||
x = x.shrink(tuple(noop + flatten(((0,k), (0,o), (0,1)) for k,o in zip(k_,o_)))).reshape(noop + flatten((k,o) for k,o in zip(k_,o_)))
|
x = x.shrink_to(noop + flatten((k,o,1) for k,o in zip(k_,o_))).reshape(noop + flatten((k,o) for k,o in zip(k_,o_)))
|
||||||
# permute to move reduce to the end
|
# permute to move reduce to the end
|
||||||
return x.permute(*range(len(noop)), *[len(noop)+i*2+1 for i in range(len(i_))], *[len(noop)+i*2 for i in range(len(i_))])
|
return x.permute(*range(len(noop)), *[len(noop)+i*2+1 for i in range(len(i_))], *[len(noop)+i*2 for i in range(len(i_))])
|
||||||
# TODO: once the shapetracker can optimize well, remove this alternative implementation
|
# TODO: once the shapetracker can optimize well, remove this alternative implementation
|
||||||
x = self.pad(tuple(noop + [(0, max(0,o*s-i)) for i,o,s in zip(i_,o_,s_)])).shrink(tuple(noop + [(0,o*s) for o,s in zip(o_,s_)]))
|
x = self.pad(tuple(noop + [(0, max(0,o*s-i)) for i,o,s in zip(i_,o_,s_)])).shrink(tuple(noop + [(0,o*s) for o,s in zip(o_,s_)]))
|
||||||
x = x.reshape(noop + flatten(((o,s) for o,s in zip(o_,s_))))
|
x = x.reshape(noop + flatten(((o,s) for o,s in zip(o_,s_))))
|
||||||
x = x.shrink(tuple(noop + flatten(((0,o), (0,k)) for o,k in zip(o_,k_))))
|
x = x.shrink_to(noop + flatten((o,k) for o,k in zip(o_,k_)))
|
||||||
return x.permute(*range(len(noop)), *[len(noop)+i*2 for i in range(len(i_))], *[len(noop)+i*2+1 for i in range(len(i_))])
|
return x.permute(*range(len(noop)), *[len(noop)+i*2 for i in range(len(i_))], *[len(noop)+i*2+1 for i in range(len(i_))])
|
||||||
|
|
||||||
def _resolve_pool_pads(self, padding:int|Sequence[int], dims:int) -> Sequence[int]:
|
def _resolve_pool_pads(self, padding:int|Sequence[int], dims:int) -> Sequence[int]:
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue