Issue
I would like to perform hyperparameter optimisation for a model I have trained in Scikit-Learn. I want to first use a random search to find a promising region of the search space and then follow it up with a grid search. The validation method I need to use is Leave One Group Out (LOGO).
So something to this effect:
distributions = {
    "n_estimators": randint(low=1, high=500),
    "criterion": ["squared_error", "absolute_error", "poisson"],
    "max_depth": randint(low=1, high=100)
}
random_search = RandomizedSearchCV(
    forest_reg,
    distributions,
    cv=LeaveOneGroupOut(),
    groups=group,
    scoring="neg_mean_squared_error",
    return_train_score=True,
    random_state=42,
    n_jobs=-1,
    n_iter=20
)
random_search.fit(X, y)
Neither RandomizedSearchCV nor GridSearchCV appears to support LOGO validation with user-defined groups. When I use a method such as cross_val_score(), I can pass in a chosen cross-validation method like so:
scores = cross_val_score(
    forest_reg,
    X,
    y,
    scoring="neg_mean_squared_error",
    cv=LeaveOneGroupOut(),
    groups=group,
    n_jobs=-1
)
Is there a reason the same is not supported by either of the hyperparameter search methods? Am I using the API the wrong way? Is there a way to achieve what I want using sklearn without kludging something together myself?
Solution
Groups should be passed to the fit() method when using LeaveOneGroupOut, not to the search constructor. The RandomizedSearchCV.fit() documentation specifies that the groups parameter is only used in conjunction with a “Group” cv instance such as GroupKFold or LeaveOneGroupOut. The same applies to GridSearchCV.fit().
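To illustrate why the groups have to reach the splitter, LeaveOneGroupOut can be called directly on toy data. The arrays below are made up purely for illustration. The splitter produces one split per distinct group label, holds out every sample of that group as the test set, and raises a ValueError when groups is omitted:

```python
import numpy as np
from sklearn.model_selection import LeaveOneGroupOut

# Hypothetical toy data: 6 samples spread across 3 groups.
X = np.arange(12).reshape(6, 2)
y = np.arange(6)
groups = np.array([0, 0, 1, 1, 2, 2])

logo = LeaveOneGroupOut()

# One split per distinct group label.
n_splits = logo.get_n_splits(X, y, groups=groups)
print("number of splits:", n_splits)

# Each split uses all samples of exactly one group as the test set.
splits = list(logo.split(X, y, groups=groups))
for train_idx, test_idx in splits:
    print("train:", train_idx, "test:", test_idx)

# Without groups the splitter cannot decide what to hold out.
try:
    list(logo.split(X, y))
    missing_groups_error = None
except ValueError as exc:
    missing_groups_error = str(exc)
print("error without groups:", missing_groups_error)
```

This is the same call a search object makes internally on each fit, which is why the labels belong in fit() rather than in the constructor.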
See example below:
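from sklearn.datasets import make_regression
from sklearn.ensemble import RandomForestRegressor
from sklearn.model_selection import RandomizedSearchCV, LeaveOneGroupOut
import numpy as np

# Small discrete grid: 3 x 3 = 9 candidate combinations
params = {
    "n_estimators": [1, 5, 10],
    "max_depth": [2, 5, 10]
}

X, y = make_regression()
# Assign each sample to one of 5 groups (labels 0-4)
groups = np.random.randint(5, size=y.shape)

search = RandomizedSearchCV(RandomForestRegressor(),
                            params,
                            n_iter=5,  # the default of 10 exceeds the 9 combinations and triggers a warning
                            cv=LeaveOneGroupOut()
                            )
# groups go to fit(), which forwards them to LeaveOneGroupOut.split()
search.fit(X, y, groups=groups)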
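Since the question's goal is a random search followed by a grid search, the same pattern carries over to both stages: groups is passed to each fit() call. The sketch below is an illustration under assumptions not taken from the original answer: synthetic data, hypothetical group labels, much smaller ranges than the question's (and no criterion), and a simple grid of neighbouring values around the random-search optimum.

```python
import numpy as np
from scipy.stats import randint
from sklearn.datasets import make_regression
from sklearn.ensemble import RandomForestRegressor
from sklearn.model_selection import GridSearchCV, LeaveOneGroupOut, RandomizedSearchCV

X, y = make_regression(n_samples=100, n_features=5, random_state=0)
# Hypothetical group labels: each sample assigned to one of 5 groups.
rng = np.random.default_rng(0)
groups = rng.integers(0, 5, size=y.shape)

logo = LeaveOneGroupOut()

# Stage 1: coarse random search (small ranges so the sketch runs quickly).
distributions = {
    "n_estimators": randint(low=5, high=30),
    "max_depth": randint(low=2, high=10),
}
random_search = RandomizedSearchCV(
    RandomForestRegressor(random_state=0),
    distributions,
    n_iter=4,
    cv=logo,
    scoring="neg_mean_squared_error",
    random_state=42,
)
random_search.fit(X, y, groups=groups)  # groups go to fit(), not the constructor
best = random_search.best_params_

# Stage 2: grid search over neighbouring values (an assumed refinement rule).
n_best, d_best = best["n_estimators"], best["max_depth"]
param_grid = {
    "n_estimators": [max(1, n_best - 5), n_best, n_best + 5],
    "max_depth": [max(1, d_best - 1), d_best, d_best + 1],
}
grid_search = GridSearchCV(
    RandomForestRegressor(random_state=0),
    param_grid,
    cv=logo,
    scoring="neg_mean_squared_error",
)
grid_search.fit(X, y, groups=groups)  # same groups, same splitter
print(grid_search.best_params_, grid_search.best_score_)
```

Reusing one LeaveOneGroupOut instance and one groups array keeps both stages evaluated on identical folds, so their scores are comparable.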
Answered By - Antoine Dubuis