Issue
I'm using some ShuffleNet code, but I have trouble understanding how correct is computed in the function below (it computes precision at 1 and 5).
As I understand it, pred, returned by output.topk(...), holds indices, so I can't understand why it is then compared to target with an equality function (pred.eq(...)), given that pred contains the indices of the highest probabilities in output.
def accuracy(output, target, topk=(1,)):
    """Compute the precision@k for the specified values of k."""
    maxk = max(topk)
    batch_size = target.size(0)

    # indices of the maxk highest logits per sample, shape (batch_size, maxk)
    _, pred = output.topk(maxk, 1, True, True)
    # transpose to (maxk, batch_size) so row i holds the i-th best guess of every sample
    pred = pred.t()
    # compare every top-k guess against the ground-truth label of its sample
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        # number of samples whose label appears among their top-k predictions
        correct_k = correct[:k].contiguous().view(-1).float().sum(0)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res
Solution
Looking at the code, I can speculate that output is shaped (batch_size, n_logits), while target is a dense representation shaped (batch_size, 1). This means the ground-truth class is designated by an integer value: the corresponding class label.
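Just to make that assumption concrete (a minimal sketch of my own, not taken from the ShuffleNet code), such inputs could be built like this:
>>> import torch
>>> output = torch.rand(2, 10)             # (batch_size, n_logits): one score per class
>>> target = torch.randint(0, 10, (2, 1))  # (batch_size, 1): integer class labels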
Before looking into this implementation of top-k accuracy, we first need to understand what it measures: top-k accuracy counts how many ground-truth labels appear among the k highest predictions of our output. It is essentially a generalization of the standard top-1 accuracy, where we only look at the single highest prediction and check whether it matches the target.
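As a minimal sketch of that definition (my own helper, not the ShuffleNet function), a sample counts as correct when its label appears anywhere among its k highest-scoring classes:

import torch

def topk_accuracy_sketch(output, target, k):
    # indices of the k largest logits per sample, shape (batch_size, k)
    topk_idx = output.topk(k, dim=1).indices
    # a sample is a hit if its ground-truth label is among those k indices
    hits = (topk_idx == target.view(-1, 1)).any(dim=1)
    return hits.float().mean() * 100.0

With k=1 this reduces to the usual accuracy; the function from the question computes the same quantity, just for several values of k at once.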
Let's take a simple example with batch_size=2, n_logits=10, and k=3, i.e. we're interested in the top-3 accuracy. Here we sample a random prediction:
>>> output
tensor([[0.2110, 0.9992, 0.0597, 0.9557, 0.8316, 0.8407, 0.8398, 0.3631, 0.2889, 0.3226],
[0.6811, 0.2932, 0.2117, 0.6522, 0.2734, 0.8841, 0.0336, 0.7357, 0.9232, 0.2633]])
We first look at the k
highest logits and retrieve their indices:
>>> _, pred = output.topk(k=3, dim=1, largest=True, sorted=True)
>>> pred
tensor([[3, 6, 4],
[7, 3, 5]])
This is nothing more than a sliced torch.argsort
: output.argsort(1, descending=True)[:, :3]
will return the same result.
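As a quick sanity check (ignoring possible ties in the logits), both routes give identical indices for the output above:
>>> torch.equal(output.argsort(1, descending=True)[:, :3], pred)
True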
We can then transpose so the batch dimension comes last, giving shape (3, 2):
>>> pred = pred.T
>>> pred
tensor([[3, 7],
        [6, 3],
        [4, 5]])
Now that we have the top-k predictions for each batch element, we need to compare them with the ground truths. Let us now imagine a target tensor (remember, it is shaped (batch_size=2, 1)):
>>> target
tensor([[1],
[5]])
We first need to expand it to the shape of pred:
>>> target.view(1, -1).expand_as(pred)
tensor([[1, 5],
        [1, 5],
        [1, 5]])
We then compare the two with torch.eq, the element-wise equality operator:
>>> correct = torch.eq(pred, target.view(1, -1).expand_as(pred))
>>> correct
tensor([[False, False],
        [False, False],
        [False,  True]])
As you can tell, on the second batch element one of the three highest predictions matches the ground-truth class label 5, so it counts as one 'correct'. On the first batch element, none of the three highest predictions matches the ground-truth label, so it is not correct.
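This is exactly what the last lines of accuracy turn into a number: flatten the first k rows of the mask, sum the True entries, and scale by 100 / batch_size. For our toy batch (batch_size=2), the top-3 accuracy comes out at 50%:
>>> correct_k = correct[:3].contiguous().view(-1).float().sum(0)
>>> correct_k
tensor(1.)
>>> correct_k.mul_(100.0 / 2)  # batch_size = 2
tensor(50.)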
Of course, based on this equality mask tensor correct, you can slice it further to compute other top-k' accuracies where k' <= k. For instance, for k' = 1:
>>> correct[:1]
tensor([[False, False]])
Here, for the top-1 accuracy, we have zero correct instances out of the two batch elements.
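Putting everything together, calling the accuracy function from the question on this toy batch (assuming it is in scope, and asking for both top-1 and top-3) returns the two percentages directly:
>>> accuracy(output, target, topk=(1, 3))
[tensor(0.), tensor(50.)]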
Answered By - Ivan