Issue
I want to use the pre-trained models in PyTorch for image classification on my own dataset. How should I change the number of output classes while freezing the parameters of the feature-extraction layers?
These are the models I want to include:
resnet18 = models.resnet18(pretrained=True)
densenet161 = models.densenet161(pretrained=True)
inception_v3 = models.inception_v3(pretrained=True)
shufflenet_v2_x1_0 = models.shufflenet_v2_x1_0(pretrained=True)
mobilenet_v3_large = models.mobilenet_v3_large(pretrained=True)
mobilenet_v3_small = models.mobilenet_v3_small(pretrained=True)
mnasnet1_0 = models.mnasnet1_0(pretrained=True)
resnext50_32x4d = models.resnext50_32x4d(pretrained=True)
vgg16 = models.vgg16(pretrained=True)
Thanks a lot in advance!
Code I added:
import torch
from torch import nn
from torchvision import models

class MyResModel(nn.Module):
    """Replacement head: 512 ResNet-18 features -> 3 classes."""
    def __init__(self):
        super().__init__()
        self.classifier = nn.Sequential(
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Dropout(p=0.5),
            nn.Linear(256, 3),
        )

    def forward(self, x):
        return self.classifier(x)

resnet18 = models.resnet18(pretrained=True)

# Freeze the pre-trained weights first so the new head stays trainable.
for param in resnet18.parameters():
    param.requires_grad_(False)

resnet18.fc = MyResModel()
Solution
You have to change the final Linear layer of the respective model.
For example in the case of resnet, when we print the model, we see that the last layer is a fully connected layer as shown below:
(fc): Linear(in_features=512, out_features=1000, bias=True)
Thus, you must reinitialize model.fc to be a Linear layer with 512 input features and num_classes output features (3 in your case):
model.fc = nn.Linear(512, num_classes)
For other models you can check here
To freeze the parameters of the network you have to use the following code:
for name, param in model.named_parameters():
    if 'fc' not in name:
        print(name, param.requires_grad)
        param.requires_grad = False
To validate:
for name, param in model.named_parameters():
    print(name, param.requires_grad)
Note that for this example 'fc' was the name of the classification layer. This is not the case for other models. You have to inspect the model in order to find the name of the classification layer.
Answered By - xro7