Issue
I have a torch tensor like so:
a=[1, 234, 54, 6543, 55, 776]
and other tensors like so:
b=[234, 54]
c=[55, 776]
I want to create a new mask tensor in which an element is True if the corresponding value of a is equal to a value in one of the other tensors (b or c).
For example, in the tensors we have above I would like to create the following masking tensor:
a_masked =[False, True, True, False, True, True]
# The first two True values correspond to tensor `b`, while the last two True values correspond to tensor `c`.
I have seen other methods to check whether a full tensor is contained in another, but that isn't the case here.
Is there a torch way to do this efficiently?
Thanks!
Solution
Based on answers on the PyTorch forum, you could explicitly use a for loop, e.g.,
import torch
a = torch.tensor([1, 234, 54, 6543, 55, 776])
b = torch.tensor([234, 54])
c = torch.tensor([55, 776])
# Compare a against each value of b and c in turn, then OR the two masks together
a_masked = sum(a == i for i in b).bool() | sum(a == i for i in c).bool()
print(a_masked)
# tensor([False,  True,  True, False,  True,  True])
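The loop above does one Python-level comparison per value in b and c. A vectorized alternative, which is my own suggestion rather than something from the forum answers, is to broadcast a against all of the wanted values at once and reduce with any(). This is a sketch that assumes the combined set of values is small enough for the len(a) x len(values) comparison matrix to fit in memory.

```python
import torch

a = torch.tensor([1, 234, 54, 6543, 55, 776])
b = torch.tensor([234, 54])
c = torch.tensor([55, 776])

# All values we want to look for, gathered into one 1-D tensor
values = torch.cat([b, c])

# a.unsqueeze(1) has shape (6, 1) and values has shape (4,), so the comparison
# broadcasts to a (6, 4) boolean matrix where entry [i, j] is a[i] == values[j]
matches = a.unsqueeze(1) == values

# An element of a is True in the mask if it equals at least one of the values
a_masked = matches.any(dim=1)
print(a_masked)
```

The memory cost grows with len(a) * len(values), which is why the built-in isin function is usually the better choice for large inputs.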
However, there is actually a PyTorch isin function (torch.isin), for which you could do:
a_masked = torch.isin(a, torch.cat([b, c]))
This is several times faster than the sum method.
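Below is a self-contained sketch of the isin approach. Calling torch.isin once per tensor, in addition to once on the concatenation, is my own addition: it is shown because it also tells you which of b or c each match came from, as described in the question's example. torch.isin is available in PyTorch 1.10 and later.

```python
import torch

a = torch.tensor([1, 234, 54, 6543, 55, 776])
b = torch.tensor([234, 54])
c = torch.tensor([55, 776])

# Single call: True wherever an element of a appears anywhere in b or c
a_masked = torch.isin(a, torch.cat([b, c]))

# Separate calls keep track of which tensor each match came from
in_b = torch.isin(a, b)
in_c = torch.isin(a, c)

# OR-ing the per-tensor masks reproduces the single-call result
assert torch.equal(a_masked, in_b | in_c)

print(a_masked)
print(in_b)
print(in_c)
```

If you only need the combined mask, the single call on torch.cat([b, c]) is enough; the per-tensor masks cost one extra isin call each.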
Answered By - Matt Pitkin