Issue
Suppose I have a 3D tensor X:
import torch
X = torch.arange(24).view(4, 3, 2)
print(X)
and I need to mask it with a 2D tensor:
mask = torch.zeros((4, 3), dtype=torch.int64)  # or dtype=torch.bool
mask[0, 0] = 1
mask[1, 1] = 1
mask[3, 0] = 1
print('Mask: ', mask)
Calling torch.masked_select with this mask leads to the following error:
torch.masked_select(X, (mask == 1))
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
<ipython-input-72-fd6809d2c4cc> in <module>
12
13 # Select based on new mask
---> 14 Y = torch.masked_select(X, (mask == 1))
15 #Y = X * mask_
16 print(Y)
RuntimeError: The size of tensor a (2) must match the size of tensor b (3) at non-singleton dimension 2
How can I mask a 3D tensor with a 2D mask and keep the dimensions of the original tensor? Any hints would be appreciated.
Solution
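Essentially, the mask must have the same shape as the tensor being masked, or at least be broadcastable to it. Broadcasting aligns shapes starting from the trailing dimension, so the (4, 3) mask is compared against the last two dimensions (3, 2) of X. That is why the error complains about sizes 2 and 3.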
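The shape mismatch can be checked without building any tensors. The sketch below uses torch.broadcast_shapes (available in recent PyTorch versions) with the shapes from the question; the variable names are only illustrative.

```python
import torch

# Shapes from the question: X is (4, 3, 2), the mask is (4, 3).
x_shape = (4, 3, 2)
mask_shape = (4, 3)

# Broadcasting aligns shapes from the right, so (4, 3) is matched against (3, 2)
# and the last dimension (2 vs 3) conflicts.
try:
    torch.broadcast_shapes(x_shape, mask_shape)
    compatible = True
except RuntimeError:
    compatible = False
print(compatible)  # False

# A trailing singleton dimension gives (4, 3, 1), which broadcasts to (4, 3, 2).
print(torch.broadcast_shapes(x_shape, (4, 3, 1)))  # torch.Size([4, 3, 2])
```

Adding that trailing singleton dimension is exactly what mask.unsqueeze(-1) does in both approaches.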
There are two ways to do it.
Approach 1: masked_select, which returns a flattened 1D tensor and therefore does not preserve the original tensor dimensions.
import torch
X = torch.arange(24).view(4, 3, 2)
print(X)
mask = torch.zeros((4, 3), dtype=torch.int64)  # or dtype=torch.bool
mask[0, 0] = 1
mask[1, 1] = 1
mask[3, 0] = 1
print('Mask: ', mask)
# Add a dimension to the mask tensor and expand it to the size of original tensor
mask_ = mask.unsqueeze(-1).expand(X.size())
print(mask_)
# Select based on the new expanded mask
Y = torch.masked_select(X, (mask_ == 1)) # does not preserve the dims
print(Y)
The output for approach 1 (only the final print(Y) is shown):
tensor([ 0, 1, 8, 9, 18, 19])
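If the goal is to keep the trailing dimension (each selected pair of values stays together) rather than getting a fully flat result, plain boolean indexing with the 2D mask is an alternative that the answer above does not use. A minimal sketch, assuming the same X and mask as in the question:

```python
import torch

X = torch.arange(24).view(4, 3, 2)
mask = torch.zeros((4, 3), dtype=torch.int64)
mask[0, 0] = 1
mask[1, 1] = 1
mask[3, 0] = 1

# A boolean mask over the first two dimensions selects whole rows along the
# last dimension, so the result has shape (number_of_selected_rows, 2).
rows = X[mask == 1]
print(rows)
# tensor([[ 0,  1],
#         [ 8,  9],
#         [18, 19]])

# Flattening the rows gives the same values as masked_select in Approach 1.
flat = rows.flatten()
print(flat)  # tensor([ 0,  1,  8,  9, 18, 19])
```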
Approach 2: element-wise multiplication, which preserves the original tensor dimensions by zeroing out the unselected elements.
import torch
X = torch.arange(24).view(4, 3, 2)
print(X)
mask = torch.zeros((4, 3), dtype=torch.int64)  # or dtype=torch.bool
mask[0, 0] = 1
mask[1, 1] = 1
mask[3, 0] = 1
print('Mask: ', mask)
# Add a dimension to the mask tensor and expand it to the size of original tensor
mask_ = mask.unsqueeze(-1).expand(X.size())
print(mask_)
# Zero out the elements where the expanded mask is 0 (element-wise multiplication)
Y = X * mask_
print(Y)
The output for approach 2 (X, the mask, the expanded mask, and Y, in print order):
tensor([[[ 0,  1],
         [ 2,  3],
         [ 4,  5]],

        [[ 6,  7],
         [ 8,  9],
         [10, 11]],

        [[12, 13],
         [14, 15],
         [16, 17]],

        [[18, 19],
         [20, 21],
         [22, 23]]])
Mask:  tensor([[1, 0, 0],
        [0, 1, 0],
        [0, 0, 0],
        [1, 0, 0]])
tensor([[[1, 1],
         [0, 0],
         [0, 0]],

        [[0, 0],
         [1, 1],
         [0, 0]],

        [[0, 0],
         [0, 0],
         [0, 0]],

        [[1, 1],
         [0, 0],
         [0, 0]]])
tensor([[[ 0,  1],
         [ 0,  0],
         [ 0,  0]],

        [[ 0,  0],
         [ 8,  9],
         [ 0,  0]],

        [[ 0,  0],
         [ 0,  0],
         [ 0,  0]],

        [[18, 19],
         [ 0,  0],
         [ 0,  0]]])
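Because mask.unsqueeze(-1) already has shape (4, 3, 1), broadcasting stretches it across the last dimension, so the explicit expand() is optional for element-wise operations. The sketch below, assuming the same X and mask, compares that shortcut with torch.where, which is a swapped-in alternative rather than part of the original answer:

```python
import torch

X = torch.arange(24).view(4, 3, 2)
mask = torch.zeros((4, 3), dtype=torch.int64)
mask[0, 0] = 1
mask[1, 1] = 1
mask[3, 0] = 1

# (4, 3, 1) broadcasts against (4, 3, 2); no expand() needed.
Y_broadcast = X * mask.unsqueeze(-1)

# torch.where keeps X where the condition holds and uses the zeros elsewhere.
Y_where = torch.where(mask.unsqueeze(-1) == 1, X, torch.zeros_like(X))

print(torch.equal(Y_broadcast, Y_where))  # True
print(Y_broadcast.shape)  # torch.Size([4, 3, 2])
```

For floating-point data, torch.where can be the safer choice: multiplying inf or NaN by 0 yields NaN, whereas torch.where simply substitutes the fill value.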
Answered By - Anjani Anjani