Issue
The numpy.where() function always returns indices in ascending order. I am using it to get the indices of values in a list that are found in another list, like this:
import numpy as np

lst = [1, 2, 8, 7, 3, 4, 6, 5]
values = [4, 8]
indices = np.where(np.isin(lst, values))[0]  # [0] unpacks the 1-tuple that np.where returns
The output is [2, 5] instead of the expected [5, 2].
The problem is that I need the indices returned in the order in which the values are listed. In my example, 4 is the first value, so the index that corresponds to it (5) should come first in the output, but it doesn't. Is there an alternative to numpy.where() that keeps this order instead of sorting? I could use some for loops and basic indexing to figure this out, but speed is very important because this is part of my custom 2D max-pool function for a neural net.
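For context, here is a minimal sketch of why the order is lost. np.isin returns a boolean mask aligned with lst, so the mask carries no information about the order of values, and np.where can only report the True positions from left to right. The variable names mask and result are illustrative only.

```python
import numpy as np

lst = [1, 2, 8, 7, 3, 4, 6, 5]
values = [4, 8]

# The mask has one entry per element of lst, not per element of values
mask = np.isin(lst, values)
print(mask.tolist())       # [False, False, True, False, False, True, False, False]

# np.where with a single argument returns a tuple with one index array per dimension
result = np.where(mask)
print(result[0].tolist())  # [2, 5]

# Reversing values produces exactly the same mask, hence the same indices
same = np.array_equal(mask, np.isin(lst, values[::-1]))
print(same)                # True
```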
Solution
Assuming values is in sorted order, you can reorder the indices based on the matched values:
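import numpy as np

lst = [1, 2, 8, 7, 3, 4, 6, 5]
values = [4, 8]
a = np.array(lst)
m = np.isin(a, values)  # boolean mask of the elements of lst found in values
order = np.argsort(a[m], kind='stable')  # stable sort keeps equal values in index order
indices = np.where(m)[0][order]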
Output: array([5, 2])
Output with lst = [1, 2, 8, 7, 3, 4, 6, 5, 4, 8]: array([5, 8, 2, 9])
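If values is not sorted, one possible way to keep this sort-based approach is to rank each matched element by its position in values and sort on that rank instead. The sketch below is only an illustration: the function name ordered_indices is hypothetical, and it assumes the entries of values are unique.

```python
import numpy as np

def ordered_indices(lst, values):
    """Indices of lst that match values, grouped in the order of values.

    Hypothetical helper; assumes the entries of values are unique.
    """
    a = np.asarray(lst)
    v = np.asarray(values)
    # Positions in lst whose element occurs in values (ascending order)
    idx = np.flatnonzero(np.isin(a, v))
    # sorter maps positions in sorted(v) back to positions in v
    sorter = np.argsort(v, kind='stable')
    # For each matched element, find its position in the original values
    pos = sorter[np.searchsorted(v, a[idx], sorter=sorter)]
    # Stable sort by that position keeps duplicates in index order
    return idx[np.argsort(pos, kind='stable')]

lst = [1, 2, 8, 7, 3, 4, 6, 5, 4, 8]
print(ordered_indices(lst, [4, 8]).tolist())  # [5, 8, 2, 9]
print(ordered_indices(lst, [8, 4]).tolist())  # [2, 9, 5, 8]
```

This avoids building a len(values) by len(lst) comparison matrix, at the cost of two extra sorts.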
Alternative
Another approach, which works with any order in values, is to broadcast the comparison:
import numpy as np

#      0  1  2  3  4  5  6  7  8  9
lst = [1, 2, 8, 7, 3, 4, 6, 5, 4, 8]
values = [4, 8]
# a: row index (position in values), b: column index (position in lst)
a, b = np.where(np.array(lst) == np.array(values)[:,None])
indices = b[np.argsort(a, kind='stable')]  # stable sort keeps ties in index order
# array([5, 8, 2, 9])

# other example
#      0  1  2  3  4  5  6  7  8  9
lst = [1, 2, 8, 7, 3, 4, 6, 5, 4, 8]
values = [8, 4]
a, b = np.where(np.array(lst) == np.array(values)[:,None])
indices = b[np.argsort(a, kind='stable')]
# array([2, 9, 5, 8])
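The broadcast comparison can also be wrapped in a small helper, sketched below. The function name indices_in_value_order is hypothetical. It relies on the documented behaviour that np.nonzero (and therefore single-argument np.where) scans in row-major order, so the column indices already come out grouped by row, that is, by position in values.

```python
import numpy as np

def indices_in_value_order(lst, values):
    """Hypothetical helper: indices of lst matching values, in the order of values."""
    a = np.asarray(lst)
    v = np.asarray(values)
    # matches[i, j] is True where lst[j] == values[i];
    # its shape is (len(values), len(lst))
    matches = a == v[:, None]
    # Row-major scan: all hits for values[0] first, then values[1], and so on
    rows, cols = np.nonzero(matches)
    return cols

lst = [1, 2, 8, 7, 3, 4, 6, 5, 4, 8]
print(indices_in_value_order(lst, [4, 8]).tolist())  # [5, 8, 2, 9]
print(indices_in_value_order(lst, [8, 4]).tolist())  # [2, 9, 5, 8]
```

Note that the intermediate boolean matrix holds len(values) * len(lst) entries, so memory use grows with both inputs.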
NB. My previous statement was incorrect. I had read unique instead of where.
Answered By - mozway