Issue
I would like to customize Inception v3 so that it works with 4-channel input. I tried to modify the first layer of inception_v3 as shown below.
import torch
from torch import nn
from torchvision import models

x = torch.randn((5, 4, 299, 299))
model_ft = models.inception_v3(pretrained=True)
model_ft.Conv2d_1a_3x3.conv = nn.Conv2d(4, 32, kernel_size=(3, 3), stride=(2, 2), bias=False)
print(x.shape)
print(model_ft.Conv2d_1a_3x3.conv)
out = model_ft(x)
but it produces the following error. I think the input shape and the network are modified correctly, so I can't understand why it raises an error. Does anyone have any advice?
torch.Size([5, 4, 299, 299])
Conv2d(4, 32, kernel_size=(3, 3), stride=(2, 2), bias=False)
RuntimeError Traceback (most recent call last)
<ipython-input-118-41c045338348> in <module>
29 print(model_ft.Conv2d_1a_3x3.conv)
30
---> 31 out=model_ft(x)
32 print(out)
/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
1049 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
1050 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1051 return forward_call(*input, **kwargs)
1052 # Do not call functions when jit is used
1053 full_backward_hooks, non_full_backward_hooks = [], []
/usr/local/lib/python3.6/dist-packages/torchvision/models/inception.py in forward(self, x)
202 def forward(self, x: Tensor) -> InceptionOutputs:
203 x = self._transform_input(x)
--> 204 x, aux = self._forward(x)
205 aux_defined = self.training and self.aux_logits
206 if torch.jit.is_scripting():
/usr/local/lib/python3.6/dist-packages/torchvision/models/inception.py in _forward(self, x)
141 def _forward(self, x: Tensor) -> Tuple[Tensor, Optional[Tensor]]:
142 # N x 3 x 299 x 299
--> 143 x = self.Conv2d_1a_3x3(x)
144 # N x 32 x 149 x 149
145 x = self.Conv2d_2a_3x3(x)
/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
1049 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
1050 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1051 return forward_call(*input, **kwargs)
1052 # Do not call functions when jit is used
1053 full_backward_hooks, non_full_backward_hooks = [], []
/usr/local/lib/python3.6/dist-packages/torchvision/models/inception.py in forward(self, x)
474
475 def forward(self, x: Tensor) -> Tensor:
--> 476 x = self.conv(x)
477 x = self.bn(x)
478 return F.relu(x, inplace=True)
/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
1049 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
1050 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1051 return forward_call(*input, **kwargs)
1052 # Do not call functions when jit is used
1053 full_backward_hooks, non_full_backward_hooks = [], []
/usr/local/lib/python3.6/dist-packages/torch/nn/modules/conv.py in forward(self, input)
441
442 def forward(self, input: Tensor) -> Tensor:
--> 443 return self._conv_forward(input, self.weight, self.bias)
444
445 class Conv3d(_ConvNd):
/usr/local/lib/python3.6/dist-packages/torch/nn/modules/conv.py in _conv_forward(self, input, weight, bias)
438 _pair(0), self.dilation, self.groups)
439 return F.conv2d(input, weight, bias, self.stride,
--> 440 self.padding, self.dilation, self.groups)
441
442 def forward(self, input: Tensor) -> Tensor:
RuntimeError: Given groups=1, weight of size [32, 4, 3, 3], expected input[5, 3, 299, 299] to have 4 channels, but got 3 channels instead
Solution
The error is due to the param pretrained=True. When pretrained weights are requested, torchvision also sets transform_input=True, and the _transform_input call visible near the top of the traceback then normalizes the input by slicing out channels 0, 1, and 2 and concatenating them back together, so your 4-channel tensor is reduced to 3 channels before it ever reaches the modified conv layer. That matches the error message exactly: the new weight of size [32, 4, 3, 3] expects 4 input channels, but the tensor arriving at it has only 3.
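For reference, this is roughly what _transform_input does (paraphrased from the torchvision inception.py version shown in the traceback; the constants remap ImageNet normalization and may differ slightly between versions):

def _transform_input(self, x):
    if self.transform_input:
        # Re-normalize each of the first three channels and stack them;
        # the output therefore always has exactly 3 channels.
        x_ch0 = torch.unsqueeze(x[:, 0], 1) * (0.229 / 0.5) + (0.485 - 0.5) / 0.5
        x_ch1 = torch.unsqueeze(x[:, 1], 1) * (0.224 / 0.5) + (0.456 - 0.5) / 0.5
        x_ch2 = torch.unsqueeze(x[:, 2], 1) * (0.225 / 0.5) + (0.406 - 0.5) / 0.5
        x = torch.cat((x_ch0, x_ch1, x_ch2), 1)
    return x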
Please use it in this way (which will only load the architecture, leaving transform_input at its default of False):
x = torch.randn((5, 4, 299, 299))
model_ft = models.inception_v3(pretrained=False)
model_ft.Conv2d_1a_3x3.conv = nn.Conv2d(4, 32, kernel_size=(3, 3), stride=(2, 2), bias=False)
print(x.shape)
print(model_ft.Conv2d_1a_3x3.conv)
out = model_ft(x)
and it will work
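If you want to keep the pretrained weights for the rest of the network instead of training from scratch, one common workaround (a sketch, not part of the original answer) is to pass transform_input=False explicitly so torchvision does not enable the 3-channel transform, then copy the pretrained RGB kernel into the first three channels of the new conv. Seeding the extra channel from the mean of the RGB filters is just one conventional choice, not the only option:

import torch
from torch import nn
from torchvision import models

# Load pretrained weights but disable the 3-channel input transform.
model_ft = models.inception_v3(pretrained=True, transform_input=False)

old_conv = model_ft.Conv2d_1a_3x3.conv  # Conv2d(3, 32, ...), pretrained
new_conv = nn.Conv2d(4, 32, kernel_size=(3, 3), stride=(2, 2), bias=False)

with torch.no_grad():
    new_conv.weight[:, :3] = old_conv.weight  # reuse the pretrained RGB filters
    new_conv.weight[:, 3:] = old_conv.weight.mean(dim=1, keepdim=True)  # seed the 4th channel

model_ft.Conv2d_1a_3x3.conv = new_conv
out = model_ft(torch.randn((5, 4, 299, 299)))

Also note that a freshly constructed inception_v3 is in training mode with aux_logits=True, so the forward pass returns an InceptionOutputs namedtuple of (logits, aux_logits); call model_ft.eval() first if you want a single tensor.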
Answered By - Prajot Kuvalekar