Created
April 19, 2016 21:13
-
-
Save daviddao/ccb2aa48c7b534cbcd91e56345430d06 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Confusion Matrix (improved version with total numbers)
def plot_confusion_matrix(conf_arr, labels, title='Confusion matrix', cmap=plt.cm.Blues):
    """Draw a confusion matrix as an annotated heat map on a new figure.

    Cell colours are taken from the row-normalised matrix (every entry
    divided by the sum of its row). The text written into each cell is the
    corresponding entry of ``conf_arr`` itself, formatted with two decimals;
    entries equal to zero are left blank.

    The figure is created but not displayed; this function does not call
    ``plt.show()``.

    Args:
        conf_arr: Rectangular 2-D sequence or array indexed as
            ``conf_arr[row][col]``.
        labels: Tick labels, applied to both the x and the y axis.
        title: Title placed above the plot.
        cmap: Matplotlib colormap used for the heat map.
    """
    cells = np.asarray(conf_arr, dtype=float)

    # Row-normalise for colouring. A row that sums to zero is kept as all
    # zeros instead of raising ZeroDivisionError.
    row_totals = cells.sum(axis=1, keepdims=True)
    shading = np.divide(
        cells,
        row_totals,
        out=np.zeros_like(cells),
        where=row_totals != 0,
    )

    fig = plt.figure()
    ax = fig.add_subplot(111)
    ax.set_aspect(1)
    res = ax.imshow(shading, cmap=cmap, interpolation='nearest')

    # imshow puts the column index on the x axis and the row index on the
    # y axis, hence xy=(col, row).
    n_rows, n_cols = cells.shape
    for row in range(n_rows):
        for col in range(n_cols):
            value = cells[row, col]
            if value != 0:
                ax.annotate("%.2f" % value, xy=(col, row),
                            horizontalalignment='center',
                            verticalalignment='center')

    plt.title(title)
    plt.colorbar(res)
    tick_marks = np.arange(len(labels))
    plt.xticks(tick_marks, labels, rotation=45)
    plt.yticks(tick_marks, labels)
    plt.ylabel('True label')
    plt.xlabel('Predicted label')
    # Run the layout pass last so the axis labels are taken into account.
    plt.tight_layout()
def ConfusionMatrix(y_test, y_pred, labels):
    """Compute a confusion matrix, print the accuracy, and plot the matrix.

    The matrix is obtained from ``sklearn.metrics.confusion_matrix`` and each
    row is divided by its own sum before being handed to
    ``plot_confusion_matrix``. The overall accuracy (percentage of samples on
    the diagonal) is printed and also used in the plot title. The figure is
    displayed with ``plt.show()``.

    Args:
        y_test: Ground-truth targets, passed as the first argument to
            ``confusion_matrix``.
        y_pred: Predicted targets, passed as the second argument to
            ``confusion_matrix``.
        labels: Tick labels forwarded to ``plot_confusion_matrix``; they are
            not passed to ``confusion_matrix``, so they are presumably
            expected to follow the class order it produces (verify against
            the caller).
    """
    from sklearn.metrics import confusion_matrix

    cm = confusion_matrix(y_test, y_pred)

    # Misclassification rate in percent: everything off the diagonal.
    n_objects = cm.sum()
    mis_rate = float(n_objects - np.diag(cm).sum()) * 100 / n_objects

    # Divide every row by its total. A row whose total is zero yields NaN
    # entries here (0/0 in floating point).
    cm_normalized = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]

    np.set_printoptions(precision=2)
    heading = 'Confusion matrix (Classification Accuracy: %3.2f%%)' % (100 - mis_rate)
    print(heading)
    plot_confusion_matrix(cm_normalized, labels, title=heading)
    plt.show()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment