Issue
I'm trying to implement a fast routine that calculates an array of energies and finds the smallest calculated value and its index. Here is my code, which works fine:
from jax import jit, vmap

@jit
def findMinEnergy(x):
    def calcEnergy(a):
        return a*a  # very simplified body, it is actually 15 lines of code
    energies = vmap(calcEnergy, in_axes=(0))(x)
    idx = energies.argmin(axis=0)
    minenrgy = energies[idx]
    return idx, minenrgy
Is it possible to avoid the separate argmin call and instead return the minimum calculated energy and its index directly from the vmap (similar to how other aggregate functions work, e.g. jnp.sum)? I hope that would be more efficient.
Solution
If you JIT-compile your current approach, you should find that it's as efficient as doing something more sophisticated.
Looking at the implementation of argmin, you'll see that it computes both the value and the index before returning only the index: https://github.com/google/jax/blob/jax-v0.4.18/jax/_src/lax/lax.py#L3892-L3914
If you want, you could follow this implementation and define a function using lax.reduce that returns both of these values in a single pass:
import jax
import jax.numpy as jnp

@jax.jit
def min_and_argmin_onepass(x):
    # This only works for 1D float arrays, but you could generalize it.
    assert x.ndim == 1
    assert jnp.issubdtype(x.dtype, jnp.floating)

    def reducer(op_val_index, acc_val_index):
        op_val, op_index = op_val_index
        acc_val, acc_index = acc_val_index
        pick_op_val = (op_val < acc_val) | jnp.isnan(op_val)
        pick_op_index = pick_op_val | ((op_val == acc_val) & (op_index < acc_index))
        return (jnp.where(pick_op_val, op_val, acc_val),
                jnp.where(pick_op_index, op_index, acc_index))

    indices = jnp.arange(len(x))
    return jax.lax.reduce((x, indices), (jnp.inf, 0), reducer, (0,))
Testing this, we see it matches the output of the less sophisticated approach:
@jax.jit
def min_and_argmin(x):
    i = jnp.argmin(x)
    return x[i], i

x = jax.random.uniform(jax.random.key(0), (1000000,))
print(min_and_argmin_onepass(x))
# (Array(9.536743e-07, dtype=float32), Array(24430, dtype=int32))
print(min_and_argmin(x))
# (Array(9.536743e-07, dtype=float32), Array(24430, dtype=int32))
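To see what the reducer in min_and_argmin_onepass is doing, it can help to trace the same logic in plain Python on a tiny input. The sketch below is only an illustration: it swaps the jnp operations for scalar Python ones, and the fold helper and the sample values are made up for this example. It folds the (value, index) pairs in both directions, since, as far as I know, a reduction like lax.reduce does not promise a particular order of application. The index comparison in the tie-breaking term is what keeps the result the same either way.

```python
import math
from functools import reduce

def reducer(op_val_index, acc_val_index):
    # Scalar mirror of the reducer used in min_and_argmin_onepass.
    op_val, op_index = op_val_index
    acc_val, acc_index = acc_val_index
    pick_op_val = (op_val < acc_val) or math.isnan(op_val)
    # On equal values, prefer the smaller index.
    pick_op_index = pick_op_val or (op_val == acc_val and op_index < acc_index)
    return (op_val if pick_op_val else acc_val,
            op_index if pick_op_index else acc_index)

def fold(pairs):
    # Hypothetical helper: functools.reduce passes (accumulator, item),
    # while the reducer expects (item, accumulator), so swap them.
    return reduce(lambda acc, op: reducer(op, acc), pairs, (math.inf, 0))

values = [3.0, 1.0, 2.0, 1.0]  # the minimum 1.0 appears at indices 1 and 3
pairs = [(v, i) for i, v in enumerate(values)]

forward = fold(pairs)             # left-to-right
backward = fold(reversed(pairs))  # right-to-left
print(forward)   # (1.0, 1)
print(backward)  # (1.0, 1)
```

Both directions return the first occurrence of the minimum, which matches what jnp.argmin returns when values are tied.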
If you time the two, you'll see that their runtimes are comparable:
%timeit jax.block_until_ready(min_and_argmin_onepass(x))
# 2.17 ms ± 68.7 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
%timeit jax.block_until_ready(min_and_argmin(x))
# 2.07 ms ± 66.9 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
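The %timeit lines above are IPython magics. Outside a notebook, a rough equivalent can be put together with the standard-library timeit module. This is a minimal sketch, not a rigorous benchmark: the warm-up call and the repeat count n are my own choices, and the numbers will differ by machine and backend. The warm-up call matters because the first call of a jitted function includes compilation time, and block_until_ready matters because JAX dispatches work asynchronously.

```python
import timeit

import jax
import jax.numpy as jnp

@jax.jit
def min_and_argmin(x):
    i = jnp.argmin(x)
    return x[i], i

x = jax.random.uniform(jax.random.key(0), (1_000_000,))

# Warm-up call so that compilation is not included in the measurement.
jax.block_until_ready(min_and_argmin(x))

n = 100  # number of timed calls (arbitrary choice for this sketch)
seconds = timeit.timeit(
    lambda: jax.block_until_ready(min_and_argmin(x)), number=n)
print(f"{1e3 * seconds / n:.3f} ms per call")
```

The same pattern can be applied to min_and_argmin_onepass to compare the two.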
The jax.jit decorator here means that the compiler optimizes the sequence of operations in the less sophisticated approach, so you don't gain much from trying to express things more cleverly. Given this, I think your best option is to stick with your original code rather than trying to out-optimize the XLA compiler.
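If you are curious what the compiler actually produced for the simple version, one way to look is to lower and compile the jitted function ahead of time and print the optimized program text. This is only a sketch: the small input x is chosen for illustration, and the text you get back depends on your JAX version and backend, so no particular output is shown here.

```python
import jax
import jax.numpy as jnp

@jax.jit
def min_and_argmin(x):
    i = jnp.argmin(x)
    return x[i], i

# Small illustrative input: [7., 6., ..., 0.], so the minimum is at index 7.
x = jnp.arange(8.0)[::-1]

# Lower and compile for this input shape/dtype, then dump the compiled program.
compiled = min_and_argmin.lower(x).compile()
hlo_text = compiled.as_text()
print(hlo_text[:2000])  # print only the beginning; the full text can be long
```

Reading the dump is optional. The timing comparison is usually the more practical check.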
Answered By - jakevdp