Issue
Hello, I have the following code:
import torch
x = torch.zeros(1, 8, 4, 576)  # create a 4-dimensional tensor
x[0, 4, 2, 333] = 1.0  # put a 1 at an arbitrary spot
# I want to find the index of the highest value (0,4,2,333)
print(x.argmax()) # this should return the index
This returns
tensor(10701)
How does 10701 make sense?
How do I get the actual indices (0, 4, 2, 333)?
Solution
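When called without a dim argument, argmax() treats the 4-dimensional tensor as if it were flattened into a 1-D array in row-major (C) order, and returns the index into that flat representation.
For shape (1, 8, 4, 576), element (i, j, k, l) sits at flat index ((i*8 + j)*4 + k)*576 + l, so (0, 4, 2, 333) maps to (4*4 + 2)*576 + 333 = 10701.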
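To make the row-major arithmetic concrete, here is a minimal pure-Python sketch that converts between a flat index and per-dimension indices. The helper names unravel and ravel are hypothetical, introduced only for illustration, and the code assumes row-major (C) ordering, the same order argmax() uses for its flattened index.

```python
def unravel(flat_index, shape):
    """Convert a flat row-major (C-order) index into per-dimension indices.

    Hypothetical helper for illustration only.
    """
    coords = []
    # Walk the dimensions from last (fastest-changing) to first,
    # peeling off one coordinate per step with divmod.
    for size in reversed(shape):
        flat_index, rem = divmod(flat_index, size)
        coords.append(rem)
    return tuple(reversed(coords))


def ravel(coords, shape):
    """Inverse of unravel: combine per-dimension indices into a flat index.

    Hypothetical helper for illustration only.
    """
    flat = 0
    # Horner-style accumulation: flat = ((c0*s1 + c1)*s2 + c2)*s3 + c3
    for c, size in zip(coords, shape):
        flat = flat * size + c
    return flat


shape = (1, 8, 4, 576)
print(unravel(10701, shape))         # (0, 4, 2, 333)
print(ravel((0, 4, 2, 333), shape))  # 10701
```

The divmod loop is exactly the reverse of the formula above: each step divides by the size of the current innermost dimension and keeps the remainder as that dimension's index.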
NumPy provides np.unravel_index, which converts a flat index into the corresponding multi-dimensional indices for a given shape:
import numpy as np
print(np.unravel_index(10701, (1, 8, 4, 576)))  # indices (0, 4, 2, 333)
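Rather than hard-coding the index and the shape, you can feed the result of argmax() and the array's own shape straight into np.unravel_index. The sketch below uses a NumPy array as a stand-in for the question's tensor (an assumption made so the example runs without PyTorch) and checks the round trip with np.ravel_multi_index. With the original tensor, the equivalent call should be np.unravel_index(x.argmax().item(), x.shape), since torch.Size is a tuple subclass.

```python
import numpy as np

# NumPy stand-in for the question's tensor (same shape, same single 1.0).
x = np.zeros((1, 8, 4, 576))
x[0, 4, 2, 333] = 1.0

flat = int(np.argmax(x))                # flat row-major index, like torch's argmax()
idx = np.unravel_index(flat, x.shape)   # tuple with one index per dimension
coords = tuple(int(i) for i in idx)     # convert NumPy integers to plain ints
print(flat, coords)                     # 10701 (0, 4, 2, 333)

# Round trip: ravel_multi_index turns the per-dimension indices back into the flat index.
assert np.ravel_multi_index(idx, x.shape) == flat

# The recovered indices address the maximum element directly.
assert x[idx] == 1.0
```

If you prefer to stay in PyTorch, recent releases (2.2 and later) also include torch.unravel_index, which returns a tuple of tensors; check the documentation for your installed version.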
Answered By - dannyadam