Issue
I'm new to PyTorch. Given a tensor, I need to sort its rows by a key value stored in the last column. For example,
A =
[[0.9133, 0.5071, 0.6222, 3.],
[0.5951, 0.9315, 0.6548, 1.],
[0.7704, 0.0720, 0.0330, 2.]]
My expected result after sorting is:
A' =
[[0.5951, 0.9315, 0.6548, 1.],
[0.7704, 0.0720, 0.0330, 2.],
[0.9133, 0.5071, 0.6222, 3.]]
I tried Python's built-in sorted function, but it was too slow inside my training loop. How can I do this more efficiently? Thanks!
Solution
%%timeit -r 10 -n 10
A[A[:,-1].argsort()]
38.6 µs ± 23 µs per loop (mean ± std. dev. of 10 runs, 10 loops each)
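To show why the one-liner works, here is the same expression broken into steps on the example tensor from the question. The names keys, order and A_sorted are only for illustration.

```python
import torch

# The example tensor from the question; the last column holds the sort key.
A = torch.tensor([[0.9133, 0.5071, 0.6222, 3.],
                  [0.5951, 0.9315, 0.6548, 1.],
                  [0.7704, 0.0720, 0.0330, 2.]])

keys = A[:, -1]         # the key of each row: 3., 1., 2.
order = keys.argsort()  # row indices in ascending key order: 1, 2, 0
A_sorted = A[order]     # advanced indexing gathers whole rows in that order

print(order)
print(A_sorted)
```

Because the indexing moves entire rows, every row stays intact; only the row order changes. The whole operation is a couple of vectorized tensor calls, with no Python-level loop over rows.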
%%timeit -r 10 -n 10
sorted(A, key = lambda x: x[-1])
69.6 µs ± 34.8 µs per loop (mean ± std. dev. of 10 runs, 10 loops each)
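Both put the rows in the same order (note that sorted returns a Python list of row tensors rather than a single tensor; wrap it in torch.stack to get the tensor shown):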
tensor([[0.5951, 0.9315, 0.6548, 1.0000],
[0.7704, 0.0720, 0.0330, 2.0000],
[0.9133, 0.5071, 0.6222, 3.0000]])
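%%timeit only works inside IPython/Jupyter. Outside a notebook, a comparable check can be sketched with the standard-library timeit module. The function names below are hypothetical, and the repeat/number settings are arbitrary choices, so the printed timings will differ from the numbers above and between machines.

```python
import timeit

import torch

A = torch.tensor([[0.9133, 0.5071, 0.6222, 3.],
                  [0.5951, 0.9315, 0.6548, 1.],
                  [0.7704, 0.0720, 0.0330, 2.]])


def sort_with_argsort(t):
    # Vectorized: compute the row order once, then gather the rows.
    return t[t[:, -1].argsort()]


def sort_with_builtin(t):
    # Python-level sort: iterates over the rows and compares keys one by one,
    # then re-stacks the resulting list of row tensors into a single tensor.
    return torch.stack(sorted(t, key=lambda row: row[-1]))


# Both approaches should produce the same tensor.
assert torch.equal(sort_with_argsort(A), sort_with_builtin(A))

for fn in (sort_with_argsort, sort_with_builtin):
    # Best of 10 repeats of 100 calls each, reported per call in microseconds.
    best = min(timeit.repeat(lambda: fn(A), repeat=10, number=100)) / 100
    print(f"{fn.__name__}: {best * 1e6:.1f} µs per call")
```

With only three rows, the gap is small. The Python-level loop inside sorted is what tends to dominate as the number of rows grows.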
There is also torch.sort, which returns both the sorted values and their indices:
%%timeit -r 10 -n 10
a, b = torch.sort(A, dim=-2)
The slowest run took 8.45 times longer than the fastest. This could mean that an intermediate result is being cached.
14.3 µs ± 18.1 µs per loop (mean ± std. dev. of 10 runs, 10 loops each)
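with a as the sorted tensor and b as the indices. Note, however, that sorting the whole tensor along dim=-2 sorts each column independently, so the rows are not kept together and a is not the result the question asks for (which also means the timing above is not a like-for-like comparison). To sort rows by the key, sort only the key column and index with the returned indices: _, idx = torch.sort(A[:, -1]), then A[idx].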
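A minimal sketch contrasting the two uses of torch.sort on the example tensor: sorting along dim=-2 reorders each column on its own, while sorting only the key column and reusing its indices keeps rows intact. The stable=True argument is an optional addition (it keeps rows with equal keys in their original relative order) and assumes a PyTorch release whose torch.sort accepts it.

```python
import torch

A = torch.tensor([[0.9133, 0.5071, 0.6222, 3.],
                  [0.5951, 0.9315, 0.6548, 1.],
                  [0.7704, 0.0720, 0.0330, 2.]])

# Sorting the whole tensor along dim=-2 sorts every column independently,
# so values from different rows get mixed together.
a, b = torch.sort(A, dim=-2)
print(a[:, 1])  # column 1 is now ascending: 0.0720, 0.5071, 0.9315
print(a[0])     # a "row" that never existed in A

# To keep rows intact, sort only the key column and reuse its indices.
# stable=True keeps rows with equal keys in their original relative order.
keys, idx = torch.sort(A[:, -1], stable=True)
A_sorted = A[idx]
print(A_sorted)
```

The second form does the same work as A[A[:, -1].argsort()], but it also hands back the sorted keys, which can be handy if they are needed later.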
Answered By - warped