Created February 5, 2017 22:19
-
-
Save crowsonkb/873034df27b2aae5bad6604cffe751e7 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import argparse | |
from functools import partial, reduce | |
import math | |
from pathlib import Path | |
import random | |
from keras.layers import * | |
from keras.models import Model | |
from keras.utils import io_utils | |
import numpy as np | |
from PIL import Image | |
def seq(*layers):
    """Compose callables left to right.

    ``seq(f, g, h)(x)`` evaluates ``h(g(f(x)))``. With no arguments the
    result is the identity function.
    """
    def composed(value):
        # Feed each stage's output into the next, in the order given.
        for stage in layers:
            value = stage(value)
        return value
    return composed
def main():
    """Train a small CNN cat/dog classifier from an HDF5 dataset.

    Reads ``x_train``/``y_train`` and ``x_val``/``y_val`` from the file given
    by ``--dataset`` and trains a single-sigmoid-output network with Adam and
    binary cross-entropy for ``--epochs`` epochs (default 100). Pressing
    Ctrl-C stops training early without a traceback.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset', default='dogs_vs_cats.h5',
                        help='the path to the HDF5 format dataset')
    parser.add_argument('--epochs', type=int, default=100,
                        help='the number of epochs to train for')
    args = parser.parse_args()

    x_train = io_utils.HDF5Matrix(args.dataset, 'x_train')
    y_train = io_utils.HDF5Matrix(args.dataset, 'y_train')
    x_val = io_utils.HDF5Matrix(args.dataset, 'x_val')
    y_val = io_utils.HDF5Matrix(args.dataset, 'y_val')

    # The dataset builder stores images channels-first, (3, H, W). State the
    # axis ordering explicitly rather than relying on image_dim_ordering in
    # the user's keras.json; with a channels-last config the spatial layers
    # would misinterpret the axes.
    th = {'dim_ordering': 'th'}

    def conv(filters):
        # 3x3, 'same'-padded ReLU convolution (Keras 1 positional signature:
        # nb_filter, nb_row, nb_col).
        return Conv2D(filters, 3, 3, activation='relu', border_mode='same', **th)

    input_layer = Input(shape=x_train.shape[1:])
    layers = seq(
        conv(8),
        conv(8),
        MaxPooling2D((2, 2), **th),
        conv(16),
        conv(16),
        MaxPooling2D((2, 2), **th),
        conv(32),
        conv(32),
        GlobalAveragePooling2D(**th),
        Dense(1, activation='sigmoid'),
    )
    model = Model(input_layer, layers(input_layer))
    model.summary()
    model.compile('adam', 'binary_crossentropy', metrics=['accuracy'])

    print('Building model...')
    try:
        # shuffle='batch' shuffles in batch-sized chunks, the mode Keras
        # provides for HDF5Matrix-backed inputs.
        model.fit(x_train, y_train, validation_data=(x_val, y_val),
                  shuffle='batch', nb_epoch=args.epochs)
    except KeyboardInterrupt:
        # Deliberate best-effort stop: Ctrl-C ends training early, quietly.
        pass
if __name__ == '__main__':
    # Start training only when run as a script, never on import.
    main()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import concurrent.futures | |
import os | |
from pathlib import Path | |
import random | |
import threading | |
import h5py | |
import numpy as np | |
from PIL import Image | |
# Root directory containing the train/ and test1/ image folders. It can be
# overridden with the DOGS_VS_CATS_DATASET environment variable instead of
# editing this file; the default is the original author's local path.
DATASET_PATH = os.environ.get('DOGS_VS_CATS_DATASET',
                              '/Users/kat/Documents/datasets/dogs-vs-cats')
# Target image size as (height, width).
IMAGE_SIZE = (32, 32)
def extract_label(path):
    """Return the class label encoded in an image's file name.

    Names starting with ``cat`` map to 0 and names starting with ``dog`` map
    to 1. Any other name raises ValueError.
    """
    filename = Path(path).name
    # Checked in order: 'cat' first, then 'dog'.
    for prefix, label in (('cat', 0), ('dog', 1)):
        if filename.startswith(prefix):
            return label
    raise ValueError(f'Could not determine label for {filename}.')
def write_image(path, x, y, i):
    """Load one image into row ``i`` of ``x`` and, if ``y`` is given, its label into ``y[i]``.

    The image is converted to RGB, resized to IMAGE_SIZE with Lanczos
    resampling, scaled from 0..255 to 0..1 as float32, and stored
    channels-first as (3, H, W).
    """
    # Use a context manager so the file handle is released deterministically;
    # Image.open is lazy and may otherwise keep the file open.
    with Image.open(path) as image:
        # PIL's resize takes (width, height); IMAGE_SIZE is (height, width).
        resized = image.convert('RGB').resize((IMAGE_SIZE[1], IMAGE_SIZE[0]), Image.LANCZOS)
    # (H, W, 3) -> (3, H, W) to match the channels-first dataset layout.
    x[i] = np.float32(resized).transpose((2, 0, 1)) / 255
    if y is not None:
        y[i] = extract_label(path)
def write_dataset(paths, x, y=None):
    """Fill ``x`` (and optionally ``y``) from the images in ``paths``, in parallel.

    Row ``i`` of ``x``/``y`` receives the image and label for ``paths[i]``.
    Work runs on a thread pool sized to the CPU count. Progress is printed
    every 100 completed images and once more at the end. If any image fails,
    the first failure in ``paths`` order is re-raised after all work
    finishes, so no row is left silently unwritten.
    """
    total = len(paths)
    done = 0
    lock = threading.Lock()

    def callback(fut):
        nonlocal done
        with lock:
            done += 1
            # done has already been incremented, so it is the true count.
            if done % 100 == 0 or done == total:
                print(f'{done}/{total} images processed.')

    futures = []
    with concurrent.futures.ThreadPoolExecutor(max_workers=os.cpu_count()) as e:
        for i, path in enumerate(paths):
            fut = e.submit(write_image, path, x, y, i)
            fut.add_done_callback(callback)
            futures.append(fut)
        concurrent.futures.wait(futures)

    # wait() never re-raises worker exceptions; result() does.
    for fut in futures:
        fut.result()
def main():
    """Build dogs_vs_cats.h5 from the train/ and test1/ folders under DATASET_PATH.

    A random 10% of the training images is held out as the validation split.
    The output file contains float32 image arrays of shape (N, 3) + IMAGE_SIZE
    (x_train, x_val, x_test) and uint8 labels (y_train, y_val). The test
    images are unlabelled and stored in sorted file-name order.
    """
    root = Path(DATASET_PATH)

    def list_images(directory):
        # Keep only visible regular files: stray entries such as macOS's
        # .DS_Store are not images and carry no cat/dog label. iterdir()
        # yields entries in arbitrary order, so sort to make the x_test row
        # order reproducible.
        return sorted(p for p in directory.iterdir()
                      if p.is_file() and not p.name.startswith('.'))

    train_paths = list_images(root / 'train')
    test_paths = list_images(root / 'test1')
    random.shuffle(train_paths)
    val_samples = round(len(train_paths) * 0.1)
    val_paths = train_paths[:val_samples]
    train_paths = train_paths[val_samples:]

    image_shape = (3,) + IMAGE_SIZE  # channels-first
    with h5py.File('dogs_vs_cats.h5', 'w') as f:
        x_train = f.create_dataset('x_train', (len(train_paths),) + image_shape, np.float32)
        y_train = f.create_dataset('y_train', (len(train_paths),), np.uint8)
        x_val = f.create_dataset('x_val', (len(val_paths),) + image_shape, np.float32)
        y_val = f.create_dataset('y_val', (len(val_paths),), np.uint8)
        x_test = f.create_dataset('x_test', (len(test_paths),) + image_shape, np.float32)
        write_dataset(train_paths, x_train, y_train)
        write_dataset(val_paths, x_val, y_val)
        write_dataset(test_paths, x_test)
if __name__ == '__main__':
    # Build the HDF5 dataset only when executed directly, not on import.
    main()
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.