mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
minor cleanup of image (#2820)
use transpose for transpose instead of permute, and use pad for pad instead of slice
This commit is contained in:
parent
959d9cfed4
commit
21ec7e09f6
1 changed files with 7 additions and 12 deletions
|
|
@ -10,19 +10,14 @@ def image_dot(self, w):
|
|||
assert self.shape[-1] == w.shape[-min(n2, 2)], f"Input Tensor shapes {self.shape} and {w.shape} cannot be multiplied ({self.shape[-1]} != {w.shape[-min(n2, 2)]})" # noqa: E501
|
||||
bs, groups = prod(self.shape[0:-2]), prod(w.shape[0:-2])
|
||||
cin, cout = w.shape[-2], w.shape[-1]
|
||||
out_shape_t = self.shape[0:-2] + (cout,-1)
|
||||
if len(self.shape) > 1:
|
||||
order = tuple(range(len(self.shape)-2)) + (len(self.shape)-1, len(self.shape)-2)
|
||||
else:
|
||||
order, out_shape_t = (0,), (cout, )
|
||||
worder = tuple(range(len(w.shape)-2)) + (len(w.shape)-1, len(w.shape)-2)
|
||||
out_shape_t = self.shape[0:-2] + (cout,-1) if len(self.shape) > 1 else (cout, )
|
||||
|
||||
# NOTE: with NHWC we can remove the transposes
|
||||
# bs x groups*cin x H x W
|
||||
cx = self.permute(order=order).reshape(shape=(bs//groups, groups*cin, -1, 1))
|
||||
cx = self.transpose(self.ndim-1, self.ndim-2).reshape((bs//groups, groups*cin, -1, 1))
|
||||
# groups*cout x cin x H, W
|
||||
cw = w.permute(order=worder).reshape(shape=(groups*cout, cin, 1, 1))
|
||||
return image_conv2d(cx, cw, groups=groups).reshape(shape=out_shape_t).permute(order=order)
|
||||
cw = w.transpose(w.ndim-1, w.ndim-2).reshape((groups*cout, cin, 1, 1))
|
||||
return image_conv2d(cx, cw, groups=groups).reshape(out_shape_t).transpose(self.ndim-1, self.ndim-2)
|
||||
|
||||
def image_conv2d(self, weight, bias=None, groups=1, stride=1, dilation=1, padding=0):
|
||||
base_image_type = dtypes.imageh if getenv("FLOAT16", 0) else dtypes.imagef
|
||||
|
|
@ -35,8 +30,8 @@ def image_conv2d(self, weight, bias=None, groups=1, stride=1, dilation=1, paddin
|
|||
if cin % 4 != 0 and not (cin == 1 and groups%4 == 0):
|
||||
x = x.reshape(bs, groups, cin, iy, ix) # do this always?
|
||||
added_input_channels = 4 - (cin % 4)
|
||||
w = w.pad(tuple((0, added_input_channels) if i == 2 else (0, 0) for i in range(len(w.shape))))
|
||||
x = x.pad(tuple((0, added_input_channels) if i == 2 else (0, 0) for i in range(len(x.shape))))
|
||||
w = w.pad(tuple((0, added_input_channels) if i == 2 else None for i in range(w.ndim)))
|
||||
x = x.pad(tuple((0, added_input_channels) if i == 2 else None for i in range(x.ndim)))
|
||||
cin = cin + added_input_channels
|
||||
x = x.reshape(bs, groups*cin, iy, ix)
|
||||
|
||||
|
|
@ -46,7 +41,7 @@ def image_conv2d(self, weight, bias=None, groups=1, stride=1, dilation=1, paddin
|
|||
added_output_channels = 4 - (rcout % 4)
|
||||
rcout += added_output_channels
|
||||
cout = groups * rcout
|
||||
w = w.slice(tuple((0, rcout) if i == 1 else (0, s) for i,s in enumerate(w.shape)))
|
||||
w = w.pad(tuple((0, added_output_channels) if i == 1 else None for i in range(w.ndim)))
|
||||
|
||||
# packed (note: flipping bs and iy would make the auto-padding work)
|
||||
x = x.permute(0,2,3,1)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue