mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
Refactor Conv2d/ConvTranspose2d into a single parent class (#1906)
* refactor Conv2d/ConvTranspose2d
* raise in __call__ for the parent class
* use ABC
* drop ABC it's just syntactic sugar
* use conv2d as base for the transposed version
This commit is contained in:
parent
97dc813329
commit
2201b46bce
1 changed file with 8 additions and 8 deletions
|
|
@@ -43,7 +43,7 @@ class Conv2d:
|
|||
def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True):
|
||||
self.kernel_size = (kernel_size, kernel_size) if isinstance(kernel_size, int) else tuple(kernel_size)
|
||||
self.stride, self.padding, self.dilation, self.groups = stride, padding, dilation, groups
|
||||
self.weight = Tensor.kaiming_uniform(out_channels, in_channels//groups, *self.kernel_size, a=math.sqrt(5))
|
||||
self.weight = self.initialize_weight(out_channels, in_channels, groups)
|
||||
assert all_int(self.weight.shape), "does not support symbolic shape"
|
||||
bound = 1 / math.sqrt(prod(self.weight.shape[1:]))
|
||||
self.bias = Tensor.uniform(out_channels, low=-bound, high=bound) if bias else None
|
||||
|
|
@@ -51,21 +51,21 @@ class Conv2d:
|
|||
def __call__(self, x):
|
||||
return x.conv2d(self.weight, self.bias, padding=self.padding, stride=self.stride, dilation=self.dilation, groups=self.groups)
|
||||
|
||||
def initialize_weight(self, out_channels, in_channels, groups): return Tensor.kaiming_uniform(out_channels, in_channels//groups, *self.kernel_size, a=math.sqrt(5))
|
||||
|
||||
def ConvTranspose1d(in_channels, out_channels, kernel_size, stride=1, padding=0, output_padding=0, dilation=1, groups=1, bias=True):
|
||||
return ConvTranspose2d(in_channels, out_channels, (kernel_size,), stride, padding, output_padding, dilation, groups, bias)
|
||||
|
||||
class ConvTranspose2d:
|
||||
class ConvTranspose2d(Conv2d):
|
||||
def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, output_padding=0, dilation=1, groups=1, bias=True):
|
||||
self.kernel_size = (kernel_size, kernel_size) if isinstance(kernel_size, int) else tuple(kernel_size)
|
||||
self.stride, self.padding, self.output_padding, self.dilation, self.groups = stride, padding, output_padding, dilation, groups
|
||||
self.weight = Tensor.kaiming_uniform(in_channels, out_channels//groups, *self.kernel_size, a=math.sqrt(5))
|
||||
assert all_int(self.weight.shape), "does not support symbolic shape"
|
||||
bound = 1 / math.sqrt(prod(self.weight.shape[1:]))
|
||||
self.bias = Tensor.uniform(out_channels, low=-bound, high=bound) if bias else None
|
||||
super().__init__(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias)
|
||||
self.output_padding = output_padding
|
||||
|
||||
def __call__(self, x):
|
||||
return x.conv_transpose2d(self.weight, self.bias, padding=self.padding, output_padding=self.output_padding, stride=self.stride, dilation=self.dilation, groups=self.groups)
|
||||
|
||||
def initialize_weight(self, out_channels, in_channels, groups): return Tensor.kaiming_uniform(in_channels, out_channels//groups, *self.kernel_size, a=math.sqrt(5))
|
||||
|
||||
class Linear:
|
||||
def __init__(self, in_features, out_features, bias=True):
|
||||
self.weight = Tensor.kaiming_uniform(out_features, in_features, a=math.sqrt(5))
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue