Issue
I have a dataset consisting of both numeric and categorical data and I want to predict adverse outcomes for patients based on their medical characteristics. I defined a prediction pipeline for my dataset like so:
from sklearn.compose import ColumnTransformer, make_column_selector as selector
from sklearn.impute import KNNImputer, SimpleImputer
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import OneHotEncoder, StandardScaler

# `dataset` is a pandas DataFrame; `selector` is assumed to be make_column_selector
X = dataset.drop(columns=['target'])
y = dataset['target']

# define categorical and numeric transformers
numeric_transformer = Pipeline(steps=[
    ('knnImputer', KNNImputer(n_neighbors=2, weights="uniform")),
    ('scaler', StandardScaler())])
categorical_transformer = Pipeline(steps=[
    ('imputer', SimpleImputer(strategy='constant', fill_value='missing')),
    ('onehot', OneHotEncoder(handle_unknown='ignore'))])

# dispatch object columns to the categorical_transformer and remaining columns to the numeric_transformer
preprocessor = ColumnTransformer(transformers=[
    ('num', numeric_transformer, selector(dtype_exclude="object")),
    ('cat', categorical_transformer, selector(dtype_include="object"))
])

# Append classifier to preprocessing pipeline.
# Now we have a full prediction pipeline.
clf = Pipeline(steps=[('preprocessor', preprocessor),
                      ('classifier', LogisticRegression())])

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)
clf.fit(X_train, y_train)
print("model score: %.3f" % clf.score(X_test, y_test))
However, when running this code, I get the following warning message:
ConvergenceWarning: lbfgs failed to converge (status=1):
STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.
Increase the number of iterations (max_iter) or scale the data as shown in:
https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
extra_warning_msg=_LOGISTIC_SOLVER_CONVERGENCE_MSG)
model score: 0.988
Can someone explain to me what this warning means? I am new to machine learning so am a little lost as to what I can do to improve the prediction model. As you can see from the numeric_transformer, I scaled the data through standardisation. I am also confused as to how the model score is quite high and whether this is a good or bad thing.
Solution
The warning mostly means what it says: the solver (the optimisation algorithm) did not converge, and it suggests things to try so that it does.
lbfgs stands for "Limited-memory Broyden–Fletcher–Goldfarb–Shanno algorithm". It is one of the solver algorithms provided by the scikit-learn library.
The term limited-memory simply means that it stores only a few vectors, which implicitly represent an approximation of the inverse Hessian built from recent gradients.
It has better convergence on relatively small datasets.
But what is algorithm convergence?
In simple words: if the error being minimised stays within a very small range from one iteration to the next (i.e., it is almost not changing), then the algorithm has reached a solution (not necessarily the best solution, as it might be stuck at a so-called "local optimum").
On the other hand, if the error still varies noticeably between iterations, that is, the change per iteration is greater than some tolerance, then we say the algorithm did not converge. This can happen even when the error itself is already relatively small, as in your case, where the score was good.
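To make the idea of a tolerance concrete, here is a minimal sketch in plain Python. It is not how lbfgs works internally; it is an illustrative gradient descent on the toy function f(w) = (w - 3)^2, and the names `minimize`, `tol` and `lr` are made up for this example. The loop stops early once the change in error between two iterations drops below `tol`; otherwise it runs out of iterations and reports that it did not converge.

```python
def minimize(max_iter, tol=1e-8, lr=0.1):
    """Toy gradient descent on f(w) = (w - 3)^2.

    Returns (w, iterations_used, converged).
    """
    w = 0.0
    prev_loss = (w - 3.0) ** 2
    for i in range(1, max_iter + 1):
        grad = 2.0 * (w - 3.0)      # derivative of (w - 3)^2
        w -= lr * grad              # one descent step
        loss = (w - 3.0) ** 2
        # Converged: the error barely changed since the previous iteration.
        if abs(prev_loss - loss) < tol:
            return w, i, True
        prev_loss = loss
    # Iteration budget exhausted while the error was still changing.
    return w, max_iter, False


w_short, n_short, ok_short = minimize(max_iter=5)
w_long, n_long, ok_long = minimize(max_iter=1000)
print(f"max_iter=5:    w={w_short:.4f}, converged={ok_short}")
print(f"max_iter=1000: w={w_long:.4f}, converged={ok_long} after {n_long} iterations")
```

With a budget of 5 iterations the error is still changing when the loop ends, which is the same situation the ConvergenceWarning describes; with a larger budget the loop stops by itself well before the limit.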
Now, you need to know that the scikit-learn API often lets the user specify the maximum number of iterations the algorithm may take while it searches for the solution iteratively:
LogisticRegression(... solver='lbfgs', max_iter=100 ...)
As you can see, the default solver in LogisticRegression is 'lbfgs' and the maximum number of iterations is 100 by default.
As a final word, please note that increasing the maximum number of iterations does not guarantee convergence, but it certainly helps!
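As a rough sketch of what changing max_iter does, the snippet below fits LogisticRegression twice on synthetic data from make_classification (a stand-in for the real patient data, which is an assumption of this example) and records whether a ConvergenceWarning was raised. The helper name `fit_and_report` is hypothetical; `n_iter_` is the fitted attribute that reports how many iterations the solver actually used.

```python
import warnings

from sklearn.datasets import make_classification
from sklearn.exceptions import ConvergenceWarning
from sklearn.linear_model import LogisticRegression

# Synthetic stand-in for the real dataset (illustrative assumption).
X, y = make_classification(n_samples=300, n_features=20, random_state=0)


def fit_and_report(max_iter):
    """Fit with the given iteration cap; return (iterations used, warned?)."""
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")  # make sure the warning is recorded
        model = LogisticRegression(solver="lbfgs", max_iter=max_iter)
        model.fit(X, y)
    warned = any(issubclass(w.category, ConvergenceWarning) for w in caught)
    return int(model.n_iter_[0]), warned


n_small, warned_small = fit_and_report(max_iter=1)
n_large, warned_large = fit_and_report(max_iter=5000)
print(f"max_iter=1:    used {n_small} iteration(s), warning raised: {warned_small}")
print(f"max_iter=5000: used {n_large} iteration(s), warning raised: {warned_large}")
```

In the question's pipeline the same change would go into the classifier step, for example `('classifier', LogisticRegression(max_iter=1000))`; the value 1000 is only an example, not a recommended setting.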
Update:
Based on your comment, some tips to try (out of many) that might help the algorithm to converge are:
- Increase the number of iterations (max_iter);
- Try a different optimizer (the solver parameter);
- Scale your data;
- Add engineered features;
- Pre-process your data;
- Add more data.
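Two of these tips, trying a different solver and scaling the data, can be compared side by side. The following is a hedged sketch on synthetic data whose columns are deliberately given very different scales; the helper `try_config`, the scale factors and the choice of solvers are assumptions made for illustration, and the iteration counts will differ on real data.

```python
import warnings

import numpy as np
from sklearn.datasets import make_classification
from sklearn.exceptions import ConvergenceWarning
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler

# Synthetic data; columns are stretched to very different scales on purpose.
X, y = make_classification(n_samples=300, n_features=20, random_state=0)
X = X * np.logspace(0, 3, X.shape[1])


def try_config(solver, scale):
    """Fit one solver, with or without standardisation; return (n_iter, warned?)."""
    steps = [StandardScaler()] if scale else []
    steps.append(LogisticRegression(solver=solver, max_iter=200))
    pipe = make_pipeline(*steps)
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        pipe.fit(X, y)
    warned = any(issubclass(w.category, ConvergenceWarning) for w in caught)
    return int(pipe[-1].n_iter_[0]), warned


results = {}
for solver in ("lbfgs", "liblinear", "newton-cg"):
    for scale in (False, True):
        results[(solver, scale)] = try_config(solver, scale)
        n_iter, warned = results[(solver, scale)]
        print(f"solver={solver:<10} scaled={scale!s:<5} n_iter={n_iter:<4} warned={warned}")
```

Comparing the rows for the same solver with and without scaling gives a quick feel for how much standardisation matters for that solver on a given dataset.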
Answered By - Yahya