Issue
Problem:
- I have M objects sampled at different frames, and I want to calculate the distance between each pair of objects at each frame. I store the distances in a multidimensional array xij with three axes, where the element xij[t,i,j] is the distance between objects i and j at time t. For example, we could have:
import numpy as np
N = 10**5
M = 10
xij = np.random.uniform(0, 10, N).reshape(int(N/M**2), M, M)
- Now I want to calculate, for each object, its average distance to the other objects (that is, excluding the self-pairs xij[t,i,i]). The way I implemented this was to first set those diagonal entries to NaN and then use np.nanmean():
xij[..., np.arange(M), np.arange(M)] = np.nan
mean = np.nanmean(xij, axis=-1)
- However, setting all these values to np.nan becomes a bottleneck in my program, and it seems to me that it may not be necessary. Is there a faster alternative? I see that np.mean has a where argument that selects the elements to include in the calculation via a boolean array. Could that array be created more efficiently than with the NaN trick I implemented? Alternatively, could masked arrays help? I am not familiar with them, though.
Solution
You could sum each row along the last axis, subtract the diagonal element xij[t,i,i], and divide by M-1 (the number of other objects):
meanDistance = (np.sum(xij, axis = -1) - np.diagonal(xij, axis1=-2, axis2=-1)) / (M - 1)
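The question also asked whether np.mean's where argument could replace the NaN trick. Below is a minimal sketch, assuming NumPy 1.20 or later (the release that added where to np.mean); the helper name mean_excluding_self and the small random test array are illustrative only. The boolean mask is built once with shape (M, M) and broadcast over every frame, and the input array is not modified.

```python
import numpy as np

def mean_excluding_self(xij):
    """Average of xij[t, i, j] over j != i, via np.mean's `where` argument.

    Hypothetical helper; assumes NumPy >= 1.20 and square trailing axes (M, M).
    """
    M = xij.shape[-1]
    # (M, M) boolean mask: True off the diagonal, False on it.
    # It broadcasts against the full (T, M, M) array, so it is built only once.
    off_diag = ~np.eye(M, dtype=bool)
    return np.mean(xij, axis=-1, where=off_diag)

rng = np.random.default_rng(0)
M = 4
xij = rng.uniform(0, 10, size=(3, M, M))

via_where = mean_excluding_self(xij)
via_sum_diag = (np.sum(xij, axis=-1) - np.diagonal(xij, axis1=-2, axis2=-1)) / (M - 1)
print(np.allclose(via_where, via_sum_diag))  # True
```

This version still reduces over all M×M entries of every frame, so whether it beats the sum-minus-diagonal line above is worth timing on your own data.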
Demo results (N = 10**7, M = 10, produced by the demo code below):
(sum-diag) / (M-1):
time in seconds: 0.03786587715148926
t=0 first three means: [5.42617836 5.03198446 5.67675881]
nanmean:
time in seconds: 0.18410110473632812
t=0 first three means: [5.42617836 5.03198446 5.67675881]
Demo code:
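import numpy as np
from time import perf_counter
N = 10**7
M = 10
# N // M**2 frames, each holding an M x M matrix of pairwise distances
xij = np.random.uniform(0, 10, N).reshape(N // M**2, M, M)
print('(sum-diag) / (M-1):')
t0 = perf_counter()
# Row sums minus the self-distance xij[t, i, i], averaged over the M-1 other objects
meanDistance = (np.sum(xij, axis=-1) - np.diagonal(xij, axis1=-2, axis2=-1)) / (M - 1)
print(' time in seconds:', perf_counter() - t0)
print(' t=0 first three means:', meanDistance[0, :3])
print()
print('nanmean:')
t0 = perf_counter()
# The question's approach: overwrite the diagonal with NaN (this mutates xij in place,
# which is why it runs after the first method), then average while ignoring NaNs
xij[..., np.arange(M), np.arange(M)] = np.nan
meanDistance = np.nanmean(xij, axis=-1)
print(' time in seconds:', perf_counter() - t0)
print(' t=0 first three means:', meanDistance[0, :3])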
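For the masked-array route mentioned in the question, here is a sketch along the same lines using numpy.ma. The helper name mean_excluding_self_masked is hypothetical, and no timing is claimed for it. Masking the diagonal leaves xij intact, unlike assigning np.nan in place.

```python
import numpy as np

def mean_excluding_self_masked(xij):
    """Average of xij[t, i, j] over j != i, using a numpy.ma masked array.

    Hypothetical helper: the diagonal is masked instead of overwritten with NaN,
    so the input array itself is left unchanged.
    """
    M = xij.shape[-1]
    # (T, M, M) boolean mask that is True on each frame's diagonal.
    # .copy() turns the read-only broadcast view into an ordinary array,
    # since MaskedArray may keep a reference to the mask it is given.
    diag_mask = np.broadcast_to(np.eye(M, dtype=bool), xij.shape).copy()
    masked = np.ma.MaskedArray(xij, mask=diag_mask)
    # Each row keeps M-1 unmasked entries (assuming M >= 2), so no result is
    # masked and .filled() returns the plain ndarray of means.
    return masked.mean(axis=-1).filled()

rng = np.random.default_rng(1)
M = 4
xij = rng.uniform(0, 10, size=(3, M, M))
original = xij.copy()

via_masked = mean_excluding_self_masked(xij)
via_sum_diag = (np.sum(xij, axis=-1) - np.diagonal(xij, axis1=-2, axis2=-1)) / (M - 1)
print(np.allclose(via_masked, via_sum_diag))  # True
print(np.array_equal(xij, original))  # True: xij was not modified
```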
Answered By - don't talk just code