Issue
I need to find the minimum of a cost function with several thousand variables. The cost function is simply a least squares calculation and can be computed easily and quickly with numpy vectorization. Despite this, the optimization still takes a painfully long time. My guess is that the slow runtime is occurring in SciPy's minimizer rather than in my cost function. How can I change the parameters of SciPy's minimizer to speed up the runtime?
Sample Code:
import numpy as np
from scipy.optimize import minimize

# random data
x = np.random.randn(100, 75)

# initial weights guess
startingWeights = np.ones(shape=(100, 75))

# random y vector
y = np.random.randn(100)


def costFunction(weights):
    # reshape flattened weights into a 2d matrix
    weights = np.reshape(weights, newshape=(100, 75))

    # weighted row-wise sum
    weighted = np.sum(x * weights, axis=1)

    # squared residuals
    residualsSquared = (y - weighted) ** 2

    return np.sum(residualsSquared)


result = minimize(costFunction, startingWeights.flatten())
Solution
As already noted in the comments, it's highly recommended to provide the exact objective gradient for a large problem with N = 100*75 = 7500 variables. If no gradient is provided, it is approximated by finite differences via SciPy's approx_derivative function. However, finite differences are error-prone and computationally expensive, since each evaluation of the gradient requires on the order of N to 2*N evaluations of the objective function, depending on the difference scheme (and without caching).
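As an aside, for this specific least-squares objective the exact gradient can also be derived by hand: with residuals r = y - (x * W).sum(axis=1), the derivative of the cost with respect to W[i, j] is -2 * r[i] * x[i, j]. A minimal sketch of passing such a hand-written gradient to minimize (the costGradient name is mine, not part of the original answer):

def costGradient(weights):
    # hand-derived gradient of costFunction: d cost / d W[i, j] = -2 * r[i] * x[i, j]
    W = np.reshape(weights, newshape=(100, 75))
    residuals = y - np.sum(x * W, axis=1)
    return (-2.0 * residuals[:, np.newaxis] * x).ravel()

result = minimize(costFunction, startingWeights.flatten(), jac=costGradient)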
The cost of the finite-difference approximation is easy to illustrate by timing the objective and the approximated gradient:
In [7]: %timeit costFunction(startingWeights.flatten())
23.5 µs ± 2.03 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
In [8]: from scipy.optimize._numdiff import approx_derivative
In [9]: %timeit approx_derivative(costFunction, startingWeights.flatten())
633 ms ± 33.3 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
Consequently, each gradient evaluation takes more than half a second on my machine! A more efficient approach is to evaluate the gradient by algorithmic differentiation. Using the JAX library, this is quite easy:
import jax.numpy as jnp
from jax import jit, value_and_grad


def costFunction(weights):
    # reshape flattened weights into a 2d matrix
    weights = jnp.reshape(weights, (100, 75))

    # weighted row-wise sum
    weighted = jnp.sum(x * weights, axis=1)

    # squared residuals
    residualsSquared = (y - weighted) ** 2

    return jnp.sum(residualsSquared)


# create a jit-compiled function that returns the objective value and its gradient
obj_and_grad = jit(value_and_grad(costFunction))
Here, value_and_grad creates a function that evaluates both the objective and the gradient and returns them as a tuple, i.e. obj_value, grad_values = obj_and_grad(x0). So let's time this function:
In [12]: %timeit obj_and_grad(startingWeights.flatten())
132 µs ± 6.62 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
Thus, we now evaluate the objective and the gradient nearly 5000 times faster than before. Finally, we can tell minimize that our objective function returns both the objective value and the gradient by setting jac=True:

minimize(obj_and_grad, x0=startingWeights.flatten(), jac=True)

This should significantly speed up the optimization.
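Putting it all together, a minimal end-to-end sketch might look like the following. Enabling 64-bit floats in JAX and the choice of method="L-BFGS-B" are my additions, not part of the original answer; JAX defaults to single precision, which can hurt the convergence of SciPy's optimizers.

import numpy as np
import jax
import jax.numpy as jnp
from jax import jit, value_and_grad
from scipy.optimize import minimize

# use double precision, which SciPy's optimizers expect
jax.config.update("jax_enable_x64", True)

x = np.random.randn(100, 75)
y = np.random.randn(100)
startingWeights = np.ones(shape=(100, 75))


def costFunction(weights):
    weights = jnp.reshape(weights, (100, 75))
    weighted = jnp.sum(x * weights, axis=1)
    return jnp.sum((y - weighted) ** 2)


# jit-compiled function returning both the objective value and its gradient
obj_and_grad = jit(value_and_grad(costFunction))

result = minimize(obj_and_grad, x0=startingWeights.flatten(),
                  jac=True, method="L-BFGS-B")

# reshape the flat solution vector back into the (100, 75) weight matrix
optimalWeights = result.x.reshape(100, 75)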
PS: You can also try the state-of-the-art Ipopt solver via the cyipopt package, which provides a scipy-like interface similar to scipy.optimize.minimize.
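A minimal sketch of that route, assuming cyipopt is installed and reusing the JAX objective from above (recent cyipopt versions expose minimize_ipopt; the conversion lambdas are mine):

import numpy as np
from cyipopt import minimize_ipopt
from jax import grad, jit

obj = jit(costFunction)             # JAX version of the objective from above
obj_grad = jit(grad(costFunction))  # exact gradient via automatic differentiation

res = minimize_ipopt(
    lambda w: float(obj(w)),                # Ipopt expects a plain Python float
    x0=startingWeights.flatten(),
    jac=lambda w: np.asarray(obj_grad(w)),  # and a plain numpy array for the gradient
)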
Answered By - joni