Issue
I created a multi-class classifier, and now I want to show the confusion matrix and the per-class accuracies in a clean way.
I already found a function in sklearn that plots the confusion matrix, sklearn.metrics.plot_confusion_matrix, but I don't see a way to add an extra column showing the accuracy for each class (row).
Here is an example of how to plot the confusion matrix:
import matplotlib.pyplot as plt
from sklearn.datasets import make_classification
from sklearn.metrics import ConfusionMatrixDisplay
from sklearn.model_selection import train_test_split
from sklearn.svm import SVC

# `sklearn.metrics.plot_confusion_matrix` was deprecated in scikit-learn 1.0 and
# removed in 1.2. `ConfusionMatrixDisplay.from_estimator` replaces it and takes
# the same arguments used below. Fall back to the old function only on versions
# that lack the replacement (scikit-learn < 1.0).
if hasattr(ConfusionMatrixDisplay, "from_estimator"):
    plot_confusion_matrix = ConfusionMatrixDisplay.from_estimator
else:
    from sklearn.metrics import plot_confusion_matrix

# Small synthetic binary problem, split into train/test halves.
X, y = make_classification(random_state=0)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

clf = SVC(random_state=0)
clf.fit(X_train, y_train)

# Plot the confusion matrix of the fitted classifier on the held-out data.
plot_confusion_matrix(clf, X_test, y_test)
plt.show()
In the following picture, I drew a sketch in Paint to show what I mean by "add an extra column":
Is there a way to change this example to add the extra column? Or are there other libraries that support what I want to do?
Solution
Nothing seems to do this out of the box, so I wrote a function for it:
def plot_class_accuracies(plotted_cm, axis, display_labels=None, cmap="viridis"):
    """Draw a one-column heat map of per-class accuracy beside a confusion matrix.

    The accuracy of class ``i`` is ``cm[i, i] / cm[i, :].sum()``: the diagonal
    entry divided by its row total (rows are true labels).

    Parameters
    ----------
    plotted_cm : sklearn.metrics.ConfusionMatrixDisplay
        Result of ``sklearn.metrics.plot_confusion_matrix`` (or
        ``ConfusionMatrixDisplay.from_estimator``). Only its
        ``confusion_matrix`` and, when present, ``display_labels`` attributes
        are read. This works for raw counts and for matrices normalized with
        ``normalize="true"`` or ``"all"``. It does NOT work with
        ``normalize="pred"``, because column-normalized rows no longer
        reflect the class support.
    axis : matplotlib.axes.Axes
        Axes to draw into, e.g. ``ax2`` from
        ``fig, (ax1, ax2) = plt.subplots(1, 2)``.
    display_labels : sequence, optional
        Human-readable class names for the y axis. Defaults to the
        ``display_labels`` of ``plotted_cm`` when it has any; otherwise
        ``0 .. n_classes - 1``.
    cmap : str or matplotlib colormap, default "viridis"
        Colormap for the accuracy column. The text colors are taken from this
        same colormap.

    Returns
    -------
    numpy.ndarray
        The plotted per-class accuracies. A class with no true samples
        (an all-zero row) gets ``nan``.
    """
    # Imported locally so the function works even if its module does not
    # import numpy (the cmatrix.py snippet as posted does not).
    import numpy as np

    cmatrix = np.asarray(plotted_cm.confusion_matrix, dtype=float)
    support = cmatrix.sum(axis=1)
    # Leave empty rows as NaN instead of dividing by zero, which would emit a
    # RuntimeWarning.
    accuracies = np.full(support.shape, np.nan)
    np.divide(np.diag(cmatrix), support, out=accuracies, where=support != 0)
    n_classes = accuracies.shape[0]

    # Default to the labels already shown on the confusion matrix so both
    # panels agree, then to 0..n-1 as ConfusionMatrixDisplay itself does.
    if display_labels is None:
        display_labels = getattr(plotted_cm, "display_labels", None)
    labels = np.arange(n_classes) if display_labels is None else display_labels

    image = axis.imshow(
        accuracies.reshape(n_classes, 1),
        interpolation="nearest",
        cmap=cmap,
    )
    # Take the text colors from the colormap actually used for this column,
    # not from the confusion-matrix image, which may use a different one.
    cmap_min, cmap_max = image.cmap(0), image.cmap(1.0)

    # imshow scales colors between the finite min and max, so their midpoint
    # splits the cells into the two halves of the colormap. NaNs are excluded
    # so that one empty class does not disable the contrast logic.
    finite = accuracies[np.isfinite(accuracies)]
    thresh = (finite.max() + finite.min()) / 2.0 if finite.size else 0.0
    for row, value in enumerate(accuracies):
        # Give each label the colormap end opposite its cell, for contrast.
        color = cmap_min if value > thresh else cmap_max
        axis.text(0, row, format(value, ".2g"), ha="center", va="center", color=color)

    axis.set(
        yticks=np.arange(n_classes),
        ylabel="True label",
        xlabel="Class accuracy",
        yticklabels=labels,
    )
    # A single column needs no x ticks.
    axis.tick_params(axis="x", bottom=False, labelbottom=False)
    # Row 0 at the top, matching the confusion-matrix panel.
    axis.set_ylim((n_classes - 0.5, -0.5))
    return accuracies
Assuming this function is saved in a file called cmatrix.py (which also needs `import numpy as np` at the top), it can be used like this:
from sklearn.linear_model import LogisticRegression
import matplotlib.pyplot as plt
import numpy as np
# NOTE(review): `plot_confusion_matrix` was deprecated in scikit-learn 1.0 and
# removed in 1.2; on newer versions use
# `sklearn.metrics.ConfusionMatrixDisplay.from_estimator` instead. Confirm
# against the installed version.
from sklearn.metrics import plot_confusion_matrix
# Import `plot_class_accuracies` from `cmatrix.py`
from cmatrix import plot_class_accuracies

if __name__ == "__main__":

    class ExampleClassifier(LogisticRegression):
        """Stand-in classifier whose "predictions" are exactly its input.

        This lets the demo pass hand-written predicted labels straight to
        `plot_confusion_matrix` without training a model.
        """

        def __init__(self):
            # Does not call LogisticRegression.__init__; only `classes_` is set.
            self.classes_ = None

        def predict(self, X_test):
            # Record the distinct labels seen and return the input unchanged
            # as the predictions.
            # NOTE(review): presumably `classes_` is filled here because
            # plot_confusion_matrix reads it after calling predict (for its
            # default display labels). Confirm for the installed version.
            self.classes_ = np.unique(X_test)
            return X_test

    # Despite its name, `X_test` holds the *predicted* labels, because
    # ExampleClassifier.predict returns it unchanged; `y_test` holds the true
    # labels. They differ at two positions (index 4: true 0 / predicted 1, and
    # the last element: true 3 / predicted 2).
    X_test = np.array([0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 3, 3, 2])
    y_test = np.array([0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 3, 3, 3])
    # Left axis: confusion matrix; right axis: per-class accuracy column.
    fig, (ax1, ax2) = plt.subplots(1, 2)
    clf = ExampleClassifier()
    disp = plot_confusion_matrix(
        clf, X_test, y_test, ax=ax1, cmap=plt.cm.Blues, normalize="true"
    )
    # Same colormap on both panels so their shading is comparable.
    plot_class_accuracies(disp, ax2, cmap=plt.cm.Blues)
    plt.show()
Result:
And here's an example based on the confusion matrix example from the scikit-learn documentation:
import numpy as np
import matplotlib.pyplot as plt
from sklearn import svm, datasets
from sklearn.model_selection import train_test_split
from sklearn.metrics import ConfusionMatrixDisplay
from cmatrix import plot_class_accuracies

# `sklearn.metrics.plot_confusion_matrix` was deprecated in scikit-learn 1.0 and
# removed in 1.2. `ConfusionMatrixDisplay.from_estimator` replaces it and takes
# the same arguments used below. Fall back to the old function only on versions
# that lack the replacement (scikit-learn < 1.0).
if hasattr(ConfusionMatrixDisplay, "from_estimator"):
    plot_confusion_matrix = ConfusionMatrixDisplay.from_estimator
else:
    from sklearn.metrics import plot_confusion_matrix

iris = datasets.load_iris()
X = iris.data
y = iris.target
class_names = iris.target_names

X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)
classifier = svm.SVC(kernel='linear', C=0.01).fit(X_train, y_train)

# Left axis: confusion matrix; right axis: per-class accuracy. Both get the
# same class names and colormap so the two panels line up.
fig, (ax1, ax2) = plt.subplots(1, 2)
disp = plot_confusion_matrix(classifier, X_test, y_test,
                             display_labels=class_names,
                             ax=ax1,
                             cmap=plt.cm.Blues)
plot_class_accuracies(disp, ax2, display_labels=class_names, cmap=plt.cm.Blues)
plt.show()
Result:
Answered By - Alexander L. Hayes
0 comments:
Post a Comment
Note: Only a member of this blog may post a comment.