Skip to content

Instantly share code, notes, and snippets.

@tzutalin
Last active August 31, 2015 06:50
Show Gist options
Save tzutalin/912d1774d96266c4e76b to your computer and use it in GitHub Desktop.
classify_test.py
import numpy as np
import matplotlib.pyplot as plt
import sys
import caffe
import os
import time
def convertbinarytonpy(binaryprotosrc, npytarget):
    """Convert a Caffe ``.binaryproto`` blob file into a NumPy ``.npy`` file.

    Parses the serialized ``BlobProto`` stored at *binaryprotosrc*, converts
    it to an array with ``caffe.io.blobproto_to_array``, and saves the first
    entry along the leading axis (``arr[0]``) to *npytarget* with ``np.save``.

    Args:
        binaryprotosrc: Path to the serialized ``BlobProto`` file.
        npytarget: Destination path handed to ``np.save``.

    Returns:
        True if the source existed and was converted; False if the source
        path does not exist, in which case nothing is written.
    """
    if not os.path.exists(binaryprotosrc):
        # Best-effort skip, as in the original: a missing source is not an
        # error here, but the caller can now tell that nothing was written.
        return False
    blob = caffe.proto.caffe_pb2.BlobProto()
    # Use a context manager so the file handle is closed deterministically
    # instead of being leaked by open(...).read().
    with open(binaryprotosrc, 'rb') as proto_file:
        blob.ParseFromString(proto_file.read())
    arr = np.array(caffe.io.blobproto_to_array(blob))
    # NOTE(review): arr[0] drops the leading (presumably num/batch) axis so the
    # saved array is a single mean image - confirm against the blob shape.
    np.save(npytarget, arr[0])
    return True
# Network definition, trained weights, test image and mean-image paths.
model_file = 'deploy_nin.prototxt'
pretrained = 'nin_imagenet.caffemodel'
# An optional first command-line argument overrides the default test image;
# with no arguments the behavior is unchanged.
image_file = sys.argv[1] if len(sys.argv) > 1 else 'cat.jpg'
mean_file = 'imagenet_mean.npy'
img_size = (224, 224)

# Convert the binaryproto mean to .npy, then reduce it to one value per
# channel by averaging over axes 1 and 2 (presumably height and width -
# confirm against the saved array's shape).
convertbinarytonpy('imagenet_mean.binaryproto', mean_file)
meanfile = np.load(mean_file).mean(1).mean(1)

caffe.set_mode_cpu()
# pycaffe exposes the wrapper class as caffe.Classifier (capital C).
# NOTE(review): raw_scale=255 and channel_swap=(2, 1, 0) presumably adapt
# caffe.io.load_image output (float RGB) to the BGR 0-255 input the model
# was trained on - confirm.
net = caffe.Classifier(model_file, pretrained,
                       mean=meanfile,
                       channel_swap=(2, 1, 0),
                       raw_scale=255,
                       image_dims=img_size)

input_image = caffe.io.load_image(image_file)
# NOTE(review): without plt.show() (or an interactive backend) this draws
# into a figure that is never displayed.
plt.imshow(input_image)

start_time = time.time()
# predict() takes any number of images and formats them for the net.
prediction = net.predict([input_image])
end_time = time.time()

# Single-argument print(...) calls produce identical output on Python 2 and 3.
print('---------------------------------------------')
print('takes ' + str(end_time - start_time) + ' secs')
print('prediction shape: %s' % (prediction[0].shape,))
print('predicted class: %s' % (prediction[0].argmax(),))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment