Issue
Suppose I had a PyTorch tensor such as:
import torch
x = torch.randn([3, 4, 5])
and I wanted a new tensor with the same number of dimensions that contains only the last index along dimension 1. I could just do:
x[:, -1:, :]
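The reason this slice keeps the number of dimensions is that `-1:` is a range of length 1, while a plain integer index removes the dimension. A minimal illustration, using only standard PyTorch indexing:

```python
import torch

x = torch.randn([3, 4, 5])

# The slice -1: selects a range of length 1, so dimension 1 is kept with size 1.
print(x[:, -1:, :].shape)  # torch.Size([3, 1, 5])

# An integer index selects a single position and drops dimension 1.
print(x[:, -1, :].shape)   # torch.Size([3, 5])
```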
However, if x had an arbitrary number of dimensions and I wanted the final values along a specific dimension, what is the best way to do it?
Solution
You can use index_select:
torch.index_select(x, dim=dim, index=torch.tensor([x.size(dim) - 1]))
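Wrapped in a small function, the index_select call works for any `dim`, including negative ones. This is a sketch; the helper name `last_along_dim` is hypothetical, and placing the index on `x.device` is an assumption added so the call also works for tensors on a GPU.

```python
import torch


def last_along_dim(x: torch.Tensor, dim: int) -> torch.Tensor:
    """Return the last entry of `x` along `dim`, keeping `dim` with size 1.

    Hypothetical helper: it only wraps torch.index_select.
    """
    # 1-D index tensor holding the last position, on the same device as x.
    index = torch.tensor([x.size(dim) - 1], device=x.device)
    return torch.index_select(x, dim=dim, index=index)


x = torch.randn([3, 4, 5])
print(last_along_dim(x, 1).shape)   # torch.Size([3, 1, 5])
print(last_along_dim(x, -1).shape)  # torch.Size([3, 4, 1])
```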
The output tensor has the same number of dimensions as the input, with dim reduced to size 1. If you want to drop that dimension instead, call squeeze on dim:
torch.index_select(x, dim=dim, index=torch.tensor([x.size(dim) - 1])).squeeze(dim=dim)
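Passing dim to squeeze matters when other dimensions also happen to have size 1: a bare squeeze() removes all of them. A short illustration, assuming an input whose first dimension is 1:

```python
import torch

dim = 1
x = torch.randn([1, 4, 5])  # dimension 0 also has size 1 here

index = torch.tensor([x.size(dim) - 1])
y = torch.index_select(x, dim=dim, index=index)
print(y.shape)                   # torch.Size([1, 1, 5])

# squeeze(dim=dim) removes only the selected dimension...
print(y.squeeze(dim=dim).shape)  # torch.Size([1, 5])

# ...whereas a bare squeeze() removes every size-1 dimension.
print(y.squeeze().shape)         # torch.Size([5])
```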
Note: select (for example x.select(dim, -1)) returns a view of the input tensor with dim removed, whereas index_select copies the selected values into a new tensor.
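The difference shows up when the result is modified in place: writes to the select view reach the original tensor, while writes to the index_select result do not. A minimal demonstration:

```python
import torch

dim = 1
x = torch.zeros([3, 4, 5])

# select returns a view: the last entry along dim, with dim removed.
view = x.select(dim, -1)
# index_select copies the data into a new tensor (dim kept with size 1).
copy = torch.index_select(x, dim, torch.tensor([x.size(dim) - 1]))

view.fill_(1.0)  # writes through to x
copy.fill_(2.0)  # leaves x untouched

print(x[:, -1, :].sum().item())  # 15.0
print(view.shape, copy.shape)    # torch.Size([3, 5]) torch.Size([3, 1, 5])
```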
Example:
In [1]: dim = 1
In [2]: x = torch.randn([3, 4, 5])
In [3]: torch.index_select(x, dim=dim, index=torch.tensor([x.size(dim) - 1])).shape
Out[3]: torch.Size([3, 1, 5])
In [4]: torch.index_select(x, dim=dim, index=torch.tensor([x.size(dim) - 1])).squeeze(dim=dim).shape
Out[4]: torch.Size([3, 5])
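If a view is preferred over a copy while still keeping dim, the original x[:, -1:, :] slice can be generalized by building the index tuple with Python slice objects. This is a swapped-in alternative to index_select, not the approach above, and the helper name `last_along_dim_view` is hypothetical:

```python
import torch


def last_along_dim_view(x: torch.Tensor, dim: int) -> torch.Tensor:
    """Hypothetical helper: x[:, ..., -1:, ..., :] with -1: placed at `dim`.

    Uses basic slicing, so the result is a view and `dim` keeps size 1.
    """
    index = [slice(None)] * x.dim()  # ':' for every dimension
    index[dim] = slice(-1, None)     # '-1:' at dim; negative dim works via list indexing
    return x[tuple(index)]


x = torch.randn([3, 4, 5])
print(last_along_dim_view(x, 1).shape)  # torch.Size([3, 1, 5])
```

x.narrow(dim, x.size(dim) - 1, 1) returns the same view in a single call.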
Answered By - heemayl