Last active
January 5, 2017 12:56
-
-
Save notwa/a209d70ecff2987e0baa61b00381d8af to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python3
"""Train and evaluate a small residual network on MNIST with Keras.

An optional first command-line argument names previously saved weights to
restore (see the `restore_fn` handling below).
"""
import pickle
import sys
import time

import numpy as np

import keras.backend as K
from keras.callbacks import LearningRateScheduler
from keras.datasets import mnist
from keras.layers import BatchNormalization
from keras.layers import Convolution2D, MaxPooling2D
from keras.layers import Flatten, Reshape
from keras.layers import Input, merge, Dense, Activation
from keras.models import Model
from keras.utils.np_utils import to_categorical

# Tensor shapes in this script are written channels-first (Theano 'th'
# ordering), e.g. Input(shape=(1, width, height)) and
# BatchNormalization(axis=1). Fail loudly otherwise: an `assert` would be
# silently skipped under `python -O`.
if K.image_dim_ordering() != 'th':
    raise RuntimeError("this script requires image_dim_ordering 'th', got {!r}"
                       .format(K.image_dim_ordering()))
# Dataset geometry and training objective.
nb_classes = 10             # one class per MNIST digit
width = height = 28         # image size in pixels
loss = 'categorical_crossentropy'
# Run identifier: 'resnet-' plus the current Unix time rounded to whole
# seconds; used as the basename of the saved .json/.pkl files.
name = 'resnet-' + format(time.time(), '.0f')
args = dict(enumerate(sys.argv)) | |
restore_fn = args.get(1) | |
if restore_fn == '.': # TODO: accept any directory | |
# just use most recent resnet-*.pkl file in directory | |
import os | |
is_valid = lambda fn: fn.startswith('resnet-') and fn.endswith('.pkl') | |
files = sorted([fn for fn in os.listdir(restore_fn) if is_valid(fn)]) | |
if len(files) == 0: | |
raise Exception("couldn't find any appropriate .pkl files in the CWD") | |
restore_fn = files[-1] | |
dont_train = False         # True: skip sparsifying/training/saving; only evaluate
verbose_summary = False    # True: model.summary(); False: print only the param count
reslayers = 4              # number of residual units
size = 8                   # filters (conv) or units (dense) per residual layer
# Width of the first Dense layer in the fully-connected variant
# (convolutional = False). It must equal `size`, because every residual
# unit adds its input to a `size`-wide output with a 'sum' merge.
dense_size = size
batch_size = 128
epochs = 24
convolutional = True       # False: residual MLP on flattened pixels
resnet_enabled = True      # False: no skip connections (plain deep net)
original_resnet = False    # True: post-activation units; False: pre-activation
LR = 1e-2                  # base learning rate (divided by 10 when weights are restored)
LRprod = 0.1**(1/20.)      # per-epoch decay: a tenth of the learning rate after 20 epochs
use_image_generator = True  # train on randomly augmented images
def prepare(X, y):
    """Convert raw MNIST arrays into model inputs and one-hot targets.

    Args:
        X: images, one per leading index (presumably shape (N, height, width)
           with 0-255 pixel values, as returned by mnist.load_data()).
        y: integer class labels in [0, nb_classes).

    Returns:
        (X, Y): X as float32 scaled by 1/255 with shape (N, 1, height, width),
        i.e. channels-first as required at the top of the script, and Y as a
        one-hot matrix with nb_classes columns.
    """
    # Channels-first layout is (channels, rows, cols) == (1, height, width).
    # (Equivalent to the old (1, width, height) only while width == height.)
    X = X.reshape(X.shape[0], 1, height, width).astype('float32') / 255
    # convert class vectors to binary class matrices
    Y = to_categorical(y, nb_classes)
    return X, Y
# Load MNIST, which comes already shuffled and split into train/test sets,
# and convert both splits into network-ready arrays.
(X_train, y_train), (X_test, y_test) = mnist.load_data()
(X_train, Y_train), (X_test, Y_test) = (prepare(X_train, y_train),
                                        prepare(X_test, y_test))

if use_image_generator:
    from keras.preprocessing.image import ImageDataGenerator
    # On-the-fly augmentation: small rotation, shifts, shear (5/180*pi, i.e.
    # 5 degrees in radians) and zoom; pixels uncovered by a transform are
    # filled with the constant 0.
    idg = ImageDataGenerator(
        fill_mode='constant',
        cval=0.,
        rotation_range=5.,
        width_shift_range=.10,
        height_shift_range=.10,
        shear_range=5 / 180 * np.pi,
        zoom_range=0.1,
    )
# Factory for the layers inside each residual unit: a 3x3 'same'-padded
# convolution with `n` filters, or a Dense layer with `n` units in the
# fully-connected variant. He-normal init is used because it is reputed to
# pair best with ReLU activations.
if not convolutional:
    layer = lambda n: Dense(n, init='he_normal')
else:
    layer = lambda n: Convolution2D(n, 3, 3, init='he_normal',
                                    border_mode='same')

# Build the functional graph: a single-channel image input, then the stem.
x = Input(shape=(1, width, height))
if not convolutional:
    # Fully-connected stem: flatten the pixels and project to dense_size
    # units (which must equal `size` for the residual sums to line up).
    y = Flatten()(x)
    y = Dense(dense_size)(y)
else:
    # Convolutional stem (other kernel sizes might be worth trying): 7x7
    # stride-2 convolution, then max pooling with the default pool size.
    # The head's Reshape assumes this shrinks each spatial dim by 4.
    y = Convolution2D(size, 7, 7, subsample=(2, 2), border_mode='same')(x)
    y = MaxPooling2D()(y)
# Stack `reslayers` residual units. The shortcut is added element-wise, so
# each unit's input must already be `size` wide: the conv stem produces
# `size` channels, and the Dense stem must use dense_size == size.
for i in range(reslayers):
    skip = y  # identity shortcut
    if original_resnet:
        # Post-activation unit (original ResNet ordering):
        # layer -> BN -> ReLU -> layer -> BN -> (+ shortcut) -> ReLU
        y = layer(size)(y)
        # axis=1 is the channel axis (channels-first) or the feature axis
        # of the Dense variant's (batch, units) output.
        y = BatchNormalization(axis=1)(y)
        y = Activation('relu')(y)
        y = layer(size)(y)
        y = BatchNormalization(axis=1)(y)
        if resnet_enabled: y = merge([skip, y], mode='sum')
        y = Activation('relu')(y)
    else:
        # Pre-activation unit:
        # BN -> ReLU -> layer -> BN -> ReLU -> layer -> (+ shortcut)
        y = BatchNormalization(axis=1)(y)
        y = Activation('relu')(y)
        y = layer(size)(y)
        y = BatchNormalization(axis=1)(y)
        y = Activation('relu')(y)
        y = layer(size)(y)
        if resnet_enabled: y = merge([skip, y], mode='sum')
# Classification head: optional pooling, then Dense + softmax over classes.
if convolutional:
    # Imported here because only this branch uses it.
    from keras.layers import AveragePooling1D
    # View the (size, H/4, W/4) feature map as (size, H*W/16); the divisor
    # assumes the stem's stride-2 convolution followed by 2x2 max pooling.
    y = Reshape((size, int(width * height / 2**2 / 2**2)))(y)
    # NOTE(review): AveragePooling1D pools along axis 1, which after this
    # Reshape is the channel axis (length `size`). So this averages the
    # `size` channels at each spatial position (presumably leaving shape
    # (1, H*W/16)) rather than global-average-pooling over space per
    # channel. Confirm this is intended; changing it would change the Dense
    # weight shape and invalidate previously saved .pkl weights.
    y = AveragePooling1D(size)(y)
y = Flatten()(y)
y = Dense(nb_classes)(y)
y = Activation('softmax')(y)
model = Model(input=x, output=y)
if verbose_summary:
    model.summary()
else:
    # Sum the per-layer parameter counts. A distinct loop name is used so the
    # module-level `layer` factory is not clobbered by the last model layer.
    total_params = sum(lyr.count_params() for lyr in model.layers)
    print("Total params: {}".format(total_params))
if restore_fn:
    # SECURITY: unpickling can execute arbitrary code. This path comes from
    # the command line, so only restore weight files you produced yourself.
    with open(restore_fn, 'rb') as f:
        W = pickle.load(f)
    if not dont_train:
        # Sparsify an existing model before fine-tuning: in every residual
        # 3x3 kernel (shape (size, size, 3, 3)), zero the weights whose
        # magnitude is below that kernel's median magnitude (roughly half).
        for i, w in enumerate(W):
            if w.shape == (size, size, 3, 3):
                middle = np.median(np.abs(w))  # median over all elements
                where = np.abs(w) < middle
                fmt = 'W[{}]: zeroing {} params of {}'
                print(fmt.format(i, int(np.count_nonzero(where)), int(w.size)))
                W[i] = np.where(where, 0, w)
    model.set_weights(W)
    # Continue from restored weights with a smaller base learning rate.
    LR /= 10
model.compile(loss=loss, optimizer='adam', metrics=['accuracy'])

if not dont_train:
    # Exponential decay: for epoch index e the scheduler sets the learning
    # rate to LR * LRprod**e.
    callbacks = [LearningRateScheduler(lambda e: LR * LRprod**e)]
    # Options shared by both fitting paths; the test split doubles as the
    # per-epoch validation set.
    kwargs = {
        'nb_epoch': epochs,
        'validation_data': (X_test, Y_test),
        'callbacks': callbacks,
        'verbose': 1,
    }
    if not use_image_generator:
        history = model.fit(X_train, Y_train, batch_size=batch_size, **kwargs)
    else:
        # Stream augmented batches; one epoch covers len(X_train) samples.
        history = model.fit_generator(
            idg.flow(X_train, Y_train, batch_size=batch_size),
            samples_per_epoch=len(X_train),
            **kwargs)
def evaluate(X, Y):
    """Score the global `model` on (X, Y) and print one line per metric.

    The 'acc' metric is printed as a percentage; any other metric (e.g. the
    loss) is printed as a plain number.
    """
    results = model.evaluate(X, Y, verbose=0)
    for metric, value in zip(model.metrics_names, results):
        line = ("{:7} {:6.2f}%".format(metric, value * 100) if metric == "acc"
                else "{:7} {:7.5f}".format(metric, value))
        print(line)
# Report metrics on each split separately, then on both splits combined.
for title, X_split, Y_split in (('TRAIN', X_train, Y_train),
                                ('TEST', X_test, Y_test)):
    print(title)
    evaluate(X_split, Y_split)
print('ALL')
evaluate(np.vstack((X_train, X_test)), np.vstack((Y_train, Y_test)))
if not dont_train:
    # Save the architecture (JSON) and the trained weights (pickled list of
    # arrays) under the run name; the .pkl can later be passed as argv[1].
    # Context managers guarantee both files are flushed and closed.
    with open(name + '.json', 'w') as f:
        f.write(model.to_json())
    with open(name + '.pkl', 'wb') as f:
        pickle.dump(model.get_weights(), f)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment