minor cleanup of image (#2820)

Use `transpose` instead of `permute` for transposes, and `pad` instead of `slice` for padding.
This commit is contained in:
chenyu 2023-12-17 20:53:49 -05:00 committed by GitHub
commit 21ec7e09f6
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -10,19 +10,14 @@ def image_dot(self, w):
assert self.shape[-1] == w.shape[-min(n2, 2)], f"Input Tensor shapes {self.shape} and {w.shape} cannot be multiplied ({self.shape[-1]} != {w.shape[-min(n2, 2)]})" # noqa: E501
bs, groups = prod(self.shape[0:-2]), prod(w.shape[0:-2])
cin, cout = w.shape[-2], w.shape[-1]
out_shape_t = self.shape[0:-2] + (cout,-1)
if len(self.shape) > 1:
order = tuple(range(len(self.shape)-2)) + (len(self.shape)-1, len(self.shape)-2)
else:
order, out_shape_t = (0,), (cout, )
worder = tuple(range(len(w.shape)-2)) + (len(w.shape)-1, len(w.shape)-2)
out_shape_t = self.shape[0:-2] + (cout,-1) if len(self.shape) > 1 else (cout, )
# NOTE: with NHWC we can remove the transposes
# bs x groups*cin x H x W
cx = self.permute(order=order).reshape(shape=(bs//groups, groups*cin, -1, 1))
cx = self.transpose(self.ndim-1, self.ndim-2).reshape((bs//groups, groups*cin, -1, 1))
# groups*cout x cin x H, W
cw = w.permute(order=worder).reshape(shape=(groups*cout, cin, 1, 1))
return image_conv2d(cx, cw, groups=groups).reshape(shape=out_shape_t).permute(order=order)
cw = w.transpose(w.ndim-1, w.ndim-2).reshape((groups*cout, cin, 1, 1))
return image_conv2d(cx, cw, groups=groups).reshape(out_shape_t).transpose(self.ndim-1, self.ndim-2)
def image_conv2d(self, weight, bias=None, groups=1, stride=1, dilation=1, padding=0):
base_image_type = dtypes.imageh if getenv("FLOAT16", 0) else dtypes.imagef
@ -35,8 +30,8 @@ def image_conv2d(self, weight, bias=None, groups=1, stride=1, dilation=1, paddin
if cin % 4 != 0 and not (cin == 1 and groups%4 == 0):
x = x.reshape(bs, groups, cin, iy, ix) # do this always?
added_input_channels = 4 - (cin % 4)
w = w.pad(tuple((0, added_input_channels) if i == 2 else (0, 0) for i in range(len(w.shape))))
x = x.pad(tuple((0, added_input_channels) if i == 2 else (0, 0) for i in range(len(x.shape))))
w = w.pad(tuple((0, added_input_channels) if i == 2 else None for i in range(w.ndim)))
x = x.pad(tuple((0, added_input_channels) if i == 2 else None for i in range(x.ndim)))
cin = cin + added_input_channels
x = x.reshape(bs, groups*cin, iy, ix)
@ -46,7 +41,7 @@ def image_conv2d(self, weight, bias=None, groups=1, stride=1, dilation=1, paddin
added_output_channels = 4 - (rcout % 4)
rcout += added_output_channels
cout = groups * rcout
w = w.slice(tuple((0, rcout) if i == 1 else (0, s) for i,s in enumerate(w.shape)))
w = w.pad(tuple((0, added_output_channels) if i == 1 else None for i in range(w.ndim)))
# packed (note: flipping bs and iy would make the auto-padding work)
x = x.permute(0,2,3,1)