Skip to content

Instantly share code, notes, and snippets.

@debidatta
Last active January 24, 2016 18:49
Show Gist options
  • Save debidatta/473b8e5957203ad7d655 to your computer and use it in GitHub Desktop.
Save debidatta/473b8e5957203ad7d655 to your computer and use it in GitHub Desktop.
# -*- coding: utf-8 -*-
import sys
import numpy as np
import argparse
from sklearn.metrics import confusion_matrix
import matplotlib.pyplot as plt
# Make sure pycaffe is importable (e.g. add <caffe_root>/python to PYTHONPATH) before running this script
import caffe
def plot_confusion_matrix(cm, title='Confusion matrix', cmap=plt.cm.Blues):
    """Draw a confusion matrix as a heat map on the current matplotlib figure.

    Args:
        cm: 2-D confusion matrix; rows are drawn on the y axis (labelled
            "True label") and columns on the x axis ("Predicted label").
        title: Title shown above the plot.
        cmap: Matplotlib colormap used to colour the cells.
    """
    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    plt.title(title)
    plt.colorbar()
    plt.ylabel('True label')
    plt.xlabel('Predicted label')
    # tight_layout() only accounts for the text that exists when it runs, so
    # it has to come after the axis labels or they may be clipped.
    plt.tight_layout()
def flat_shape(x):
    """Return x without its length-1 dimensions, e.g. (1,28,28) -> (28,28).

    np.squeeze drops exactly the axes of length 1. The previous
    ``reshape(filter(...))`` form relied on Python 2's list-returning
    ``filter`` and also dropped zero-length axes, which made the reshape fail.
    """
    return np.squeeze(x)
def lmdb_reader(fpath):
    """Yield ``(key, image, label)`` for every Caffe ``Datum`` in an LMDB.

    Args:
        fpath: Path of the LMDB environment to read.

    Yields:
        key: The raw database key.
        image: ``caffe.io.datum_to_array`` output cast to ``np.uint8`` with
            length-1 dimensions removed by ``flat_shape``.
        label: ``datum.label`` converted to ``int``.
    """
    import lmdb
    # This reader never writes, so open read-only rather than requiring write
    # access to the database.
    lmdb_env = lmdb.open(fpath, readonly=True)
    try:
        lmdb_cursor = lmdb_env.begin().cursor()
        for key, value in lmdb_cursor:
            datum = caffe.proto.caffe_pb2.Datum()
            datum.ParseFromString(value)
            label = int(datum.label)
            image = caffe.io.datum_to_array(datum).astype(np.uint8)
            yield (key, flat_shape(image), label)
    finally:
        # Runs when iteration ends or the generator is closed/collected, so
        # the environment (and its transaction/cursor) is always released.
        lmdb_env.close()
def leveldb_reader(fpath):
    """Yield ``(key, image, label)`` for every Caffe ``Datum`` in a LevelDB.

    Args:
        fpath: Path of an existing LevelDB database directory.

    Yields:
        key: The raw database key.
        image: ``caffe.io.datum_to_array`` output cast to ``np.uint8`` with
            length-1 dimensions removed by ``flat_shape``.
        label: ``datum.label`` converted to ``int``.
    """
    import leveldb
    # create_if_missing=False: a mistyped path should fail loudly instead of
    # opening a brand-new, empty database that yields no samples.
    db = leveldb.LevelDB(fpath, create_if_missing=False)
    for key, value in db.RangeIter():
        datum = caffe.proto.caffe_pb2.Datum()
        datum.ParseFromString(value)
        label = int(datum.label)
        image = caffe.io.datum_to_array(datum).astype(np.uint8)
        yield (key, flat_shape(image), label)
def npz_reader(fpath):
    """Yield ``(index, image, label)`` pairs from an ``.npz`` archive.

    The archive must hold the images as ``arr_0`` and the labels as ``arr_1``
    (the names ``np.savez`` gives to positional arrays). Both are indexed
    along their first axis.

    Args:
        fpath: Path (or file-like object) accepted by ``np.load``.

    Raises:
        ValueError: If ``arr_0`` and ``arr_1`` have different lengths.
    """
    # Indexing an NpzFile loads the member fully, so the archive can be
    # closed as soon as both arrays are read.
    with np.load(fpath) as npz:
        xs = npz['arr_0']
        ls = npz['arr_1']
    if len(xs) != len(ls):
        raise ValueError("arr_0 has %d samples but arr_1 has %d labels"
                         % (len(xs), len(ls)))
    # zip pairs the arrays without stacking them. np.array([xs, ls]) builds
    # a ragged array (an error on modern NumPy) and can upcast the labels.
    for i, (x, l) in enumerate(zip(xs, ls)):
        yield (i, x, l)
def _select_reader(args):
    """Return the sample generator for whichever database option was given.

    The required, mutually exclusive argparse group in the script below
    ensures exactly one of --lmdb / --leveldb / --npz is set.
    """
    if args.lmdb is not None:
        return lmdb_reader(args.lmdb)
    if args.leveldb is not None:
        return leveldb_reader(args.leveldb)
    return npz_reader(args.npz)


def _predict_label(net, image):
    """Run one forward pass and return the arg-max class of the 'prob' blob.

    If the net declares input blobs (deploy-style prototxt), *image* is
    copied into the first sample slot of the first input, so the prediction
    belongs to this image. Otherwise the net is run as-is.
    NOTE(review): in that fallback the net presumably reads its own data
    layer. Predictions then line up with the reader's labels only if that
    layer reads the same database, in the same order, with batch_size 1.
    """
    if net.inputs:
        blob = net.blobs[net.inputs[0]]
        # flat_shape() dropped length-1 dims; restore the per-sample shape.
        blob.data[0] = image.reshape(blob.data.shape[1:])
    out = net.forward()
    return int(out['prob'][0].argmax(axis=0))


def _normalize_rows(cm):
    """Divide each row of *cm* by its sum, i.e. by the samples of that class.

    A row that sums to zero (a label that only ever appears as a prediction)
    stays all zeros instead of becoming NaN.
    """
    cm = np.asarray(cm, dtype=float)
    row_sums = cm.sum(axis=1)[:, np.newaxis]
    row_sums[row_sums == 0] = 1.0
    return cm / row_sums


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--proto', type=str, required=True)
    parser.add_argument('--model', type=str, required=True)
    # Exactly one data source must be given.
    group = parser.add_mutually_exclusive_group(required=True)
    group.add_argument('--lmdb', type=str, default=None)
    group.add_argument('--leveldb', type=str, default=None)
    group.add_argument('--npz', type=str, default=None)
    args = parser.parse_args()

    net = caffe.Net(args.proto, args.model, caffe.TEST)
    caffe.set_mode_gpu()
    # Prints the same text as the Python 2 statement `print "args", vars(args)`
    # but is also valid Python 3 syntax.
    print("args %s" % (vars(args),))

    y_true = []
    y_pred = []
    for _, image, label in _select_reader(args):
        plabel = _predict_label(net, image)
        y_true.append(label)
        y_pred.append(plabel)
    if not y_true:
        sys.exit("No samples were read from the database; nothing to plot.")

    # Compute the confusion matrix.
    cm = confusion_matrix(y_true, y_pred)
    np.set_printoptions(precision=2)
    plt.figure()
    plot_confusion_matrix(cm)
    # Normalise by row, i.e. by the number of samples in each true class.
    plt.figure()
    plot_confusion_matrix(_normalize_rows(cm), title='Normalized confusion matrix')
    plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment