Skip to content

Instantly share code, notes, and snippets.

@KentaKudo
Created February 4, 2018 18:57
Show Gist options
  • Save KentaKudo/0ace4d7cee7d5232cb93d287ebf0e046 to your computer and use it in GitHub Desktop.
Save KentaKudo/0ace4d7cee7d5232cb93d287ebf0e046 to your computer and use it in GitHub Desktop.
import numpy as np
from keras.datasets import cifar10
from keras.utils import to_categorical
from sklearn.model_selection import train_test_split
(train_X, train_y), (test_X, test_y) = cifar10.load_data()
# Rescale the 8-bit pixel intensities of both splits into [0, 1] as float32.
train_X, test_X = (images.astype('float32') / 255 for images in (train_X, test_X))
# Turn the integer class ids of both splits into one-hot vectors.
train_y, test_y = (to_categorical(labels) for labels in (train_y, test_y))
def train(train_X, train_y, *, epochs=20, batch_size=32, validation_split=0.1):
    """
    Build, compile and fit a small convolutional classifier.

    The network is one Conv -> BN -> Conv -> BN -> MaxPool -> Dropout stage
    (32 filters, 3x3 kernels, L2 weight decay 1e-4, dropout 0.2), then a
    flattened 10-way softmax layer. It is trained with Adam on categorical
    cross-entropy, and accuracy is reported as a metric.

    Args:
        train_X: training images. The input layer takes its shape from
            ``train_X.shape[1:]``.
        train_y: one-hot training labels. The output layer has 10 units, so
            this presumably has shape (N, 10); confirm for non-CIFAR-10 data.
        epochs: number of passes over the training data (default 20).
        batch_size: samples per gradient update (default 32).
        validation_split: fraction of the training data that ``fit`` holds
            out for validation (default 0.1).

    Returns:
        A ``(model, history)`` tuple: the fitted Keras ``Model`` and the
        ``History`` object returned by ``Model.fit``.
    """
    from keras.models import Model
    from keras.layers import Dense, Conv2D, MaxPooling2D, Flatten, Input, Dropout, BatchNormalization
    from keras.regularizers import l2

    def ccp(x, filters, kernel_size, weight_decay, dropout):
        """Apply a Convolution -> Convolution -> Pooling stage to tensor ``x``.

        Each convolution uses 'same' padding, ReLU and an L2 kernel
        regulariser, and is followed by batch normalisation. A 2x2 max-pool
        then halves the spatial size, and dropout is applied last.
        """
        x = Conv2D(filters, kernel_size, padding='same',
                   kernel_regularizer=l2(weight_decay), activation='relu')(x)
        x = BatchNormalization()(x)
        x = Conv2D(filters, kernel_size, padding='same',
                   kernel_regularizer=l2(weight_decay), activation='relu')(x)
        x = BatchNormalization()(x)
        x = MaxPooling2D(pool_size=(2, 2))(x)
        x = Dropout(dropout)(x)
        return x

    inputs = Input(shape=train_X.shape[1:])
    x = ccp(inputs, 32, (3, 3), 1e-4, 0.2)
    x = Flatten()(x)
    y = Dense(10, activation='softmax')(x)

    m = Model(inputs=inputs, outputs=y)
    m.compile(
        loss='categorical_crossentropy',
        optimizer='adam',
        metrics=['accuracy'],
    )
    h = m.fit(x=train_X, y=train_y, batch_size=batch_size, epochs=epochs,
              validation_split=validation_split)
    return m, h
def plot(h, path='result.png'):
    """
    Save side-by-side accuracy and loss learning curves for a training run.

    Great thanks to: http://parneetk.github.io/blog/cnn-cifar10/

    Args:
        h: the ``History`` object returned by Keras ``Model.fit``. Only its
            ``history`` dict is read. Validation curves are drawn, so the
            model must have been fitted with validation data
            (``val_*`` keys present).
        path: file the figure is written to (default ``'result.png'``).
    """
    import matplotlib.pyplot as plt

    history = h.history
    # Keras < 2.3 records the 'accuracy' metric as 'acc'/'val_acc'; newer
    # releases record it as 'accuracy'/'val_accuracy'. Accept either.
    acc_key = 'acc' if 'acc' in history else 'accuracy'

    fig, axs = plt.subplots(1, 2, figsize=(15, 5))
    _plot_curves(axs[0], history[acc_key], history['val_' + acc_key], 'Accuracy')
    _plot_curves(axs[1], history['loss'], history['val_loss'], 'Loss')

    # Save this specific figure rather than whatever pyplot considers
    # "current", then release it so repeated calls don't leak open figures.
    fig.savefig(path)
    plt.close(fig)


def _plot_curves(ax, train_values, val_values, title):
    """Draw train/validation curves of one metric on ``ax`` against 1-based epochs."""
    n_epochs = len(train_values)
    ax.plot(np.arange(1, n_epochs + 1), train_values)
    ax.plot(np.arange(1, len(val_values) + 1), val_values)
    ax.set_title(title)
    ax.set_ylabel(title)
    ax.set_xlabel('Epoch')
    # Aim for roughly ten major ticks. The step must go into np.arange: the
    # second positional argument of Axes.set_xticks is `minor` (older
    # matplotlib) or `labels` (matplotlib >= 3.5), not a tick spacing.
    step = max(1, n_epochs // 10)
    ax.set_xticks(np.arange(1, n_epochs + 1, step))
    ax.legend(['train', 'val'], loc='best')
if __name__ == '__main__':
    model, history = train(train_X, train_y)
    plot(history)
    # The test split is prepared at import time but is never seen during
    # training, so report it. That way the run ends with a held-out score,
    # not only the validation metrics from fit.
    # The model is compiled with one metric (accuracy), so evaluate returns
    # [loss, accuracy].
    test_loss, test_acc = model.evaluate(test_X, test_y, verbose=0)
    print('test loss: {:.4f}, test accuracy: {:.4f}'.format(test_loss, test_acc))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment