import argparse, os
import numpy as np
# Note: scipy.misc.imread and imresize were removed from modern SciPy; this
# script targets the older SciPy era (on newer installs, substitute
# imageio.imread and skimage.transform.resize).
from scipy.misc import imread, imresize
from skimage.filters import gaussian
import h5py
parser = argparse.ArgumentParser()
parser.add_argument('--train_dir', default='data/yang-91')
parser.add_argument('--val_dir', default='data/set5')
parser.add_argument('--max_train', default=-1, type=int)
parser.add_argument('--max_val', default=-1, type=int)
parser.add_argument('--train_list', default=None)
parser.add_argument('--val_list', default=None)
parser.add_argument('--output_h5', default='data/yang-91.h5')
parser.add_argument('--patch_size', default=96, type=int)
parser.add_argument('--patch_stride', default=7, type=int)
parser.add_argument('--sizes', default='2,3,4,8,16')
parser.add_argument('--sigma', default=1.0, type=float)
args = parser.parse_args()
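
# Example invocation (the script name here is illustrative; every flag falls
# back to the defaults declared above):
#   python make_sr_dataset.py --train_dir data/yang-91 --val_dir data/set5 \
#       --output_h5 data/yang-91.h5 --patch_size 96 --sizes 2,3,4,8,16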

def handle_split(split, file_list, h5_file):
  # This should be easy to fit in memory
  sizes = [int(x) for x in args.sizes.split(',')]
  size_tuples = [args.patch_size / s for s in sizes]
  size_tuples = [(s, s) for s in size_tuples]
  patches = []
  small_patches = {s: [] for s in sizes}

  # For validation images, use stride = patch size to reduce size
  stride = args.patch_stride
  if split == 'val':
    stride = args.patch_size
  # Extract patches from all images
  for i, in_path in enumerate(file_list):
    print 'Starting image %d / %d' % (i + 1, len(file_list))
    img = imread(in_path)
    if img.ndim == 0: continue
    # Replicate grayscale images across three channels
    if img.ndim == 2: img = img[:, :, None][:, :, [0, 0, 0]]
    # Drop an alpha channel if present so the shape assert below holds
    if img.shape[2] == 4: img = img[:, :, :3]
    H, W = img.shape[0], img.shape[1]
    # + 1 so patches flush with the right and bottom borders are included
    for x0 in xrange(0, W - args.patch_size + 1, stride):
      x1 = x0 + args.patch_size
      for y0 in xrange(0, H - args.patch_size + 1, stride):
        y1 = y0 + args.patch_size
        patch = img[y0:y1, x0:x1]
        assert patch.shape == (args.patch_size, args.patch_size, 3), patch.shape
        patches.append(patch[np.newaxis])
        for size, size_tuple in zip(sizes, size_tuples):
          # Blur each channel independently; gaussian returns floats in [0, 1],
          # and imresize rescales float input to its own min / max, so convert
          # back to uint8 first to preserve per-patch brightness.
          blurred = gaussian(patch, args.sigma, multichannel=True)
          blurred = (255.0 * blurred).astype(np.uint8)
          small_patch = imresize(blurred, size_tuple)
          small_patches[size].append(small_patch[np.newaxis])

  # Shuffle and concatenate all patches into numpy arrays (N x C x H x W)
  patches = np.concatenate(patches, axis=0).transpose(0, 3, 1, 2)
  order = np.random.permutation(patches.shape[0])
  patches = patches[order]
  for k, v in small_patches.iteritems():
    small_patches[k] = np.concatenate(v, axis=0).transpose(0, 3, 1, 2)
    small_patches[k] = small_patches[k][order]

  # Write patches to the HDF5 file: high-res targets go to <split>/y,
  # low-res inputs for each downsampling factor go to <split>/x_<factor>
  print patches.shape
  h5_file.create_dataset('%s/y' % split, data=patches)
  for k, v in small_patches.iteritems():
    print v.shape
    h5_file.create_dataset('%s/x_%d' % (split, k), data=v)

def get_file_list(image_dir, image_list, max_files):
  # Either list every file in image_dir, or read one path per line from image_list
  if image_list is None:
    file_list = [os.path.join(image_dir, fn) for fn in os.listdir(image_dir)]
  else:
    with open(image_list, 'r') as f:
      file_list = [line.strip() for line in f]
  if max_files > 0:
    file_list = file_list[:max_files]
  return file_list

if __name__ == '__main__':
  with h5py.File(args.output_h5, 'w') as h5_file:
    val_list = get_file_list(args.val_dir, args.val_list, args.max_val)
    train_list = get_file_list(args.train_dir, args.train_list, args.max_train)
    handle_split('val', val_list, h5_file)
    handle_split('train', train_list, h5_file)
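
For reference, a minimal sketch of reading the file back, assuming only what the script above writes: handle_split stores high-res patches under '<split>/y' and low-res patches under '<split>/x_<factor>' as uint8 arrays in N x C x H x W layout. The file name and the 4x factor below are just the parser defaults; substitute your own.

import h5py
import numpy as np

with h5py.File('data/yang-91.h5', 'r') as f:
  # High-res training targets: (N, 3, 96, 96) with the default patch size
  y_train = np.asarray(f['train/y'])
  # Matching low-res inputs for the 4x factor: (N, 3, 24, 24)
  x_train = np.asarray(f['train/x_4'])
  print y_train.shape, x_train.shape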