Issue
I am trying to freeze the batch norm layers of a pretrained ResNet-50 and analyse their inputs/outputs with forward hooks.
For the frozen BN layers, I just can't understand why the output captured by the hook is different from the output I reproduce by re-applying the layer to the hooked input.
I would really appreciate it if anyone could help me.
Here's the code:
import torch
import torchvision
import numpy

def set_bn_eval(m):
    # switch every BatchNorm layer to eval mode (freeze running statistics)
    classname = m.__class__.__name__
    if classname.find('BatchNorm') != -1:
        m.eval()

image = torch.randn((1, 3, 224, 224))
res = torchvision.models.resnet50(pretrained=True)
res.apply(set_bn_eval)
b = res(image)

layer_out = []
layer_in = []

def layer_hook(mod, inp, out):
    # record each hooked module's input and output
    layer_out.append(out)
    layer_in.append(inp[0])

for name, key in res.named_modules():
    hook = key.register_forward_hook(layer_hook)
    res(image)
    hook.remove()
    out = layer_out.pop()
    inp = layer_in.pop()
    try:
        # re-apply the module to its hooked input and compare with the hooked output
        assert (out.equal(key(inp)))
    except AssertionError:
        print(name)
        break
Solution
TLDR; Some operators only appear in the forward of the module, such as non-parametrized layers.
Some components are not registered in the child module list. This is usually the case for activation functions, but it ultimately depends on the module implementation. In your case, ResNet's Bottleneck has its ReLUs applied in the forward definition, just after each batch normalization layer is called.
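As a quick illustration of that point, here is a minimal, self-contained sketch (Block is a hypothetical toy module, not torchvision code): an activation applied functionally inside forward never shows up in named_modules(), so there is nothing to attach a hook to.
import torch.nn as nn
import torch.nn.functional as F

class Block(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 8, kernel_size=3, padding=1)
        self.bn = nn.BatchNorm2d(8)

    def forward(self, x):
        # the ReLU only exists here, inside forward
        return F.relu(self.bn(self.conv(x)))

print([name for name, _ in Block().named_modules()])
# prints ['', 'conv', 'bn'] -- no trace of the ReLU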
Since torchvision's ResNet uses in-place ReLUs (inplace=True), that subsequent ReLU modifies the very tensor the hook captured. This means the output you catch with the layer hook will be different from the tensor you compute from just the module and its input.
import torch.nn.functional as F

for name, module in res.named_modules():
    if name != 'bn1':
        continue
    hook = module.register_forward_hook(layer_hook)
    res(image)
    hook.remove()

    inp = layer_in.pop()
    out = layer_out.pop()
    # the hooked output has already been modified by the in-place ReLU,
    # so it matches relu(bn1(inp)) rather than bn1(inp)
    assert out.equal(F.relu(module(inp)))
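To see the in-place effect in isolation, here is a minimal sketch outside of ResNet: just a frozen BatchNorm2d followed by an in-place ReLU, mimicking the conv -> bn -> relu sequence in ResNet's forward (the tensor shape is arbitrary).
import torch
import torch.nn as nn

bn = nn.BatchNorm2d(8).eval()          # frozen batch norm, as in the question
relu = nn.ReLU(inplace=True)           # torchvision's ResNet uses in-place ReLUs

captured = []
hook = bn.register_forward_hook(lambda mod, inp, out: captured.append((inp[0], out)))

x = torch.randn(1, 8, 4, 4)
y = relu(bn(x))                        # bn's output tensor is modified in place by relu
hook.remove()

inp, out = captured[0]
print(out.equal(bn(inp)))              # False: `out` was relu'd after the hook fired
print(out.equal(torch.relu(bn(inp))))  # True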
Therefore, it's a bit tricky to actually implement, since you can't rely entirely on the content of res.named_modules().
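One pragmatic workaround, not part of the original answer but a sketch that follows from it: clone the tensors inside the hook, so that later in-place operations in the forward pass cannot touch your stored copies.
def layer_hook(mod, inp, out):
    # store detached copies so that later in-place ops (e.g. ReLU(inplace=True))
    # cannot modify what was captured
    layer_in.append(inp[0].detach().clone())
    layer_out.append(out.detach().clone())
With copies like these, the captured output of a frozen BatchNorm layer should match what you recompute from its captured input, without having to re-apply the activation yourself.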
Answered By - Ivan