Issue
I have a torch tensor that looks like this:
tensor([[[-1.8510e-01, 1.3181e-01, 3.2903e-01, ..., 1.9867e-01,
5.1037e-03, 6.4071e-03],
[-4.6331e-01, 2.0216e-01, 2.7916e-01, ..., 2.6695e-01,
-1.3543e-02, 5.3604e-02],
[-3.8719e-01, 2.9603e-01, 2.5516e-01, ..., 1.7509e-01,
8.9148e-02, 3.7516e-02],
The shape of this tensor is [500, 197, 768]: there are 500 images, each with 197*768 dimensions. I need to remove some of the images. For example, if I remove 5 images, the shape should become [495, 197, 768].
Can anyone tell me how to remove them by index?
Solution
It depends on where along the first (image) dimension you want to remove the items.
To remove the first/last n elements (using normal Python slicing):
new_data = data[n:] # Remove first n elements
new_data = data[:data.shape[0] - n] # Remove last n elements (data[:-n] would return an empty tensor when n == 0)
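A quick check of both slices on a small stand-in tensor. The shape (500, 4, 8) is an assumption chosen only to keep memory low; the real data is (500, 197, 768), and the slicing works the same way for it.

```python
import torch

# Stand-in for the real (500, 197, 768) tensor; smaller trailing dims to save memory.
data = torch.randn(500, 4, 8)
n = 5

without_first = data[n:]                  # drop the first n images
without_last = data[: data.shape[0] - n]  # drop the last n images (data[:-n] would be empty when n == 0)

print(without_first.shape)  # torch.Size([495, 4, 8])
print(without_last.shape)   # torch.Size([495, 4, 8])

# Basic slicing returns views that share memory with `data`;
# call .clone() if you need an independent copy.
```

Because these are views, no data is copied, which matters for a tensor of this size.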
To remove n items from inside the tensor, you need to specify a start index s (s + n should not exceed the length along that dimension):
new_data = torch.cat((data[:s], data[s+n:]), dim=0) # Remove n elements starting at s
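The same torch.cat approach on a small stand-in tensor, removing a contiguous block. The shape (500, 4, 8) and the values s = 100, n = 5 are hypothetical, chosen only for illustration.

```python
import torch

data = torch.randn(500, 4, 8)  # stand-in for the (500, 197, 768) tensor
s, n = 100, 5                  # hypothetical: remove images 100..104

assert s + n <= data.shape[0], "s + n must not exceed the length of dim 0"
# Concatenate everything before the block with everything after it.
new_data = torch.cat((data[:s], data[s + n:]), dim=0)

print(new_data.shape)  # torch.Size([495, 4, 8])
# Image 105 of the original now sits at position 100.
assert torch.equal(new_data[s], data[s + n])
```

Unlike plain slicing, torch.cat allocates a new tensor and copies the kept rows into it.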
To remove elements given a list of indices, you could do the following:
import torch
indices = [6, 7, 9, 100, 204] # Arbitrary list of indices that should be removed
indices_set = set(indices) # Set gives fast membership checks
indices_to_keep = [i for i in range(data.shape[0]) if i not in indices_set] # List of indices that should be kept
new_data = data[torch.tensor(indices_to_keep, dtype=torch.long, device=data.device)]
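As an alternative (not part of the original answer), a boolean mask avoids building the keep-list in a Python loop. This is a sketch on a small stand-in tensor of shape (500, 4, 8), which is an assumption to keep memory low; it also checks that the result matches gathering the kept indices with torch.index_select.

```python
import torch

data = torch.randn(500, 4, 8)  # stand-in for the (500, 197, 768) tensor
indices = [6, 7, 9, 100, 204]  # images to remove

# Build a mask that is True for every image we want to keep.
keep = torch.ones(data.shape[0], dtype=torch.bool, device=data.device)
keep[torch.tensor(indices, device=data.device)] = False

new_data = data[keep]  # boolean indexing along dim 0 copies the kept rows
print(new_data.shape)  # torch.Size([495, 4, 8])

# Same result as gathering the kept indices explicitly.
kept_idx = keep.nonzero(as_tuple=True)[0]
assert torch.equal(new_data, torch.index_select(data, 0, kept_idx))
```

The mask is built on the same device as data, so this also works when the tensor lives on a GPU.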
Answered By - willdalh