Issue
In my network, I want to compute both the forward pass and the backward pass during the forward pass.
For this, I have to manually define the backward methods for all of the forward-pass layers.
For the activation functions, that's easy, and it also worked well for the linear and conv layers. But I'm really struggling with BatchNorm, as the BatchNorm paper only discusses the 1D case.
So far, my implementation looks like this:
def backward_batchnorm2d(input, output, grad_output, layer):
    gamma = layer.weight
    beta = layer.bias
    avg = layer.running_mean
    var = layer.running_var
    eps = layer.eps
    B = input.shape[0]
    # avg, var, gamma and beta have shape [channel_size],
    # while input, output, grad_output have shape [batch_size, channel_size, w, h].
    # For my calculations I reshape avg, var, gamma and beta to
    # [batch_size, channel_size, w, h] by repeating the channel values
    # over the whole image and batch.
    dL_dxi_hat = grad_output * gamma
    dL_dvar = (-0.5 * dL_dxi_hat * (input - avg) / ((var + eps) ** 1.5)).sum((0, 2, 3), keepdim=True)
    dL_davg = (-1.0 / torch.sqrt(var + eps) * dL_dxi_hat).sum((0, 2, 3), keepdim=True) + dL_dvar * (-2.0 * (input - avg)).sum((0, 2, 3), keepdim=True) / B
    dL_dxi = dL_dxi_hat / torch.sqrt(var + eps) + 2.0 * dL_dvar * (input - avg) / B + dL_davg / B
    dL_dgamma = (grad_output * output).sum((0, 2, 3), keepdim=True)
    dL_dbeta = (grad_output).sum((0, 2, 3), keepdim=True)
    return dL_dxi, dL_dgamma, dL_dbeta
When I check my gradients with torch.autograd.grad(), I notice that dL_dgamma and dL_dbeta are correct, but dL_dxi is off by a lot, and I can't find my mistake. Where is it?
For reference, here is the definition of BatchNorm:
And here are the formulas for the derivatives for the 1D case:
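(The images with the formulas did not survive; reconstructed here from the batch-norm paper by Ioffe & Szegedy, for a mini-batch $\{x_1,\dots,x_m\}$ of size $m$, these are the definitions and 1D derivatives the code implements:)

$$
\mu = \frac{1}{m}\sum_{i=1}^{m} x_i, \qquad
\sigma^2 = \frac{1}{m}\sum_{i=1}^{m} (x_i - \mu)^2, \qquad
\hat{x}_i = \frac{x_i - \mu}{\sqrt{\sigma^2 + \epsilon}}, \qquad
y_i = \gamma \hat{x}_i + \beta
$$

$$
\frac{\partial L}{\partial \hat{x}_i} = \frac{\partial L}{\partial y_i}\,\gamma
$$

$$
\frac{\partial L}{\partial \sigma^2} = \sum_{i=1}^{m} \frac{\partial L}{\partial \hat{x}_i}\,(x_i - \mu)\cdot\left(-\tfrac{1}{2}\right)(\sigma^2 + \epsilon)^{-3/2}
$$

$$
\frac{\partial L}{\partial \mu} = \sum_{i=1}^{m} \frac{\partial L}{\partial \hat{x}_i}\cdot\frac{-1}{\sqrt{\sigma^2 + \epsilon}} \;+\; \frac{\partial L}{\partial \sigma^2}\cdot\frac{1}{m}\sum_{i=1}^{m} -2(x_i - \mu)
$$

$$
\frac{\partial L}{\partial x_i} = \frac{\partial L}{\partial \hat{x}_i}\cdot\frac{1}{\sqrt{\sigma^2 + \epsilon}} \;+\; \frac{\partial L}{\partial \sigma^2}\cdot\frac{2(x_i - \mu)}{m} \;+\; \frac{\partial L}{\partial \mu}\cdot\frac{1}{m}
$$

$$
\frac{\partial L}{\partial \gamma} = \sum_{i=1}^{m} \frac{\partial L}{\partial y_i}\,\hat{x}_i, \qquad
\frac{\partial L}{\partial \beta} = \sum_{i=1}^{m} \frac{\partial L}{\partial y_i}
$$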
Solution
def backward_batchnorm2d(input, output, grad_output, layer):
    gamma = layer.weight
    gamma = gamma.view(1, -1, 1, 1)  # edit: reshape so gamma broadcasts over [B, C, H, W]
    # beta = layer.bias
    # avg = layer.running_mean  # running stats are only used at inference time
    # var = layer.running_var
    eps = layer.eps
    B = input.shape[0] * input.shape[2] * input.shape[3]  # edit: count spans batch and spatial dims
    # add new: batch statistics, computed from the current input
    mean = input.mean(dim=(0, 2, 3), keepdim=True)
    variance = input.var(dim=(0, 2, 3), unbiased=False, keepdim=True)
    x_hat = (input - mean) / (torch.sqrt(variance + eps))
    dL_dxi_hat = grad_output * gamma
    dL_dvar = (-0.5 * dL_dxi_hat * (input - mean)).sum((0, 2, 3), keepdim=True) * ((variance + eps) ** -1.5)  # edit
    dL_davg = (-1.0 / torch.sqrt(variance + eps) * dL_dxi_hat).sum((0, 2, 3), keepdim=True) + (dL_dvar * (-2.0 * (input - mean)).sum((0, 2, 3), keepdim=True) / B)  # edit
    dL_dxi = (dL_dxi_hat / torch.sqrt(variance + eps)) + (2.0 * dL_dvar * (input - mean) / B) + (dL_davg / B)
    dL_dgamma = (grad_output * x_hat).sum((0, 2, 3), keepdim=True)  # edit: x_hat, not output
    dL_dbeta = (grad_output).sum((0, 2, 3), keepdim=True)
    return dL_dxi, dL_dgamma, dL_dbeta
- Since you didn't share your forward-pass code: if your gamma has shape [channel_size], you need to reshape it to [1, gamma.shape[0], 1, 1] so it broadcasts over the batch and spatial dimensions.
- Your formulas follow the 1D case, where the scale factor is the batch size. In 2D, the sums run over three dimensions (batch, height, width), so the count is B = input.shape[0] * input.shape[2] * input.shape[3].
- running_mean and running_var are only used in test/inference mode; they are not used during training (you can find this in the paper). The mean and variance you need are computed from the current input. You can store mean, variance, and x_hat = (x - mean) / sqrt(variance + eps) in your layer object during the forward pass, or recompute them as in the code above (the lines after # add new), and use them in the formulas for dL_dvar, dL_davg and dL_dxi.
- Your dL_dgamma should also be incorrect, since you multiplied grad_output by output instead of by the normalized input; it should be grad_output * x_hat.
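As a sanity check, the corrected backward pass can be compared against autograd. Here is a small self-contained sketch (the function is repeated so the snippet runs on its own; shapes and the toy loss are arbitrary choices for the test):

```python
import torch
import torch.nn as nn

def backward_batchnorm2d(input, grad_output, layer):
    # Corrected backward pass: training mode, batch statistics (not running stats).
    gamma = layer.weight.view(1, -1, 1, 1)
    eps = layer.eps
    B = input.shape[0] * input.shape[2] * input.shape[3]
    mean = input.mean(dim=(0, 2, 3), keepdim=True)
    variance = input.var(dim=(0, 2, 3), unbiased=False, keepdim=True)
    x_hat = (input - mean) / torch.sqrt(variance + eps)
    dL_dxi_hat = grad_output * gamma
    dL_dvar = (-0.5 * dL_dxi_hat * (input - mean)).sum((0, 2, 3), keepdim=True) * ((variance + eps) ** -1.5)
    dL_davg = (-1.0 / torch.sqrt(variance + eps) * dL_dxi_hat).sum((0, 2, 3), keepdim=True) \
              + dL_dvar * (-2.0 * (input - mean)).sum((0, 2, 3), keepdim=True) / B
    dL_dxi = dL_dxi_hat / torch.sqrt(variance + eps) + 2.0 * dL_dvar * (input - mean) / B + dL_davg / B
    dL_dgamma = (grad_output * x_hat).sum((0, 2, 3), keepdim=True)
    dL_dbeta = grad_output.sum((0, 2, 3), keepdim=True)
    return dL_dxi, dL_dgamma, dL_dbeta

torch.manual_seed(0)
layer = nn.BatchNorm2d(3)
layer.train()  # training mode, so the forward pass uses batch statistics
x = torch.randn(4, 3, 5, 5, requires_grad=True)
y = layer(x)
loss = (y ** 2).sum()  # arbitrary scalar loss for the check
grad_output, = torch.autograd.grad(loss, y, retain_graph=True)
# Reference gradients from autograd.
dx_ref, dg_ref, db_ref = torch.autograd.grad(loss, (x, layer.weight, layer.bias))
dx, dg, db = backward_batchnorm2d(x.detach(), grad_output, layer)
print(torch.allclose(dx, dx_ref, atol=1e-4))
print(torch.allclose(dg.flatten(), dg_ref, atol=1e-4))
print(torch.allclose(db.flatten(), db_ref, atol=1e-4))
```

Note that the manual dL_dgamma and dL_dbeta keep the [1, C, 1, 1] shape from the keepdim sums, so they are flattened before comparing against the [C]-shaped parameter gradients.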
Answered By - CuCaRot