Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Save Coderx7/3860ee9a7beadccfbd7cd48f39df77af to your computer and use it in GitHub Desktop.
Save Coderx7/3860ee9a7beadccfbd7cd48f39df77af to your computer and use it in GitHub Desktop.
confusionMatrix precision recall F1Score now with caffe classifier which allows for more preprocessing options!
# Seyyed Hossein Hasan Pour
# [email protected]
# 7/3/2016
# Added Recall/Precision/F1-Score as well
# 01/03/2017
#info:
#On Windows, you can put the following commands in a batch file to save typing:
#REM Calculating Confusion Matrix
#python confusionMatrix_convnet_test.py --proto cifar10_deploy.prototxt --model cifar10_.caffemodel --mean mean.binaryproto --db_type lmdb --db_path cifar10_test_lmdb
#pause
import sys
import caffe
import numpy as np
import lmdb
import argparse
from collections import defaultdict
from sklearn.metrics import classification_report
from sklearn.metrics import confusion_matrix
import matplotlib.pyplot as plt
import itertools
def flat_shape(x, drop_singletons=False):
    """Return ``x`` unchanged in shape, or with its size-1 axes removed.

    By default this returns ``np.reshape(x, x.shape)`` -- a same-shape view,
    so e.g. (1, 28, 28) stays (1, 28, 28). That default is kept because the
    images yielded by the DB readers are later transposed as 3-D arrays.

    Args:
        x: numpy array.
        drop_singletons: if True, remove every size-1 axis via ``np.squeeze``,
            e.g. (1, 28, 28) -> (28, 28).

    Returns:
        A numpy array view of ``x``.
    """
    if drop_singletons:
        return np.squeeze(x)
    return np.reshape(x, x.shape)
def plot_confusion_matrix(cm,
                          classes,
                          normalize=False,
                          title='Confusion matrix',
                          cmap=plt.cm.Blues):
    """Draw a confusion matrix as an annotated heat map on the current figure.

    Args:
        cm: square confusion matrix, rows = true labels and columns =
            predicted labels (the layout of ``sklearn.metrics.confusion_matrix``).
        classes: tick labels, one per row/column of ``cm``.
        normalize: if True, divide each row by its sum so every cell shows the
            fraction of that true class; rows that sum to zero are shown as 0
            instead of NaN.
        title: figure title.
        cmap: matplotlib colormap for the heat map.

    Draws with pyplot on the current figure; does not call ``plt.show()``.
    """
    cm = np.asarray(cm)
    if normalize:
        # Normalize BEFORE drawing so the heat map, colour bar and cell text
        # all describe the same numbers.
        row_sums = cm.sum(axis=1, keepdims=True)
        cm = np.divide(cm, row_sums,
                       out=np.zeros(cm.shape, dtype=float),
                       where=row_sums != 0)
        print("confusion matrix is normalized!")

    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    plt.title(title)
    plt.colorbar()
    tick_marks = np.arange(len(classes))
    plt.xticks(tick_marks, classes, rotation=45)
    plt.yticks(tick_marks, classes)

    # Integer counts print as-is; fractions are shown with two decimals.
    fmt = 'd' if np.issubdtype(cm.dtype, np.integer) else '.2f'
    thresh = cm.max() / 2.
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        plt.text(j, i, format(cm[i, j], fmt),
                 horizontalalignment="center",
                 # Light text on dark cells, dark text on light cells.
                 color="white" if cm[i, j] > thresh else "black")

    plt.tight_layout()
    plt.ylabel('True label')
    plt.xlabel('Predicted label')
def db_reader(fpath, type='lmdb'):
    """Return the record generator for the database stored at ``fpath``.

    Exactly ``'lmdb'`` selects ``lmdb_reader``; every other value of ``type``
    (case-sensitive comparison) falls through to ``leveldb_reader``.
    """
    reader_fn = lmdb_reader if type == 'lmdb' else leveldb_reader
    return reader_fn(fpath)
def lmdb_reader(fpath):
    """Yield ``(key, image, label)`` for every record of an LMDB database.

    Each value is parsed as a Caffe ``Datum``. ``image`` is the result of
    ``caffe.io.datum_to_array`` cast to uint8 and passed through
    ``flat_shape``; ``label`` is ``int(datum.label)``.

    The environment is opened read-only with ``create=False``, so a mistyped
    path raises ``lmdb.Error`` instead of silently creating (and iterating)
    an empty database. The environment is closed when iteration finishes or
    the generator is closed.
    """
    import lmdb
    lmdb_env = lmdb.open(fpath, readonly=True, create=False)
    try:
        with lmdb_env.begin() as lmdb_txn:
            for key, value in lmdb_txn.cursor():
                datum = caffe.proto.caffe_pb2.Datum()
                datum.ParseFromString(value)
                label = int(datum.label)
                image = caffe.io.datum_to_array(datum).astype(np.uint8)
                yield (key, flat_shape(image), label)
    finally:
        lmdb_env.close()
def leveldb_reader(fpath):
    """Yield ``(key, image, label)`` for every record of a LevelDB database.

    Each value is parsed as a Caffe ``Datum``. ``image`` is the result of
    ``caffe.io.datum_to_array`` cast to uint8 and passed through
    ``flat_shape``; ``label`` is ``int(datum.label)``.

    ``create_if_missing=False`` makes a wrong path fail loudly instead of
    silently creating an empty database that yields nothing.
    """
    import leveldb
    db = leveldb.LevelDB(fpath, create_if_missing=False)
    for key, value in db.RangeIter():
        datum = caffe.proto.caffe_pb2.Datum()
        datum.ParseFromString(value)
        label = int(datum.label)
        image = caffe.io.datum_to_array(datum).astype(np.uint8)
        yield (key, flat_shape(image), label)
def _parse_args():
    """Build the command-line interface and parse ``sys.argv``."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--proto', help='path to the network prototxt file(deploy)', type=str, required=True)
    parser.add_argument('--model', help='path to your caffemodel file', type=str, required=True)
    parser.add_argument('--mean', help='path to the mean file(.binaryproto)', type=str, required=True)
    parser.add_argument('--db_type', help='lmdb or leveldb', type=str.lower,
                        choices=('lmdb', 'leveldb'), required=True)
    parser.add_argument('--db_path', help='path to your lmdb/leveldb dataset', type=str, required=True)
    parser.add_argument('--class_names', nargs='+', default=['unsafe', 'safe'],
                        help='class names in label-index order (default: unsafe safe)')
    parser.add_argument('--image_dims', nargs=2, type=int, default=[256, 256], metavar=('H', 'W'),
                        help='size images are resized to before the center crop to the '
                             'net input (default: 256 256)')
    parser.add_argument('--channel_swap', nargs='+', type=int, default=None,
                        help='optional channel permutation, e.g. 2 1 0 (default: none)')
    parser.add_argument('--raw_scale', type=float, default=None,
                        help='optional factor applied to raw pixels before mean '
                             'subtraction (default: none)')
    parser.add_argument('--cpu', action='store_true', help='run on the CPU instead of the GPU')
    return parser.parse_args()


def _load_channel_mean(mean_path):
    """Return the per-channel mean of a Caffe ``.binaryproto`` mean image.

    The first blob of the file is reduced over its last two axes, i.e. a
    (C, H, W) mean image becomes a length-C vector (assumes the usual
    (C, H, W) layout -- confirm for unusual mean files).
    """
    blob = caffe.proto.caffe_pb2.BlobProto()
    with open(mean_path, 'rb') as f:
        blob.ParseFromString(f.read())
    mean_image = caffe.io.blobproto_to_array(blob)[0]
    return mean_image.mean(1).mean(1)


def main():
    """Classify every DB record, print the report and plot confusion matrices."""
    args = _parse_args()
    print("args", vars(args))

    # Pick the device before the net is constructed.
    if args.cpu:
        caffe.set_mode_cpu()
    else:
        caffe.set_mode_gpu()

    channel_mean = _load_channel_mean(args.mean)
    channel_swap = None
    if args.channel_swap is not None:
        channel_swap = tuple(args.channel_swap)
        # caffe.io.Transformer subtracts the mean AFTER swapping channels,
        # so the mean has to be permuted the same way.
        channel_mean = channel_mean[list(channel_swap)]

    # NOTE(review): by default records are fed exactly as stored (uint8,
    # 0-255, stored channel order) -- the form Caffe's Data layer trains on.
    # The old raw_scale=255 is meant for [0, 1] images and inflated these
    # values 255x; the old fixed channel_swap reordered already-ordered data.
    # Use --raw_scale / --channel_swap if your deploy pipeline really differs.
    net = caffe.Classifier(args.proto, args.model,
                           mean=channel_mean,
                           channel_swap=channel_swap,
                           raw_scale=args.raw_scale,
                           image_dims=tuple(args.image_dims))

    class_names = args.class_names
    labels = list(range(len(class_names)))
    predicted_labels = []
    true_labels = []
    for key, image, label in db_reader(args.db_path, args.db_type):
        # Datum arrays are (C, H, W); Classifier.predict expects (H, W, C).
        # (The old transpose(2, 1, 0) produced (W, H, C), flipping the image.)
        image_hwc = image.transpose(1, 2, 0)
        # Per the original author, oversample=True crashes under Python 3.
        out = net.predict([image_hwc], oversample=False)
        predicted_labels.append(out[0].argmax())
        true_labels.append(label)
        print(key, ' processed!')

    if not true_labels:
        sys.exit('No records found in {}'.format(args.db_path))
    unknown = (set(true_labels) | set(predicted_labels)) - set(labels)
    if unknown:
        sys.exit('Labels {} have no entry in --class_names'.format(sorted(unknown)))

    # Passing labels= keeps one row/column per class name even when a class
    # never appears in this database.
    print(classification_report(y_true=true_labels,
                                y_pred=predicted_labels,
                                labels=labels,
                                target_names=class_names))
    cnf_matrix = confusion_matrix(y_true=true_labels,
                                  y_pred=predicted_labels,
                                  labels=labels)
    print(cnf_matrix)
    np.set_printoptions(precision=2)

    # Plot non-normalized confusion matrix
    plt.figure()
    plot_confusion_matrix(cnf_matrix, classes=class_names,
                          title='Confusion matrix, without normalization')
    # Plot normalized confusion matrix
    plt.figure()
    plot_confusion_matrix(cnf_matrix, classes=class_names, normalize=True,
                          title='Normalized confusion matrix')
    plt.show()


if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment