Issue
I have a 3-dimensional array/tensor of shape (a, b, c), and I have a list B of length a of different indices, each in the range [0, b). I want to use the indices to get an array of shape (a, c), where row i is tensor[i, B[i], :]. Right now I do this with an ugly list comprehension:
z = torch.stack([t_[b, :] for t_, b in zip(tensor, B)])
This is implemented in the forward pass of a neural network, so I really want to avoid a list comprehension. Is there a torch (or NumPy) function that does this more efficiently?
Here is a small example:
tensor = [[[ 0,  1],
           [ 2,  3],
           [ 4,  5]],
          [[ 6,  7],
           [ 8,  9],
           [10, 11]],
          [[12, 13],
           [14, 15],
           [16, 17]],
          [[18, 19],
           [20, 21],
           [22, 23]]]  # shape: (4, 3, 2)
B = [0, 1, 2, 2]
output = [[ 0,  1],
          [ 8,  9],
          [16, 17],
          [22, 23]]  # shape: (4, 2)
Background: I have time series data with time windows of different lengths. I use torch's pack_padded_sequence (and its inverse, pad_packed_sequence) to mask the padding, but I need the LSTM output at the last time step before the padding starts, because the network's output after that point is unusable. In the example, there would be 4 sequences whose last valid time steps are at indices 0, 1, 2, 2, each with 2 features.
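For context, here is a minimal sketch of that setup. It assumes a single-layer, unidirectional LSTM with batch_first=True; the sizes, the random input, and the lengths [1, 2, 3, 3] are hypothetical (chosen so the last valid indices are [0, 1, 2, 2] as in the example). pack_padded_sequence, pad_packed_sequence and nn.LSTM are the standard torch APIs.

```python
import torch
from torch import nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

torch.manual_seed(0)

# Hypothetical sizes: 4 padded sequences, max length 3, 5 input features, hidden size 2.
batch_size, max_len, n_in, n_hidden = 4, 3, 5, 2
x = torch.randn(batch_size, max_len, n_in)
lengths = torch.tensor([1, 2, 3, 3])  # true lengths, so last valid indices are [0, 1, 2, 2]

lstm = nn.LSTM(n_in, n_hidden, batch_first=True)

# Pack so the LSTM skips padded steps; enforce_sorted=False allows unsorted lengths.
packed = pack_padded_sequence(x, lengths, batch_first=True, enforce_sorted=False)
packed_out, (h_n, c_n) = lstm(packed)

# Unpack back to a padded tensor of shape (4, 3, 2); padded positions are zeros.
out, out_lengths = pad_packed_sequence(packed_out, batch_first=True)

# Output at the last valid time step of each sequence, shape (4, 2).
last_idx = lengths - 1
last_out = out[torch.arange(out.size(0)), last_idx]
```

For a single-layer unidirectional LSTM fed a packed sequence, h_n[-1] holds the same values as last_out, so it can serve as a cross-check.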
Solution
Use advanced indexing. Besides B, we need an index for the first axis so that row i is paired with B[i]; it can be created with torch.arange():
output = tensor[torch.arange(len(B)), B]
or, if tensor is a NumPy array:
output = tensor[np.arange(len(B)), B]
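A self-contained NumPy version of the same idea, using the question's example (np.arange(24).reshape(4, 3, 2) reproduces that array). The two index arrays are broadcast together and paired elementwise, so row i of the result is tensor[i, B[i], :].

```python
import numpy as np

# The example from the question as a NumPy array of shape (4, 3, 2).
tensor = np.arange(24).reshape(4, 3, 2)
B = [0, 1, 2, 2]

# np.arange(len(B)) indexes axis 0, B indexes axis 1; the last axis is kept whole.
output = tensor[np.arange(len(B)), B]
print(output)
# [[ 0  1]
#  [ 8  9]
#  [16 17]
#  [22 23]]
```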
Both produce the same values (shown here as the torch output):
tensor([[ 0,  1],
        [ 8,  9],
        [16, 17],
        [22, 23]])
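An alternative not used in the answer is torch.gather, which some readers may find more explicit. This sketch checks that advanced indexing, gather, and the original list comprehension agree on the example. The index passed to gather must be an int64 tensor with the same number of dimensions as the input, so B is reshaped to (a, 1, 1) and expanded to (a, 1, c).

```python
import torch

tensor = torch.arange(24).reshape(4, 3, 2)
B = torch.tensor([0, 1, 2, 2])

# Advanced indexing, as in the answer.
out_indexing = tensor[torch.arange(len(B)), B]

# torch.gather along dim 1: every one of the c features picks row B[i].
idx = B.view(-1, 1, 1).expand(-1, 1, tensor.size(2))  # shape (4, 1, 2)
out_gather = tensor.gather(1, idx).squeeze(1)          # shape (4, 2)

# The original list comprehension, for comparison.
out_loop = torch.stack([t_[b, :] for t_, b in zip(tensor, B)])

print(torch.equal(out_indexing, out_gather), torch.equal(out_indexing, out_loop))
# True True
```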
Full code using the example:
import torch

tensor = torch.tensor([
    [[ 0,  1],
     [ 2,  3],
     [ 4,  5]],
    [[ 6,  7],
     [ 8,  9],
     [10, 11]],
    [[12, 13],
     [14, 15],
     [16, 17]],
    [[18, 19],
     [20, 21],
     [22, 23]]])
B = [0, 1, 2, 2]
output = tensor[torch.arange(len(B)), B]
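Because the question is about a forward pass, it may help to confirm that advanced indexing is differentiable. In this sketch (a float copy of the example, so it can track gradients), summing the output and calling backward() yields a gradient of 1 exactly at the selected positions tensor[i, B[i], :] and 0 everywhere else.

```python
import torch

# Same values as the example, as a float leaf tensor that tracks gradients.
tensor = torch.arange(24, dtype=torch.float32).reshape(4, 3, 2).requires_grad_()
B = [0, 1, 2, 2]

output = tensor[torch.arange(len(B)), B]  # shape (4, 2)
output.sum().backward()

# Gradient is 1 at tensor[i, B[i], :] and 0 elsewhere.
print(tensor.grad[:, :, 0])
```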
Answered By - cottontail