Refactor Conv2d/ConvTranspose2d into a single parent class (#1906)

* refactor Conv2d/ConvTranspose2d

* raise in __call__ for the parent class

* use ABC

* drop ABC it's just syntactic sugar

* use conv2d as base for the transposed version
This commit is contained in:
qazal 2023-09-24 09:23:41 +03:00 committed by GitHub
commit 2201b46bce
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -43,7 +43,7 @@ class Conv2d:
def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True):
self.kernel_size = (kernel_size, kernel_size) if isinstance(kernel_size, int) else tuple(kernel_size)
self.stride, self.padding, self.dilation, self.groups = stride, padding, dilation, groups
self.weight = Tensor.kaiming_uniform(out_channels, in_channels//groups, *self.kernel_size, a=math.sqrt(5))
self.weight = self.initialize_weight(out_channels, in_channels, groups)
assert all_int(self.weight.shape), "does not support symbolic shape"
bound = 1 / math.sqrt(prod(self.weight.shape[1:]))
self.bias = Tensor.uniform(out_channels, low=-bound, high=bound) if bias else None
@ -51,21 +51,21 @@ class Conv2d:
def __call__(self, x):
return x.conv2d(self.weight, self.bias, padding=self.padding, stride=self.stride, dilation=self.dilation, groups=self.groups)
def initialize_weight(self, out_channels, in_channels, groups): return Tensor.kaiming_uniform(out_channels, in_channels//groups, *self.kernel_size, a=math.sqrt(5))
def ConvTranspose1d(in_channels, out_channels, kernel_size, stride=1, padding=0, output_padding=0, dilation=1, groups=1, bias=True):
return ConvTranspose2d(in_channels, out_channels, (kernel_size,), stride, padding, output_padding, dilation, groups, bias)
class ConvTranspose2d:
class ConvTranspose2d(Conv2d):
def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, output_padding=0, dilation=1, groups=1, bias=True):
self.kernel_size = (kernel_size, kernel_size) if isinstance(kernel_size, int) else tuple(kernel_size)
self.stride, self.padding, self.output_padding, self.dilation, self.groups = stride, padding, output_padding, dilation, groups
self.weight = Tensor.kaiming_uniform(in_channels, out_channels//groups, *self.kernel_size, a=math.sqrt(5))
assert all_int(self.weight.shape), "does not support symbolic shape"
bound = 1 / math.sqrt(prod(self.weight.shape[1:]))
self.bias = Tensor.uniform(out_channels, low=-bound, high=bound) if bias else None
super().__init__(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias)
self.output_padding = output_padding
def __call__(self, x):
return x.conv_transpose2d(self.weight, self.bias, padding=self.padding, output_padding=self.output_padding, stride=self.stride, dilation=self.dilation, groups=self.groups)
def initialize_weight(self, out_channels, in_channels, groups): return Tensor.kaiming_uniform(in_channels, out_channels//groups, *self.kernel_size, a=math.sqrt(5))
class Linear:
def __init__(self, in_features, out_features, bias=True):
self.weight = Tensor.kaiming_uniform(out_features, in_features, a=math.sqrt(5))