Issue
I built a simple GRU model with PyTorch. It includes 4 sub-modules.
I noticed that the dictionaries returned by state_dict() for some of them are empty after training, while those of the other sub-modules do contain weights and biases.
The code:
import torch
import torch.nn as nn
import torch.utils.data as torch_data  # assumed alias for the torch_data name used below

class GruModel(nn.Module):
    def __init__(self, inputs, nodes, layers=2):
        super(GruModel, self).__init__()
        self.gru_m = nn.GRU(input_size=inputs, num_layers=layers, hidden_size=nodes,
                            batch_first=True, dropout=0.5)
        self.activt_f = nn.ReLU()
        self.output_f = nn.Linear(nodes, 1)
        self.probab_f = nn.Sigmoid()

    def forward(self, x, h):
        o, h = self.gru_m(x, h)
        o = self.activt_f(o[:, -1])  # keep only the last time step
        out = self.output_f(o)
        return self.probab_f(out)

    def trainWith(self, ...):
        ''' training body '''
        # Note: BCEWithLogitsLoss applies a sigmoid internally, while forward()
        # already ends with nn.Sigmoid; nn.BCELoss would match that output.
        criterion = nn.BCEWithLogitsLoss()
        optimizer = torch.optim.Adadelta(self.parameters(), lr=learn_rat)
        lr_schdlr = torch.optim.lr_scheduler.ExponentialLR(
            optimizer, gamma=0.99)
        t_loader = torch_data.DataLoader(...)
        for e in range(epochs):
            for x, p_label, n_label in t_loader:
                optimizer.zero_grad()
                out = self(x, self.init_hidden(batch_size))
                loss = criterion(out, p_label)
                loss.backward()
                optimizer.step()
            lr_schdlr.step()

    def save(self, full_path: str):
        print(self.gru_m.state_dict())
        print(self.activt_f.state_dict())
        print(self.output_f.state_dict())
        print(self.probab_f.state_dict())
When I run it, the state_dict of the sub-modules self.gru_m and self.output_f have values as expected, but those of the sub-modules self.activt_f (nn.ReLU) and self.probab_f (nn.Sigmoid) are empty.
Never mind my training process: I feed the model tons of data and run through hundreds of epochs, and it classifies as I expect.
What I want to know is whether the latter two modules are trainable, or whether they simply do NOT need any weights and biases to do their work.
If so, can we say that torch.nn.Sigmoid is the same as torch.nn.functional.sigmoid? After all, both would be plain functions rather than stateful objects.
Solution
The two layer modules you mention are activation functions, which are not parametrized. This means they are not "trainable": they hold no parameters, which is why their state_dict() is empty.
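As a quick check of the point above, here is a minimal sketch that inspects a few freshly constructed layers. The layer sizes and variable names are arbitrary choices for illustration and are not taken from the question's model.

```python
import torch.nn as nn

# Activation modules: no learnable parameters, no buffers.
relu = nn.ReLU()
sigmoid = nn.Sigmoid()

# A parametrized layer for comparison (sizes chosen arbitrarily).
linear = nn.Linear(4, 1)

relu_params = list(relu.parameters())           # no parameters registered
sigmoid_state = sigmoid.state_dict()            # empty ordered dict
linear_keys = list(linear.state_dict().keys())  # weight and bias entries

print(len(relu_params), len(sigmoid_state), linear_keys)
```

The same check should apply to the sub-modules of the trained model, since training does not add parameters to a module that never registered any.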
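However, nn modules are classes (they can be stateful), while nn.functional utilities are functions (they are not stateful). For a parameter-free activation such as Sigmoid, both compute the same result; the module form is simply convenient to register as a sub-module or to place inside nn.Sequential.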
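To illustrate the distinction, the sketch below compares the module and functional forms of sigmoid on a random tensor, then looks at nn.PReLU as an example of an activation module that does carry state. The input shape and the choice of PReLU are assumptions made for the example, not something from the question.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

x = torch.randn(3, 5)  # arbitrary input for the comparison

# Module form vs. functional form of the same stateless activation.
module_out = nn.Sigmoid()(x)
functional_out = F.sigmoid(x)
same = torch.allclose(module_out, functional_out)

# An activation that IS parametrized: PReLU learns its negative slope,
# so it has to be a module in order to hold that parameter.
prelu = nn.PReLU()
prelu_keys = list(prelu.state_dict().keys())

print(same, prelu_keys)
```

So the module class matters when an activation has something to learn or store; for Sigmoid and ReLU the choice between the two forms is mostly a matter of style.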
Answered By - Ivan