Issue
Given an n-by-n matrix A in which each row is a permutation of [n] (here 0, ..., n-1, since PyTorch indices are zero-based), e.g.,
import torch
n = 100
AA = torch.rand(n, n)
A = torch.argsort(AA, dim=1)
We are also given another n-by-n matrix P, and we want to construct a 3D tensor Q such that
Q[i, j, k] = P[A[i, j], k]
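To pin down the definition, here is a naive reference implementation that fills Q one entry at a time. It is a sketch for checking results on small inputs, not an efficient solution; the helper name build_q_loop is hypothetical, and P is assumed to be n-by-n as in the question.

```python
import torch

def build_q_loop(A: torch.Tensor, P: torch.Tensor) -> torch.Tensor:
    """Hypothetical reference helper: Q[i, j, k] = P[A[i, j], k], entry by entry."""
    n = A.shape[0]
    m = P.shape[1]
    Q = torch.empty(n, n, m, dtype=P.dtype)
    for i in range(n):
        for j in range(n):
            for k in range(m):
                Q[i, j, k] = P[A[i, j], k]
    return Q

# Small example (n = 4) so the triple loop stays cheap.
torch.manual_seed(0)
n = 4
A = torch.argsort(torch.rand(n, n), dim=1)  # each row is a permutation of 0..n-1
P = torch.rand(n, n)
Q = build_q_loop(A, P)
print(Q.shape)  # torch.Size([4, 4, 4])
```

The three nested Python loops cost O(n^3) interpreter steps, which is why a vectorized form is wanted for n = 100 and beyond.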
Is there an efficient way to do this in PyTorch? I am aware of torch.gather, but it seems hard to apply directly here.
Solution
You can use advanced (integer-array) indexing directly:
Q = P[A]
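Indexing dim 0 of P with an integer tensor of shape (n, n) replaces that dimension with the index's shape, so P[A] has shape (n, n, n) and Q[i, j, k] = P[A[i, j], k]. The sketch below checks this and, for comparison with the torch.gather route mentioned in the question, shows two equivalent formulations (torch.gather on an expanded P, and index_select on the flattened index). These alternatives are not part of the original answer; they are included only to show that the three agree.

```python
import torch

torch.manual_seed(0)
n = 100
A = torch.argsort(torch.rand(n, n), dim=1)  # row-wise permutations of 0..n-1 (int64)
P = torch.rand(n, n)

# Advanced indexing: shape (n, n) index on dim 0 of P -> result shape (n, n, n).
Q = P[A]

# torch.gather along dim 1: input[i, m, k] = P[m, k] (broadcast view),
# index[i, j, k] = A[i, j], so out[i, j, k] = P[A[i, j], k].
Q_gather = torch.gather(P.unsqueeze(0).expand(n, n, n), 1,
                        A.unsqueeze(-1).expand(n, n, n))

# index_select on the flattened index, then restore the (n, n) leading shape.
Q_select = P.index_select(0, A.reshape(-1)).reshape(n, n, n)

# Spot-check the defining identity for one entry.
i, j, k = 3, 7, 11
assert Q[i, j, k] == P[A[i, j], k]
print(torch.equal(Q, Q_gather), torch.equal(Q, Q_select))  # True True
```

P[A] is the most direct of the three; the gather version needs explicit broadcasting of both input and index, which is likely why it seemed awkward to apply here.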
Answered By - GoodDeeds