TensorFlow Keras implementation of the GoogLeNet incarnation of the Inception network architecture. (Szegedy et al., "Going Deeper with Convolutions", CVPR 2015) https://ai.google/research/pubs/pub43022
#!/usr/bin/env python3
"""
Construct the GoogLeNet incarnation of the Inception network using Keras.
"""
from tensorflow.keras.layers import Conv2D, MaxPooling2D, AveragePooling2D, GlobalAveragePooling2D
from tensorflow.keras.layers import Input, Dense, Dropout, Concatenate, Flatten
from tensorflow import nn
from tensorflow.keras.regularizers import l2
from tensorflow.keras.utils import plot_model
from tensorflow.keras.layers import Layer
from tensorflow.keras.models import Model
class LRN(Layer):
    """Local Response Normalization, as used in the original GoogLeNet."""
    def __init__(self, alpha=0.0001, k=1, beta=0.75, n=5, **kwargs):
        self.alpha = alpha
        self.k = k
        self.beta = beta
        self.n = n
        super().__init__(**kwargs)

    def call(self, x, mask=None):
        return nn.local_response_normalization(x, depth_radius=self.n, bias=self.k,
                                               alpha=self.alpha, beta=self.beta)

    def get_config(self):
        # include the LRN hyperparameters so the layer survives (de)serialization
        config = {
            "alpha": self.alpha,
            "k": self.k,
            "beta": self.beta,
            "n": self.n
        }
        base_config = super().get_config()
        return dict(list(base_config.items()) + list(config.items()))
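# Note: when reloading a saved model that uses this custom layer, pass it via
# custom_objects, e.g. tf.keras.models.load_model('model.h5', custom_objects={'LRN': LRN}).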
def inception_module(input_tensor, n_filters, name='inception'):
    # n_filters: (#1x1, #3x3 reduce, #3x3, #5x5 reduce, #5x5, pool proj)
    n1, n2, n3, n4, n5, n6 = n_filters
    tower1 = Conv2D(n1, (1, 1), 1, padding='same', activation='relu', kernel_regularizer=l2(0.0002), name=name + '/1x1')(input_tensor)
    tower2 = Conv2D(n2, (1, 1), 1, padding='same', activation='relu', kernel_regularizer=l2(0.0002), name=name + '/3x3_reduce')(input_tensor)
    tower2 = Conv2D(n3, (3, 3), 1, padding='same', activation='relu', kernel_regularizer=l2(0.0002), name=name + '/3x3')(tower2)
    tower3 = Conv2D(n4, (1, 1), 1, padding='same', activation='relu', kernel_regularizer=l2(0.0002), name=name + '/5x5_reduce')(input_tensor)
    tower3 = Conv2D(n5, (5, 5), 1, padding='same', activation='relu', kernel_regularizer=l2(0.0002), name=name + '/5x5')(tower3)
    tower4 = MaxPooling2D((3, 3), 1, padding='same', name=name + '/pool')(input_tensor)
    tower4 = Conv2D(n6, (1, 1), 1, padding='same', activation='relu', kernel_regularizer=l2(0.0002), name=name + '/pool_proj')(tower4)
    conc = Concatenate(name=name + '/conc')([tower1, tower2, tower3, tower4])
    return conc
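# The module's output depth is n1 + n3 + n5 + n6; e.g. inception_3a below yields
# 64 + 128 + 32 + 32 = 256 feature maps, matching Table 1 of the paper.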
def auxiliary_classifier(input_tensor, num_classes, name='aux'):
    aux = AveragePooling2D((4, 4), 3, name=name + '/avg_pool')(input_tensor)
    aux = Conv2D(128, (1, 1), 1, padding='same', activation='relu', kernel_regularizer=l2(0.0002), name=name + '/1x1')(aux)
    aux = Flatten(name=name + '/flatten')(aux)
    aux = Dense(1024, activation='relu', kernel_regularizer=l2(0.0002), name=name + '/fc1')(aux)
    aux = Dropout(0.7, name=name + '/dropout')(aux)
    aux = Dense(num_classes, activation='softmax', kernel_regularizer=l2(0.0002), name=name + '/fc2')(aux)
    return aux
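# Aside: the paper specifies 5x5 average pooling with stride 3; on the 14x14 maps
# tapped below, the (4, 4)/3 pooling used here produces the same 4x4 output size.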
num_classes = 1000
inp = Input((224, 224, 3), name='input')
out = Conv2D(64, (7, 7), 2, padding='same', activation='relu', kernel_regularizer=l2(0.0002), name='conv1')(inp)
out = MaxPooling2D((3, 3), 2, padding='same', name='pool1')(out)
out = LRN(name='lrn1')(out)
out = Conv2D(192, (3, 3), 1, padding='same', activation='relu', kernel_regularizer=l2(0.0002), name='conv2')(out)
out = MaxPooling2D((3, 3), 2, padding='same', name='pool2')(out)
out = LRN(name='lrn2')(out)
out = inception_module(out, (64, 96, 128, 16, 32, 32), name='inception_3a')
out = inception_module(out, (128, 128, 192, 32, 96, 64), name='inception_3b')
out = MaxPooling2D((3, 3), 2, padding='same', name='pool3')(out)
aux1 = out = inception_module(out, (192, 96, 208, 16, 48, 64), name='inception_4a')  # tapped by aux classifier 1
out = inception_module(out, (160, 112, 224, 24, 64, 64), name='inception_4b')
out = inception_module(out, (128, 128, 256, 24, 64, 64), name='inception_4c')
aux2 = out = inception_module(out, (112, 144, 288, 32, 64, 64), name='inception_4d')  # tapped by aux classifier 2
out = inception_module(out, (256, 160, 320, 32, 128, 128), name='inception_4e')
out = MaxPooling2D((3, 3), 2, padding='same', name='pool4')(out)
out = inception_module(out, (256, 160, 320, 32, 128, 128), name='inception_5a')
out = inception_module(out, (384, 192, 384, 48, 128, 128), name='inception_5b')
out = GlobalAveragePooling2D(name='avg_pool')(out)
out = Dropout(0.4, name='dropout')(out)
out = Dense(num_classes, activation='softmax', kernel_regularizer=l2(0.0002), name='fc')(out)
aux1 = auxiliary_classifier(aux1, num_classes, 'aux_1')
aux2 = auxiliary_classifier(aux2, num_classes, 'aux_2')
model = Model(inp, [aux1, aux2, out], name='GoogLeNet')
model.summary()
plot_model(model, 'model.png', show_shapes=True)
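# Note: plot_model requires the pydot package and a Graphviz installation.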
""" | |
Construct the loss function to allow training of auxillary classifiers also. Compile the model accordingly. | |
""" | |
from tensorflow.keras.losses import categorical_crossentropy | |
from tensorflow.keras.optimizers import SGD | |
from tensorflow.keras.callbacks import LearningRateScheduler, EarlyStopping, ModelCheckpoint, TensorBoard | |
losses = { | |
"fc": categorical_crossentropy, | |
"aux_1/fc2": categorical_crossentropy, | |
"aux_2/fc2": categorical_crossentropy | |
} | |
loss_weights = { | |
"fc": 1, | |
"aux_1/fc2": 0.3, | |
"aux_2/fc2": 0.3 | |
} | |
learn_rate = 0.1 | |
sgd = SGD(lr=learn_rate, momentum=0.9) | |
model.compile(optimizer=sgd, loss=losses, loss_weights=loss_weights, metrics=['accuracy']) | |
def lr_schedule(epoch): | |
return learn_rate * ((1 - 0.04) ** (epoch // 8)) | |
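# e.g. epochs 0-7 -> 0.1, epochs 8-15 -> 0.096, epochs 16-23 -> 0.09216, ...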
callbacks = [
    LearningRateScheduler(schedule=lr_schedule, verbose=1),
    ModelCheckpoint('weights.{epoch:03d}-{val_loss:.2f}.h5', monitor='val_loss', verbose=1,
                    save_best_only=True, save_weights_only=True),
    TensorBoard('./tensorboard')
]
""" | |
Construct a data pipeline that can feed class labels to the auxillary classifiers as well. | |
""" | |
from tensorflow.keras.preprocessing.image import ImageDataGenerator | |
train_path = './imagenet/train' | |
val_path = './imagenet/val' | |
batch_size = 128 | |
image_size = (224, 224) | |
epochs = 200 | |
def preprocess(x): | |
return x / 255.0 | |
def image_gen(train_path, val_path): | |
gen = ImageDataGenerator(preprocessing_function=preprocess) | |
train, val = [gen.flow_from_directory(path, batch_size=batch_size, target_size=image_size) | |
for path in (train_path, val_path)] | |
return train, val | |
def three_way(gen): | |
for x, y in gen: | |
yield x, [y, y, y] | |
train, val = image_gen(train_path, val_path) | |
train_steps, val_steps = len(train), len(val) | |
train, val = three_way(train), three_way(val) | |
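# Each batch is now (x, [y, y, y]): the same one-hot labels are fed to aux_1/fc2,
# aux_2/fc2 and fc, in the order the outputs were passed to Model(...).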
""" | |
Train the model. | |
""" | |
train_history = model.fit_generator(train, steps_per_epoch=train_steps, epochs=epochs, | |
callbacks=callbacks, validation_data=val, validation_steps=val_steps) | |
""" | |
Visualise the training. | |
""" | |
from matplotlib import pyplot as plt | |
acc = train_history.history['fc_acc'] | |
val_acc = train_history.history['val_fc_acc'] | |
loss = train_history.history['loss'] | |
val_loss = train_history.history['val_loss'] | |
plt.figure(figsize=(8, 8)) | |
plt.subplot(2, 1, 1) | |
plt.plot(acc) | |
plt.plot(val_acc) | |
plt.legend(['Training Accuracy (fc)', 'Validation Accuracy (fc)'], loc='lower right') | |
plt.ylabel('Accuracy') | |
plt.title('Accuracy') | |
plt.subplot(2, 1, 2) | |
plt.plot(loss) | |
plt.plot(val_loss) | |
plt.legend(['Training Loss (total)', 'Validation Loss (total)'], loc='upper right') | |
plt.ylabel('Cross Entropy') | |
plt.xlabel('Epochs') | |
plt.title('Loss') | |
plt.savefig('epcoch_wise_loss_acc.png') | |
plt.show() |
Could you share the directory tree expected for the dataset?
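For reference, flow_from_directory expects one sub-directory per class under each of the train and val roots; a layout like the following would work (the synset folder names here are just placeholders):

    imagenet/
        train/
            n01440764/
                img_0001.JPEG
                ...
            n01443537/
                ...
        val/
            n01440764/
                ...
            n01443537/
                ...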