Created
December 18, 2015 22:54
-
-
Save asanakoy/ba70c3eeff6da26d68d4 to your computer and use it in GitHub Desktop.
Script to show non-deterministic caffe behaviour
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# The star import stays first so the explicit imports below keep precedence
# over any same-named symbols it exports. It presumably also supplies
# protomean2array, which this file calls without importing -- TODO confirm.
from caffeext import *

import os.path
import time

import caffe
import numpy as np
from scipy import misc
def _forward_batch(net, transformer, images, img_paths, labels):
    """Run one forward pass over `images` and score it against `labels`.

    :param net: network whose 'data' blob is already shaped for len(images)
    :param transformer: object providing preprocess('data', image)
    :param images: raw images of the current batch
    :param img_paths: paths the images were read from, in the same order
    :param labels: expected class indices, in the same order
    :return: tuple (results, probabilities, correct) where results is a list
        of (img_path, predicted_index) pairs, probabilities is a copy of the
        batch's 'prob' output and correct is a boolean array (prediction == label)
    """
    # A list comprehension instead of map(): the assignment then also works on
    # interpreters where map() returns a lazy iterator.
    net.blobs['data'].data[...] = [transformer.preprocess('data', image) for image in images]
    out = net.forward()
    # Copy the output. forward() is understood to hand back the output blob's
    # own array, so rows kept from an earlier batch would otherwise be
    # overwritten by the next forward pass -- NOTE(review): confirm against pycaffe.
    batch_prob = np.array(out['prob'])
    predict = np.argmax(batch_prob, axis=1)
    return list(zip(img_paths, predict)), batch_prob, predict == np.asarray(labels)


def run_test_on_images(net, transformer, data, batch_size):
    """Classify the images listed in `data` in batches and measure the accuracy.

    The 'data' blob of `net` is reshaped to `batch_size` images (keeping its
    other three dimensions); a trailing incomplete batch is processed with the
    blob reshaped to the number of remaining images.

    :param net: network exposing blobs['data'] and forward(), whose output
        dictionary contains 'prob'
    :param transformer: object providing preprocess('data', image)
    :param data: iterable of (img_path, label) pairs
    :param batch_size: number of images per forward pass, must be >= 1
    :return: tuple (accuracy, results, prob): the fraction of images whose
        arg-max prediction equals the label (NaN when `data` is empty), a list
        of (img_path, predicted_index) pairs and a list with one 'prob' row
        per image
    :raises ValueError: if `batch_size` is smaller than 1
    """
    if batch_size < 1:
        raise ValueError('batch_size must be a positive integer, got {!r}'.format(batch_size))

    blob_shape = tuple(net.blobs['data'].data.shape)
    channels, height, width = blob_shape[1], blob_shape[2], blob_shape[3]
    net.blobs['data'].reshape(batch_size, channels, height, width)

    results = []
    prob = []
    correct_per_batch = []

    images, img_paths, labels = [], [], []
    for img_path, label in data:
        images.append(misc.imread(img_path))
        img_paths.append(img_path)
        labels.append(label)
        if len(images) == batch_size:
            batch_results, batch_prob, batch_correct = _forward_batch(
                net, transformer, images, img_paths, labels)
            results.extend(batch_results)
            prob.extend(batch_prob)
            correct_per_batch.append(batch_correct)
            images, img_paths, labels = [], [], []

    if images:
        # Incomplete last batch: shrink the input blob to what is left.
        net.blobs['data'].reshape(len(images), channels, height, width)
        batch_results, batch_prob, batch_correct = _forward_batch(
            net, transformer, images, img_paths, labels)
        results.extend(batch_results)
        prob.extend(batch_prob)
        correct_per_batch.append(batch_correct)

    if not correct_per_batch:
        # Nothing was evaluated; report NaN explicitly instead of dividing 0 by 0.
        return float('nan'), results, prob

    right_answers = np.concatenate(correct_per_batch)
    accuracy = np.sum(right_answers) / float(len(right_answers))
    return accuracy, results, prob
######################################################################################################################## | |
def test_network(network_root_path, snapshot_iteration, sample):
    """Feed the same sample through the network with batch sizes 1, 2 and 3.

    For every batch size the sample is evaluated three times; the first four
    probabilities of each output row and the accuracy are printed so the runs
    can be compared by eye.

    :param network_root_path: network root folder (a leading ``~`` is expanded)
    :param snapshot_iteration: iteration to get snapshot from
    :param sample: (image_path, label) pair passed to run_test_on_images
    """
    network_root_path = os.path.expanduser(network_root_path)
    deploy_path = os.path.join(network_root_path, "model/net_config/deploy.prototxt")
    weights_path = os.path.join(network_root_path,
                                "model/snap_iter_{}.caffemodel".format(snapshot_iteration))
    net = caffe.Net(deploy_path, weights_path, caffe.TEST)
    caffe.set_mode_gpu()

    mean_path = os.path.join(network_root_path, "train.leveldb/mean.binaryproto")
    mean = protomean2array(mean_path)

    # transformer transforms image from RGB HxWxC -> BGR CxHxW and subtracts the mean
    transformer = caffe.io.Transformer({'data': net.blobs['data'].data.shape})
    transformer.set_transpose('data', (2, 0, 1))  # height*width*channel -> channel*height*width
    transformer.set_mean('data', mean)  # subtract mean
    transformer.set_raw_scale('data', 1)  # pixel value scaling
    transformer.set_channel_swap('data', (2, 1, 0))  # RGB -> BGR

    for batch_size in range(1, 4):
        accuracy, _, probabilities = run_test_on_images(
            net, transformer, [sample, sample, sample], batch_size=batch_size)
        for row in probabilities:
            print(row[0:4])
        # Single-argument print calls produce the same text on Python 2 and 3.
        print('Accuracy: ' + str(accuracy))
        print('--')
def main():
    """Run the batch-size comparison on a hard-coded snapshot and sample image.

    Prints which network is tested and, afterwards, the wall-clock time spent
    in test_network.
    """
    network_root_path = '~/workspace/meisterwerke/cnn/51_100_test'
    snapshot_iteration = 210000
    sample = ('/export/home/asanakoy/workspace/meisterwerke/crops_227x227_step30/mwm35259_0_flipped.png', 0)
    print('Testing network {} on {} iter'.format(network_root_path, snapshot_iteration))
    # time.time() rather than time.clock(): clock() no longer exists on
    # Python >= 3.8 and on Unix it counted CPU time, not elapsed time.
    start_time = time.time()
    test_network(network_root_path, snapshot_iteration, sample)
    print('Elapsed time: {} s'.format(time.time() - start_time))


if __name__ == '__main__':
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment