Issue
I'd like to perform fine-tuning of an entire block from DenseNet-161. At the moment, I know I can use the following to freeze all layers apart from the classifier:
import torch
from torchvision import models

model = models.densenet161(pretrained=True)
for param in model.parameters():
    param.requires_grad = False
num_ftrs = model.classifier.in_features
model.classifier = torch.nn.Linear(num_ftrs, 2)
However, I'd like to unfreeze the last few layers/block of the DenseNet for fine-tuning. What would be the most elegant way of achieving this?
Solution
First of all, you can also unfreeze the classifier by setting requires_grad of its parameters to True.
for param in model.classifier.parameters():
    param.requires_grad = True
This way you keep the original parameters of that layer, instead of the new random initialization you get when you create a new nn.Linear.
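As a quick check, here is a minimal sketch (the only assumption is that the freezing code from the question has already run, with the classifier unfrozen as above) to list which parameters will actually receive gradients:
# names of all parameters that are currently trainable
trainable = [name for name, param in model.named_parameters() if param.requires_grad]
print(trainable)  # expected: ['classifier.weight', 'classifier.bias']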
That also works for any other submodule of the DenseNet. You can see which other modules there are by printing the module. To unfreeze the last block and the last BatchNorm, you can do
# this is a torch.nn.Sequential containing the
# "denseblock4" and "norm5" submodules
submodules = model.features[-2:]
for param in submodules.parameters():
    param.requires_grad = True
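To see for yourself which submodules model.features contains (and confirm that the last two are "denseblock4" and "norm5", as in the comment above), a short sketch using named_children:
# print the names of the top-level submodules inside model.features
for name, module in model.features.named_children():
    print(name, type(module).__name__)
# ... the listing ends with "denseblock4" and "norm5"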
If you want to reset the parameters to a new random initialization, you can use some initializer from torch.nn.init on each parameter.
As requested in the comments: How to re-initialize the last two layers while keeping them frozen?
The last two layers contain convolutional layers and batch norm layers. While you probably want to reinitialize the convolutional layers randomly, this may not be what you want for the batch norm layers.
with torch.no_grad():  # allow re-initializing the parameters in place
    submodules = model.features[-2:]
    for submodule in submodules.modules():
        if isinstance(submodule, torch.nn.Conv2d):
            # randomly re-initialize the weights
            torch.nn.init.kaiming_normal_(submodule.weight)
            if submodule.bias is not None:
                # reset the bias to zero
                torch.nn.init.zeros_(submodule.bias)
        elif isinstance(submodule, torch.nn.BatchNorm2d):
            torch.nn.init.ones_(submodule.weight)
            torch.nn.init.zeros_(submodule.bias)
            # also reset the running mean and running variance
            torch.nn.init.zeros_(submodule.running_mean)
            torch.nn.init.ones_(submodule.running_var)
This code neither freezes nor unfreezes the parameters; they retain whatever requires_grad state they had before. You can freeze or unfreeze them before or afterwards using the usual procedure.
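For completeness, a hedged sketch of that usual procedure combined with an optimizer; the choice of SGD, the learning rate, and the momentum value are placeholders and not part of the original answer:
# re-freeze everything, then unfreeze the classifier and the last block
for param in model.parameters():
    param.requires_grad = False
for param in model.classifier.parameters():
    param.requires_grad = True
for param in model.features[-2:].parameters():
    param.requires_grad = True

# pass only the trainable parameters to the optimizer
trainable_params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.SGD(trainable_params, lr=1e-3, momentum=0.9)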
Answered By - cherrywoods