Issue
I have a PyTorch tensor of shape (100, 64, 22, 3, 3), and I would like to sort it along axis=0 by the trace of the (3, 3) components. The code below works, but it is very slow because of the nested for loops. Is there a way to vectorize the operation to speed it up?
import torch

x = torch.rand(100, 64, 22, 3, 3)
x_sorted = torch.zeros_like(x)
for i in range(x.shape[0]):
    # compute the trace of every (3, 3) block of x[i]: shape (64, 22)
    trace = torch.diagonal(x[i], dim1=-2, dim2=-1).sum(-1)
    # sort the traces along the first axis of x[i], largest first
    trace_values, trace_ind = torch.sort(trace, dim=0, descending=True)
    for j in range(x_sorted.shape[1]):
        for k in range(x_sorted.shape[2]):
            x_sorted[i, j, k] = x[i, trace_ind[j, k], k]
Solution
Try:
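import numpy as np
import torch
tensor = torch.tensor(np.random.rand(100, 64, 3, 3))
# trace of each (3, 3) block, then the sorting order along axis 0
orders = torch.argsort(torch.einsum('ijkk->ijk', tensor).sum(-1), dim=0)
orders.shape  # torch.Size([100, 64])
# reorder along axis 0, independently for each index on axis 1
tensor[orders, torch.arange(tensor.shape[1])[None, :]]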
Answered By - Quang Hoang
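The snippet in the answer uses a 4-D example and sorts along axis 0, while the loop in the question orders each x[i] along its first axis (axis 1 of the full tensor, size 64) in descending order. The following is a minimal sketch of how the same idea might be applied to the 5-D tensor from the question. It is not part of the original answer: the helper name `sort_by_trace` is hypothetical, and it uses `torch.gather` with an expanded index in place of the advanced indexing shown above.

```python
import torch


def sort_by_trace(x: torch.Tensor, dim: int = 1, descending: bool = True) -> torch.Tensor:
    """Reorder x of shape (..., n, n) along `dim` by the trace of its trailing (n, n) blocks.

    `dim` is assumed to be a non-negative index of one of the leading (batch) axes.
    """
    # Trace of every trailing block; the result has the shape of x minus its last two axes.
    trace = torch.diagonal(x, dim1=-2, dim2=-1).sum(-1)
    # Sorting order along `dim`, computed independently for every other batch index.
    order = torch.argsort(trace, dim=dim, descending=descending)
    # Append two singleton axes and expand so the index covers each whole (n, n) block.
    index = order[..., None, None].expand_as(x)
    return torch.gather(x, dim, index)


torch.manual_seed(0)
x = torch.rand(100, 64, 22, 3, 3)
x_sorted = sort_by_trace(x, dim=1, descending=True)
print(x_sorted.shape)  # torch.Size([100, 64, 22, 3, 3])
```

With `dim=1` and `descending=True` this should reproduce the result of the triple loop in the question without any Python-level iteration. Passing `dim=0` and `descending=False` gives the axis-0 ordering used in the answer's example.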