mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
support bias in conv like linear
This commit is contained in:
parent
bd21304e3c
commit
908db3bdea
2 changed files with 6 additions and 2 deletions
|
|
@ -44,8 +44,8 @@ class MBConvBlock:
|
|||
# has_se
|
||||
if self.has_se:
|
||||
x_squeezed = x.avg_pool2d(kernel_size=x.shape[2:4])
|
||||
x_squeezed = x_squeezed.conv2d(self._se_reduce).add(self._se_reduce_bias.reshape(shape=[1, -1, 1, 1])).swish()
|
||||
x_squeezed = x_squeezed.conv2d(self._se_expand).add(self._se_expand_bias.reshape(shape=[1, -1, 1, 1]))
|
||||
x_squeezed = x_squeezed.conv2d(self._se_reduce, self._se_reduce_bias).swish()
|
||||
x_squeezed = x_squeezed.conv2d(self._se_expand, self._se_expand_bias)
|
||||
x = x.mul(x_squeezed.sigmoid())
|
||||
|
||||
x = self._bn2(x.conv2d(self._project_conv))
|
||||
|
|
|
|||
|
|
@ -308,6 +308,10 @@ class Tensor:
|
|||
def max_pool2d(self, kernel_size=(2,2)):
|
||||
return self._pool2d(*kernel_size).max(axis=(3,5))
|
||||
|
||||
def conv2d(self, weight, bias=None, stride=1, groups=1):
|
||||
ret = self._conv2d(weight, stride=stride, groups=groups)
|
||||
return ret if bias is None else ret.add(bias.reshape(shape=[1, -1, 1, 1]))
|
||||
|
||||
# ***** functional nn ops *****
|
||||
|
||||
def linear(self, weight, bias):
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue