Issue
I'm new to tensors and PyTorch.
I have two 2D tensors, A and B.
Each row of A contains floats giving the probability assigned to each index. Each row of B is a one-hot binary vector marking the correct index.
A
tensor([[0.1, 0.4, 0.5],
        [0.5, 0.4, 0.1],
        [0.4, 0.5, 0.1]])
B
tensor([[0, 0, 1],
        [0, 1, 0],
        [0, 0, 1]])
I would like to find the rows (and so how many there are) where the one-hot index in B is among the indices of the top-k values of A. In this case, k=2.
My attempt:
tops = torch.topk(A, 2, dim=1)
top_idx = tops.indices
top_2_matches = torch.where((torch.any(top_idx, 1) == B.argmax(dim=1)))
If done properly, the example should return tensor([0, 1]), since the first two rows have top-2 matches, but I get (tensor([1]),) instead.
Not sure where I'm going wrong here. Thanks for any help!
Solution
Try this:
top_idx = torch.topk(A, 2, dim=1).indices
row_indicator = (top_idx == B.argmax(dim=1).unsqueeze(dim=1)).any(dim=1)
top_2_matches = torch.arange(len(row_indicator))[row_indicator]
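As for why the original attempt returns (tensor([1]),): torch.any(top_idx, 1) does not test membership. It reduces each row of indices to a single boolean (is any index nonzero?), and that boolean is then compared with the integer target index. The sketch below reproduces this on the tensors from the question; the names target, any_nonzero and attempt are only illustrative.

```python
import torch

A = torch.tensor([[0.1, 0.4, 0.5],
                  [0.5, 0.4, 0.1],
                  [0.4, 0.5, 0.1]])
B = torch.tensor([[0, 0, 1],
                  [0, 1, 0],
                  [0, 0, 1]])

top_idx = torch.topk(A, 2, dim=1).indices   # top-2 column indices per row
target = B.argmax(dim=1)                    # correct column index per row

# torch.any treats every index as a truth value (nonzero -> True), so this
# asks "does the row contain a nonzero index?", not "does it contain target?".
any_nonzero = torch.any(top_idx, 1)

# Comparing booleans with integer targets promotes True to 1, so only rows
# whose target index happens to equal 1 compare as equal.
attempt = torch.where(any_nonzero == target)
```

The fix in the answer compares top_idx against the target element by element first, and only then reduces with any(dim=1).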
Step by step, using the tensors from the question:
>>> import torch
>>> A = torch.tensor([[0.1, 0.4, 0.5],
...                   [0.5, 0.4, 0.1],
...                   [0.4, 0.5, 0.1]])
>>> B = torch.tensor([[0, 0, 1],
...                   [0, 1, 0],
...                   [0, 0, 1]])
>>> tops = torch.topk(A, 2, dim=1)
>>> tops
torch.return_types.topk(
values=tensor([[0.5000, 0.4000],
        [0.5000, 0.4000],
        [0.5000, 0.4000]]),
indices=tensor([[2, 1],
        [0, 1],
        [1, 0]]))
>>> top_idx = tops.indices
>>> top_idx
tensor([[2, 1],
        [0, 1],
        [1, 0]])
>>> index_indicator = top_idx == B.argmax(dim=1).unsqueeze(dim=1)
>>> index_indicator
tensor([[ True, False],
        [False,  True],
        [False, False]])
>>> row_indicator = index_indicator.any(dim=1)
>>> row_indicator
tensor([ True, True, False])
>>> top_2_matches = torch.arange(len(row_indicator))[row_indicator]
>>> top_2_matches
tensor([0, 1])
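The steps above can also be wrapped in a small helper that returns both the matching row indices and their count. This is only a sketch: the function name top_k_matches and its parameter names are made up for illustration, and it assumes every row of the one-hot tensor contains exactly one 1.

```python
import torch

def top_k_matches(probs, one_hot, k):
    """Return indices of rows whose one-hot target is among the top-k of probs."""
    top_idx = torch.topk(probs, k, dim=1).indices      # shape (n_rows, k)
    target = one_hot.argmax(dim=1, keepdim=True)       # shape (n_rows, 1), broadcasts over k
    row_indicator = (top_idx == target).any(dim=1)     # shape (n_rows,), bool
    return torch.nonzero(row_indicator).flatten()      # 1-D tensor of row indices

A = torch.tensor([[0.1, 0.4, 0.5],
                  [0.5, 0.4, 0.1],
                  [0.4, 0.5, 0.1]])
B = torch.tensor([[0, 0, 1],
                  [0, 1, 0],
                  [0, 0, 1]])

matches = top_k_matches(A, B, k=2)
num_matches = matches.numel()   # number of rows with a top-2 match

# Alternative for a strictly one-hot B: read B's value at each top-k index;
# a row matches if any of the gathered values is 1.
top_idx = torch.topk(A, 2, dim=1).indices
row_indicator_alt = B.gather(1, top_idx).bool().any(dim=1)
```

On the example tensors, matches holds rows 0 and 1, so num_matches is 2. If only the count is needed, summing the boolean row indicator gives the same number without building the index tensor.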
Answered By - A. Maman