simpler batchnorm

This commit is contained in:
George Hotz 2022-06-05 16:51:51 -07:00
commit ebd7290f56

View file

@ -30,7 +30,7 @@ class BatchNorm2D:
def normalize(self, x, mean, var):
x = (x - mean.reshape(shape=[1, -1, 1, 1])) * self.weight.reshape(shape=[1, -1, 1, 1])
return x.div(var.add(self.eps).reshape(shape=[1, -1, 1, 1])**0.5) + self.bias.reshape(shape=[1, -1, 1, 1])
return x.mul(var.add(self.eps).reshape(shape=[1, -1, 1, 1])**-0.5) + self.bias.reshape(shape=[1, -1, 1, 1])
class Conv2d:
def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, bias=True):