Issue
Let's say there's a 2D Tensor:
data = tensor([[ 1.,  2.,  3.,  4.,  5.],
               [ 6.,  7.,  8.,  9., 10.],
               [11., 12., 13., 14., 15.]])
There are also two 1D tensors holding the start and end indices of the values to select from each row:
start = tensor([0., 3., 1.])
end = start + 2 # end is always at a +2 offset from start
Is there a way to select start[i]:end[i] from the i-th row of data without iterating over data?
For the above example, the expected output is:
tensor([[ 1.,  2.],
        [ 9., 10.],
        [12., 13.]])
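For comparison, here is a minimal sketch of the straightforward per-row loop that the question wants to avoid. It is handy as a reference to check any vectorized version against. The function name select_rows_loop is hypothetical, and the int() casts are an assumption added because the question's start tensor is float.

```python
import torch

def select_rows_loop(data, start, end):
    # Slice each row separately: row i keeps columns start[i]:end[i].
    # start/end are float in the question, so cast each bound to int.
    rows = [data[i, int(s):int(e)] for i, (s, e) in enumerate(zip(start, end))]
    # All slices have the same length, so they can be stacked into one tensor.
    return torch.stack(rows)

data = torch.tensor([[ 1.,  2.,  3.,  4.,  5.],
                     [ 6.,  7.,  8.,  9., 10.],
                     [11., 12., 13., 14., 15.]])
start = torch.tensor([0., 3., 1.])
end = start + 2
print(select_rows_loop(data, start, end))
```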
Solution
This can be done with gather, provided the chunk size (end - start) is the same for all rows.
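The fixed chunk size is what makes this work: the column indices then form a rectangular rows x chunksize matrix with idx[i, j] = start[i] + j, which can be built by broadcasting instead of a per-row list comprehension. A minimal sketch, assuming PyTorch; gather_chunks is a hypothetical name, and the .long() cast is added so the question's float start tensor can be passed in directly.

```python
import torch

def gather_chunks(data, start, chunksize):
    # gather requires an int64 index, and the question's start tensor is float.
    start = start.long()
    # Broadcast a (rows, 1) column of starts against a (chunksize,) row of
    # offsets: idx[i, j] = start[i] + j, shape (rows, chunksize).
    idx = start.unsqueeze(1) + torch.arange(chunksize, device=data.device)
    # Along dim 1, gather computes out[i, j] = data[i, idx[i, j]].
    return data.gather(1, idx)

data = torch.tensor([[ 1.,  2.,  3.,  4.,  5.],
                     [ 6.,  7.,  8.,  9., 10.],
                     [11., 12., 13., 14., 15.]])
start = torch.tensor([0., 3., 1.])  # float, as in the question
print(gather_chunks(data, start, 2))
```

No Python-level loop over rows remains: the index matrix and the gather are both single tensor operations.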
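import torch

def index_function(data, start_index, chunksize, dim):
    # Build a (rows, chunksize) index matrix: row i holds start_index[i] .. start_index[i] + chunksize - 1
    # (note that this list comprehension still loops over the rows in Python)
    index_tensor = torch.stack([torch.arange(i, i + chunksize) for i in start_index])
    # With dim=1, gather picks result[i][j] = data[i][index_tensor[i][j]]
    result = data.gather(dim, index_tensor)
    return result

data = torch.tensor([[ 1.,  2.,  3.,  4.,  5.],
                     [ 6.,  7.,  8.,  9., 10.],
                     [11., 12., 13., 14., 15.]])
start_idx = torch.tensor([0, 3, 1])  # must be an integer (int64) tensor, not float: gather requires int64 indices
index_function(data, start_idx, 2, 1)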
>tensor([[ 1.,  2.],
         [ 9., 10.],
         [12., 13.]])
start_idx = torch.tensor([0, 2, 1])
index_function(data, start_idx, 3, 1)
>tensor([[ 1.,  2.,  3.],
         [ 8.,  9., 10.],
         [12., 13., 14.]])
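Another way to avoid building the index by hand, offered here as an alternative technique rather than the answer's method, is Tensor.unfold: it exposes every length-chunksize window of each row, after which advanced indexing picks window start[i] from row i. A minimal sketch, assuming PyTorch; unfold_chunks is a hypothetical name.

```python
import torch

def unfold_chunks(data, start, chunksize):
    # unfold(1, chunksize, 1) is a view of all sliding windows along dim 1:
    # shape (rows, cols - chunksize + 1, chunksize).
    windows = data.unfold(1, chunksize, 1)
    rows = torch.arange(data.size(0), device=data.device)
    # Advanced indexing selects window start[i] of row i -> (rows, chunksize).
    # This step copies; the unfold itself does not.
    return windows[rows, start.long()]

data = torch.tensor([[ 1.,  2.,  3.,  4.,  5.],
                     [ 6.,  7.,  8.,  9., 10.],
                     [11., 12., 13., 14., 15.]])
print(unfold_chunks(data, torch.tensor([0, 2, 1]), 3))
```

Like the gather approach, this only works when every row uses the same chunk size, because unfold produces windows of a single fixed length.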
Answered By - Karl