avg/max pool strides

This commit is contained in:
George Hotz 2023-02-22 18:00:48 -08:00
commit c8d89eb20e
2 changed files with 13 additions and 5 deletions

View file

@ -438,6 +438,13 @@ class TestOps(unittest.TestCase):
lambda x: torch.nn.functional.max_pool2d(x, kernel_size=ksz),
lambda x: Tensor.max_pool2d(x, kernel_size=ksz))
def test_maxpool2d_bigger_stride(self):
for stride in [(2,3), (3,2), 2, 3]:
with self.subTest(stride=stride):
helper_test_op([(32,2,110,28)],
lambda x: torch.nn.functional.max_pool2d(x, kernel_size=(2,2), stride=stride),
lambda x: Tensor.max_pool2d(x, kernel_size=(2,2), stride=stride))
def test_avgpool2d(self):
shape = (32,2,111,28)
for ksz in [(2,2), (3,3), (3,2), (5,5), (5,1), shape[2:]]:

View file

@ -286,12 +286,13 @@ class Tensor:
return self * Tensor(_mask, requires_grad=False, device=self.device) * (1/(1.0 - p))
# TODO: support arbitrary strides
def _pool2d(self, py, px):
xup = self[:, :, :self.shape[2]-self.shape[2]%py, :self.shape[3]-self.shape[3]%px] if (self.shape[2]%py != 0) or (self.shape[3]%px != 0) else self
return xup.reshape(shape=(xup.shape[0], xup.shape[1], xup.shape[2]//py, py, xup.shape[3]//px, px))
def _pool2d(self, py, px, sy, sx):
if py > sy or px > sx: raise NotImplementedError("pool2d doesn't support kernel_size > stride")
xup = self.slice(((0, self.shape[0]), (0, self.shape[1]), (0, (self.shape[2]+(sy-py))//sy*sy), (0, (self.shape[3]+(sx-px))//sx*sx)))
return xup.reshape(shape=(xup.shape[0], xup.shape[1], xup.shape[2]//sy, sy, xup.shape[3]//sx, sx))[:, :, :, :py, :, :px]
def avg_pool2d(self, kernel_size=(2,2)): return self._pool2d(*make_pair(kernel_size)).mean(axis=(3,5))
def max_pool2d(self, kernel_size=(2,2)): return self._pool2d(*make_pair(kernel_size)).max(axis=(3,5))
def avg_pool2d(self, kernel_size=(2,2), stride=None): return self._pool2d(*make_pair(kernel_size), *make_pair(stride if stride is not None else kernel_size)).mean(axis=(3,5))
def max_pool2d(self, kernel_size=(2,2), stride=None): return self._pool2d(*make_pair(kernel_size), *make_pair(stride if stride is not None else kernel_size)).max(axis=(3,5))
def conv2d(self, weight, bias=None, **kwargs):
ret = mlops.Conv2D.apply(self, weight, **kwargs)