Last active
February 24, 2021 13:48
-
-
Save Coderx7/205651853a248a512256aa21f1d3bec0 to your computer and use it in GitHub Desktop.
Confusion Matrix with Recall, Precision and F1-Score for Caffe
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/python | |
# Author: SeyyedHossein Hasanpour copyright 2017, license GPLv3. | |
# Seyyed Hossein Hasan Pour: | |
# [email protected] | |
# Changelog: | |
# 2015: | |
# initial code to calculate confusionmatrix by Axel Angel | |
# 7/3/2016:(adding new features-by-hossein) | |
# added mean subtraction so that, the accuracy can be reported accurately just like caffe when doing a mean subtraction | |
# 01/03/2017: | |
# removed old codes and Added Recall/Precision/F1-Score as well | |
# 03/05/2017 | |
# Added ConfusionMatrix which was mistakenly omitted before.
#info: | |
# On Windows, the following commands can be put in a batch file for convenience:
#REM Calculating Confusion Matrix
#python confusionMatrix_convnet_test.py --proto cifar10_deploy_94_68.prototxt --model cifar10_deploy_94_68.caffemodel --mean mean.binaryproto --db_type lmdb --db_path cifar10_test_lmdb | |
# | |
import sys | |
import caffe | |
import numpy as np | |
import lmdb | |
import argparse | |
from collections import defaultdict | |
from sklearn.metrics import classification_report | |
from sklearn.metrics import confusion_matrix | |
import matplotlib.pyplot as plt | |
import itertools | |
def flat_shape(x):
    """Return ``x`` with every singleton (length-1) dimension removed.

    Example: an array of shape (1, 28, 28) becomes (28, 28); an array whose
    dimensions are all 1 becomes a 0-d array.

    ``np.squeeze`` is used because passing ``filter(...)`` to ``reshape``
    fails on Python 3 (``filter`` is a lazy iterator, not a shape tuple).
    """
    return np.squeeze(x)
def db_reader(fpath, type='lmdb'):
    """Return a generator of ``(key, image, label)`` records from a Caffe dataset.

    Args:
        fpath: path to the database directory.
        type: backend name, ``'lmdb'`` or ``'leveldb'`` (case-insensitive).
            The parameter name shadows the builtin but is kept so existing
            keyword callers keep working.

    Returns:
        The generator produced by ``lmdb_reader`` or ``leveldb_reader``.

    Raises:
        ValueError: if ``type`` names neither backend (instead of silently
            falling back to LevelDB on a typo such as ``'lmbd'``).
    """
    backend = type.lower()
    if backend == 'lmdb':
        return lmdb_reader(fpath)
    if backend == 'leveldb':
        return leveldb_reader(fpath)
    raise ValueError("unsupported db type %r (expected 'lmdb' or 'leveldb')" % (type,))
def plot_confusion_matrix(cm,  # confusion matrix
                          classes,
                          normalize=False,
                          title='Confusion matrix',
                          cmap=plt.cm.Blues):
    """Draw a confusion matrix on the current matplotlib figure.

    Args:
        cm: square 2-D array of counts; the axes are labelled with true labels
            on the rows and predicted labels on the columns.
        classes: tick labels, one per row/column of ``cm``.
        normalize: if True, divide each row by its sum so every cell shows the
            fraction of that true class assigned to each prediction. Rows whose
            sum is zero are left as zeros instead of producing NaN.
        title: figure title.
        cmap: matplotlib colormap for the cell colours.

    The figure is not shown here; call ``plt.show()`` afterwards.
    """
    cm = np.asarray(cm)
    if normalize:
        # Normalize BEFORE drawing so the cell colours match the printed values.
        cm = cm.astype('float')
        row_sums = cm.sum(axis=1, keepdims=True)
        cm = np.divide(cm, row_sums, out=np.zeros_like(cm), where=row_sums != 0)
        print("confusion matrix is normalized!")

    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    plt.title(title)
    plt.colorbar()
    tick_marks = np.arange(len(classes))
    plt.xticks(tick_marks, classes, rotation=45)
    plt.yticks(tick_marks, classes)

    # Use white text on dark cells and black text on light cells.
    thresh = cm.max() / 2.
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        value = cm[i, j]
        # Fractions get two decimals; raw counts are shown as-is.
        text = format(value, '.2f') if normalize else value
        plt.text(j, i, text,
                 horizontalalignment="center",
                 color="white" if value > thresh else "black")

    plt.ylabel('True label')
    plt.xlabel('Predicted label')
    # Run the layout pass after the axis labels exist so they are not clipped.
    plt.tight_layout()
def lmdb_reader(fpath):
    """Yield ``(key, image, label)`` for every Caffe ``Datum`` in an LMDB database.

    Each value is parsed as ``caffe_pb2.Datum``. ``image`` is the decoded
    array cast to uint8 and passed through ``flat_shape``; ``label`` is the
    datum's label as an int.

    The environment is opened read-only because this function never writes.
    The read transaction is ended and the environment closed when iteration
    finishes or the generator is closed.
    """
    import lmdb
    lmdb_env = lmdb.open(fpath, readonly=True)
    try:
        with lmdb_env.begin() as lmdb_txn:
            for key, value in lmdb_txn.cursor():
                datum = caffe.proto.caffe_pb2.Datum()
                datum.ParseFromString(value)
                label = int(datum.label)
                image = caffe.io.datum_to_array(datum).astype(np.uint8)
                yield (key, flat_shape(image), label)
    finally:
        lmdb_env.close()
def leveldb_reader(fpath):
    """Yield ``(key, image, label)`` for each Caffe ``Datum`` stored in a LevelDB.

    Every stored value is decoded as ``caffe_pb2.Datum``; its pixel array is
    cast to uint8 and passed through ``flat_shape``, and its label is
    returned as an int.
    """
    import leveldb
    database = leveldb.LevelDB(fpath)
    for record_key, raw_value in database.RangeIter():
        parsed = caffe.proto.caffe_pb2.Datum()
        parsed.ParseFromString(raw_value)
        class_id = int(parsed.label)
        pixels = caffe.io.datum_to_array(parsed).astype(np.uint8)
        yield record_key, flat_shape(pixels), class_id
# Default class names, in label order: the ten CIFAR-10 categories.
_DEFAULT_CLASS_NAMES = ['airplane', 'automobile', 'bird', 'cat', 'deer',
                        'dog', 'frog', 'horse', 'ship', 'truck']


def _parse_args():
    """Parse the command-line options for an evaluation run."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--proto', help='path to the network prototxt file(deploy)', type=str, required=True)
    parser.add_argument('--model', help='path to your caffemodel file', type=str, required=True)
    parser.add_argument('--mean', help='path to the mean file(.binaryproto)', type=str, required=True)
    # argparse applies `type` before checking `choices`, so 'LMDB' is still accepted.
    parser.add_argument('--db_type', help='lmdb or leveldb', type=str.lower,
                        choices=['lmdb', 'leveldb'], required=True)
    parser.add_argument('--db_path', help='path to your lmdb/leveldb dataset', type=str, required=True)
    parser.add_argument('--output_blob', type=str, default='prob',
                        help="name of the output blob holding the class scores, i.e. the `top` "
                             "of the final Softmax layer (default: 'prob')")
    parser.add_argument('--class_names', type=lambda s: [c.strip() for c in s.split(',')],
                        default=_DEFAULT_CLASS_NAMES,
                        help='comma-separated class names in label order (default: CIFAR-10)')
    parser.add_argument('--cpu', action='store_true',
                        help='run on the CPU instead of the GPU')
    return parser.parse_args()


def _load_mean(path):
    """Read a Caffe ``.binaryproto`` mean file and return it as a NumPy array."""
    mean_blobproto = caffe.proto.caffe_pb2.BlobProto()
    with open(path, 'rb') as f:
        mean_blobproto.ParseFromString(f.read())
    return caffe.io.blobproto_to_array(mean_blobproto)


def main():
    """Classify every record in the dataset, then print and plot the metrics."""
    args = _parse_args()
    mean_image = _load_mean(args.mean)

    # Select the compute device before constructing the network.
    if args.cpu:
        caffe.set_mode_cpu()
    else:
        caffe.set_mode_gpu()
    # CNN reconstruction and loading the trained weights
    net = caffe.Net(args.proto, args.model, caffe.TEST)
    print("args", vars(args))

    # Fail once with a clear message rather than a bare KeyError on the first image.
    if args.output_blob not in net.outputs:
        raise SystemExit("output blob %r not found; network outputs are %s "
                         "(use --output_blob)" % (args.output_blob, list(net.outputs)))

    class_names = args.class_names
    predicted_labels = []
    true_labels = []
    for key, image, label in db_reader(args.db_path, args.db_type):
        # Prepend an axis; for single-channel images this restores the
        # channel axis that flat_shape removed.
        image_caffe = image.reshape(1, *image.shape)
        out = net.forward_all(data=np.asarray([image_caffe]) - mean_image)
        predicted_labels.append(int(out[args.output_blob][0].argmax(axis=0)))
        true_labels.append(label)
        print(key, ' processed!')

    if not true_labels:
        raise SystemExit("no records found in %s" % args.db_path)
    highest = max(max(true_labels), max(predicted_labels))
    if highest >= len(class_names):
        raise SystemExit("label %d seen but only %d class names given "
                         "(use --class_names)" % (highest, len(class_names)))

    # Pin the label set so the report and matrix always have one row/column
    # per class name, even if a class never occurs in the data or predictions.
    labels = list(range(len(class_names)))
    print(classification_report(y_true=true_labels,
                                y_pred=predicted_labels,
                                labels=labels,
                                target_names=class_names))
    cm = confusion_matrix(y_true=true_labels,
                          y_pred=predicted_labels,
                          labels=labels)
    print(cm)

    # Plot non-normalized confusion matrix
    plt.figure()
    plot_confusion_matrix(cm, classes=class_names,
                          title='Confusion matrix, without normalization')
    # Plot normalized confusion matrix
    plt.figure()
    plot_confusion_matrix(cm, classes=class_names, normalize=True,
                          title='Normalized confusion matrix')
    plt.show()


if __name__ == "__main__":
    main()
@flrndttrch: Thanks for the note. When I updated the script, I forgot to edit that part.
However, the needed information was already included as a comment in the script.
Anyway, I've updated the first post now.
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
It seems that your command instructions are not correct.
It should be:
python confusionMatrix_convnet_test.py --proto cifar10_deploy.prototxt --model cifar10_SimpleNet_xavier_95.26.caffemodel.h5 --mean mean.binaryproto --db_type lmdb --db_path cifar10_test_lmdb
Also I'm getting a KeyError for 'prob' in out:
Traceback (most recent call last): File "confusionMatrix_Recall_Precision_F1Scroe_Caffe.py", line 137, in <module> plabel = int(out['prob'][0].argmax(axis=0)) KeyError: 'prob'
Maybe it has to be 'loss' and argmax()?
EDIT: Yes, that seems to work fine.
EDIT2: OK, got it. It depends on the name given to the output blob in `top`:
layer { name: "loss" type: "Softmax" bottom: "ip2" top: "loss" }