Issue
I am using Python to train a random forest classifier. I first plotted its ROC curve as follows:
import matplotlib.pyplot as plt
import scikitplot as skplt
from sklearn import metrics
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import RocCurveDisplay
from sklearn.model_selection import train_test_split

# max_features='auto' was removed in scikit-learn 1.3; for classifiers 'sqrt' is equivalent
model2_bert = RandomForestClassifier(bootstrap=False, max_depth=None, max_features='sqrt', min_samples_leaf=5, min_samples_split=5, n_estimators=50)
rf1 = model2_bert.fit(X_train, y_train)
y_hat = rf1.predict(X_test)  # predicted class labels

# First figure: scikit-learn's built-in ROC display
ax = plt.gca()
rfc_disp = RocCurveDisplay.from_estimator(rf1, X_test, y_test, ax=ax)
plt.show()
This produces the first figure, the RocCurveDisplay plot.
Since I was getting AUC = 1 for some specifications, I decided to compute the ROC curve and its AUC in other ways.
fpr, tpr, thresholds = metrics.roc_curve(y_test, y_hat)
roc_auc = metrics.auc(fpr, tpr)
print(roc_auc)
# method I: plot the computed curve with matplotlib
plt.title('Receiver Operating Characteristic')
plt.plot(fpr, tpr, 'b', label='AUC = %0.2f' % roc_auc)
plt.plot([0, 1], [0, 1], 'r--')  # chance-level diagonal
plt.legend(loc='lower right')
plt.xlim([0, 1])
plt.ylim([0, 1])
plt.ylabel('True Positive Rate')
plt.xlabel('False Positive Rate')
plt.show()
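A small sketch of what happens in this method: roc_curve treats each distinct value of its second argument as a threshold, so when it receives the 0/1 output of predict it can only produce three points. The arrays below are hypothetical toy labels, not the data from the question:

```python
import numpy as np
from sklearn import metrics

# Hypothetical toy data: true labels and hard 0/1 predictions (like y_hat)
y_true = np.array([0, 0, 0, 0, 1, 1, 1, 1])
y_pred = np.array([0, 0, 0, 1, 1, 1, 1, 0])

fpr, tpr, thresholds = metrics.roc_curve(y_true, y_pred)
roc_auc = metrics.auc(fpr, tpr)

# Only three ROC points: (0, 0), the single operating point, and (1, 1)
print(len(fpr))
print(fpr, tpr)
print(roc_auc)
```

The resulting "curve" is two straight segments through (0.25, 0.75), and its area (0.75 here) reflects one operating point rather than how well the model ranks positives above negatives.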
And the third method:
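rf1 = model2_bert.fit(X_train, y_train)
y_prob = rf1.predict_proba(X_test)  # class probabilities, shape (n_samples, 2)
skplt.metrics.plot_roc_curve(y_test, y_prob)
# note: metrics.plot_roc_curve was removed in scikit-learn 1.2 (RocCurveDisplay.from_estimator replaces it)
metrics.plot_roc_curve(rf1, X_test, y_test)
plt.show()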
It seems like the last two coincide and are based on the calculated AUC. What is wrong with the first graph, then?
P.S. Judging from the confusion matrix, an AUC of 0.97 seems too high to me. I even got AUC = 1 in the first figure for some specifications.
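A worked check related to the P.S.: when the scores are hard labels, the ROC has vertices (0, 0), (FPR, TPR) and (1, 1), and the trapezoid area is FPR·TPR/2 + (1 − FPR)(TPR + 1)/2 = (TPR + 1 − FPR)/2, which is exactly the balanced accuracy read off the confusion matrix. The sketch below uses hypothetical labels standing in for y_test and y_hat:

```python
import numpy as np
from sklearn.metrics import balanced_accuracy_score, confusion_matrix, roc_auc_score

# Hypothetical labels and hard predictions (stand-ins for y_test and y_hat)
y_true = np.array([0, 0, 0, 0, 0, 1, 1, 1, 1, 1])
y_pred = np.array([0, 0, 0, 0, 1, 1, 1, 1, 0, 0])

# For binary labels {0, 1}, ravel() returns tn, fp, fn, tp in that order
tn, fp, fn, tp = confusion_matrix(y_true, y_pred).ravel()
tpr = tp / (tp + fn)
fpr = fp / (fp + tn)

auc_from_matrix = (tpr + 1 - fpr) / 2            # area under the 3-point ROC
auc_from_labels = roc_auc_score(y_true, y_pred)  # AUC computed from hard labels
bal_acc = balanced_accuracy_score(y_true, y_pred)

print(auc_from_matrix, auc_from_labels, bal_acc)
```

So an AUC computed from predict output can be derived directly from the confusion matrix, whereas an AUC computed from probabilities cannot.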
Solution
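I think the problem is that y_hat = rf1.predict(X_test) returns hard class labels (0 and 1). For the ROC curve and its AUC you need a continuous probability or score: roc_curve uses each distinct score as a threshold, so with only two distinct values you get a single operating point joined to (0, 0) and (1, 1), not a real curve.
Instead, use predict_proba and pass the probability of the positive class: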
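y_prob = rf1.predict_proba(X_test)
# predict_proba returns one column per class (order given by rf1.classes_);
# assuming classes_ is [0, 1], column 1 holds the positive-class probability
fpr, tpr, thresholds = metrics.roc_curve(y_test, y_prob[:, 1])
# ... <rest of your code>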
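A self-contained sketch of the fix, using synthetic data from make_classification in place of the question's (unshown) X_train/X_test and a simplified forest configuration; both are assumptions for illustration. It looks up the positive-class column via rf1.classes_ instead of hard-coding index 1, and compares the AUC from probabilities with the AUC from hard labels:

```python
import numpy as np
from sklearn import metrics
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split

# Synthetic binary data standing in for the question's features (assumption)
X, y = make_classification(n_samples=500, n_features=20, random_state=0)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=0)

rf1 = RandomForestClassifier(n_estimators=50, min_samples_leaf=5, random_state=0)
rf1.fit(X_train, y_train)

y_hat = rf1.predict(X_test)         # hard 0/1 labels
y_prob = rf1.predict_proba(X_test)  # shape (n_samples, 2)

# Column order follows rf1.classes_; select the column for positive class 1
pos_idx = list(rf1.classes_).index(1)
y_score = y_prob[:, pos_idx]

fpr, tpr, thresholds = metrics.roc_curve(y_test, y_score)
auc_scores = metrics.auc(fpr, tpr)                # AUC from probabilities
auc_labels = metrics.roc_auc_score(y_test, y_hat)  # AUC from hard labels

print(f"AUC from probabilities: {auc_scores:.3f}")
print(f"AUC from hard labels:   {auc_labels:.3f}")
```

The probability-based value is the one RocCurveDisplay.from_estimator reports, since its default response_method tries predict_proba first; the two numbers can differ noticeably on real data.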
Answered By - Cristobal Sarome