mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
parent
35c30f76f2
commit
dacb1ff38a
1 changed files with 7 additions and 10 deletions
|
|
@ -56,7 +56,7 @@ class BatchNorm:
|
|||
# NOTE: wow, this is done all throughout training in most PyTorch models
|
||||
if self.track_running_stats and Tensor.training:
|
||||
self.running_mean.assign((1-self.momentum) * self.running_mean + self.momentum * batch_mean.detach())
|
||||
self.running_var.assign((1-self.momentum) * self.running_var + self.momentum * prod(x.shape)/(prod(x.shape)-x.shape[1]) * batch_var.detach())
|
||||
self.running_var.assign((1-self.momentum) * self.running_var + self.momentum * x.numel()/(x.numel()-x.shape[1]) * batch_var.detach())
|
||||
self.num_batches_tracked += 1
|
||||
return x.batchnorm(self.weight, self.bias, batch_mean, batch_var.add(self.eps).rsqrt())
|
||||
BatchNorm2d = BatchNorm3d = BatchNorm
|
||||
|
|
@ -108,8 +108,7 @@ class Conv2d:
|
|||
self.weight = Tensor.uniform(out_channels, in_channels//groups, *self.kernel_size, low=-scale, high=scale)
|
||||
self.bias: Optional[Tensor] = Tensor.uniform(out_channels, low=-scale, high=scale) if bias else None
|
||||
|
||||
def __call__(self, x:Tensor) -> Tensor:
|
||||
return x.conv2d(self.weight, self.bias, padding=self.padding, stride=self.stride, dilation=self.dilation, groups=self.groups)
|
||||
def __call__(self, x:Tensor) -> Tensor: return x.conv2d(self.weight, self.bias, self.groups, self.stride, self.dilation, self.padding)
|
||||
|
||||
def ConvTranspose1d(in_channels:int, out_channels:int, kernel_size:int, stride=1, padding=0, output_padding=0, dilation=1,
|
||||
groups=1, bias=True) -> ConvTranspose2d:
|
||||
|
|
@ -154,8 +153,7 @@ class ConvTranspose2d(Conv2d):
|
|||
self.output_padding = output_padding
|
||||
|
||||
def __call__(self, x:Tensor) -> Tensor:
|
||||
return x.conv_transpose2d(self.weight, self.bias, padding=self.padding, output_padding=self.output_padding, stride=self.stride,
|
||||
dilation=self.dilation, groups=self.groups)
|
||||
return x.conv_transpose2d(self.weight, self.bias, self.groups, self.stride, self.dilation, self.padding, self.output_padding)
|
||||
|
||||
class Linear:
|
||||
"""
|
||||
|
|
@ -178,8 +176,7 @@ class Linear:
|
|||
self.weight = Tensor.uniform(out_features, in_features, low=-bound, high=bound)
|
||||
self.bias = Tensor.uniform(out_features, low=-bound, high=bound) if bias else None
|
||||
|
||||
def __call__(self, x:Tensor) -> Tensor:
|
||||
return x.linear(self.weight.transpose(), self.bias)
|
||||
def __call__(self, x:Tensor) -> Tensor: return x.linear(self.weight.transpose(), self.bias)
|
||||
|
||||
class GroupNorm:
|
||||
"""
|
||||
|
|
@ -210,7 +207,7 @@ class GroupNorm:
|
|||
|
||||
if self.weight is None or self.bias is None: return x
|
||||
# elementwise_affine on channels
|
||||
return x * self.weight.reshape(1, -1, *[1] * (len(x.shape)-2)) + self.bias.reshape(1, -1, *[1] * (len(x.shape)-2))
|
||||
return x * self.weight.reshape(1, -1, *[1] * (x.ndim-2)) + self.bias.reshape(1, -1, *[1] * (x.ndim-2))
|
||||
|
||||
class InstanceNorm:
|
||||
"""
|
||||
|
|
@ -237,7 +234,7 @@ class InstanceNorm:
|
|||
def __call__(self, x:Tensor) -> Tensor:
|
||||
x = x.reshape(x.shape[0], self.num_features, -1).layernorm(eps=self.eps).reshape(x.shape)
|
||||
if self.weight is None or self.bias is None: return x
|
||||
return x * self.weight.reshape(1, -1, *[1] * (len(x.shape)-2)) + self.bias.reshape(1, -1, *[1] * (len(x.shape)-2))
|
||||
return x * self.weight.reshape(1, -1, *[1] * (x.ndim-2)) + self.bias.reshape(1, -1, *[1] * (x.ndim-2))
|
||||
|
||||
class LayerNorm:
|
||||
"""
|
||||
|
|
@ -257,7 +254,7 @@ class LayerNorm:
|
|||
```
|
||||
"""
|
||||
def __init__(self, normalized_shape:Union[int, Tuple[int, ...]], eps=1e-5, elementwise_affine=True):
|
||||
self.normalized_shape: Tuple[int, ...] = (normalized_shape,) if isinstance(normalized_shape, int) else tuple(normalized_shape)
|
||||
self.normalized_shape: Tuple[int, ...] = make_tuple(normalized_shape, 1)
|
||||
self.axis, self.eps, self.elementwise_affine = tuple(-1-i for i in range(len(self.normalized_shape))), eps, elementwise_affine
|
||||
self.weight: Optional[Tensor] = Tensor.ones(*self.normalized_shape) if elementwise_affine else None
|
||||
self.bias: Optional[Tensor] = Tensor.zeros(*self.normalized_shape) if elementwise_affine else None
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue