support all _pool2d. conv will eventually be an hlop

This commit is contained in:
George Hotz 2023-02-23 08:19:47 -08:00
commit fd6082dcef
6 changed files with 39 additions and 11 deletions

View file

@ -149,7 +149,7 @@ def get_run_onnx(onnx_model):
assert opt['kernel_shape'] == opt['strides'] or opt['strides'] == (1,1)
ret = inp[0].avg_pool2d(opt['kernel_shape'])
elif n.op_type == "MaxPool":
assert opt['kernel_shape'] == opt['strides'], f"kernel_shape and stride mismatch {opt}"
#assert opt['kernel_shape'] == opt['strides'], f"kernel_shape and stride mismatch {opt}"
#opt['kernel_shape'] = opt['strides']
# TODO: this is untested and probably wrong
ret = inp[0].pad2d(opt['pads'])

View file

@ -445,6 +445,18 @@ class TestOps(unittest.TestCase):
lambda x: torch.nn.functional.max_pool2d(x, kernel_size=(2,2), stride=stride),
lambda x: Tensor.max_pool2d(x, kernel_size=(2,2), stride=stride))
def test_maxpool2d_unit_stride(self):
helper_test_op([(32,2,110,28)],
lambda x: torch.nn.functional.max_pool2d(x, kernel_size=(5,5), stride=1),
lambda x: Tensor.max_pool2d(x, kernel_size=(5,5), stride=1))
def test_maxpool2d_smaller_stride(self):
for stride in [(2,3), (3,2), 2, 3]:
with self.subTest(stride=stride):
helper_test_op([(32,2,110,28)],
lambda x: torch.nn.functional.max_pool2d(x, kernel_size=(5,5), stride=stride),
lambda x: Tensor.max_pool2d(x, kernel_size=(5,5), stride=stride))
def test_avgpool2d(self):
shape = (32,2,111,28)
for ksz in [(2,2), (3,3), (3,2), (5,5), (5,1), shape[2:]]:

View file

@ -245,6 +245,9 @@ class CLASTKernel(ASTKernel):
self.output_shape = list(self.sts[0].shape[:self.first_reduce]) + self.group_for_reduce
if DEBUG >= 3:
print("output shape", self.output_shape)
if DEBUG >= 4:
for b in self.bufs:
print(b.st)
self.printbufs("new:")
self.bufs_to_delete : Set[int] = set()

View file

@ -142,7 +142,7 @@ class Permute(Function):
def backward(self, grad_output):
return grad_output.movement_op(MovementOps.PERMUTE, tuple(argsort(self.input_order)))
# TODO: merge Slice and Flip into Stride with the 3 arguments
# TODO: merge Slice and Flip into Stride with the 3 arguments. or don't, __getitem__ should support strides as an hlop
class Slice(Function):
def forward(self, x, arg=None):
self.narg = tuple((0-p[0], x.shape[i]-p[0]) for i,p in enumerate(arg))

View file

@ -86,7 +86,7 @@ class CLProgram:
def __call__(self, *args) -> cl.Event:
if DEBUG >= 4: print(args[0], args[1], self.prg)
# print the PTX for NVIDIA. TODO: probably broken for everything else
if DEBUG >= 5: print(self.clprogram.get_info(cl.program_info.BINARIES)[0].decode('utf-8'))
if DEBUG >= 5 and not OSX: print(self.clprogram.get_info(cl.program_info.BINARIES)[0].decode('utf-8'))
e = self.clprg(CL().cl_queue, *args)
if DEBUG >= 2:
assert CL.cl_queue is not None

View file

@ -290,14 +290,27 @@ class Tensor:
_mask : np.ndarray = np.asarray(Tensor._rng.binomial(1, 1.0-p, size=self.shape), dtype=self.dtype)
return self * Tensor(_mask, requires_grad=False, device=self.device) * (1/(1.0 - p))
# TODO: support arbitrary strides
def _pool2d(self, py, px, sy, sx):
if py > sy or px > sx: raise NotImplementedError("pool2d doesn't support kernel_size > stride")
xup = self.slice(((0, self.shape[0]), (0, self.shape[1]), (0, (self.shape[2]+(sy-py))//sy*sy), (0, (self.shape[3]+(sx-px))//sx*sx)))
return xup.reshape(shape=(xup.shape[0], xup.shape[1], xup.shape[2]//sy, sy, xup.shape[3]//sx, sx))[:, :, :, :py, :, :px]
def _pool2d(self, ky, kx, sy, sx):
if ky > sy or kx > sx:
# NOTE: this gives me hope for an hlop conv. need to optimize this to one view
bs,c,iy,ix = self.shape
oy = (iy - (ky-1) - 1)//sy + 1
ox = (ix - (kx-1) - 1)//sx + 1
# duplicate the inputs for each of the kernels
xup = self.reshape(bs, c, 1, iy, 1, ix).expand(bs, c, ky, iy, kx, ix).reshape(bs, c, ky*iy, kx*ix)
# slide by 1 (this is dilation?)
xup = xup.slice(((0,bs), (0,c), (0,ky*(iy+1)), (0,kx*(ix+1))))
xup = xup.reshape(bs, c, ky, iy+1, kx, ix+1)
xup = xup.slice(((0,bs), (0,c), (0,ky), (0,oy*sy), (0,kx), (0,ox*sx)))
# handle stride, and permute to move reduce to the end
xup = xup.reshape(bs, c, ky, oy, sy, kx, ox, sx)[:, :, :, :, 0, :, :, 0]
return xup.permute(0, 1, 3, 5, 2, 4)
# TODO: once the shapetracker can optimize, remove this alternative implementation
xup = self.slice(((0, self.shape[0]), (0, self.shape[1]), (0, (self.shape[2]+(sy-ky))//sy*sy), (0, (self.shape[3]+(sx-kx))//sx*sx)))
return xup.reshape(shape=(xup.shape[0], xup.shape[1], xup.shape[2]//sy, sy, xup.shape[3]//sx, sx))[:, :, :, :ky, :, :kx].permute(0, 1, 2, 4, 3, 5)
def avg_pool2d(self, kernel_size=(2,2), stride=None): return self._pool2d(*make_pair(kernel_size), *make_pair(stride if stride is not None else kernel_size)).mean(axis=(3,5))
def max_pool2d(self, kernel_size=(2,2), stride=None): return self._pool2d(*make_pair(kernel_size), *make_pair(stride if stride is not None else kernel_size)).max(axis=(3,5))
def avg_pool2d(self, kernel_size=(2,2), stride=None): return self._pool2d(*make_pair(kernel_size), *make_pair(stride if stride is not None else kernel_size)).mean(axis=(4,5))
def max_pool2d(self, kernel_size=(2,2), stride=None): return self._pool2d(*make_pair(kernel_size), *make_pair(stride if stride is not None else kernel_size)).max(axis=(4,5))
def conv2d(self, weight, bias=None, **kwargs):
ret = mlops.Conv2D.apply(self, weight, **kwargs)