@BIGBALLON
Created April 17, 2018 09:52
Fashion-MNIST with Wide ResNet
import keras
import math
from keras.datasets import fashion_mnist
from keras.preprocessing.image import ImageDataGenerator
from keras.layers.normalization import BatchNormalization
from keras.layers import Conv2D, Dense, Input, add, Activation, Flatten, AveragePooling2D
from keras.callbacks import LearningRateScheduler, TensorBoard
from keras.regularizers import l2
from keras import optimizers
from keras.models import Model
DEPTH = 28          # total WRN depth: 6 * n_stack + 4
WIDE = 10           # widening factor k
IN_FILTERS = 16     # tracked globally so blocks know when a projection shortcut is needed
CLASS_NUM = 10
IMG_ROWS, IMG_COLS = 28, 28
IMG_CHANNELS = 1
BATCH_SIZE = 128
EPOCHS = 200
ITERATIONS = 60000 // BATCH_SIZE + 1   # steps per epoch over the 60,000 training images
WEIGHT_DECAY = 0.0005
LOG_FILE_PATH = './w_resnet/'
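
# Consistency check (added): a Wide ResNet's depth must satisfy depth = 6*n + 4,
# matching n_stack = (depth - 4) // 6 used inside wide_residual_network below.
assert (DEPTH - 4) % 6 == 0, 'DEPTH must be of the form 6*n + 4'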
from keras import backend as K

# set GPU memory: let TensorFlow grow its allocation instead of claiming it all
if 'tensorflow' == K.backend():
    import tensorflow as tf
    from keras.backend.tensorflow_backend import set_session
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)
    set_session(sess)
def scheduler(epoch):
    # half-cosine decay: 0.1 at epoch 0, 0.05 at mid-training, ~0 at the end
    return 0.1 / 2.0 + 0.1 / 2.0 * math.cos(math.pi / 2.0 * (epoch / (EPOCHS / 2.0)))
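
# Quick check of the schedule endpoints (added; the formula simplifies to
# 0.05 + 0.05 * cos(pi * epoch / EPOCHS), a half-cosine decay from 0.1 to ~0):
assert abs(scheduler(0) - 0.1) < 1e-6
assert abs(scheduler(EPOCHS // 2) - 0.05) < 1e-6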
def color_preprocessing(x_train, x_test):
    x_train = x_train.astype('float32')
    x_test = x_test.astype('float32')
    # mean/std precomputed over the full Fashion-MNIST training set
    mean = 72.94042
    std = 90.02121
    x_train = x_train.reshape((-1, 28, 28, 1))
    x_test = x_test.reshape((-1, 28, 28, 1))
    x_train = (x_train - mean) / std
    x_test = (x_test - mean) / std
    return x_train, x_test
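
# The hard-coded statistics above can be re-derived from the raw training
# split (a one-off check, not part of the training run):
#   (x, _), _ = fashion_mnist.load_data()
#   print(x.mean(), x.std())   # -> ~72.94, ~90.02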
def wide_residual_network(img_input, classes_num, depth, k):
    print('Wide-ResNet %dx%d' % (depth, k))
    n_filters = [16, 16 * k, 32 * k, 64 * k]
    n_stack = (depth - 4) // 6   # residual blocks per group: depth = 6 * n_stack + 4

    def conv3x3(x, filters):
        return Conv2D(filters=filters, kernel_size=(3, 3), strides=(1, 1), padding='same',
                      kernel_initializer='he_normal',
                      kernel_regularizer=l2(WEIGHT_DECAY),
                      use_bias=False)(x)

    def bn_relu(x):
        x = BatchNormalization(momentum=0.9, epsilon=1e-5)(x)
        x = Activation('relu')(x)
        return x
    def residual_block(x, out_filters, increase=False):
        global IN_FILTERS
        stride = (1, 1)
        if increase:
            stride = (2, 2)
        o1 = bn_relu(x)
        conv_1 = Conv2D(out_filters,
                        kernel_size=(3, 3), strides=stride, padding='same',
                        kernel_initializer='he_normal',
                        kernel_regularizer=l2(WEIGHT_DECAY),
                        use_bias=False)(o1)
        o2 = bn_relu(conv_1)
        conv_2 = Conv2D(out_filters,
                        kernel_size=(3, 3), strides=(1, 1), padding='same',
                        kernel_initializer='he_normal',
                        kernel_regularizer=l2(WEIGHT_DECAY),
                        use_bias=False)(o2)
        if increase or IN_FILTERS != out_filters:
            proj = Conv2D(out_filters,
                          kernel_size=(1, 1), strides=stride, padding='same',
                          kernel_initializer='he_normal',
                          kernel_regularizer=l2(WEIGHT_DECAY),
                          use_bias=False)(o1)
            block = add([conv_2, proj])
        else:
            block = add([conv_2, x])
        return block
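
    # Block layout (pre-activation residual block, as wired above):
    #   x -> BN-ReLU -> 3x3 conv (stride s) -> BN-ReLU -> 3x3 conv --+--> out
    #   \________ 1x1 conv (stride s), only when the shape changes __/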
    def wide_residual_layer(x, out_filters, increase=False):
        global IN_FILTERS
        x = residual_block(x, out_filters, increase)
        IN_FILTERS = out_filters
        for _ in range(1, n_stack):
            x = residual_block(x, out_filters)
        return x
    x = conv3x3(img_input, n_filters[0])
    x = wide_residual_layer(x, n_filters[1])
    x = wide_residual_layer(x, n_filters[2], increase=True)
    x = wide_residual_layer(x, n_filters[3], increase=True)
    x = BatchNormalization(momentum=0.9, epsilon=1e-5)(x)
    x = Activation('relu')(x)
    x = AveragePooling2D((7, 7))(x)
    x = Flatten()(x)
    x = Dense(classes_num,
              activation='softmax',
              kernel_initializer='he_normal',
              kernel_regularizer=l2(WEIGHT_DECAY),
              use_bias=False)(x)
    return x
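
# Shape walk-through for a 28x28x1 input with DEPTH=28, WIDE=10 (follows
# directly from the layers above; the (7,7) pooling matches the final map):
#   conv3x3             -> 28 x 28 x 16
#   group 1 (stride 1)  -> 28 x 28 x 160
#   group 2 (stride 2)  -> 14 x 14 x 320
#   group 3 (stride 2)  ->  7 x  7 x 640
#   avg-pool + flatten  -> 640 -> Dense(10, softmax)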
if __name__ == '__main__':
    # load data
    (x_train, y_train), (x_test, y_test) = fashion_mnist.load_data()
    y_train = keras.utils.to_categorical(y_train, CLASS_NUM)
    y_test = keras.utils.to_categorical(y_test, CLASS_NUM)

    # color preprocessing
    x_train, x_test = color_preprocessing(x_train, x_test)

    # build network
    img_input = Input(shape=(IMG_ROWS, IMG_COLS, IMG_CHANNELS))
    output = wide_residual_network(img_input, CLASS_NUM, DEPTH, WIDE)
    resnet = Model(img_input, output)
    print(resnet.summary())

    # set optimizer
    sgd = optimizers.SGD(lr=.1, momentum=0.9, nesterov=True)
    resnet.compile(loss='categorical_crossentropy', optimizer=sgd, metrics=['accuracy'])

    # set callbacks
    tb_cb = TensorBoard(log_dir=LOG_FILE_PATH, histogram_freq=0)
    change_lr = LearningRateScheduler(scheduler)
    cbks = [change_lr, tb_cb]

    # set data augmentation
    print('Using real-time data augmentation.')
    datagen = ImageDataGenerator(horizontal_flip=True,
                                 width_shift_range=0.125,
                                 height_shift_range=0.125,
                                 fill_mode='reflect')
    datagen.fit(x_train)

    # start training
    resnet.fit_generator(datagen.flow(x_train, y_train, batch_size=BATCH_SIZE),
                         steps_per_epoch=ITERATIONS,
                         epochs=EPOCHS,
                         callbacks=cbks,
                         validation_data=(x_test, y_test))
    resnet.save('wresnet.h5')
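
    # Optional sanity check (added): reload the saved file and score the test
    # set. Only built-in layers/regularizers are used, so no custom_objects.
    from keras.models import load_model
    model = load_model('wresnet.h5')
    loss, acc = model.evaluate(x_test, y_test, verbose=0)
    print('test loss: %.4f, test acc: %.4f' % (loss, acc))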