Issue
I'm attempting feature extraction in a somewhat unorthodox way. I run the feature extractor in eval() mode so that the batch norm and dropout layers are switched off and the running means and standard deviations from ImageNet pretraining are used. I apply this extractor to two related images, concatenate the two feature tensors, and pass the result through a linear classifier for training. I'm wondering whether I can avoid using with torch.no_grad(), since the two models are unrelated.
Here is a simplified version:
import torch
import torch.nn as nn

num_classes = 2
num_epochs = 10

densenet = DenseNetConv()
# set densenet to eval to switch off batch norm and dropout layers
# and use the ImageNet running means / std devs
densenet.eval()
densenet.to(device)

classifier = nn.Linear(4416, num_classes)
classifier.to(device)

criterion = nn.CrossEntropyLoss().to(device)
optimizer = torch.optim.Adam(classifier.parameters(), lr=0.001)

for epoch in range(num_epochs):
    classifier.train()

    for i, (inputs_1, inputs_2, labels) in enumerate(dataloaders_dict['train']):
        inputs_1 = inputs_1.to(device)
        inputs_2 = inputs_2.to(device)
        labels = labels.to(device)

        features_1 = densenet(inputs_1)  # extract features 1
        features_2 = densenet(inputs_2)  # extract features 2

        combined = torch.cat([features_1, features_2], dim=1)  # combine features
        combined = combined.view(-1, 4416)                     # reshape

        optimizer.zero_grad()

        # Forward pass to get output/logits
        outputs = classifier(combined)

        # Calculate loss: softmax --> cross entropy loss
        loss = criterion(outputs, labels)
        _, pred = torch.max(outputs, 1)
        equality_check = (labels.data == pred)

        # Getting gradients w.r.t. parameters
        loss.backward()
        optimizer.step()
As you can see, I do not use with torch.no_grad(), even though densenet is only a separate feature extractor running in eval() mode. Is there an issue with the way this is implemented, or can I assume that this will not interfere with training the classifier model?
Solution
If you are doing inference with a model, applying torch.no_grad() won't have any effect on the resulting output. As you've said, only nn.Module.eval will, since it modifies how the forward pass is performed (namely, which statistics are used to normalize the batch elements).
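For instance, here is a minimal sketch (the layer and input are made up for illustration) showing that eval() changes what batch norm computes, while torch.no_grad() leaves the values untouched:

import torch
import torch.nn as nn

torch.manual_seed(0)
bn = nn.BatchNorm1d(4)
x = torch.randn(8, 4)

bn.train()
out_train = bn(x)   # normalizes with the statistics of this batch

bn.eval()
out_eval = bn(x)    # normalizes with the stored running mean/var

print(torch.allclose(out_train, out_eval))   # False: eval() changes the output values

with torch.no_grad():
    out_nograd = bn(x)
print(torch.allclose(out_eval, out_nograd))  # True: no_grad() does not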
It is still recommended to switch off gradient computation when backpropagation is not necessary. This avoids caching activations during the forward call, resulting in faster inference and a lower memory footprint.
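As a rough illustration (the model and input below are placeholders), the output values are the same either way, but under torch.no_grad() no autograd graph is built, so no intermediate activations are kept around:

import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 8)).eval()
x = torch.randn(4, 16)

out = model(x)
print(out.requires_grad, out.grad_fn is not None)  # True True: a graph is recorded for backward

with torch.no_grad():
    out = model(x)
print(out.requires_grad, out.grad_fn)              # False None: nothing is cached for backward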
In your case, you can either wrap your inference calls on densenet with torch.no_grad:

with torch.no_grad():
    features_1 = densenet(inputs_1)  # extract features 1
    features_2 = densenet(inputs_2)  # extract features 2
Or alternatively, switch off the requires_grad flag on your module's parameter tensors using nn.Module.requires_grad_:

densenet.eval()
densenet.requires_grad_(False)
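Either way, a quick sanity check of the second approach, using small stand-in modules for illustration only, is that no gradients flow into the frozen extractor while the classifier still receives them:

import torch
import torch.nn as nn

# stand-ins for the feature extractor and the classifier
extractor = nn.Linear(8, 4).eval()
extractor.requires_grad_(False)
classifier = nn.Linear(4, 2)

x = torch.randn(3, 8)
loss = classifier(extractor(x)).sum()
loss.backward()

print([p.grad for p in extractor.parameters()])                   # [None, None]: the extractor is untouched
print(all(p.grad is not None for p in classifier.parameters()))   # True: the classifier still trains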
Answered By - Ivan