-
-
Save axel-angel/b2af7d980eb217a0af07 to your computer and use it in GitHub Desktop.
#!/usr/bin/python | |
# -*- coding: utf-8 -*- | |
# Author: Axel Angel, copyright 2015, license GPLv3. | |
import sys | |
import caffe | |
import numpy as np | |
import lmdb | |
import argparse | |
from collections import defaultdict | |
def flat_shape(x):
    """Return x reshaped without its singleton dimensions.

    Example: an array of shape (1, 28, 28) becomes (28, 28).

    Args:
        x: object exposing ``shape`` and ``reshape`` (a numpy array in this
            script).

    Returns:
        ``x`` reshaped so that only the dimensions larger than 1 remain.
    """
    # reshape() needs a real sequence: on Python 3, filter() returns a lazy
    # iterator that numpy rejects with a TypeError, so build a tuple instead.
    return x.reshape(tuple(s for s in x.shape if s > 1))
def lmdb_reader(fpath):
    """Iterate over the Caffe Datum records stored in an LMDB database.

    Args:
        fpath: path of the LMDB environment to read.

    Yields:
        (key, image, label) tuples: the raw record key, the datum converted
        to a uint8 array with its singleton dimensions removed, and the
        datum label as an int.
    """
    # Only reads are performed, so open the environment read-only.
    lmdb_env = lmdb.open(fpath, readonly=True)
    try:
        with lmdb_env.begin() as lmdb_txn:
            for key, value in lmdb_txn.cursor():
                datum = caffe.proto.caffe_pb2.Datum()
                datum.ParseFromString(value)
                label = int(datum.label)
                image = caffe.io.datum_to_array(datum).astype(np.uint8)
                yield (key, flat_shape(image), label)
    finally:
        # Release the environment once the generator is exhausted or closed.
        lmdb_env.close()
def leveldb_reader(fpath):
    """Iterate over the Caffe Datum records stored in a LevelDB database.

    Args:
        fpath: path of the LevelDB database to read.

    Yields:
        (key, image, label) tuples: the raw record key, the datum converted
        to a uint8 array and passed through flat_shape, and the datum label
        as an int.
    """
    # Imported here so the dependency is only needed for LevelDB input.
    import leveldb

    database = leveldb.LevelDB(fpath)
    for record_key, raw_record in database.RangeIter():
        record = caffe.proto.caffe_pb2.Datum()
        record.ParseFromString(raw_record)
        pixels = caffe.io.datum_to_array(record).astype(np.uint8)
        yield (record_key, flat_shape(pixels), int(record.label))
def npz_reader(fpath):
    """Iterate over the samples stored in a ``.npz`` archive.

    The archive must hold the images under ``arr_0`` and the matching labels
    under ``arr_1``.

    Args:
        fpath: path (or file-like object) of the archive, as accepted by
            ``np.load``.

    Yields:
        (index, image, label) tuples: the position of the sample, the image
        as stored in the archive, and its label as an int.
    """
    # Load both arrays, then close the archive before iterating.
    with np.load(fpath) as npz:
        xs = npz['arr_0']
        ls = npz['arr_1']
    # Pair the samples with zip(): stacking images and labels into a single
    # array fails when their shapes differ (e.g. (N, 28, 28) vs (N,)).
    for i, (x, l) in enumerate(zip(xs, ls)):
        # int() keeps labels consistent with the database readers.
        yield (i, x, int(l))
if __name__ == "__main__":
    # Command line: the network definition, the trained weights and exactly
    # one dataset source (LMDB, LevelDB or NPZ).
    parser = argparse.ArgumentParser()
    parser.add_argument('--proto', type=str, required=True)
    parser.add_argument('--model', type=str, required=True)
    group = parser.add_mutually_exclusive_group(required=True)
    group.add_argument('--lmdb', type=str, default=None)
    group.add_argument('--leveldb', type=str, default=None)
    group.add_argument('--npz', type=str, default=None)
    args = parser.parse_args()

    count = 0
    correct = 0
    matrix = defaultdict(int)  # (real,pred) -> int
    labels_set = set()

    net = caffe.Net(args.proto, args.model, caffe.TEST)
    caffe.set_mode_cpu()
    # Single-string print() calls behave the same on Python 2 and Python 3.
    print("args %s" % (vars(args),))

    # The argument group guarantees that exactly one source is set.
    if args.lmdb is not None:
        reader = lmdb_reader(args.lmdb)
    elif args.leveldb is not None:
        reader = leveldb_reader(args.leveldb)
    else:
        reader = npz_reader(args.npz)

    for i, image, label in reader:
        # Add a leading axis to the image, then wrap it in a one-item batch.
        image_caffe = image.reshape(1, *image.shape)
        out = net.forward_all(data=np.asarray([image_caffe]))
        plabel = int(out['prob'][0].argmax(axis=0))

        count += 1
        iscorrect = label == plabel
        correct += (1 if iscorrect else 0)
        matrix[(label, plabel)] += 1
        labels_set.update([label, plabel])

        if not iscorrect:
            print("\rError: i=%s, expected %i but predicted %i"
                  % (i, label, plabel))

        # "\r" rewrites the same terminal line to show a running accuracy.
        sys.stdout.write("\rAccuracy: %.1f%%" % (100.*correct/count))
        sys.stdout.flush()

    print(", %i/%i corrects" % (correct, count))

    print("")
    print("Confusion matrix:")
    print("(r , p) | count")
    # Sorted labels give a stable, readable row order.
    ordered_labels = sorted(labels_set)
    for l in ordered_labels:
        for pl in ordered_labels:
            print("(%i , %i) | %i" % (l, pl, matrix[(l, pl)]))
If you implement it, I'll happily update the gist above. Try something along these lines:
import h5py
def hdf5_reader(file_name):
    """Iterate over the Caffe Datum records stored in an HDF5 file.

    Only the first top-level group of the file is read.

    Args:
        file_name: path of the HDF5 file to read.

    Yields:
        (key, image, label) tuples: the entry name, the datum converted to a
        uint8 array and passed through flat_shape, and the datum label as an
        int. Nothing is yielded when the file has no top-level entry.
    """
    with h5py.File(file_name, 'r') as h5_file:  # open read-only
        # keys is a method: call it before taking the first name.
        group_names = list(h5_file.keys())
        if not group_names:
            return
        group = h5_file[group_names[0]]  # try to find the first group
        for key, value in group.items():
            datum = caffe.proto.caffe_pb2.Datum()
            # NOTE(review): assumes each entry can be handed to
            # ParseFromString as a serialized Datum - TODO confirm.
            datum.ParseFromString(value)
            label = int(datum.label)
            image = caffe.io.datum_to_array(datum).astype(np.uint8)
            yield key, flat_shape(image), label
To cope with encoded images I extended the code above like this:
def getImage(datum):
    """Return the image held by a Caffe Datum as a numpy array.

    Args:
        datum: a parsed ``caffe_pb2.Datum``.

    Returns:
        The image decoded with PIL when ``datum.encoded`` is set, otherwise
        the result of ``caffe.io.datum_to_array`` cast to uint8.
    """
    if datum.encoded:
        # Imported here so PIL is only needed for encoded datums.
        # io.BytesIO exists on both Python 2 and 3 (cStringIO is Python 2 only).
        from io import BytesIO
        # "import PIL" alone does not load the Image submodule.
        from PIL import Image
        # NOTE(review): PIL presumably yields height x width x channels while
        # datum_to_array is channels-first - confirm callers handle both.
        image = np.array(Image.open(BytesIO(datum.data)))
    else:
        image = caffe.io.datum_to_array(datum).astype(np.uint8)
    return image
def lmdb_reader(fpath):
    """Iterate over the Caffe Datum records stored in an LMDB database.

    Encoded datums are supported through getImage.

    Args:
        fpath: path of the LMDB environment to read.

    Yields:
        (key, image, label) tuples: the raw record key, the image returned
        by getImage with its singleton dimensions removed, and the datum
        label as an int.
    """
    # Only reads are performed, so open the environment read-only.
    lmdb_env = lmdb.open(fpath, readonly=True)
    try:
        with lmdb_env.begin() as lmdb_txn:
            for key, value in lmdb_txn.cursor():
                datum = caffe.proto.caffe_pb2.Datum()
                datum.ParseFromString(value)
                label = int(datum.label)
                image = getImage(datum)
                yield (key, flat_shape(image), label)
    finally:
        # Release the environment once the generator is exhausted or closed.
        lmdb_env.close()
def leveldb_reader(fpath):
    """Iterate over the Caffe Datum records stored in a LevelDB database.

    Encoded datums are supported through getImage.

    Args:
        fpath: path of the LevelDB database to read.

    Yields:
        (key, image, label) tuples: the raw record key, the image returned
        by getImage and passed through flat_shape, and the datum label as
        an int.
    """
    # Imported here so the dependency is only needed for LevelDB input.
    import leveldb

    database = leveldb.LevelDB(fpath)
    for record_key, raw_record in database.RangeIter():
        record = caffe.proto.caffe_pb2.Datum()
        record.ParseFromString(raw_record)
        true_label = int(record.label)
        pixels = getImage(record)
        yield (record_key, flat_shape(pixels), true_label)
@axel-angel, Don't you need to subtract the image mean like in the example here: https://github.com/BVLC/caffe/blob/7003d1b8e24416cb5bdb5537a7805cb5a9de2ca1/examples/00-classification.ipynb
Also what about channel swap?
@axel-angel, in addition, images would need a channel swap as well, correct?
I am working with an LMDB database. When I run this code, I get the error "convnet_test.py: error: argument --proto is required". Please help.
@monjoybme You need to launch it like this:
python ../src/convnet_test_lmdb.py --proto lenet.prototxt --model snapshots/lenet_mnist_v3-id_iter_1000.caffemodel --lmdb ../caffe/examples/mnist/mnist_test_lmdb/
accord @axel-angel
Regards
@axel-angel, thanks for your amazing work.
I am using your script and I got stuck at:
I1119 17:07:53.463573 12920 net.cpp:283] Network initialization done.
args{'proto': 'test.prototxt', 'model': 'models/caffenet_age_train_iter_50000.caffemodel', 'lmdb': 'lmdb_full/age_test_lmdb/', 'leveldb': None, 'npz': None}
Traceback (most recent call last):
File "convnet_test.py", line 75, in <module>
for i, image, label in reader:
File "convnet_test.py", line 28, in lmdb_reader
yield (key, flat_shape(image), label)
File "convnet_test.py", line 15, in flat_shape
return x.reshape(filter(lambda s: s > 1, x.shape))
TypeError: expected sequence object with len >= 0 or a single integer
I used Python 3 to run it. Can you suggest a solution? Thanks.
I've added flat_shape; it just removes singleton (size-1) dimensions.