Issue
Given a tensor a:
tensor([[6, 6],
[4, 8],
[7, 5],
[7, 4],
[6, 4]])
How do I find the index of the row with values [7, 5]?
More generally, how do I find the indices of arbitrary values, whether they match a full or partial row or column?
Solution
Try this:
>>> import torch
>>> a = torch.tensor([[6, 6], [4, 8], [7, 5], [7, 4], [6, 4]])
>>> (a[:, None] == torch.tensor([7, 5])).all(-1).any(-1).nonzero().flatten().item()
2
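The one-liner above can be unpacked step by step. The sketch below wraps it in a hypothetical helper, find_rows (not part of PyTorch), and returns every matching index with .tolist(), because .item() raises an error unless exactly one row matches. The extra None axis and the final .any(-1) are what let several target rows be searched at once.

```python
import torch

a = torch.tensor([[6, 6],
                  [4, 8],
                  [7, 5],
                  [7, 4],
                  [6, 4]])

def find_rows(a, targets):
    """Hypothetical helper: indices of rows of `a` equal to any row in `targets`.

    `targets` may be one row, shape (n_cols,), or several, shape (n_targets, n_cols).
    """
    # Normalise targets to shape (n_targets, n_cols).
    targets = torch.as_tensor(targets, dtype=a.dtype).reshape(-1, a.shape[1])
    # a[:, None, :] is (n_rows, 1, n_cols); targets[None] is (1, n_targets, n_cols).
    # Broadcasting compares every row with every target: (n_rows, n_targets, n_cols).
    eq = a[:, None, :] == targets[None, :, :]
    # .all(-1): row i equals target j in every column -> (n_rows, n_targets).
    # .any(-1): row i equals at least one target      -> (n_rows,).
    hits = eq.all(-1).any(-1)
    # .nonzero() gives shape (k, 1); .flatten() turns it into a 1-D index tensor.
    return hits.nonzero().flatten()

print(find_rows(a, [7, 5]).tolist())            # [2]
print(find_rows(a, [[7, 5], [6, 4]]).tolist())  # [2, 4]
```

If no row matches, the result is simply an empty tensor, whereas .item() would raise.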
Answered By - U12-Forward
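For the general part of the question (any value, partial row, or column), the same compare-then-reduce pattern applies; only the axis you reduce over changes. A minimal sketch on the same tensor, where the selected columns, rows and target values are illustrative assumptions:

```python
import torch

a = torch.tensor([[6, 6],
                  [4, 8],
                  [7, 5],
                  [7, 4],
                  [6, 4]])

# 1. Every (row, col) position holding a given value.
positions = (a == 7).nonzero()          # shape (k, 2): one [row, col] per hit
print(positions.tolist())               # [[2, 0], [3, 0]]

# 2. Partial row match: compare only selected columns.
cols = [0]                              # assumption: match on column 0 only
vals = torch.tensor([7])                # value(s) expected in those columns
row_idx = (a[:, cols] == vals).all(-1).nonzero().flatten()
print(row_idx.tolist())                 # [2, 3]

# 3. Full column match: compare every column against a column vector,
#    then reduce over the row axis (dim 0) instead of the column axis.
col = torch.tensor([6, 4, 7, 7, 6])
col_idx = (a == col[:, None]).all(0).nonzero().flatten()
print(col_idx.tolist())                 # [0]

# 4. Partial column match: restrict to selected rows first.
sel_rows = [3, 4]                       # assumption: only rows 3 and 4 matter
part = torch.tensor([4, 4])             # expected values in those rows
partial_col_idx = (a[sel_rows] == part[:, None]).all(0).nonzero().flatten()
print(partial_col_idx.tolist())         # [1]
```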