Issue
I have an array x of shape (2 x 1 x 2 x 2 x 2):
array([[[[[ 7.,  9.],
          [10., 11.]],
         [[19., 18.],
          [20., 16.]]]],
       [[[[24.,  5.],
          [ 6., 10.]],
         [[18., 11.],
          [45., 12.]]]]])
The last two dimensions are H (height) and W (width), respectively. I also have two separate index arrays, one along H and one along W:
idx2=np.array([1, 1, 0, 1]) # index along H
idx3=np.array([1, 0, 0, 0]) # index along W
In terms of the last two dimensions, I'd like to extract the (1, 1) element of [[7., 9.], [10., 11.]], that is, 11; the (1, 0) element of [[19., 18.], [20., 16.]], that is, 20; and so on. The final result should be a (2 x 1 x 2) array:
array([[[11., 20.]],
       [[24., 45.]]])
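To pin down the intended semantics, here is a plain-loop reference (a sketch, not from the original post). It assumes, as the example suggests, that the k-th entry of idx2/idx3 belongs to the k-th (H, W) slice when the leading axes are walked in C (row-major) order.

```python
import numpy as np

# Input array from the question, shape (2, 1, 2, 2, 2); the last two axes are H and W.
x = np.array([[[[[ 7.,  9.], [10., 11.]],
                [[19., 18.], [20., 16.]]]],
              [[[[24.,  5.], [ 6., 10.]],
                [[18., 11.], [45., 12.]]]]])
idx2 = np.array([1, 1, 0, 1])  # index along H
idx3 = np.array([1, 0, 0, 0])  # index along W

# Visit every (H, W) slice in C order and pick one element from each.
lead_shape = x.shape[:-2]      # (2, 1, 2)
out = np.empty(lead_shape, dtype=x.dtype)
for k, pos in enumerate(np.ndindex(*lead_shape)):
    out[pos] = x[pos][idx2[k], idx3[k]]

print(out)
```

The loop is slow for large arrays, but it is a handy check for any vectorized version.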
Thanks for any help!
Solution
One way to solve this is to combine np.ravel_multi_index and np.take_along_axis.
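You can ravel (flatten) each (h, w) index pair into a single index into an axis of length H*W, then use np.take_along_axis along that flattened axis. First compute the flat indices, with a trailing length-1 axis so they have as many dimensions as the flattened x: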
>>> flat_idx = (idx2*x.shape[-1]+idx3).reshape(*x.shape[:-2], 1)
>>> flat_idx
array([[[[3],
         [2]]],
       [[[0],
         [2]]]])
Alternatively, you can use the built-in np.ravel_multi_index, which is slightly more verbose:
>>> flat_idx = np.ravel_multi_index((idx2, idx3), x.shape[-2:]).reshape(*x.shape[:-2], 1)
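Both forms rely on row-major flattening, where (h, w) in an H x W grid maps to h * W + w. A small sketch checking that the manual formula and np.ravel_multi_index agree for the question's indices (H = W = 2 here):

```python
import numpy as np

H, W = 2, 2                      # sizes of the last two axes of x
idx2 = np.array([1, 1, 0, 1])    # row (H) indices
idx3 = np.array([1, 0, 0, 0])    # column (W) indices

# Row-major flattening of an (H, W) grid maps (h, w) to h * W + w.
manual = idx2 * W + idx3
builtin = np.ravel_multi_index((idx2, idx3), (H, W))

# In an arange grid, each cell's value equals its flat position.
grid = np.arange(H * W).reshape(H, W)
assert (grid[idx2, idx3] == manual).all()

print(manual, builtin)           # both print [3 2 0 2]
```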
Then flatten the last two dimensions of x and gather the values at those indices:
>>> res = np.take_along_axis(x.reshape(*x.shape[:-2], -1), flat_idx, -1)
>>> res
array([[[[11.],
         [20.]]],
       [[[24.],
         [45.]]]])
Finally, a reshape drops the trailing length-1 axis:
>>> res.reshape(*x.shape[0:-2])
array([[[11., 20.]],
       [[24., 45.]]])
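Putting the three steps together, a self-contained version of the approach above (the array literal is re-typed from the question; the names lead and flat_x are just local helpers):

```python
import numpy as np

x = np.array([[[[[ 7.,  9.], [10., 11.]],
                [[19., 18.], [20., 16.]]]],
              [[[[24.,  5.], [ 6., 10.]],
                [[18., 11.], [45., 12.]]]]])
idx2 = np.array([1, 1, 0, 1])  # index along H
idx3 = np.array([1, 0, 0, 0])  # index along W

lead = x.shape[:-2]            # (2, 1, 2): every axis except H and W

# 1. Flat position of (h, w) inside each H*W slice, plus a trailing length-1 axis.
flat_idx = np.ravel_multi_index((idx2, idx3), x.shape[-2:]).reshape(*lead, 1)

# 2. Merge H and W into one axis of length H*W, then gather along it.
flat_x = x.reshape(*lead, -1)                         # shape (2, 1, 2, 4)
res = np.take_along_axis(flat_x, flat_idx, axis=-1)   # shape (2, 1, 2, 1)

# 3. Drop the trailing length-1 axis.
res = res.reshape(lead)                               # shape (2, 1, 2)
print(res)
```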
If you hard-code the shapes for this particular x, idx2 and idx3, this comes down to:
>>> flat_idx = (idx2*2+idx3).reshape(2, 1, 2, 1)
>>> res = np.take_along_axis(x.reshape(2, 1, 2, 4), flat_idx, -1)
>>> res.reshape((2, 1, 2))
The same method handles the more general case with idx2, idx3, idx4, ..., that is, one index array per trailing dimension.
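As a sketch of that generalization, here is a hypothetical helper take_trailing (not part of the original answer). It assumes one index array per trailing axis, each holding one entry per leading position in C order, exactly like idx2 and idx3 above:

```python
import numpy as np

def take_trailing(x, *idxs):
    """Hypothetical helper: pick one element from each trailing sub-array of x.

    idxs[i] indexes axis x.ndim - len(idxs) + i; each must hold one entry per
    sub-array, listed in C (row-major) order over the leading axes.
    """
    k = len(idxs)
    lead, tail = x.shape[:-k], x.shape[-k:]
    # Combine the k index arrays into flat positions within a tail-sized block.
    flat_idx = np.ravel_multi_index(idxs, tail).reshape(*lead, 1)
    # Merge the k trailing axes into one, gather, then drop the length-1 axis.
    res = np.take_along_axis(x.reshape(*lead, -1), flat_idx, axis=-1)
    return res.reshape(lead)

# Usage: two leading axes and three trailing axes (idx2, idx3, idx4).
rng = np.random.default_rng(0)
y = rng.random((2, 3, 4, 5, 6))
n = 2 * 3
i2 = rng.integers(0, 4, n)
i3 = rng.integers(0, 5, n)
i4 = rng.integers(0, 6, n)
out = take_trailing(y, i2, i3, i4)          # shape (2, 3)

# Compare against a plain loop over the leading positions.
expected = np.array([y[p][i2[k], i3[k], i4[k]]
                     for k, p in enumerate(np.ndindex(2, 3))]).reshape(2, 3)
assert np.array_equal(out, expected)
```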
Answered By - Ivan