Last active
January 4, 2019 13:26
-
-
Save axel-angel/b2af7d980eb217a0af07 to your computer and use it in GitHub Desktop.
Caffe script to compute accuracy and confusion matrix
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/python | |
# -*- coding: utf-8 -*- | |
# Author: Axel Angel, copyright 2015, license GPLv3. | |
import sys | |
import caffe | |
import numpy as np | |
import lmdb | |
import argparse | |
from collections import defaultdict | |
def flat_shape(x):
    """Return ``x`` reshaped without its singleton (length-1) dimensions.

    Example: an array of shape (1, 28, 28) becomes (28, 28).

    The target shape is built as a tuple because on Python 3 ``filter``
    returns a lazy iterator, which ``ndarray.reshape`` rejects with
    "expected sequence object with len >= 0 or a single integer".
    Only axes of length exactly 1 are dropped; zero-length axes are kept so
    that reshaping an empty array does not fail.

    :param x: a numpy array.
    :returns: an array with the same data and every length-1 axis removed
        (an array whose axes are all length 1 becomes a 0-d array).
    """
    return x.reshape(tuple(s for s in x.shape if s != 1))
def lmdb_reader(fpath):
    """Yield ``(key, image, label)`` samples from a Caffe LMDB database.

    Each stored value is parsed as a serialized ``caffe_pb2.Datum``.
    The image is converted to ``uint8`` and its singleton axes are removed
    by ``flat_shape``.

    The environment is opened read-only, so a mistyped path raises an error
    instead of silently creating an empty database. The environment is closed
    when the generator is exhausted or closed early.

    :param fpath: path to the LMDB directory.
    :returns: a generator of ``(key, image, label)`` tuples, ``label`` being
        an ``int``.
    """
    import lmdb
    lmdb_env = lmdb.open(fpath, readonly=True)
    try:
        # The transaction is ended when iteration finishes or the generator
        # is closed (GeneratorExit propagates through the ``with``).
        with lmdb_env.begin() as lmdb_txn:
            for key, value in lmdb_txn.cursor():
                datum = caffe.proto.caffe_pb2.Datum()
                datum.ParseFromString(value)
                label = int(datum.label)
                image = caffe.io.datum_to_array(datum).astype(np.uint8)
                yield (key, flat_shape(image), label)
    finally:
        lmdb_env.close()
def leveldb_reader(fpath):
    """Iterate over every record of a Caffe LevelDB database.

    Each stored value is decoded as a serialized ``caffe_pb2.Datum``.

    :param fpath: path to the LevelDB directory.
    :returns: a generator yielding ``(key, image, label)``. ``image`` is the
        datum's array cast to ``uint8`` with singleton axes removed by
        ``flat_shape``, and ``label`` is an ``int``.
    """
    import leveldb
    database = leveldb.LevelDB(fpath)
    for record_key, serialized in database.RangeIter():
        sample = caffe.proto.caffe_pb2.Datum()
        sample.ParseFromString(serialized)
        target = int(sample.label)
        pixels = caffe.io.datum_to_array(sample).astype(np.uint8)
        yield record_key, flat_shape(pixels), target
def npz_reader(fpath):
    """Yield ``(index, image, label)`` samples from a NumPy ``.npz`` archive.

    The archive must hold the images as ``arr_0`` and the labels as
    ``arr_1``, which are the names ``np.savez`` gives to positional
    arguments. Both arrays must have the same length along the first axis.

    Samples are paired with ``zip`` rather than ``np.array([xs, ls]).T``.
    That expression stacks arrays of different shapes into one ragged array,
    which, depending on the NumPy version, either yields an object array or
    raises ``ValueError``.

    :param fpath: path to the ``.npz`` file.
    :returns: a generator of ``(i, image, label)`` tuples. ``i`` is the
        0-based sample index and ``label`` is an ``int``, the same label type
        the LMDB/LevelDB readers yield.
    :raises ValueError: if the image and label arrays differ in length.
    """
    # Load both arrays eagerly so the archive's file handle is released
    # before iteration starts.
    with np.load(fpath) as npz:
        xs = npz['arr_0']
        ls = npz['arr_1']
    if len(xs) != len(ls):
        raise ValueError("npz arrays differ in length: %d images, %d labels"
                         % (len(xs), len(ls)))
    for i, (x, l) in enumerate(zip(xs, ls)):
        yield (i, x, int(l))
if __name__ == "__main__":
    # Evaluate a Caffe classifier on a labelled dataset (LMDB, LevelDB or
    # npz). Print every misclassified sample, the running accuracy, and
    # finally a (real, predicted) -> count confusion matrix.
    parser = argparse.ArgumentParser(
        description="Compute accuracy and confusion matrix of a Caffe "
                    "classifier on a labelled dataset.")
    parser.add_argument('--proto', type=str, required=True)
    parser.add_argument('--model', type=str, required=True)
    parser.add_argument('--output', type=str, default='prob',
                        help="name of the output blob holding the class "
                             "scores (default: prob)")
    group = parser.add_mutually_exclusive_group(required=True)
    group.add_argument('--lmdb', type=str, default=None)
    group.add_argument('--leveldb', type=str, default=None)
    group.add_argument('--npz', type=str, default=None)
    args = parser.parse_args()

    count = 0
    correct = 0
    matrix = defaultdict(int)  # (real, pred) -> int
    labels_set = set()

    # Select the compute mode before building the net, so the net is
    # constructed under the mode it will run in.
    caffe.set_mode_cpu()
    net = caffe.Net(args.proto, args.model, caffe.TEST)
    # print(...) with a single argument behaves the same on Python 2 and 3.
    print("args %s" % (vars(args),))

    # The group is required and mutually exclusive: exactly one source is set.
    if args.lmdb is not None:
        reader = lmdb_reader(args.lmdb)
    elif args.leveldb is not None:
        reader = leveldb_reader(args.leveldb)
    else:
        reader = npz_reader(args.npz)

    for i, image, label in reader:
        # Build a (batch, channels, height, width) input. flat_shape drops
        # singleton axes, so a single-channel image arrives as (H, W) and
        # needs its channel axis back. A multi-channel image is already
        # (C, H, W) and must not get an extra axis.
        if image.ndim >= 3:
            image_caffe = image
        else:
            image_caffe = image.reshape(1, *image.shape)
        out = net.forward_all(data=image_caffe[np.newaxis])
        plabel = int(out[args.output][0].argmax(axis=0))
        count += 1
        iscorrect = label == plabel
        correct += 1 if iscorrect else 0
        matrix[(label, plabel)] += 1
        labels_set.update([label, plabel])
        if not iscorrect:
            print("\rError: i=%s, expected %i but predicted %i"
                  % (i, label, plabel))
        # '\r' rewrites the same console line with the running accuracy.
        sys.stdout.write("\rAccuracy: %.1f%%" % (100. * correct / count))
        sys.stdout.flush()

    print(", %i/%i corrects" % (correct, count))
    print("")
    print("Confusion matrix:")
    print("(r , p) | count")
    # Sorted so the rows come out in a stable, readable order.
    ordered_labels = sorted(labels_set)
    for l in ordered_labels:
        for pl in ordered_labels:
            print("(%i , %i) | %i" % (l, pl, matrix[(l, pl)]))
I am working with an LMDB database. When I run this code, I get the error `convnet_test.py: error: argument --proto is required`. Please help.
@monjoybme You need to launch it like this:
python ../src/convnet_test_lmdb.py --proto lenet.prototxt --model snapshots/lenet_mnist_v3-id_iter_1000.caffemodel --lmdb ../caffe/examples/mnist/mnist_test_lmdb/
accord @axel-angel
Regards
@axel-angel, thanks for your amazing work.
I am using your script and I got stuck at:
I1119 17:07:53.463573 12920 net.cpp:283] Network initialization done.
args{'proto': 'test.prototxt', 'model': 'models/caffenet_age_train_iter_50000.caffemodel', 'lmdb': 'lmdb_full/age_test_lmdb/', 'leveldb': None, 'npz': None}
Traceback (most recent call last):
File "convnet_test.py", line 75, in <module>
for i, image, label in reader:
File "convnet_test.py", line 28, in lmdb_reader
yield (key, flat_shape(image), label)
File "convnet_test.py", line 15, in flat_shape
return x.reshape(filter(lambda s: s > 1, x.shape))
TypeError: expected sequence object with len >= 0 or a single integer
I ran it with Python 3. Can you suggest a solution? Thanks.
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
@axel-angel, in addition, the images would also need a channel swap, correct?