Last active
September 22, 2017 20:32
-
-
Save Coderx7/ab07928250515ca5e8a55c5139eacdd9 to your computer and use it in GitHub Desktop.
Like the previous script for calculating the confusion matrix, precision, recall and related metrics, but with more preprocessing (cropping, mean subtraction, etc.).
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#in the name of God the most compassionate the most merciful | |
# added mean subtraction so that, the accuracy can be reported accurately just like caffe when doing a mean subtraction | |
# Seyyed Hossein Hasan Pour | |
# [email protected] | |
# 7/3/2016 | |
# Added Recall/Precision/F1-Score as well | |
# 01/03/2017 | |
# added classification using caffe.Classifier, which lets you do oversampling (allows for multiple preprocessing options and crops) + the manual center crop
# method. It is sloppy at the moment: for some reason, caffe.Classifier's predict method yields lower accuracy than Caffe's own test accuracy, while
# the manual center crop (i.e. the one which uses caffe.Net) works well and achieves the same accuracy as Caffe's.
# 8/17/2017 | |
# Fixed the discrepancies between the accuracy achieved using this script and Caffe's test switch
#info: | |
#if on windows, one can put these commands in a batch file to make running the script easier
#REM Calculating Confusion Matrix
#python confusionMatrix_convnet_test.py --proto deploy.prototxt --model model_280500_rmsa.caffemodel --mean mean.binaryproto --db_type lmdb --db_path mycustom_test_lmdb
#pause | |
import sys | |
import caffe | |
import numpy as np | |
import lmdb | |
import argparse | |
from collections import defaultdict | |
from sklearn.metrics import classification_report | |
from sklearn.metrics import confusion_matrix | |
import matplotlib.pyplot as plt | |
import itertools | |
from sklearn.metrics import roc_curve, auc | |
import random | |
def plot_confusion_matrix(cm,  # confusion matrix
                          classes,
                          normalize=False,
                          title='Confusion matrix',
                          cmap=plt.cm.Blues):
    """
    Print a notice (when normalizing) and plot the confusion matrix on the
    current matplotlib figure.

    Normalization can be applied by setting `normalize=True`; each row is then
    divided by its row sum, and the image, the colour bar and the cell text
    all show the normalized values.

    cm        -- 2-D array-like confusion matrix (indexed cm[i, j]).
    classes   -- sequence of tick labels; len(classes) ticks are drawn on both axes.
    normalize -- divide each row by its sum before plotting.
    title     -- title of the plot.
    cmap      -- matplotlib colour map used by imshow.
    """
    cm = np.asarray(cm)

    if normalize:
        row_totals = cm.sum(axis=1)[:, np.newaxis]
        # A row with no samples would divide by zero and fill the row with
        # NaN; dividing by 1 instead leaves that row at 0.
        cm = cm.astype('float') / np.where(row_totals == 0, 1, row_totals)
        print("confusion matrix is normalized!")

    # Draw the image only after the optional normalization, so that the
    # colours agree with the numbers written into the cells.
    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    plt.title(title)
    plt.colorbar()
    tick_marks = np.arange(len(classes))
    plt.xticks(tick_marks, classes, rotation=45)
    plt.yticks(tick_marks, classes)

    # Integer counts are printed as-is, fractions with two decimals.
    cell_format = 'd' if np.issubdtype(cm.dtype, np.integer) else '.2f'
    # Cells darker than half of the maximum get white text for contrast.
    thresh = cm.max() / 2.
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        plt.text(j, i, format(cm[i, j], cell_format),
                 horizontalalignment="center",
                 color="white" if cm[i, j] > thresh else "black")

    plt.tight_layout()
    plt.ylabel('True label')
    plt.xlabel('Predicted label')
def db_reader(fpath, type='lmdb'):
    """Return the record generator for the database stored at `fpath`.

    `type` equal to 'lmdb' selects lmdb_reader; any other value selects
    leveldb_reader.
    """
    reader = lmdb_reader if type == 'lmdb' else leveldb_reader
    return reader(fpath)
def lmdb_reader(fpath):
    """Yield (key, image, label) for every Caffe Datum stored in an LMDB.

    fpath -- path of the LMDB directory.

    The environment is opened read-only (nothing is written here) and both
    the environment and the read transaction are closed once the generator
    is exhausted or closed, instead of being left open.
    Because of the read-only open, a missing path raises lmdb.Error instead
    of silently creating an empty database.
    """
    import lmdb
    with lmdb.open(fpath, readonly=True) as lmdb_env:
        with lmdb_env.begin() as lmdb_txn:
            for key, value in lmdb_txn.cursor():
                datum = caffe.proto.caffe_pb2.Datum()
                datum.ParseFromString(value)
                label = int(datum.label)
                image = caffe.io.datum_to_array(datum).astype(np.uint8)
                # NOTE(review): flat_shape is not defined in the visible part
                # of this file -- presumably provided elsewhere; confirm.
                yield (key, flat_shape(image), label)
def leveldb_reader(fpath):
    """Yield (key, image, label) for every Caffe Datum stored in a LevelDB.

    fpath -- path of the LevelDB directory.
    """
    import leveldb
    store = leveldb.LevelDB(fpath)
    for record_key, raw_value in store.RangeIter():
        parsed = caffe.proto.caffe_pb2.Datum()
        parsed.ParseFromString(raw_value)
        class_id = int(parsed.label)
        pixels = caffe.io.datum_to_array(parsed).astype(np.uint8)
        # NOTE(review): flat_shape is not defined in the visible part of this
        # file -- presumably provided elsewhere; confirm.
        yield (record_key, flat_shape(pixels), class_id)
def ShowInfo(correct, count, true_labels, predicted_lables, class_names, misclassified,
             filename='misclassifieds.txt',
             title='Receiver Operating Characteristic_ROC',
             title_CM='Confusion matrix, without normalization',
             title_CM_N='Normalized confusion matrix'):
    """Report accuracy, precision/recall/F1, ROC and confusion matrices.

    correct          -- number of correctly classified samples.
    count            -- total number of samples.
    true_labels      -- ground-truth label of every classified sample.
    predicted_lables -- predicted label of every classified sample.
    class_names      -- display names of the classes.
    misclassified    -- keys of the misclassified samples; written to `filename`.
    filename         -- text file receiving the misclassified keys.
    title            -- title printed before, and used for, the ROC plot.
    title_CM         -- title of the raw confusion-matrix plot.
    title_CM_N       -- title of the normalized confusion-matrix plot.

    Blocks on plt.show() until the plot windows are closed.
    """
    # An empty database must not crash the report with a ZeroDivisionError.
    accuracy = 100. * correct / count if count else 0.0
    sys.stdout.write("\rAccuracy: %.2f%%" % accuracy)
    sys.stdout.flush()
    print(", %i/%i corrects" % (correct, count))
    np.savetxt(filename, misclassified, fmt="%s")

    if len(true_labels) == 0:
        print("No samples were classified; skipping the report and the plots.")
        return

    # Passing the label set explicitly keeps the report and the matrix at
    # len(class_names) entries even when a class never occurs in the data.
    # NOTE(review): assumes labels are the class indices 0..len(class_names)-1
    # (predictions come from argmax) -- confirm for other datasets.
    label_ids = list(range(len(class_names)))
    print(classification_report(y_true=true_labels,
                                y_pred=predicted_lables,
                                labels=label_ids,
                                target_names=class_names))
    cm = confusion_matrix(y_true=true_labels,
                          y_pred=predicted_lables,
                          labels=label_ids)
    print(cm)
    print(title)

    if len(class_names) == 2:
        false_positive_rate, true_positive_rate, thresholds = roc_curve(true_labels, predicted_lables)
        roc_auc = auc(false_positive_rate, true_positive_rate)
        # Use the caller-supplied title rather than a hard-coded one.
        plt.title(title)
        plt.plot(false_positive_rate, true_positive_rate, 'b',
                 label='AUC = %0.2f' % roc_auc)
        plt.legend(loc='lower right')
        plt.plot([0, 1], [0, 1], 'r--')
        plt.xlim([-0.1, 1.2])
        plt.ylim([-0.1, 1.2])
        plt.ylabel('True Positive Rate')
        plt.xlabel('False Positive Rate')
        plt.show()
    else:
        # roc_curve only supports binary targets.
        print("ROC curve skipped: it is only computed for two classes.")

    np.set_printoptions(precision=2)

    # Plot non-normalized confusion matrix
    plt.figure()
    plot_confusion_matrix(cm, classes=class_names,
                          title=title_CM)

    # Plot normalized confusion matrix
    plt.figure()
    plot_confusion_matrix(cm, classes=class_names, normalize=True,
                          title=title_CM_N)
    plt.show()
def _load_mean_pixel(mean_path):
    """Return the per-channel mean of a Caffe .binaryproto mean image."""
    mean_blob = caffe.proto.caffe_pb2.BlobProto()
    with open(mean_path, 'rb') as mean_file:
        mean_blob.ParseFromString(mean_file.read())
    mean_image = caffe.io.blobproto_to_array(mean_blob)
    # Average over the two spatial axes of the first blob -> one value per channel.
    return mean_image[0].mean(1).mean(1)


def _iter_raw_records(db_type, db_path):
    """Yield the raw (key, value) pairs of an lmdb (or, otherwise, leveldb) database."""
    if db_type.lower() == 'lmdb':
        # Read-only open; environment and transaction are released when done.
        with lmdb.open(db_path, readonly=True) as lmdb_env:
            with lmdb_env.begin() as lmdb_txn:
                for key, value in lmdb_txn.cursor():
                    yield key, value
    else:
        import leveldb
        db = leveldb.LevelDB(db_path)
        for key, value in db.RangeIter():
            yield key, value


def _center_crop_minus_mean(img, mean_pixel, cropx, cropy):
    """Center-crop a C,H,W image to cropy x cropx and subtract the channel mean."""
    _, height, width = img.shape
    top = height // 2 - cropy // 2
    left = width // 2 - cropx // 2
    cropped = img[:, top:top + cropy, left:left + cropx]
    # Broadcast the per-channel mean over the spatial axes; the result is a
    # new float32 array, so the buffered image itself is left untouched.
    return (cropped - mean_pixel[:, np.newaxis, np.newaxis]).astype(np.float32)


def _predict_batch(batch, classifier, raw_net, mean_pixel,
                   use_caffe_classifier, multi_crop, cropx, cropy):
    """Return the predicted class index of every (key, image, label) in batch."""
    if use_caffe_classifier:
        # The stored images are C,H,W; caffe.Classifier.predict wants H,W,C.
        ims = [image.transpose(1, 2, 0) for _, image, _ in batch]
        scores = classifier.predict(np.asarray(ims), oversample=multi_crop)
    else:
        data = np.asarray([_center_crop_minus_mean(image, mean_pixel, cropx, cropy)
                           for _, image, _ in batch], dtype=np.float32)
        # Reshape to the actual batch so a final, smaller batch also fits.
        raw_net.blobs['data'].reshape(*data.shape)
        raw_net.blobs['data'].data[...] = data
        scores = raw_net.forward()['pred']
    return np.asarray(scores).argmax(axis=1)


def _record_batch(batch, predictions, true_labels, predicted_lables, misclassified):
    """Append the batch results to the running lists; return the number of hits."""
    hits = 0
    for (key, _, label), predicted in zip(batch, predictions):
        if predicted == label:
            hits += 1
        else:
            misclassified.append(key)
        predicted_lables.append(predicted)
        true_labels.append(label)
    return hits


def main():
    """Classify every sample of an lmdb/leveldb dataset and report the metrics."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--proto', help='path to the network prototxt file(deploy)', type=str, required=True)
    parser.add_argument('--model', help='path to your caffemodel file', type=str, required=True)
    parser.add_argument('--mean', help='path to the mean file(.binaryproto)', type=str, required=True)
    parser.add_argument('--db_type', help='lmdb or leveldb', type=str, required=True)
    parser.add_argument('--db_path', help='path to your lmdb/leveldb dataset', type=str, required=True)
    # Optional switch for CPU-only Caffe builds; the default stays GPU.
    parser.add_argument('--cpu', help='run Caffe in CPU mode instead of GPU mode', action='store_true')
    args = parser.parse_args()

    # Extract mean from the mean image file
    mean_pixel = _load_mean_pixel(args.mean)

    if args.cpu:
        caffe.set_mode_cpu()
    else:
        caffe.set_mode_gpu()

    class_names = ['unsafe', 'safe']
    batch_size = 50
    cropx = 224
    cropy = 224
    multi_crop = False
    use_caffe_classifier = True

    # CNN reconstruction and loading the trained weights; only the network
    # that is actually used gets loaded.
    classifier = None
    raw_net = None
    if use_caffe_classifier:
        classifier = caffe.Classifier(args.proto, args.model,
                                      mean=mean_pixel,
                                      image_dims=(256, 256))
    else:
        raw_net = caffe.Net(args.proto, args.model, caffe.TEST)

    predicted_lables = []
    true_labels = []
    misclassified = []
    count = 0
    correct = 0
    batch = []

    def classify_pending():
        """Classify the buffered samples and return how many were correct."""
        predictions = _predict_batch(batch, classifier, raw_net, mean_pixel,
                                     use_caffe_classifier, multi_crop, cropx, cropy)
        return _record_batch(batch, predictions, true_labels,
                             predicted_lables, misclassified)

    # One loop serves both database back ends, so `count` is maintained for
    # leveldb as well as for lmdb.
    for key, value in _iter_raw_records(args.db_type, args.db_path):
        datum = caffe.proto.caffe_pb2.Datum()
        datum.ParseFromString(value)
        label = int(datum.label)
        # img has c,h,w shape: it was already transposed and channel-swapped
        # when it was saved into the database.
        image = caffe.io.datum_to_array(datum).astype(np.float32)

        count += 1
        if count % 5000 == 0:
            print('{0} samples processed so far'.format(count))

        # buffer n images, then process them together
        batch.append((key, image, label))
        if len(batch) >= batch_size:
            correct += classify_pending()
            del batch[:]

    # Samples left over when the dataset size is not a multiple of batch_size
    # are counted in `count`, so they have to be classified as well.
    if batch:
        correct += classify_pending()
        del batch[:]

    ShowInfo(correct, count, true_labels, predicted_lables, class_names, misclassified,
             filename='misclassifieds.txt',
             title='Receiver Operating Characteristic_ROC')


if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Hi. I´m trying to use the script because I need to plot confusion matrix but I am facing an issue.
I am using dummy.zip dataset with tif files, and Caffe with NO CPU.
Everything with caffe is working fine because I can run ./test_dummy_net.sh and get the accuracy... but when I try to run the script I am facing the error below.
Command to call:
python confusionMatrix_Precison_Recall_F1Score_BatchMode.py --proto ~/caffe/dummy/models/lenet_deploy.prototxt --model ~/caffe/dummy/models/lenet_iter_1000.caffemodel --mean ~/caffe/dummy/data/digits/dummy_mean.binaryproto --db_type lmdb --db_path ~/caffe/dummy/data/digits/dummy_train_lmdb
Error
To me, it seems that one of my images is broken, so when the script attempts to read and reshape it, I get an exception.
I am novice, I would appreciate if you could help.
Thanks in advance.