Issue
I received an error:
Expected object of scalar type Long but got scalar type Int for argument #3 'index'
It comes from this line:
targets = torch.zeros(log_probs.size()).scatter_(1, targets.unsqueeze(1).data.cpu(), 1)
I am not sure what to do, because I have already tried converting this to a long in several places. I tried appending .long() at the end, and I also tried setting the dtype to torch.long, but neither worked.
This is very similar to "Expected Long but got Int" while running PyTorch script, but that question didn't show what was done to arrive at the answer.
I have changed a lot of the code. Here is my latest version, which still gives me the same error:
def forward(self, inputs, targets):
    """
    Args:
        inputs: prediction matrix (before softmax) with shape (batch_size, num_classes)
        targets: ground truth labels with shape (batch_size,)
    """
    log_probs = self.logsoftmax(inputs)
    targets = torch.zeros(log_probs.size()).scatter_(1, targets.unsqueeze(1).data.cpu(), 1)
    if self.use_gpu:
        targets = targets.to(torch.device('cuda'))
    targets = (1 - self.epsilon) * targets + self.epsilon / self.num_classes
    loss = (- targets * log_probs).mean(0).sum()
    return loss
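As an aside (not part of the original question or answer), the smoothing and loss lines in forward() appear to implement a label-smoothed cross entropy. The pure-Python sketch below reproduces that arithmetic so the formula can be checked without PyTorch. The helper names smoothed_targets and smoothed_loss are hypothetical, not from any library, and the sketch assumes labels are plain integer class indices.

```python
import math


def smoothed_targets(label, num_classes, epsilon):
    """One row of the smoothed target matrix built in forward()."""
    # One-hot row: the role played by scatter_ in the question's code.
    one_hot = [1.0 if k == label else 0.0 for k in range(num_classes)]
    # Same formula as forward(): (1 - epsilon) * one_hot + epsilon / num_classes
    return [(1 - epsilon) * y + epsilon / num_classes for y in one_hot]


def smoothed_loss(log_probs, labels, epsilon):
    """Per-sample sums averaged over the batch, equal to (-targets * log_probs).mean(0).sum()."""
    num_classes = len(log_probs[0])
    total = 0.0
    for row, label in zip(log_probs, labels):
        targets = smoothed_targets(label, num_classes, epsilon)
        total += sum(-t * lp for t, lp in zip(targets, row))
    return total / len(log_probs)


# Uniform predictions over 4 classes give a loss of log(4), about 1.386.
print(smoothed_targets(2, 4, 0.1))
print(smoothed_loss([[math.log(0.25)] * 4] * 2, [0, 3], 0.1))
```

Because each smoothed row sums to 1, averaging over the batch and summing over classes gives the same value as the PyTorch expression's mean(0).sum().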
Solution
The dtype of your index argument (i.e., targets.unsqueeze(1).data.cpu()) needs to be torch.int64.
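(The error message is a bit confusing, but "Long" in PyTorch means int64, and torch.long is simply an alias for torch.int64. The conversion has to be applied to the index tensor itself, e.g. targets.unsqueeze(1).long(), not to the tensor created by torch.zeros.)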
Answered By - Brennan Vincent
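A minimal sketch of the fix, assuming a recent PyTorch and a hypothetical batch of 3 samples and 4 classes whose labels arrive as int32 (the case that triggers the error). The labels are converted to int64 before being passed as the scatter_ index, while the zeros tensor keeps its default float dtype.

```python
import torch

# Hypothetical batch: 3 samples, 4 classes.
log_probs = torch.log_softmax(torch.randn(3, 4), dim=1)
# Labels stored as int32, which scatter_ rejects as an index.
labels = torch.tensor([2, 0, 3], dtype=torch.int32)

# Convert the index (not the zeros tensor) to int64; torch.long is the same dtype.
index = labels.unsqueeze(1).to(torch.int64)
one_hot = torch.zeros(log_probs.size()).scatter_(1, index, 1)

print(one_hot)
```

Calling labels.unsqueeze(1).long() is equivalent. In the question's forward(), the same change applies to targets.unsqueeze(1) before it is passed to scatter_.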