Issue
I'm doing 10-fold cross-validation and need to see how the accuracy of each class varies across the folds. I managed to create a DataFrame of per-class "mean +/- standard deviation" values like this:
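Snippet:
import pandas as pd

# Build a matrix of "mean +/- std" strings, one cell per (true, predicted) class pair.
# classes holds the highest class label (15 here), so the loops cover labels 0..15.
chars = []
for i in range(0, int(classes) + 1):
    row = []
    for j in range(0, int(classes) + 1):
        row.append(str(round(means[i, j], 3)) + " +/- " + str(round(stds[i, j], 3)))
    chars.append(row)
con_mat_df = pd.DataFrame(chars, index=classes_list, columns=classes_list)

Printed, con_mat_df looks like this: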
0 1 ... 14 15
0 100.0 +/- 0.0 0.0 +/- 0.0 ... 0.0 +/- 0.0 0.0 +/- 0.0
1 0.49 +/- 0.703 98.53 +/- 1.416 ... 0.0 +/- 0.0 0.0 +/- 0.0
2 0.0 +/- 0.0 0.12 +/- 0.36 ... 0.0 +/- 0.0 0.0 +/- 0.0
3 0.0 +/- 0.0 0.0 +/- 0.0 ... 0.0 +/- 0.0 0.0 +/- 0.0
4 0.0 +/- 0.0 0.0 +/- 0.0 ... 0.0 +/- 0.0 0.0 +/- 0.0
5 0.55 +/- 0.905 0.14 +/- 0.42 ... 0.0 +/- 0.0 0.0 +/- 0.0
6 0.0 +/- 0.0 0.0 +/- 0.0 ... 0.0 +/- 0.0 0.0 +/- 0.0
7 0.0 +/- 0.0 0.0 +/- 0.0 ... 0.0 +/- 0.0 0.0 +/- 0.0
8 0.0 +/- 0.0 0.0 +/- 0.0 ... 0.0 +/- 0.0 0.0 +/- 0.0
9 0.62 +/- 1.318 0.2 +/- 0.6 ... 0.0 +/- 0.0 0.0 +/- 0.0
10 0.65 +/- 0.927 0.24 +/- 0.265 ... 0.0 +/- 0.0 0.0 +/- 0.0
11 1.02 +/- 1.558 0.0 +/- 0.0 ... 0.0 +/- 0.0 1.36 +/- 1.482
12 0.0 +/- 0.0 0.0 +/- 0.0 ... 0.0 +/- 0.0 0.0 +/- 0.0
13 0.32 +/- 0.96 0.0 +/- 0.0 ... 0.0 +/- 0.0 0.0 +/- 0.0
14 0.78 +/- 1.191 0.0 +/- 0.0 ... 98.96 +/- 1.274 0.0 +/- 0.0
15 0.0 +/- 0.0 0.0 +/- 0.0 ... 0.0 +/- 0.0 94.78 +/- 6.884
[16 rows x 16 columns]
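The snippet assumes `means` and `stds` already exist. A minimal sketch of how they might be computed, assuming you kept one confusion matrix of raw counts per fold (the input `fold_cms` and the helper name `fold_stats` are hypothetical) with rows as true labels, and that each row is expressed as a percentage, which appears to match the near-100 diagonal values in the table above:

```python
import numpy as np


def fold_stats(fold_cms):
    """Return (means, stds) of row-normalized confusion matrices, in percent.

    fold_cms: array-like of shape (n_folds, n_classes, n_classes) holding raw
    counts; rows are true labels, columns are predicted labels (assumed layout).
    """
    cms = np.asarray(fold_cms, dtype=float)
    # Per-fold row totals, kept as shape (n_folds, n_classes, 1) for broadcasting
    row_totals = cms.sum(axis=2, keepdims=True)
    # Divide each row by its total; rows with no samples in a fold stay at 0
    pct = np.divide(cms, row_totals, out=np.zeros_like(cms),
                    where=row_totals > 0) * 100
    # Aggregate over the fold axis (np.std defaults to ddof=0)
    return pct.mean(axis=0), pct.std(axis=0)


# Two hypothetical folds of a 2-class problem
folds = [[[9, 1], [0, 10]],
         [[10, 0], [2, 8]]]
means, stds = fold_stats(folds)
```

`np.std` computes the population standard deviation by default; pass `ddof=1` instead if you want the sample standard deviation across folds.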
Now I want to plot it as an annotated heatmap, with each cell showing its mean +/- standard deviation. If I use sns.heatmap, it throws an error (TypeError: ufunc 'isnan' not supported for the input types...). Any ideas? Thanks.
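The TypeError most likely comes from the cells being strings: con_mat_df has object dtype, and a heatmap needs numbers to map to colors. A minimal sketch, using the top-left 2x2 corner of the table above as sample data, that keeps the numeric array for the colors and builds the "mean ± std" strings as a separate array (the helper name `make_labels` is hypothetical):

```python
import numpy as np


def make_labels(means, stds, decimals=2):
    """Return an array of "mean ± std" strings with the same shape as `means`."""
    means = np.asarray(means, dtype=float)
    stds = np.asarray(stds, dtype=float)
    return np.array(
        [[f"{m:.{decimals}f} ± {s:.{decimals}f}" for m, s in zip(m_row, s_row)]
         for m_row, s_row in zip(means, stds)]
    )


# Top-left 2x2 corner of the table above (mean and std, in percent)
means = np.array([[100.0, 0.0], [0.49, 98.53]])
stds = np.array([[0.0, 0.0], [0.703, 1.416]])
labels = make_labels(means, stds)

# Strings stored as object dtype (as in a pandas DataFrame of str) cannot be
# checked for NaN, which is the kind of failure reported in the question.
try:
    np.isnan(labels.astype(object))
except TypeError as err:
    print("string labels:", err)

# The numeric array works, so use it for the colors and `labels` for the text.
print(np.isnan(means).any())  # prints False
```

With seaborn installed, `sns.heatmap(means, annot=labels, fmt="")` should then work: `annot` accepts an array of the same shape as the data, and `fmt=""` stops seaborn from applying its default numeric format to the strings.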
Solution
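The easiest way I found was to draw the numeric mean matrix with plt.imshow and write each cell's mean and standard deviation on top of it with plt.text, so the colors come from numbers rather than from the strings in the DataFrame (cm is the array of means and cms is the array of standard deviations):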
import matplotlib.pyplot as plt
import numpy as np


def plot_confusion_matrix(cm, cms, classes, cmap=plt.cm.Blues):
    """
    Plot the mean confusion matrix `cm` as a heatmap and annotate each
    cell with its mean and its standard deviation (`cms`).
    """
    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    plt.colorbar()
    tick_marks = np.arange(len(classes))
    plt.xticks(tick_marks, classes, rotation=45)
    plt.yticks(tick_marks, classes)

    # Use white text on dark cells and black text on light ones
    thresh = cm.max() / 2.
    for i in range(cm.shape[0]):
        for j in range(cm.shape[1]):
            # r'$\pm$' is a raw string so '\p' is not treated as an escape sequence
            plt.text(j, i, '{0:.2f}'.format(cm[i, j]) + '\n' + r'$\pm$' + '{0:.2f}'.format(cms[i, j]),
                     horizontalalignment="center",
                     verticalalignment="center", fontsize=7,
                     color="white" if cm[i, j] > thresh else "black")

    plt.ylabel('True label')
    plt.xlabel('Predicted label')
    plt.tight_layout()


# Plot the mean confusion matrix annotated with standard deviations
plt.figure()
plot_confusion_matrix(means, stds, classes=classes_list)
plt.show()
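If you need the plot inside a larger figure or want to save it, an object-oriented variant of the same idea may be easier to reuse. A sketch (the name `plot_cm_with_std` and its `ax`/`decimals` parameters are my own, not part of the answer above) that draws on a given Axes and returns it:

```python
import matplotlib.pyplot as plt
import numpy as np


def plot_cm_with_std(means, stds, classes, ax=None, cmap="Blues", decimals=2):
    """Draw the mean confusion matrix as a heatmap, each cell labelled mean ± std.

    Draws on `ax` (or on a new figure if ax is None) and returns the Axes.
    """
    means = np.asarray(means, dtype=float)
    stds = np.asarray(stds, dtype=float)
    if ax is None:
        _, ax = plt.subplots(figsize=(8, 8))
    im = ax.imshow(means, interpolation="nearest", cmap=cmap)
    ax.figure.colorbar(im, ax=ax)

    ticks = np.arange(len(classes))
    ax.set_xticks(ticks)
    ax.set_xticklabels(classes, rotation=45)
    ax.set_yticks(ticks)
    ax.set_yticklabels(classes)

    # Switch the text color on dark cells so the labels stay readable
    thresh = means.max() / 2.0
    for i in range(means.shape[0]):
        for j in range(means.shape[1]):
            ax.text(j, i, f"{means[i, j]:.{decimals}f}\n±{stds[i, j]:.{decimals}f}",
                    ha="center", va="center", fontsize=7,
                    color="white" if means[i, j] > thresh else "black")

    ax.set_xlabel("Predicted label")
    ax.set_ylabel("True label")
    return ax
```

For example, `ax = plot_cm_with_std(means, stds, classes_list)` followed by `ax.figure.savefig("confusion_matrix.png", bbox_inches="tight")`. Using a plain "±" character instead of `$\pm$` sidesteps mathtext and string-escaping issues entirely.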
Answered By - Giorgio Luigi Morales Luna