Issue
Why is the loss always printing zero after the first epoch?
I suspect it is because of loss = loss_fn(outputs, torch.max(labels, 1)[1]).
But if I use loss = loss_fn(outputs, labels), I get the error:
RuntimeError: 0D or 1D target tensor expected, multi-target not supported
nepochs = 5
losses = np.zeros(nepochs)
loss_fn = nn.CrossEntropyLoss()
optimizer = optim.Adam(modell.parameters(), lr=0.001)

for epoch in range(nepochs):
    running_loss = 0.0
    n = 0
    for data in train_loader:
        # single batch
        if n == 1:
            break
        inputs, labels = data
        optimizer.zero_grad()
        outputs = modell(inputs)
        # loss = loss_fn(outputs, labels)
        loss = loss_fn(outputs, torch.max(labels, 1)[1])
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        n += 1
    losses[epoch] = running_loss / n
    print(f"epoch: {epoch+1} loss: {losses[epoch]:.3f}")
The model is:
class Classifier(nn.Module):
    def __init__(self, labels=10):
        super(Classifier, self).__init__()
        self.fc = nn.Linear(3 * 64 * 64, labels)

    def forward(self, x):
        out = x.reshape(x.size(0), -1)
        out = self.fc(out)
        return out
Any idea?
The labels are a tensor of 64 elements, each wrapped in its own row (shape [64, 1]), like this:
tensor([[7],[1],[2],[3],[2],[9],[9],[8],[9],[8],[1],[7],[9],[2],[5],[1],[3],[3],[8],[3],[7],[1],[7],[9],[8],[8],[3],[7],[5],[1],[7],[3],[2],[1],[3],[3],[2],[0],[3],[4],[0],[7],[1],[8],[4],[1],[5],[3],[4],[3],[4],[8],[4],[1],[9],[7],[3],[2],[6],[4],[8],[3],[7],[3]])
Solution
Usually the loss calculation is loss = loss_fn(outputs, labels), where outputs are the raw logits returned by the model (shape [batch_size, num_classes]) and labels are class indices (shape [batch_size]). torch.max belongs on the outputs side, not the labels side, and only when you want the predicted classes rather than the loss, e.g.:
_, predicted = torch.max(outputs, 1)
Common practice is to leave the labels alone and post-process the outputs only for metrics such as accuracy, as sketched below:
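A short sketch of that pattern (assuming outputs are the raw logits from modell and labels are already 1D class indices):

_, predicted = torch.max(outputs, 1)               # class index with the highest logit per sample
accuracy = (predicted == labels).float().mean()    # fraction of correct predictions
loss = loss_fn(outputs, labels)                    # the loss still takes the raw logits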
torch.max() returns a namedtuple (values, indices), where values is the maximum value of each row of the input tensor in the given dimension dim, and indices is the index location of each maximum value found (the argmax).
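A quick check (using a hypothetical [4, 1] slice shaped like the labels tensor in the question) shows why torch.max(labels, 1)[1] breaks the training signal: each row has a single element, so its argmax is always 0.

import torch

labels = torch.tensor([[7], [1], [2], [3]])   # shape [4, 1], like the [64, 1] labels above
values, indices = torch.max(labels, 1)        # namedtuple (values, indices)
print(values)    # tensor([7, 1, 2, 3])  -> the max of each row
print(indices)   # tensor([0, 0, 0, 0])  -> every target collapses to class 0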
In your code snippet the labels already are class indices, but they come out of the DataLoader with shape [64, 1] instead of the 1D shape [64] that nn.CrossEntropyLoss expects, which is what raises the multi-target RuntimeError. Taking torch.max(labels, 1)[1] instead turns every target into class 0 (each row has only one element, so its argmax is 0), so the model only has to learn to predict class 0 and the printed loss quickly drops to zero. The fix is to flatten the labels rather than take their max:
loss = loss_fn(outputs, labels.squeeze(1))
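A minimal sketch of the corrected inner loop, assuming labels arrive from train_loader with shape [batch_size, 1] as shown in the question (modell, loss_fn and optimizer are the objects defined above):

inputs, labels = data
optimizer.zero_grad()
outputs = modell(inputs)                    # raw logits, shape [batch_size, 10]
loss = loss_fn(outputs, labels.squeeze(1))  # targets as 1D class indices, shape [batch_size]
loss.backward()
optimizer.step()

labels.flatten() or labels[:, 0] would work just as well; the point is that the target passed to nn.CrossEntropyLoss is a 1D tensor of class indices.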
Answered By - yakhyo