import argparse, os, tempfile

import h5py
import numpy as np
from scipy.misc import imread, imresize

# Workaround for some corrupt / truncated images: let PIL load them anyway
from PIL import ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True

import caffe
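
# Extract CNN features for a list of images with Caffe, writing them to an
# HDF5 file. A hypothetical invocation (the script filename is assumed):
#
#   python extract_features.py --image_list images.txt --gpu 0 \
#     --blob_name fc7 --output_h5_file features.h5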

def write_temp_deploy(source_prototxt, batch_size):
    """
    Modifies an existing prototxt by setting the batch size to a specific value.
    The modified prototxt is written to a temporary file.

    Inputs:
    - source_prototxt: Path to a deploy.prototxt that will be modified
    - batch_size: Desired batch size for the network

    Returns:
    - Path to the temporary file containing the modified prototxt
    """
    _, target = tempfile.mkstemp()
    with open(source_prototxt, 'r') as f:
        lines = f.readlines()
    # In an old-style deploy.prototxt the first input_dim line is the batch
    # size; replace only that line and copy everything else through unchanged.
    found_batch_size_line = False
    with open(target, 'w') as f:
        for line in lines:
            if line.startswith('input_dim:') and not found_batch_size_line:
                found_batch_size_line = True
                line = 'input_dim: %d\n' % batch_size
            f.write(line)
    return target
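
# For reference, a rough sketch of the edit on CaffeNet's old-style
# deploy.prototxt (exact contents may differ by Caffe version):
#
#   input: "data"           input: "data"
#   input_dim: 10           input_dim: 100   <-- replaced with batch_size
#   input_dim: 3      ==>   input_dim: 3
#   input_dim: 227          input_dim: 227
#   input_dim: 227          input_dim: 227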

def resize_mean_image(mean_image, height, width):
    """
    Resize the (ImageNet) mean image to a given size.

    Inputs:
    - mean_image: numpy float array of shape (3, H, W), in BGR order.
      This is the format of the mean ImageNet image provided by Caffe.
    - height, width: Desired height and width

    Returns:
    A numpy float array of shape (3, height, width) in BGR order.
    """
    # scipy's imresize works on (H, W, C) uint8 images, so round-trip through
    # that format; the uint8 cast costs less than one intensity level.
    mean_image_t = mean_image.transpose(1, 2, 0).astype('uint8')
    mean_image_t_resized = imresize(mean_image_t, (height, width))
    mean_image_resized = mean_image_t_resized.transpose(2, 0, 1).astype('float')
    return mean_image_resized
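
# Rough shape flow, assuming the standard 256x256 Caffe mean image and a
# 227x227 CaffeNet input:
#   (3, 256, 256) -> transpose -> (256, 256, 3) -> imresize -> (227, 227, 3)
#                 -> transpose -> (3, 227, 227)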

def load_image(image_filename, height, width, mean_image):
    """
    Read an image off disk and prepare it for caffe. We need to do the following:
    (1) Resize to (height, width)
    (2) Swap color channels from RGB to BGR
    (3) Transpose from (H, W, C) to (C, H, W)
    (4) Convert from uint8 to float
    (5) Subtract mean image (which is already BGR)

    Inputs:
    - image_filename: Path to the image file to read
    - height, width: Input size of the network; we'll reshape the image to this size
    - mean_image: numpy float array of shape (3, height, width) in BGR format
      giving the mean image to be subtracted.

    Returns:
    A numpy float array of shape (3, height, width), preprocessed for caffe.
    """
    img = imread(image_filename)
    try:
        img = imresize(img, (height, width))
    except ValueError:
        raise ValueError('Cannot resize image %s with shape %s'
                         % (image_filename, img.shape))
    if img.ndim == 2:
        # Handle grayscale by adding a channel dim and replicating it three times
        img = img[:, :, None][:, :, [0, 0, 0]]
    # Reverse channels (RGB -> BGR), move channels first (H, W, C) -> (C, H, W),
    # cast to float, and subtract the mean
    img = img[:, :, [2, 1, 0]].transpose(2, 0, 1).astype('float') - mean_image
    return img
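
# Hypothetical sanity check (not part of the pipeline): undoing the
# preprocessing should recover something that looks like the original image.
#
#   bgr = load_image('cat.jpg', 227, 227, mean_image_resized) + mean_image_resized
#   rgb = bgr[[2, 1, 0]].transpose(1, 2, 0).astype('uint8')
#   import matplotlib.pyplot as plt
#   plt.imshow(rgb); plt.show()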

if __name__ == '__main__':
    CAFFENET = '$CAFFE_ROOT/models/bvlc_reference_caffenet'
    CAFFENET_DEPLOY = os.path.join(CAFFENET, 'deploy.prototxt')
    CAFFENET_CAFFEMODEL = os.path.join(CAFFENET, 'bvlc_reference_caffenet.caffemodel')

    parser = argparse.ArgumentParser()
    parser.add_argument('--image_list', required=True)
    parser.add_argument('--deploy_txt', default=CAFFENET_DEPLOY)
    parser.add_argument('--caffemodel', default=CAFFENET_CAFFEMODEL)
    parser.add_argument('--mean_file',
                        default='$CAFFE_ROOT/python/caffe/imagenet/ilsvrc_2012_mean.npy')
    parser.add_argument('--vgg_mean', action='store_true')
    parser.add_argument('--gpu', type=int, default=0)
    parser.add_argument('--blob_name', default='fc7')
    parser.add_argument('--batch_size', default=100, type=int)
    parser.add_argument('--output_h5_file', default='features.h5')
    args = parser.parse_args()
    # A negative --gpu value selects CPU mode
    if args.gpu < 0:
        caffe.set_mode_cpu()
    else:
        caffe.set_mode_gpu()
        caffe.set_device(args.gpu)

    deploy_file = os.path.expandvars(args.deploy_txt)
    caffemodel_file = os.path.expandvars(args.caffemodel)
    temp_deploy = write_temp_deploy(deploy_file, args.batch_size)
    net = caffe.Net(temp_deploy, caffemodel_file, caffe.TEST)
    net_height = net.blobs['data'].data.shape[2]
    net_width = net.blobs['data'].data.shape[3]
    # Read image filenames from the text file
    image_filenames = []
    with open(args.image_list, 'r') as f:
        for line in f:
            image_filenames.append(line.strip())
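
    # The image list is expected to be a plain text file with one image path
    # per line, e.g. (hypothetical paths):
    #
    #   /data/images/cat.jpg
    #   /data/images/dog.jpg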

    # Sanity check: shape of the blob we will extract
    print net.blobs[args.blob_name].data.shape
    mean_image_file = os.path.expandvars(args.mean_file)
    mean_image = np.load(mean_image_file)
    if args.vgg_mean:
        print 'using vgg mean'
        # VGG was trained by subtracting the mean pixel, not the mean image.
        # The mean BGR pixel value is given at
        # https://gist.github.com/ksimonyan/3785162f95cd2d5fee77
        pixel = [103.939, 116.779, 123.68]
        mean_image = np.asarray(pixel).reshape(3, 1, 1)
    # Expand the mean (pixel or image) to the network's input resolution
    mean_image_resized = resize_mean_image(mean_image, net_height, net_width)
    num_images = len(image_filenames)
    h5_f = h5py.File(args.output_h5_file, 'w')
    feature_shape = (num_images,) + net.blobs[args.blob_name].data.shape[1:]
    dset = h5_f.create_dataset('features', feature_shape, dtype='f4')
    # Record provenance so the features can be traced back to the model
    dset.attrs['blob_name'] = args.blob_name
    dset.attrs['deploy_txt'] = deploy_file
    dset.attrs['caffemodel'] = caffemodel_file
    dset.attrs['mean_file'] = mean_image_file

    next_batch_idx = 0
    next_dset_idx = 0
    batch_data = np.zeros_like(net.blobs['data'].data)
    for i, image_filename in enumerate(image_filenames):
        img = load_image(image_filename, net_height, net_width, mean_image_resized)
        batch_data[next_batch_idx] = img
        next_batch_idx += 1
        if next_batch_idx == args.batch_size:
            # Run a full batch through the network and store the features
            net.forward(data=batch_data)
            next_batch_idx = 0
            dset[next_dset_idx:(next_dset_idx + args.batch_size)] = \
                net.blobs[args.blob_name].data.copy()
            next_dset_idx += args.batch_size
            print 'done with %d / %d images' % (i + 1, num_images)

    if next_batch_idx > 0:
        # Final partial batch: the trailing slots of batch_data still hold stale
        # images from the previous batch, so keep only the first next_batch_idx
        # rows of features.
        net.forward(data=batch_data)
        dset[next_dset_idx:] = net.blobs[args.blob_name].data[:next_batch_idx].copy()
    h5_f.close()
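
# A minimal sketch of reading the output back (paths and shapes assumed;
# fc7 features from CaffeNet are 4096-dimensional):
#
#   import h5py
#   f = h5py.File('features.h5', 'r')
#   feats = f['features']             # shape (num_images, 4096) for fc7
#   print feats.attrs['blob_name'], feats.shape
#   f.close()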