Issue
I was wondering how I can keep my batch norm layers active while using an untrained feature extraction network.
Would the following be considered feature extraction with an "untrained" network?
import torch
import torch.nn.functional as F
from torch import nn
from torchvision import models

class DenseNetConv(torch.nn.Module):
    def __init__(self):
        super(DenseNetConv, self).__init__()
        original_model = models.densenet161(pretrained=False)  # randomly initialised weights
        self.features = torch.nn.Sequential(*list(original_model.children())[:-1])
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        for param in self.parameters():
            param.requires_grad = False  # freeze the feature extractor

    def forward(self, x):
        x = self.features(x)
        x = F.relu(x, inplace=True)
        x = F.avg_pool2d(x, kernel_size=7).view(x.size(0), -1)
        return x
The above should return a tensor of shape [batch_size, 2208]; however, I want to make sure that by stating pretrained=False I am basically extracting features from an untrained network.
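A quick way to confirm the [batch_size, 2208] output shape (a minimal check, assuming 3-channel 224x224 inputs, which is not stated above) is to pass a dummy batch through the extractor:

    extractor = DenseNetConv()
    extractor.eval()  # evaluation mode is fine for a pure shape check
    dummy = torch.randn(4, 3, 224, 224)  # assumed input size; the 7x7 pooling kernel matches 224x224 inputs
    with torch.no_grad():
        out = extractor(dummy)
    print(out.shape)  # torch.Size([4, 2208])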
I then use the following to define the classifier layers:
class MyDenseNetDens(torch.nn.Module):
    def __init__(self, nb_out=2):
        super().__init__()
        self.dens1 = torch.nn.Linear(in_features=2208, out_features=512)
        self.dens2 = torch.nn.Linear(in_features=512, out_features=128)
        self.dens3 = torch.nn.Linear(in_features=128, out_features=nb_out)

    def forward(self, x):
        x = self.dens1(x)
        x = torch.nn.functional.selu(x)
        x = F.dropout(x, p=0.25, training=self.training)
        x = self.dens2(x)
        x = torch.nn.functional.selu(x)
        x = F.dropout(x, p=0.25, training=self.training)
        x = self.dens3(x)
        return x
and finally join them together here:
class MyDenseNet(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.mrnc = MyDenseNetConv()
        self.mrnd = MyDenseNetDens()

    def forward(self, x):
        x = self.mrnc(x)
        x = self.mrnd(x)
        return x

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

densenet = MyDenseNet()
densenet.to(device)
densenet.train()
If I allow this to train, for example by calling densenet.train(), will this be sufficient to allow batch normalisation statistics to be generated for each mini-batch, and to allow the running means and standard deviations to be learnt and applied during inference, while keeping the convolutional layers untrained?
Solution
Unfortunately, it will update the running stats anyway, since the default batch norm layers in your self.mrnc are still initialised with a running mean and a running variance. All you need to do is turn them off:
class MyDenseNetConv(torch.nn.Module):  # renamed to match the rest of your code
    def __init__(self):
        super(MyDenseNetConv, self).__init__()
        original_model = models.densenet161(pretrained=False)
        for n, v in original_model.named_modules():  # this part removes the running-stat tracking
            if 'norm' in n:
                v.track_running_stats = False
                v.running_mean = None
                v.running_var = None
                v.num_batches_tracked = None
        self.features = torch.nn.Sequential(*list(original_model.children())[:-1])
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        for param in self.parameters():
            param.requires_grad = False

    def forward(self, x):
        x = self.features(x)
        x = F.relu(x, inplace=True)
        x = F.avg_pool2d(x, kernel_size=7).view(x.size(0), -1)
        return x
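With the running statistics removed, the batch norm layers normalise each mini-batch with its own statistics in both train() and eval() mode, while the frozen weights stay untrained. A minimal sketch (assuming the modified class above) to confirm that no running statistics are kept:

    model = MyDenseNetConv()
    # every batch norm layer should now report no stored statistics
    for name, module in model.named_modules():
        if isinstance(module, torch.nn.BatchNorm2d):
            assert module.track_running_stats is False
            assert module.running_mean is None and module.running_var is None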
Answered By - CuCaRot