Created
November 29, 2017 17:57
-
-
Save morganmcg1/d8ba1963c83cbfc97e497460362a4481 to your computer and use it in GitHub Desktop.
Plot a confusion matrix, with the option to normalise the values
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def plot_confusion_matrix(cm,
                          classes,
                          normalize=False,
                          title='Confusion matrix',
                          cmap=None):
    """Plot a confusion matrix, optionally normalising each row.

    Parameters
    ----------
    cm : array-like of shape (n_classes, n_classes)
        Confusion matrix, e.g. from ``sklearn.metrics.confusion_matrix``.
        Rows are drawn as true labels and columns as predicted labels.
    classes : sequence of str
        Tick labels, one per row/column of ``cm``.
    normalize : bool, default False
        If True, divide each row by its sum so every cell shows the fraction
        of that true class. This can be useful for large samples with few
        outliers. Rows that sum to zero are shown as 0 instead of NaN.
    title : str, default 'Confusion matrix'
        Title for the plot.
    cmap : matplotlib colormap, optional
        Colormap for the image; defaults to ``plt.cm.Blues``.

    Returns
    -------
    None
        Draws on the current matplotlib figure and prints the (possibly
        normalised) matrix to stdout. Call ``plt.show()`` or
        ``plt.savefig()`` afterwards to display or save it.
    """
    import itertools

    import matplotlib.pyplot as plt
    import numpy as np

    if cmap is None:
        cmap = plt.cm.Blues

    cm = np.asarray(cm)

    # Normalise BEFORE drawing so the colours, the annotations and the
    # text-colour threshold all describe the same matrix.
    if normalize:
        row_sums = cm.sum(axis=1, keepdims=True)
        # Leave rows with no samples at 0 instead of producing 0/0 -> NaN.
        cm = np.divide(cm.astype(float), row_sums,
                       out=np.zeros(cm.shape, dtype=float),
                       where=row_sums != 0)
        print("Normalized confusion matrix")
    else:
        print('Confusion matrix, without normalization')
    print(cm)

    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    plt.title(title)
    plt.colorbar()
    tick_marks = np.arange(len(classes))
    plt.xticks(tick_marks, classes, rotation=45)
    plt.yticks(tick_marks, classes)

    # Integer counts are shown as-is; fractions get a fixed two-decimal format.
    fmt = 'd' if np.issubdtype(cm.dtype, np.integer) else '.2f'
    # Use white text on the darker half of the colour scale so it stays legible.
    thresh = cm.max() / 2. if cm.size else 0
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        plt.text(j, i, format(cm[i, j], fmt),
                 horizontalalignment="center",
                 color="white" if cm[i, j] > thresh else "black")

    plt.ylabel('True label')
    plt.xlabel('Predicted label')
    # Run after the labels are set so the layout accounts for them.
    plt.tight_layout()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment