Issue
I'd like to know what exactly the running_mean and running_var that I can access on nn.BatchNorm2d are.
Example code is here, where bn means nn.BatchNorm2d:
vector = torch.cat([
torch.mean(self.conv3.bn.running_mean).view(1), torch.std(self.conv3.bn.running_mean).view(1),
torch.mean(self.conv3.bn.running_var).view(1), torch.std(self.conv3.bn.running_var).view(1),
torch.mean(self.conv5.bn.running_mean).view(1), torch.std(self.conv5.bn.running_mean).view(1),
torch.mean(self.conv5.bn.running_var).view(1), torch.std(self.conv5.bn.running_var).view(1)
])
I couldn't figure out what running_mean and running_var mean from the PyTorch official documentation or the user community.
What do nn.BatchNorm2d.running_mean and nn.BatchNorm2d.running_var mean?
Solution
From the original batch normalization paper:
Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift, Sergey Ioffe and Christian Szegedy, ICML 2015
You can see in Algorithm 1 how the statistics of a given batch are measured.
However, what is kept in memory across batches is the running stats, i.e. the statistics measured iteratively on each batch inference. The computation of the running mean and running variance is actually quite well explained in the documentation page of nn.BatchNorm2d:
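Quoting the update rule given there, where x̂ is the estimated (running) statistic and x_t is the statistic of the current batch:

x̂_new = (1 − momentum) × x̂ + momentum × x_t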
By default, the momentum coefficient is set to 0.1. It regulates how much of the current batch statistics will affect the running statistics: a value closer to 1 means the new running stats stay closer to the current batch statistics, whereas a value closer to 0 means the current batch stats will not contribute much to updating the running stats.
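As a side note, momentum can also be set to None, in which case the layer keeps a cumulative (simple) moving average over all batches seen so far instead of an exponential one (bn_cma is just an illustrative name here):

>>> bn_cma = nn.BatchNorm2d(10, momentum=None)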
It's worth pointing out that BatchNorm2d is applied across the spatial dimensions, in addition, of course, to the batch dimension. Given a batch of shape (b, c, h, w), it will compute the statistics across (b, h, w). This means the running statistics have shape (c,), i.e. there are as many statistics components as there are input channels (for both mean and variance).
Here is a minimal example:
>>> import torch
>>> import torch.nn as nn
>>> bn = nn.BatchNorm2d(10)
>>> x = torch.rand(2, 10, 2, 2)
Since track_running_stats is set to True by default on BatchNorm2d, it will track the running stats when inferring in training mode.
The running mean and variance are initialized to zeros and ones, respectively; we mirror that initialization by hand so we can check the update afterwards:
>>> running_mean, running_var = torch.zeros(x.size(1)), torch.ones(x.size(1))
Let's perform inference on bn in training mode and check its running stats:
>>> bn(x)
>>> bn.running_mean, bn.running_var
(tensor([0.0650, 0.0432, 0.0373, 0.0534, 0.0476,
0.0622, 0.0651, 0.0660, 0.0406, 0.0446]),
tensor([0.9027, 0.9170, 0.9162, 0.9082, 0.9087,
0.9026, 0.9136, 0.9043, 0.9126, 0.9122]))
Now let's compute those stats by hand (note that running_var is updated with the unbiased batch variance, hence unbiased=True):
>>> momentum = 0.1
>>> xmean = x.mean(dim=(0, 2, 3))
>>> xvar = x.var(dim=(0, 2, 3), unbiased=True)
>>> (1 - momentum)*running_mean + momentum*xmean
tensor([0.0650, 0.0432, 0.0373, 0.0534, 0.0476,
        0.0622, 0.0651, 0.0660, 0.0406, 0.0446])
>>> (1 - momentum)*running_var + momentum*xvar
tensor([0.9027, 0.9170, 0.9162, 0.9082, 0.9087,
        0.9026, 0.9136, 0.9043, 0.9126, 0.9122])
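For completeness, here is a quick sketch (reusing the same bn and x as above) showing that in evaluation mode the layer normalizes with the stored running statistics and no longer updates them:

>>> _ = bn.eval()
>>> before = bn.running_mean.clone()
>>> out = bn(x)
>>> torch.equal(bn.running_mean, before)
True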
Answered By - Ivan