Last active: March 19, 2018 01:09
Save standarderror/43582e9a15038806da8a846903438ebe to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
Based on the tflearn CIFAR-10 example at: | |
https://github.com/tflearn/tflearn/blob/master/examples/images/convnet_cifar10.py | |
""" | |
from __future__ import division, print_function, absolute_import

import os
from glob import glob

import numpy as np
# NOTE: scipy.misc.imresize was removed in SciPy 1.3 (it also needs Pillow);
# this script therefore requires an older SciPy.
from scipy.misc import imresize
from skimage import color, io

try:
    from sklearn.model_selection import train_test_split
except ImportError:  # scikit-learn < 0.18 only ships the old module
    from sklearn.cross_validation import train_test_split

import tflearn
from tflearn.data_utils import shuffle, to_categorical
from tflearn.layers.core import input_data, dropout, fully_connected
from tflearn.layers.conv import conv_2d, max_pool_2d
from tflearn.layers.estimator import regression
from tflearn.data_preprocessing import ImagePreprocessing
from tflearn.data_augmentation import ImageAugmentation
from tflearn.metrics import Accuracy
###################################
### Import picture files
###################################
files_path = '/rawdata/train/'
cat_files_path = os.path.join(files_path, 'cat*.jpg')
dog_files_path = os.path.join(files_path, 'dog*.jpg')
cat_files = sorted(glob(cat_files_path))
dog_files = sorted(glob(dog_files_path))
n_files = len(cat_files) + len(dog_files)
print(n_files)
size_image = 64


def _load_image(path, size):
    """Read the image at *path* and return it resized to size x size.

    Greyscale images are expanded to three identical channels, and an alpha
    channel (if any) is dropped. Previously such images failed to fit into a
    3-channel slot and were silently discarded. Any error raised while
    reading or resizing propagates to the caller.
    """
    img = io.imread(path)
    if img.ndim == 2:
        img = color.gray2rgb(img)
    img = img[..., :3]  # drop alpha channel, if present
    return imresize(img, (size, size, 3))


allX = np.zeros((n_files, size_image, size_image, 3), dtype='float64')
ally = np.zeros(n_files)
count = 0
# Label convention: 0 = cat, 1 = dog.
for paths, label in ((cat_files, 0), (dog_files, 1)):
    for f in paths:
        try:
            allX[count] = _load_image(f, size_image)
        except Exception as err:
            # Best effort: one unreadable/odd file should not abort the
            # whole load, but report it instead of hiding it. (Unlike the
            # previous bare `except:`, this lets Ctrl-C through.)
            print('Skipping {}: {}'.format(f, err))
            continue
        ally[count] = label
        count += 1

# Trim the preallocated arrays to the images actually loaded. Otherwise every
# skipped file leaves an all-zero row that would be trained on as a black
# image labelled 0 (cat).
allX = allX[:count]
ally = ally[:count]
print('Loaded {} of {} images'.format(count, n_files))
###################################
# Prepare train & test samples
###################################
# Hold out 10% of the images; the fixed seed makes the split reproducible.
X, X_test, Y, Y_test = train_test_split(
    allX, ally, test_size=0.1, random_state=42)
# One-hot encode the 0/1 (cat/dog) labels into two columns, one per class.
Y, Y_test = (to_categorical(labels, 2) for labels in (Y, Y_test))
###################################
# Image transformations
###################################
img_prep, img_aug = ImagePreprocessing(), ImageAugmentation()

# Feature-wise normalisation: shift each input feature to zero mean, then
# scale it to unit standard deviation.
img_prep.add_featurewise_zero_center()
img_prep.add_featurewise_stdnorm()

# Random augmentation: left-right flips and rotations of up to 25 degrees.
# (Per the tflearn docs these are applied on the fly to training batches;
# they do not enlarge the stored dataset.)
img_aug.add_random_flip_leftright()
img_aug.add_random_rotation(max_angle=25.0)
###################################
# Define network architecture
###################################
# Input: batches of size_image x size_image images with 3 colour channels
# (red, green, blue). The shape is tied to size_image so it always matches
# the resized training data.
network = input_data(shape=[None, size_image, size_image, 3],
                     data_preprocessing=img_prep,
                     data_augmentation=img_aug)
# 1: Convolution layer with 32 filters, each 3x3 (over 3 input channels)
conv_1 = conv_2d(network, 32, 3, activation='relu', name='conv_1')
# 2: Max pooling layer (kernel size 2)
network = max_pool_2d(conv_1, 2)
# 3: Convolution layer with 64 filters
conv_2 = conv_2d(network, 64, 3, activation='relu', name='conv_2')
# 4: Convolution layer with 64 filters
conv_3 = conv_2d(conv_2, 64, 3, activation='relu', name='conv_3')
# 5: Max pooling layer (kernel size 2)
network = max_pool_2d(conv_3, 2)
# 6: Fully-connected 512 node layer
network = fully_connected(network, 512, activation='relu')
# 7: Dropout layer to combat overfitting (tflearn's argument is keep_prob)
network = dropout(network, 0.5)
# 8: Fully-connected output layer: softmax over the two classes (cat, dog)
network = fully_connected(network, 2, activation='softmax')
# Configure how the network will be trained
acc = Accuracy(name="Accuracy")
network = regression(network, optimizer='adam',
                     loss='categorical_crossentropy',
                     learning_rate=0.0005, metric=acc)
# Wrap the network in a model object, keeping at most 3 checkpoints.
# NOTE(review): tensorboard_dir is relative ('tmp/...', not '/tmp/...'), so
# logs land under the current working directory; confirm this is intended.
model = tflearn.DNN(network, checkpoint_path='model_cat_dog_6.tflearn',
                    max_checkpoints=3, tensorboard_verbose=3,
                    tensorboard_dir='tmp/tflearn_logs/')
###################################
# Train the model (100 epochs)
###################################
# Mini-batches of 500 images; the held-out split is reported as the
# validation set after each epoch, with accuracy shown during training.
model.fit(X, Y,
          n_epoch=100,
          batch_size=500,
          validation_set=(X_test, Y_test),
          show_metric=True,
          run_id='model_cat_dog_6')
# Save the final weights, separately from the files under checkpoint_path.
model.save('model_cat_dog_6_final.tflearn')
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.