# Keras SqueezeNet architecture
# GitHub gist tldrafael/210b75b252f1a41aa52e505aa0a57c3b (last active February 11, 2020 15:54).
# GitHub notice: this file contains hidden or bidirectional Unicode text that may be
# interpreted or compiled differently than what appears below. To review it, open the
# file in an editor that reveals hidden Unicode characters.
# SqueezeNet: AlexNet-level accuracy with 50x fewer parameters and <0.5MB model size
# https://arxiv.org/pdf/1602.07360.pdf
from keras import backend as K
from keras.layers import Input, Convolution2D, MaxPooling2D, Activation, concatenate
from keras.layers import GlobalAveragePooling2D
from keras.layers import Dropout
from keras.models import Model
class SqueezeNet:
    """SqueezeNet image classifier built with the Keras functional API.

    The layer layout (3x3 / 64-filter ``conv1`` with stride 2, max-pooling
    after ``fire3`` and ``fire5``) corresponds to the SqueezeNet v1.1 variant
    published by the authors, not the v1.0 layout of the paper's Table 1.

    Args:
        input_shape: Shape of a single input image, excluding the batch axis,
            in the layout reported by ``K.image_data_format()``.
        n_classes: Number of output classes (filters in ``conv10``).
        dropout_rate: Optional dropout ratio applied after ``fire9``. The
            paper uses 0.5 at that point. ``None`` (the default) adds no
            dropout layer, so the built graph matches earlier versions of
            this class.

    Attributes:
        net: The built ``keras.models.Model`` named ``'squeezenet'``.

    Raises:
        ValueError: If ``dropout_rate`` is given and is not in ``[0, 1)``.
    """

    def __init__(self, input_shape, n_classes, dropout_rate=None):
        if dropout_rate is not None and not 0.0 <= dropout_rate < 1.0:
            raise ValueError(
                'dropout_rate must be in [0, 1), got {!r}'.format(dropout_rate))
        self.input_shape = input_shape
        self.n_classes = n_classes
        self.dropout_rate = dropout_rate
        # Must come last: build_nn reads the attributes set above.
        self.net = self.build_nn()

    @staticmethod
    def fire_module(x, fire_id, squeeze=16, expand=64):
        """Append a Fire module to ``x`` and return its output tensor.

        A 1x1 "squeeze" convolution with ``squeeze`` filters feeds two
        parallel "expand" convolutions (1x1 and 3x3, ``expand`` filters each).
        Every convolution is followed by a ReLU. The two expand outputs are
        concatenated on the channel axis, so the result has ``2 * expand``
        channels.

        Args:
            x: 4D input tensor (batch axis plus an image in the current
                Keras data format).
            fire_id: Index used to build unique layer names (``fire{id}/...``).
            squeeze: Number of filters in the squeeze 1x1 convolution.
            expand: Number of filters in each expand convolution.

        Returns:
            The concatenated output tensor of the module.
        """
        channel_axis = 1 if K.image_data_format() == 'channels_first' else 3

        s1 = Convolution2D(squeeze, (1, 1), padding='valid', name='fire{}/squeeze1x1'.format(fire_id))(x)
        s1 = Activation('relu', name='fire{}/squeeze1x1_relu'.format(fire_id))(s1)

        e1 = Convolution2D(expand, (1, 1), padding='valid', name='fire{}/expand1x1'.format(fire_id))(s1)
        e1 = Activation('relu', name='fire{}/expand1x1_relu'.format(fire_id))(e1)

        # 'same' padding keeps the 3x3 branch at the 1x1 branch's spatial size,
        # which the concatenation below requires.
        e3 = Convolution2D(expand, (3, 3), padding='same', name='fire{}/expand3x3'.format(fire_id))(s1)
        e3 = Activation('relu', name='fire{}/expand3x3_relu'.format(fire_id))(e3)

        return concatenate([e1, e3], axis=channel_axis, name='fire{}/concat'.format(fire_id))

    def build_nn(self):
        """Build and return the SqueezeNet ``Model``.

        The model maps images of ``self.input_shape`` to a softmax
        distribution over ``self.n_classes`` classes.
        """
        img_inputs = Input(shape=self.input_shape)

        x = Convolution2D(64, (3, 3), strides=(2, 2), padding='valid', name='conv1')(img_inputs)
        x = Activation('relu', name='relu1')(x)
        x = MaxPooling2D(pool_size=(3, 3), strides=(2, 2), name='pool1')(x)

        x = SqueezeNet.fire_module(x, fire_id=2, squeeze=16, expand=64)
        x = SqueezeNet.fire_module(x, fire_id=3, squeeze=16, expand=64)
        x = MaxPooling2D(pool_size=(3, 3), strides=(2, 2), name='pool3')(x)

        x = SqueezeNet.fire_module(x, fire_id=4, squeeze=32, expand=128)
        x = SqueezeNet.fire_module(x, fire_id=5, squeeze=32, expand=128)
        x = MaxPooling2D(pool_size=(3, 3), strides=(2, 2), name='pool5')(x)

        x = SqueezeNet.fire_module(x, fire_id=6, squeeze=48, expand=192)
        x = SqueezeNet.fire_module(x, fire_id=7, squeeze=48, expand=192)
        x = SqueezeNet.fire_module(x, fire_id=8, squeeze=64, expand=256)
        x = SqueezeNet.fire_module(x, fire_id=9, squeeze=64, expand=256)

        # The paper applies 50% dropout after fire9. Here it is opt-in so the
        # default graph is unchanged. getattr keeps subclasses that skip this
        # __init__ working.
        dropout_rate = getattr(self, 'dropout_rate', None)
        if dropout_rate:
            x = Dropout(dropout_rate, name='drop9')(x)

        # conv10 emits one feature map per class. Global average pooling
        # reduces each map to a single score, so no dense layer is needed.
        x = Convolution2D(self.n_classes, (1, 1), padding='valid', name='conv10')(x)
        x = Activation('relu', name='relu10')(x)
        x = GlobalAveragePooling2D(name='pool10')(x)
        x = Activation('softmax', name='output')(x)

        return Model(img_inputs, x, name='squeezenet')
# Sign up for free to join this conversation on GitHub.
# Already have an account? Sign in to comment.