Issue
Very simple question, but I have been struggling with this forever now.
import torch
t = torch.tensor([[2,3],[4,6]])
overlap = [2, 6]
f = lambda x: x in overlap
I want to apply f to every element of t and get:
torch.tensor([[True,False],[False,True]])
Both the tensor and overlap are very big, so efficiency matters here.
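One loop-free way to build this mask, sketched here as an assumption rather than anything from the original post, is to sort overlap once and binary-search every element of t with torch.searchsorted. An element is a member exactly when the value at its insertion point equals it. For N elements and M lookup values this costs roughly O(M log M + N log M) time and avoids an N×M intermediate. The helper name isin_sorted is hypothetical.

```python
import torch

def isin_sorted(t: torch.Tensor, values: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: boolean mask of which elements of t occur in values."""
    if values.numel() == 0:
        # Nothing to match against: every element is absent
        return torch.zeros_like(t, dtype=torch.bool)
    # Sort the lookup values once; searchsorted needs a sorted 1-D sequence,
    # and we cast to t's dtype so queries and boundaries match
    sorted_vals, _ = torch.sort(values.flatten().to(t.dtype))
    # Leftmost insertion index for every element of t (same shape as t)
    idx = torch.searchsorted(sorted_vals, t)
    # Elements larger than every value get index len(sorted_vals); clamp to stay in bounds
    idx = idx.clamp(max=sorted_vals.numel() - 1)
    # Member iff the value found at the insertion point equals the element
    return sorted_vals[idx] == t

t = torch.tensor([[2, 3], [4, 6]])
overlap = torch.tensor([2, 6])
mask = isin_sorted(t, overlap)
print(mask)
```

Sorting only the (possibly large) lookup list keeps memory proportional to N + M, which is the main reason to prefer this over a broadcast comparison of every element against every value.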
Solution
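I found an easy way. PyTorch is not built on NumPy, but NumPy can read a CPU tensor directly, so np.vectorize can be applied to it. Note that np.vectorize is provided for convenience rather than speed (it is essentially a Python loop over the elements), so expect it to be slow on very large inputs: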
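import torch
import numpy as np
t = torch.tensor([[2,3],[4,6]])
overlap = [2, 6]
overlap_set = set(overlap)  # O(1) membership tests instead of scanning the list
f = lambda x: x in overlap_set
# np.vectorize calls f once per element in a Python-level loop; works for CPU tensors
mask = np.vectorize(f)(t)
# np.vectorize returns a NumPy bool array; convert it to get a torch.bool tensor
mask = torch.from_numpy(mask)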
Found here.
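If your PyTorch version provides torch.isin (added in PyTorch 1.10), the same mask can be computed without a Python-level loop and without converting to NumPy, so it also works for tensors on a GPU. A minimal sketch, assuming overlap is first turned into a tensor:

```python
import torch

t = torch.tensor([[2, 3], [4, 6]])
# torch.isin compares against a tensor, so convert the Python list first
overlap = torch.tensor([2, 6])
# Elementwise membership test: a torch.bool tensor with the same shape as t
mask = torch.isin(t, overlap)
print(mask)
```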
Answered By - Marcel Braasch