Issue
I am using the MultiOutputClassifier() wrapper from scikit-learn for a multi-label classification task.
clf = MultiOutputClassifier(RandomForestClassifier())
Now I want to use RandomizedSearchCV to find the best parameters for the RandomForestClassifier which is wrapped inside MultiOutputClassifier.
The parameter grid:
params = {
'n_estimators': [i for i in range(50,225,25)],
'max_depth' : [10,20,30,40,50],
'max_features' : ['auto', 'sqrt', 'log2']
}
But when I run the following:
clf = RandomizedSearchCV(clf, params, cv=5, return_train_score=False)
clf.fit(X_train, y_train)
I get this error:
ValueError: Invalid parameter n_estimators for estimator MultiOutputClassifier(estimator=RandomForestClassifier(bootstrap=True,
ccp_alpha=0.0,
class_weight=None,
criterion='gini',
max_depth=None,
max_features='auto',
max_leaf_nodes=None,
max_samples=None,
min_impurity_decrease=0.0,
min_impurity_split=None,
min_samples_leaf=1,
min_samples_split=2,
min_weight_fraction_leaf=0.0,
n_estimators=100,
n_jobs=None,
oob_score=False,
random_state=42,
verbose=0,
warm_start=False),
n_jobs=None). Check the list of available parameters with `estimator.get_params().keys()`.
Is there a way to pass these parameters to the RandomForestClassifier wrapped inside MultiOutputClassifier when using RandomizedSearchCV?
Solution
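Prefix each hyperparameter name with estimator__. MultiOutputClassifier stores the wrapped model in its estimator parameter, and scikit-learn addresses the parameters of a nested estimator as <component>__<parameter>: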
from sklearn.datasets import make_multilabel_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import RandomizedSearchCV
from sklearn.multioutput import MultiOutputClassifier
X, y = make_multilabel_classification(n_classes=3, random_state=0)
clf = MultiOutputClassifier(RandomForestClassifier())
params = {
'estimator__n_estimators': [i for i in range(50, 225, 25)],
'estimator__max_depth': [10, 20, 30, 40, 50],
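# Note: 'auto' was removed for RandomForestClassifier in scikit-learn 1.3; drop it on newer versions.
'estimator__max_features': ['auto', 'sqrt', 'log2']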
}
clf = RandomizedSearchCV(clf, params, cv=5, return_train_score=False)
clf.fit(X, y)
print(clf.best_params_)
# Example output (the search is randomized, so results vary between runs):
# {'estimator__n_estimators': 75,
# 'estimator__max_features': 'sqrt',
# 'estimator__max_depth': 20}
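To see why the prefix is needed, it can help to inspect the wrapper's parameter names, as the error message suggests. The following is a minimal sketch, not part of the original answer; the values passed to set_params are arbitrary and chosen only for illustration.

```python
from sklearn.ensemble import RandomForestClassifier
from sklearn.multioutput import MultiOutputClassifier

clf = MultiOutputClassifier(RandomForestClassifier())

# The wrapper exposes the forest's parameters under the "estimator__" prefix.
nested = sorted(k for k in clf.get_params() if k.startswith("estimator__"))
print(nested)

# The bare name is not a parameter of the wrapper, which is what the
# ValueError in the question reports.
print("n_estimators" in clf.get_params())

# set_params accepts the same prefixed names that the search passes in
# (illustrative values only).
clf.set_params(estimator__n_estimators=75, estimator__max_depth=20)
print(clf.estimator.n_estimators, clf.estimator.max_depth)
```

The same naming rule applies to other wrappers and to Pipeline steps, where the prefix is the step name instead of estimator.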
Answered By - Flavia Giammarino