don't increment num_batches_tracked if not tracking running stats

This commit is contained in:
David Hou 2024-02-29 13:45:24 -08:00
commit 78de0ea9ee

View file

@ -39,7 +39,7 @@ class BatchNorm2d:
self.running_mean.assign((1-self.momentum) * self.running_mean + self.momentum * batch_mean.detach())
self.running_var.assign((1-self.momentum) * self.running_var +
self.momentum * prod(y.shape[1:])/(prod(y.shape[1:])-y.shape[2]) * batch_var.detach())
self.num_batches_tracked += 1
self.num_batches_tracked += 1
else:
batch_mean = self.running_mean
# NOTE: this can be precomputed for static inference. we expand it here so it fuses