Last active
July 23, 2017 08:27
-
-
Save Coderx7/3860ee9a7beadccfbd7cd48f39df77af to your computer and use it in GitHub Desktop.
confusionMatrix precision recall F1Score now with caffe classifier which allows for more preprocessing options!
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Seyyed Hossein Hasan Pour | |
# [email protected] | |
# 7/3/2016 | |
# Added Recall/Precision/F1-Score as well | |
# 01/03/2017 | |
#info: | |
# On Windows, the following commands can be placed in a batch file for convenience:
#REM Calculating Confusion Matrix
#python confusionMatrix_convnet_test.py --proto cifar10_deploy.prototxt --model cifar10_.caffemodel --mean mean.binaryproto --db_type lmdb --db_path cifar10_test_lmdb
#pause | |
import sys | |
import caffe | |
import numpy as np | |
import lmdb | |
import argparse | |
from collections import defaultdict | |
from sklearn.metrics import classification_report | |
from sklearn.metrics import confusion_matrix | |
import matplotlib.pyplot as plt | |
import itertools | |
def flat_shape(x):
    """Return ``x`` reshaped to its own shape (the shape is left unchanged).

    NOTE(review): the previous description of this helper promised that
    singleton dimensions are removed, e.g. (1,28,28) -> (28,28). Reshaping an
    array to ``x.shape`` does not do that: a (1,28,28) input comes back as
    (1,28,28). The script's main loop transposes three axes of the result, so
    the current behaviour is kept as-is -- TODO confirm whether squeezing was
    ever intended.
    """
    unchanged_shape = x.shape
    return np.reshape(x, unchanged_shape)
def plot_confusion_matrix(cm,  # confusion matrix
                          classes,
                          normalize=False,
                          title='Confusion matrix',
                          cmap=plt.cm.Blues):
    """
    Print and plot the confusion matrix on the current matplotlib figure.

    Normalization can be applied by setting `normalize=True`.

    Args:
        cm: square 2-D array of counts; rows are true labels and columns are
            predicted labels (matching the axis labels set below).
        classes: sequence of class names used as tick labels on both axes.
        normalize: when True, every row is divided by its sum before the
            matrix is drawn and annotated.
        title: title of the plot.
        cmap: matplotlib colormap used for the heat map.
    """
    cm = np.asarray(cm)
    if normalize:
        # Normalize BEFORE drawing, so the heat map and the per-cell
        # annotations show the same values.
        row_totals = cm.sum(axis=1)[:, np.newaxis]
        # Rows without any sample stay at 0 instead of becoming NaN.
        cm = np.divide(cm.astype('float'), row_totals,
                       out=np.zeros(cm.shape, dtype='float'),
                       where=row_totals != 0)
        print("confusion matrix is normalized!")

    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    plt.title(title)
    plt.colorbar()
    tick_marks = np.arange(len(classes))
    plt.xticks(tick_marks, classes, rotation=45)
    plt.yticks(tick_marks, classes)

    # Two decimals for fractional matrices, plain integers for counts.
    cell_format = '.2f' if np.issubdtype(cm.dtype, np.floating) else 'd'
    # Cells darker than half the maximum get white text for readability.
    thresh = cm.max() / 2.
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        plt.text(j, i, format(cm[i, j], cell_format),
                 horizontalalignment="center",
                 color="white" if cm[i, j] > thresh else "black")

    plt.ylabel('True label')
    plt.xlabel('Predicted label')
    # Run the layout pass last so the axis labels are taken into account.
    plt.tight_layout()
def db_reader(fpath, type='lmdb'):
    """Return a record generator for the database stored at ``fpath``.

    Args:
        fpath: path to the LMDB / LevelDB database.
        type: backend name, either 'lmdb' (default) or 'leveldb'. The
            parameter name shadows the builtin ``type``; it is kept so that
            existing keyword callers keep working.

    Returns:
        The generator produced by ``lmdb_reader`` or ``leveldb_reader``.

    Raises:
        ValueError: if ``type`` is neither 'lmdb' nor 'leveldb'. Unknown
            values used to fall through to the LevelDB reader silently.
    """
    if type == 'lmdb':
        return lmdb_reader(fpath)
    if type == 'leveldb':
        return leveldb_reader(fpath)
    raise ValueError(
        "unsupported database type %r (expected 'lmdb' or 'leveldb')" % (type,))
def lmdb_reader(fpath):
    """Yield ``(key, image, label)`` for every record of an LMDB database.

    Each value is parsed as a Caffe ``Datum``; the image is converted with
    ``caffe.io.datum_to_array`` and cast to ``uint8``, the label to ``int``.

    Args:
        fpath: path to the LMDB database.

    Yields:
        Tuples ``(key, image, label)`` in cursor order.
    """
    import lmdb
    # This function only reads, so the environment is opened read-only.
    lmdb_env = lmdb.open(fpath, readonly=True)
    try:
        # The transaction is ended when the ``with`` block is left, including
        # when the generator is closed before it is exhausted.
        with lmdb_env.begin() as lmdb_txn:
            for key, value in lmdb_txn.cursor():
                datum = caffe.proto.caffe_pb2.Datum()
                datum.ParseFromString(value)
                label = int(datum.label)
                image = caffe.io.datum_to_array(datum).astype(np.uint8)
                yield (key, flat_shape(image), label)
    finally:
        # Release the environment handle; it used to stay open forever.
        lmdb_env.close()
def leveldb_reader(fpath):
    """Yield ``(key, image, label)`` for every record of a LevelDB database.

    Each value is parsed as a Caffe ``Datum``; the image is converted with
    ``caffe.io.datum_to_array`` and cast to ``uint8``, the label to ``int``.
    """
    import leveldb
    database = leveldb.LevelDB(fpath)
    for record_key, serialized in database.RangeIter():
        record = caffe.proto.caffe_pb2.Datum()
        record.ParseFromString(serialized)
        class_id = int(record.label)
        pixels = caffe.io.datum_to_array(record).astype(np.uint8)
        yield (record_key, flat_shape(pixels), class_id)
if __name__ == "__main__":
    # Command-line driver: classify every record of the dataset, then print
    # the classification report and plot the raw and normalized confusion
    # matrices.
    parser = argparse.ArgumentParser()
    parser.add_argument('--proto', help='path to the network prototxt file(deploy)', type=str, required=True)
    parser.add_argument('--model', help='path to your caffemodel file', type=str, required=True)
    parser.add_argument('--mean', help='path to the mean file(.binaryproto)', type=str, required=True)
    # type=str.lower keeps accepting any letter case, as before; choices makes
    # a mistyped backend fail at parse time instead of deep inside a reader.
    parser.add_argument('--db_type', help='lmdb or leveldb', type=str.lower,
                        choices=['lmdb', 'leveldb'], required=True)
    parser.add_argument('--db_path', help='path to your lmdb/leveldb dataset', type=str, required=True)
    args = parser.parse_args()

    # Extract mean from the mean image file
    mean_blobproto_new = caffe.proto.caffe_pb2.BlobProto()
    # The context manager closes the file even when parsing raises.
    with open(args.mean, 'rb') as f:
        mean_blobproto_new.ParseFromString(f.read())
    mean_image = caffe.io.blobproto_to_array(mean_blobproto_new)[0]

    # CNN reconstruction and loading the trained weights
    # The mean image is reduced to one mean value per channel.
    # NOTE(review): raw_scale=255 rescales the inputs by 255, but the readers
    # already return uint8 pixel data; channel_swap=(2,1,0) reverses the
    # channel order of whatever the database stores. Confirm both against how
    # the dataset was created and how the network was trained.
    net = caffe.Classifier(args.proto, args.model,
                           mean=mean_image.mean(1).mean(1),
                           channel_swap=(2, 1, 0),
                           raw_scale=255,
                           image_dims=(256, 256))
    # You may also use set_mode_cpu() if you didn't compile caffe with gpu support
    caffe.set_mode_gpu()

    print ("args", vars(args))

    reader = db_reader(args.db_path, args.db_type)
    predicted_labels = []
    true_labels = []
    class_names = ['unsafe', 'safe']
    # Pin both metrics to the label indices of class_names, so the report and
    # the matrix keep one row/column per class even when a class never shows
    # up in the data or in the predictions.
    label_ids = list(range(len(class_names)))

    for i, image, label in reader:
        # datum_to_array yields (channels, height, width) while
        # Classifier.predict expects (height, width, channels). The previous
        # transpose(2,1,0) produced (width, height, channels), i.e. a
        # spatially transposed image.
        image_caffe = image.transpose(1, 2, 0)
        # oversample: if set to True, should use python2! in python3 setting
        # this to true causes the script to crash!
        out = net.predict([image_caffe], oversample=False)
        plabel = out[0].argmax()
        predicted_labels.append(plabel)
        true_labels.append(label)
        print(i, ' processed!')

    print(classification_report(y_true=true_labels,
                                y_pred=predicted_labels,
                                labels=label_ids,
                                target_names=class_names))

    # Compute confusion matrix
    cm = confusion_matrix(y_true=true_labels,
                          y_pred=predicted_labels,
                          labels=label_ids)
    print(cm)
    np.set_printoptions(precision=2)

    # Plot non-normalized confusion matrix
    plt.figure()
    plot_confusion_matrix(cm, classes=class_names,
                          title='Confusion matrix, without normalization')

    # Plot normalized confusion matrix
    plt.figure()
    plot_confusion_matrix(cm, classes=class_names, normalize=True,
                          title='Normalized confusion matrix')
    plt.show()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment