Issue
For example, I have a tensor:
tensor = torch.rand(12, 512, 768)
I also have an index list, for example:
[0,2,3,400,5,32,7,8,321,107,100,511]
I want to use the index list to select one of the 512 elements along the second dimension (dim=1 in PyTorch's zero-based numbering), so that the resulting tensor has size (12, 1, 768).
Is there a way to do it?
Solution
You can do this in pure PyTorch, without a Python loop, by combining indexing with torch.split:
import torch

tensor = torch.rand(12, 512, 768)
# indices to select along dim 1 (the size-512 dimension)
idx_list = [0,2,3,400,5,32,7,8,321,107,100,511]
# convert the list to an index tensor
idx_tensor = torch.tensor(idx_list)
# select the indexed positions along dim 1, then split them into size-1 chunks
list_of_tensors = tensor[:, idx_tensor, :].split(1, dim=1)
Calling tensor[:, idx_tensor, :] returns a tensor of shape (12, len(idx_list), 768), where the size of the second dimension equals the number of indices (12 here).
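A minimal sketch to confirm that shape, assuming only the random tensor and index list from the question. It also uses torch.index_select, a standard PyTorch function that gathers along a single dimension; it is shown as an equivalent spelling, not as part of the original answer.

```python
import torch

tensor = torch.rand(12, 512, 768)
idx_tensor = torch.tensor([0, 2, 3, 400, 5, 32, 7, 8, 321, 107, 100, 511])

# Advanced indexing on dim 1 keeps dims 0 and 2 and replaces dim 1
# with one entry per index, in the order the indices are listed.
selected = tensor[:, idx_tensor, :]
print(selected.shape)  # torch.Size([12, 12, 768])

# index_select performs the same selection along a single dimension.
selected_alt = torch.index_select(tensor, 1, idx_tensor)
print(torch.equal(selected, selected_alt))  # True

# Position j of the result holds input position idx_tensor[j] (here j=3 -> 400).
print(torch.equal(selected[:, 3, :], tensor[:, 400, :]))  # True
```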
torch.split then splits this tensor along dim=1 into chunks of size 1, giving a tuple of tensors, each of shape (12, 1, 768).
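A short sketch of what the split step returns, assuming the same inputs as above. Tensor.split returns a tuple rather than a list, and with a chunk size of 1 there is exactly one piece per index, each keeping the size-1 dimension.

```python
import torch

tensor = torch.rand(12, 512, 768)
idx_tensor = torch.tensor([0, 2, 3, 400, 5, 32, 7, 8, 321, 107, 100, 511])

# Split the (12, 12, 768) selection into chunks of size 1 along dim 1.
pieces = tensor[:, idx_tensor, :].split(1, dim=1)

print(type(pieces).__name__)  # tuple
print(len(pieces))            # 12
print(pieces[0].shape)        # torch.Size([12, 1, 768])

# Each piece equals the slice tensor[:, i:i+1, :] for its index i,
# so the size-1 dimension is kept rather than squeezed away.
for piece, i in zip(pieces, idx_tensor.tolist()):
    assert torch.equal(piece, tensor[:, i:i + 1, :])
```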
So list_of_tensors ends up holding one tensor per index, with these shapes:
[torch.Size([12, 1, 768]),
torch.Size([12, 1, 768]),
torch.Size([12, 1, 768]),
torch.Size([12, 1, 768]),
torch.Size([12, 1, 768]),
torch.Size([12, 1, 768]),
torch.Size([12, 1, 768]),
torch.Size([12, 1, 768]),
torch.Size([12, 1, 768]),
torch.Size([12, 1, 768]),
torch.Size([12, 1, 768]),
torch.Size([12, 1, 768])]
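The approach above selects all 12 indices for every one of the 12 rows. If the question is instead read as picking one position per row, the i-th index for the i-th row (the index list and the first dimension both have length 12, which suggests this reading, but the question does not say so), a single (12, 1, 768) tensor can be built with advanced indexing or torch.gather. This is a sketch under that assumption, not part of the original answer.

```python
import torch

tensor = torch.rand(12, 512, 768)
idx_tensor = torch.tensor([0, 2, 3, 400, 5, 32, 7, 8, 321, 107, 100, 511])

# Assumed per-row reading: pair row r with idx_tensor[r]. Both index
# tensors have length 12, so the result is (12, 768); unsqueeze adds dim 1.
rows = torch.arange(tensor.size(0))
per_row = tensor[rows, idx_tensor].unsqueeze(1)
print(per_row.shape)  # torch.Size([12, 1, 768])

# Same selection with torch.gather: the index needs the output's shape,
# so reshape (12,) -> (12, 1, 1) and expand it across the last dimension.
gather_idx = idx_tensor.view(-1, 1, 1).expand(-1, 1, tensor.size(2))
per_row_gather = torch.gather(tensor, 1, gather_idx)
print(torch.equal(per_row, per_row_gather))  # True
```

Advanced indexing is the shorter form here; torch.gather becomes handy when the index tensor already has the same number of dimensions as the input.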
Answered By - MBT