Last active
September 3, 2019 08:22
-
-
Save Coderx7/3b02f6a8eb61110d7d0876e75933f98d to your computer and use it in GitHub Desktop.
Caffe confusion matrix, precision and recall and F1 Score script!
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#in the name of God the most compassionate the most merciful | |
# added mean subtraction so that, the accuracy can be reported accurately just like caffe when doing a mean subtraction | |
# Seyyed Hossein Hasan Pour | |
# [email protected] | |
# 7/3/2016 | |
# Added Recall/Precision/F1-Score as well | |
# 01/03/2017 | |
# Added batch processing: what used to take a minute or so now takes only several seconds!
# 07/25/2017 | |
#info: | |
#On Windows, you can put these commands in a batch file for convenience:
#REM Calculating Confusion Matrix
#python confusionMatrix_convnet_test.py --proto cifar10_deploy.prototxt --model cifar10.caffemodel --mean mean.binaryproto --db_type lmdb --db_path cifar10_test_lmdb
#pause | |
import sys | |
import caffe | |
import numpy as np | |
import lmdb | |
import argparse | |
from collections import defaultdict | |
from sklearn.metrics import classification_report | |
from sklearn.metrics import confusion_matrix | |
import matplotlib.pyplot as plt | |
import itertools | |
from sklearn.metrics import roc_curve, auc | |
import random | |
def flat_shape(x):
    """Return ``x`` with every singleton (length-1) dimension removed.

    Example: an array of shape (1, 28, 28) comes back with shape (28, 28),
    while an array of shape (3, 32, 32) is returned with its shape unchanged.

    The previous implementation reshaped ``x`` to its own shape, which left
    singleton dimensions in place despite the documented contract.
    """
    return np.squeeze(x)
def plot_confusion_matrix(cm,  # confusion matrix
                          classes,
                          normalize=False,
                          title='Confusion matrix',
                          cmap=plt.cm.Blues):
    """Draw a confusion matrix on the current matplotlib figure.

    Args:
        cm: 2-D numpy array; the axes are labelled 'True label' (rows) and
            'Predicted label' (columns).
        classes: tick labels, one per row/column of ``cm``.
        normalize: when True, every row is divided by its sum before the
            matrix is drawn, and a notice is printed to stdout.
        title: title of the plot.
        cmap: matplotlib colormap used for the cells.

    The function only draws; it does not call ``plt.show()``.
    """
    if normalize:
        row_totals = cm.sum(axis=1)[:, np.newaxis]
        # A row without any samples would produce 0/0 = NaN; dividing by 1
        # instead keeps such rows at zero.
        cm = cm.astype('float') / np.where(row_totals == 0, 1, row_totals)
        print("confusion matrix is normalized!")

    # Draw only after the optional normalization so that the cell colours,
    # the cell text and the contrast threshold all describe the same values.
    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    plt.title(title)
    plt.colorbar()
    tick_marks = np.arange(len(classes))
    plt.xticks(tick_marks, classes, rotation=45)
    plt.yticks(tick_marks, classes)

    # Integer counts are printed verbatim; fractional values are limited to
    # two decimals so the numbers stay inside their cells.
    cell_format = 'd' if np.issubdtype(cm.dtype, np.integer) else '.2f'
    # Cells darker than half of the maximum get white text for contrast.
    thresh = cm.max() / 2.
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        plt.text(j, i, format(cm[i, j], cell_format),
                 horizontalalignment="center",
                 color="white" if cm[i, j] > thresh else "black")

    plt.tight_layout()
    plt.ylabel('True label')
    plt.xlabel('Predicted label')
def db_reader(fpath, type='lmdb'):
    """Return a record generator for the database stored at ``fpath``.

    ``type`` selects the backend: the exact string 'lmdb' picks
    ``lmdb_reader``; any other value falls back to ``leveldb_reader``.
    """
    reader = lmdb_reader if type == 'lmdb' else leveldb_reader
    return reader(fpath)
def lmdb_reader(fpath):
    """Yield ``(key, image, label)`` for every Datum stored in an LMDB.

    Args:
        fpath: path of the LMDB environment to open.

    Yields:
        key: the raw record key as returned by the LMDB cursor.
        image: the Datum converted by ``caffe.io.datum_to_array``, cast to
            ``np.uint8`` and passed through ``flat_shape``.
        label: ``datum.label`` converted to ``int``.

    The environment and its transaction are held in ``with`` blocks so they
    are released when the generator is exhausted or closed early; previously
    they were never closed.
    """
    import lmdb
    with lmdb.open(fpath) as lmdb_env:
        with lmdb_env.begin() as lmdb_txn:
            for key, value in lmdb_txn.cursor():
                datum = caffe.proto.caffe_pb2.Datum()
                datum.ParseFromString(value)
                label = int(datum.label)
                image = caffe.io.datum_to_array(datum).astype(np.uint8)
                yield (key, flat_shape(image), label)
def leveldb_reader(fpath):
    """Yield ``(key, image, label)`` for every Datum stored in a LevelDB.

    Each value is parsed as a Caffe ``Datum``; the image is obtained with
    ``caffe.io.datum_to_array``, cast to ``np.uint8`` and passed through
    ``flat_shape``, and the label is ``datum.label`` converted to ``int``.
    """
    import leveldb
    database = leveldb.LevelDB(fpath)
    for key, raw_record in database.RangeIter():
        record = caffe.proto.caffe_pb2.Datum()
        record.ParseFromString(raw_record)
        class_id = int(record.label)
        pixels = caffe.io.datum_to_array(record).astype(np.uint8)
        yield (key, flat_shape(pixels), class_id)
def ShowInfo(correct, count, true_labels, predicted_lables, class_names, misclassified,
             filename='misclassifieds.txt',
             title='Receiver Operating Characteristic_ROC',
             title_CM='Confusion matrix, without normalization',
             title_CM_N='Normalized confusion matrix'):
    """Report accuracy, per-class metrics and confusion matrices.

    Args:
        correct: number of correctly classified samples.
        count: number of evaluated samples.
        true_labels: ground-truth label of every evaluated sample.
        predicted_lables: predicted label of every evaluated sample.
        class_names: display names handed to ``classification_report`` as
            ``target_names`` and used as tick labels of the plots.
        misclassified: entries written to ``filename``, one per line.
        filename: output file for the misclassified entries.
        title: currently unused (it belonged to a disabled ROC plot); kept
            so existing callers keep working.
        title_CM: title of the non-normalized confusion-matrix figure.
        title_CM_N: title of the normalized confusion-matrix figure.

    Side effects: prints to stdout, writes ``filename``, sets numpy's print
    precision to 2 and opens two matplotlib figures via ``plt.show()``.
    When ``count`` is 0 only a short notice is printed, because there is
    nothing to divide by and nothing to plot.
    """
    if not count:
        # Guard against ZeroDivisionError in the accuracy computation below.
        print("Accuracy: n/a, 0/0 corrects (no samples were evaluated)")
        return

    sys.stdout.write("\rAccuracy: %.1f%%" % (100.*correct/count))
    sys.stdout.flush()
    print(", %i/%i corrects" % (correct, count))

    np.savetxt(filename, misclassified, fmt="%s")

    print(classification_report(y_true=true_labels,
                                y_pred=predicted_lables,
                                target_names=class_names))

    cm = confusion_matrix(y_true=true_labels,
                          y_pred=predicted_lables)
    print(cm)

    np.set_printoptions(precision=2)

    # Plot non-normalized confusion matrix
    plt.figure()
    plot_confusion_matrix(cm, classes=class_names,
                          title=title_CM)

    # Plot normalized confusion matrix
    plt.figure()
    plot_confusion_matrix(cm, classes=class_names, normalize=True,
                          title=title_CM_N)

    plt.show()
def _predict_batch(net, batch, mean):
    """Run one forward pass and return the predicted class index per record.

    ``batch`` is a list of ``(key, image, label)`` tuples that may be shorter
    than the batch dimension of the net's 'data' blob. ``mean`` is subtracted
    from every image before it is copied into the blob.
    """
    filled = len(batch)
    # Only the first `filled` rows are overwritten, so a short final batch
    # can reuse the already-shaped input blob. The remaining rows still go
    # through the forward pass, but their outputs are discarded below.
    net.blobs['data'].data[:filled, ...] = np.asarray(
        [image - mean for _, image, _ in batch])
    probabilities = np.asarray(net.forward()['prob'])
    return probabilities[:filled].argmax(axis=1)


def _score_batch(net, batch, mean, true_labels, predicted_lables, misclassified):
    """Classify ``batch``, append the outcome to the result lists and return
    the number of correctly classified records."""
    hits = 0
    for (key, _, label), predicted in zip(batch, _predict_batch(net, batch, mean)):
        if predicted == label:
            hits += 1
        else:
            misclassified.append(key)
        predicted_lables.append(predicted)
        true_labels.append(label)
    return hits


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--proto', help='path to the network prototxt file(deploy)', type=str, required=True)
    parser.add_argument('--model', help='path to your caffemodel file', type=str, required=True)
    parser.add_argument('--mean', help='path to the mean file(.binaryproto)', type=str, required=True)
    parser.add_argument('--db_type', help='lmdb or leveldb', type=str, required=True)
    parser.add_argument('--db_path', help='path to your lmdb/leveldb dataset', type=str, required=True)
    args = parser.parse_args()

    # Extract mean from the mean image file
    mean_blobproto_new = caffe.proto.caffe_pb2.BlobProto()
    with open(args.mean, 'rb') as f:
        mean_blobproto_new.ParseFromString(f.read())
    mean_image = caffe.io.blobproto_to_array(mean_blobproto_new)

    caffe.set_mode_gpu()

    #class_names = ['unsafe','safe']
    class_names = ['airplane','automobile','bird','cat','deer','dog','frog','horse','ship','truck']
    batch_size = 50

    # CNN reconstruction and loading the trained weights
    net1 = caffe.Net(args.proto, args.model, caffe.TEST)
    net1.blobs['data'].reshape(batch_size, 3, 32, 32)

    predicted_lables = []
    true_labels = []
    misclassified = []
    count = 0
    correct = 0
    batch = []

    # db_reader picks the lmdb reader for 'lmdb' and the leveldb reader for
    # anything else, which is the same dispatch this script used to inline
    # twice. Counting happens here, so both backends report a real `count`
    # (the former leveldb branch never incremented it).
    for record in db_reader(args.db_path, args.db_type.lower()):
        count += 1
        if count % 2000 == 0:
            print('count: ', count)
        batch.append(record)
        if len(batch) >= batch_size:
            correct += _score_batch(net1, batch, mean_image[0],
                                    true_labels, predicted_lables, misclassified)
            batch = []

    # Classify the records left over when the dataset size is not a multiple
    # of batch_size; they used to be counted but never evaluated.
    if batch:
        correct += _score_batch(net1, batch, mean_image[0],
                                true_labels, predicted_lables, misclassified)

    ShowInfo(correct, count, true_labels, predicted_lables, class_names, misclassified,
             filename='misclassifieds.txt',
             title='Receiver Operating Characteristic_ROC')
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment