Issue
I can do the following with a single int to retrieve a bool tensor:
import torch
a = torch.tensor([1,2,3])
a != 2
#tensor([ True, False, True])
Can I do the same with a list of values in plain PyTorch? I.e.:
import torch
a = torch.tensor([1,2,3])
a not in [2,3]  # desired behaviour; on a multi-element tensor this raises a RuntimeError
# desired output: tensor([ True, False, False])
Thanks a lot for your time!
Solution
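I think you want torch.isin, which returns a boolean tensor marking which elements of the input appear in a second tensor of test values: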
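import torch
a = torch.tensor([1, 2, 3])
# negate the membership mask with ~
out = ~torch.isin(a, torch.tensor([2, 3]))
# or, equivalently, let isin invert it
out = torch.isin(a, torch.tensor([2, 3]), invert=True)
print(out)
# tensor([ True, False, False])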
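If torch.isin is not available in your PyTorch version (it was added in 1.10), a broadcasting comparison gives the same mask. This is a sketch, not part of the original answer; the helper name not_in is hypothetical. It compares every element of a against every candidate value, so memory grows with a.numel() * len(values), which is fine for short lists.

```python
import torch

def not_in(a: torch.Tensor, values) -> torch.Tensor:
    """Hypothetical helper: True where an element of `a` is NOT in `values`."""
    # Convert the values to a tensor on the same device/dtype as `a`.
    v = torch.as_tensor(values, dtype=a.dtype, device=a.device)
    # a.unsqueeze(-1) has shape (*a.shape, 1); comparing with v (shape (k,))
    # broadcasts to (*a.shape, k): one column per candidate value.
    matches = a.unsqueeze(-1) == v
    # An element is "in" the list if it matches any value; invert for "not in".
    return ~matches.any(dim=-1)

a = torch.tensor([1, 2, 3])
print(not_in(a, [2, 3]))
# tensor([ True, False, False])
```

The same helper works for tensors of any shape, since only the last (added) dimension is reduced.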
Answered By - Ynjxsjmh