minor nn cleanups (#8018)

use more .numel and .ndim
This commit is contained in:
chenyu 2024-12-03 12:34:52 -05:00 committed by GitHub
commit dacb1ff38a
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -56,7 +56,7 @@ class BatchNorm:
# NOTE: wow, this is done all throughout training in most PyTorch models
if self.track_running_stats and Tensor.training:
self.running_mean.assign((1-self.momentum) * self.running_mean + self.momentum * batch_mean.detach())
self.running_var.assign((1-self.momentum) * self.running_var + self.momentum * prod(x.shape)/(prod(x.shape)-x.shape[1]) * batch_var.detach())
self.running_var.assign((1-self.momentum) * self.running_var + self.momentum * x.numel()/(x.numel()-x.shape[1]) * batch_var.detach())
self.num_batches_tracked += 1
return x.batchnorm(self.weight, self.bias, batch_mean, batch_var.add(self.eps).rsqrt())
BatchNorm2d = BatchNorm3d = BatchNorm
@ -108,8 +108,7 @@ class Conv2d:
self.weight = Tensor.uniform(out_channels, in_channels//groups, *self.kernel_size, low=-scale, high=scale)
self.bias: Optional[Tensor] = Tensor.uniform(out_channels, low=-scale, high=scale) if bias else None
def __call__(self, x:Tensor) -> Tensor:
return x.conv2d(self.weight, self.bias, padding=self.padding, stride=self.stride, dilation=self.dilation, groups=self.groups)
def __call__(self, x:Tensor) -> Tensor: return x.conv2d(self.weight, self.bias, self.groups, self.stride, self.dilation, self.padding)
def ConvTranspose1d(in_channels:int, out_channels:int, kernel_size:int, stride=1, padding=0, output_padding=0, dilation=1,
groups=1, bias=True) -> ConvTranspose2d:
@ -154,8 +153,7 @@ class ConvTranspose2d(Conv2d):
self.output_padding = output_padding
def __call__(self, x:Tensor) -> Tensor:
return x.conv_transpose2d(self.weight, self.bias, padding=self.padding, output_padding=self.output_padding, stride=self.stride,
dilation=self.dilation, groups=self.groups)
return x.conv_transpose2d(self.weight, self.bias, self.groups, self.stride, self.dilation, self.padding, self.output_padding)
class Linear:
"""
@ -178,8 +176,7 @@ class Linear:
self.weight = Tensor.uniform(out_features, in_features, low=-bound, high=bound)
self.bias = Tensor.uniform(out_features, low=-bound, high=bound) if bias else None
def __call__(self, x:Tensor) -> Tensor:
return x.linear(self.weight.transpose(), self.bias)
def __call__(self, x:Tensor) -> Tensor: return x.linear(self.weight.transpose(), self.bias)
class GroupNorm:
"""
@ -210,7 +207,7 @@ class GroupNorm:
if self.weight is None or self.bias is None: return x
# elementwise_affine on channels
return x * self.weight.reshape(1, -1, *[1] * (len(x.shape)-2)) + self.bias.reshape(1, -1, *[1] * (len(x.shape)-2))
return x * self.weight.reshape(1, -1, *[1] * (x.ndim-2)) + self.bias.reshape(1, -1, *[1] * (x.ndim-2))
class InstanceNorm:
"""
@ -237,7 +234,7 @@ class InstanceNorm:
def __call__(self, x:Tensor) -> Tensor:
x = x.reshape(x.shape[0], self.num_features, -1).layernorm(eps=self.eps).reshape(x.shape)
if self.weight is None or self.bias is None: return x
return x * self.weight.reshape(1, -1, *[1] * (len(x.shape)-2)) + self.bias.reshape(1, -1, *[1] * (len(x.shape)-2))
return x * self.weight.reshape(1, -1, *[1] * (x.ndim-2)) + self.bias.reshape(1, -1, *[1] * (x.ndim-2))
class LayerNorm:
"""
@ -257,7 +254,7 @@ class LayerNorm:
```
"""
def __init__(self, normalized_shape:Union[int, Tuple[int, ...]], eps=1e-5, elementwise_affine=True):
self.normalized_shape: Tuple[int, ...] = (normalized_shape,) if isinstance(normalized_shape, int) else tuple(normalized_shape)
self.normalized_shape: Tuple[int, ...] = make_tuple(normalized_shape, 1)
self.axis, self.eps, self.elementwise_affine = tuple(-1-i for i in range(len(self.normalized_shape))), eps, elementwise_affine
self.weight: Optional[Tensor] = Tensor.ones(*self.normalized_shape) if elementwise_affine else None
self.bias: Optional[Tensor] = Tensor.zeros(*self.normalized_shape) if elementwise_affine else None