Skip to content

Instantly share code, notes, and snippets.

@yongkangc
Created March 31, 2020 02:05
Show Gist options
  • Save yongkangc/86cc0f9d16bfe53ada84c0daa51afe09 to your computer and use it in GitHub Desktop.
Save yongkangc/86cc0f9d16bfe53ada84c0daa51afe09 to your computer and use it in GitHub Desktop.
#confusionmatrix
from sklearn.metrics import confusion_matrix
import numpy as np
def print_confusion_matrix(confusion_matrix, class_names, figsize=(10, 7), fontsize=14, normalize=False):
    """Draw a confusion matrix as an annotated seaborn heatmap.

    Args:
        confusion_matrix: square array-like of counts. Rows are drawn on the
            y-axis (titled "True label") and columns on the x-axis (titled
            "Predicted label"). This parameter shadows the module-level
            sklearn ``confusion_matrix`` function inside this body.
        class_names: tick labels used for both the rows and the columns; its
            length must match the matrix dimensions.
        figsize: figure size passed to ``plt.figure``.
        fontsize: font size of the tick labels on both axes.
        normalize: if True, divide each row by its sum so each cell shows the
            fraction of that row. Rows whose sum is zero are shown as 0
            instead of NaN.

    Returns:
        The matplotlib Figure containing the heatmap.
    """
    # NOTE(review): pd, plt and sns are not imported in this snippet; they are
    # assumed to be provided by the surrounding notebook/module — confirm.
    if normalize:
        counts = np.asarray(confusion_matrix, dtype=float)
        row_sums = counts.sum(axis=1, keepdims=True)
        # A class with no true samples has a zero row sum; leave that row at 0
        # rather than producing NaN (and a divide-by-zero RuntimeWarning).
        confusion_matrix = np.divide(
            counts, row_sums, out=np.zeros_like(counts), where=row_sums != 0
        )
        fmt = '.2f'
        title = 'Normalized Confusion Matrix'
    else:
        # NOTE(review): 'd' requires integer cells; a float matrix passed with
        # normalize=False will fail to annotate.
        fmt = 'd'
        title = 'Confusion Matrix'
    df_cm = pd.DataFrame(confusion_matrix, index=class_names, columns=class_names)
    fig = plt.figure(figsize=figsize)
    heatmap = sns.heatmap(df_cm, annot=True, fmt=fmt)
    heatmap.yaxis.set_ticklabels(heatmap.yaxis.get_ticklabels(), rotation=0, ha='right', fontsize=fontsize)
    # Read the x tick labels from the x-axis (previously copied from the y-axis).
    heatmap.xaxis.set_ticklabels(heatmap.xaxis.get_ticklabels(), rotation=45, ha='right', fontsize=fontsize)
    heatmap.set_ylabel('True label')
    heatmap.set_xlabel('Predicted label')
    heatmap.set_title(title)
    return fig
# Evaluate the fitted model on the held-out split and plot the result.
# NOTE(review): y_test, x_test, rf, cat_map, encoder_mapping, plt and
# print_confusion_matrix are expected to already be defined (earlier cells).
conf_mat = confusion_matrix(y_true=y_test, y_pred=rf.predict(x_test))

# Human-readable class names, ordered by the sorted keys of encoder_mapping.
# NOTE(review): assumes that order matches the row/column order used by
# confusion_matrix (sorted labels present in y_test/predictions) — confirm.
labels = [cat_map[encoder_mapping[key]] for key in sorted(encoder_mapping.keys())]

# print_confusion_matrix returns a matplotlib Figure, despite the name `ax`.
ax = print_confusion_matrix(conf_mat, class_names=labels, normalize=True)
plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment