Issue
I have a PyTorch tensor of size torch.Size([4, 3, 2]):
tensor([[[0.4003, 0.2742],
         [0.9414, 0.1222],
         [0.9624, 0.3063]],

        [[0.9600, 0.5381],
         [0.5758, 0.8458],
         [0.6342, 0.5872]],

        [[0.5891, 0.9453],
         [0.8859, 0.6552],
         [0.5120, 0.5384]],

        [[0.3017, 0.9407],
         [0.4887, 0.8097],
         [0.9454, 0.6027]]])
I would like to delete the 2nd row (the second 3x2 block along the first dimension) so that the tensor becomes torch.Size([3, 3, 2]):
tensor([[[0.4003, 0.2742],
         [0.9414, 0.1222],
         [0.9624, 0.3063]],

        [[0.5891, 0.9453],
         [0.8859, 0.6552],
         [0.5120, 0.5384]],

        [[0.3017, 0.9407],
         [0.4887, 0.8097],
         [0.9454, 0.6027]]])
How can I delete the nth row of the 3D tensor?
Solution
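import torch

x = torch.randn(size=(4, 3, 2))
# Zero-based index: 1 removes the second 3x2 block, as in the question
row_exclude = 1
# Keep everything before and after the excluded index, then join along dim 0
x = torch.cat((x[:row_exclude], x[row_exclude + 1:]))
print(x.shape)
# torch.Size([3, 3, 2])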
Answered By - Vinson Ciawandy
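The same slicing idea generalizes to any dimension. Below is a minimal sketch that drops position n along a chosen dim by building a boolean "keep" mask and passing the kept positions to torch.index_select. The helper name delete_index is hypothetical (it is not part of PyTorch), and it assumes n is zero-based, with negative values counting from the end. Like torch.cat, it returns a new tensor rather than modifying the input in place.

```python
import torch


def delete_index(t: torch.Tensor, n: int, dim: int = 0) -> torch.Tensor:
    """Hypothetical helper: return a copy of `t` without position `n` along `dim`."""
    size = t.size(dim)
    if not -size <= n < size:
        raise IndexError(f"index {n} out of range for dim {dim} of size {size}")
    n = n % size  # map negative indices such as -1 to a non-negative position
    # Boolean mask that is True everywhere except at position n
    keep = torch.ones(size, dtype=torch.bool, device=t.device)
    keep[n] = False
    # index_select takes integer indices, so convert the mask to kept positions
    idx = keep.nonzero(as_tuple=True)[0]
    return t.index_select(dim, idx)


# The tensor from the question
x = torch.tensor([[[0.4003, 0.2742], [0.9414, 0.1222], [0.9624, 0.3063]],
                  [[0.9600, 0.5381], [0.5758, 0.8458], [0.6342, 0.5872]],
                  [[0.5891, 0.9453], [0.8859, 0.6552], [0.5120, 0.5384]],
                  [[0.3017, 0.9407], [0.4887, 0.8097], [0.9454, 0.6027]]])

y = delete_index(x, 1)  # remove the second 3x2 block
print(y.shape)  # torch.Size([3, 3, 2])
```

Passing dim=1 or dim=2 removes a row or column inside every block instead of a whole block.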