Issue
With Python lists, we can do:
a = [1, 2, 3]
assert a.index(2) == 1
How can I find the index of a value in a PyTorch tensor directly, the way .index() does for a list?
Solution
I don't think there is a direct PyTorch equivalent of list.index(). However, you can get a similar result by comparing the tensor with the number (tensor == number) and then calling nonzero() on the resulting boolean mask. For example:
import torch
t = torch.Tensor([1, 2, 3])
print((t == 2).nonzero(as_tuple=True)[0])
This code prints tensor([1]): a 1-D LongTensor holding the index of every element equal to 2 (here only position 1).
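If you want list.index() semantics exactly (only the first match, and a ValueError when the value is absent), the mask-and-nonzero() approach can be wrapped in a small helper. This is a minimal sketch assuming a 1-D tensor; tensor_index is a hypothetical name, not part of PyTorch.

```python
import torch


def tensor_index(t: torch.Tensor, value) -> int:
    # Hypothetical helper (not a PyTorch API): mimics list.index() for a 1-D tensor.
    if t.dim() != 1:
        raise ValueError("tensor_index expects a 1-D tensor")
    # Boolean mask of matches, then the positions of all True entries.
    matches = (t == value).nonzero(as_tuple=True)[0]
    if matches.numel() == 0:
        # Same exception type that list.index() raises for a missing value.
        raise ValueError(f"{value!r} is not in tensor")
    # nonzero() returns positions in ascending order, so the first one is the earliest match.
    return int(matches[0])


t = torch.tensor([1, 2, 3, 2])
print(tensor_index(t, 2))  # prints 1
```

Because nonzero() lists matching positions in ascending order, taking the first entry gives the earliest occurrence, which is what list.index() returns.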
Answered By - Manuel Lagunas