Skip to content

Instantly share code, notes, and snippets.

@tldrafael
Last active February 11, 2020 15:54
Show Gist options
  • Save tldrafael/210b75b252f1a41aa52e505aa0a57c3b to your computer and use it in GitHub Desktop.
Save tldrafael/210b75b252f1a41aa52e505aa0a57c3b to your computer and use it in GitHub Desktop.
Keras SqueezeNet architecture
# Reference: SqueezeNet paper, https://arxiv.org/pdf/1602.07360.pdf
#
from keras import backend as K
from keras.layers import Input, Convolution2D, MaxPooling2D, Activation, concatenate
from keras.layers import GlobalAveragePooling2D
from keras.models import Model
class SqueezeNet:
    """SqueezeNet-style image classifier assembled with the Keras functional API.

    Creating an instance immediately builds the network and keeps it on
    ``self.net``.

    Attributes:
        input_shape: Value handed to ``Input(shape=...)`` for the image input.
        n_classes: Filter count of the final 1x1 convolution, and therefore
            the width of the softmax output.
        net: The ``Model`` returned by ``build_nn``.
    """

    def __init__(self, input_shape, n_classes):
        """Store the configuration, then build the model."""
        self.input_shape = input_shape
        self.n_classes = n_classes
        # build_nn() reads both attributes above, so it has to run last.
        self.net = self.build_nn()

    @staticmethod
    def fire_module(x, fire_id, squeeze=16, expand=64):
        """Append one fire module (squeeze conv + two parallel expand convs).

        Args:
            x: Tensor the module is applied to.
            fire_id: Value substituted into every layer name of this module
                (``fire<id>/...``).
            squeeze: Filter count of the 1x1 squeeze convolution.
            expand: Filter count of each expand convolution (1x1 and 3x3).

        Returns:
            The concatenation of both ReLU-activated expand branches along
            the channel axis.
        """
        def tag(template):
            # Fill the module id into a layer-name template.
            return template.format(fire_id)

        # Channels sit on axis 1 for 'channels_first', otherwise on axis 3.
        channel_axis = 1 if K.image_data_format() == 'channels_first' else 3

        # Squeeze: 1x1 convolution followed by ReLU.
        squeezed = Convolution2D(
            squeeze, (1, 1), padding='valid', name=tag('fire{}/squeeze1x1')
        )(x)
        squeezed = Activation('relu', name=tag('fire{}/squeeze1x1_relu'))(squeezed)

        # Expand, 1x1 branch.
        branch_1x1 = Convolution2D(
            expand, (1, 1), padding='valid', name=tag('fire{}/expand1x1')
        )(squeezed)
        branch_1x1 = Activation('relu', name=tag('fire{}/expand1x1_relu'))(branch_1x1)

        # Expand, 3x3 branch; 'same' padding keeps its spatial size equal to
        # the 1x1 branch so the two can be concatenated.
        branch_3x3 = Convolution2D(
            expand, (3, 3), padding='same', name=tag('fire{}/expand3x3')
        )(squeezed)
        branch_3x3 = Activation('relu', name=tag('fire{}/expand3x3_relu'))(branch_3x3)

        return concatenate(
            [branch_1x1, branch_3x3], axis=channel_axis, name=tag('fire{}/concat')
        )

    def build_nn(self):
        """Wire up the full network and return it as a Keras ``Model``.

        Layout: conv1 + ReLU + max-pool, fire modules 2-9 with a max-pool
        after fire 3 and after fire 5, then a 1x1 convolution with
        ``self.n_classes`` filters, ReLU, global average pooling and softmax.

        Returns:
            A ``Model`` named ``'squeezenet'`` mapping the image input to the
            softmax output.
        """
        # (fire_id, squeeze filters, expand filters), in network order.
        fire_specs = (
            (2, 16, 64),
            (3, 16, 64),
            (4, 32, 128),
            (5, 32, 128),
            (6, 48, 192),
            (7, 48, 192),
            (8, 64, 256),
            (9, 64, 256),
        )
        # Fire modules that are directly followed by a max-pool, and its name.
        pool_after = {3: 'pool3', 5: 'pool5'}

        img_inputs = Input(shape=self.input_shape)

        # Stem.
        net = Convolution2D(
            64, (3, 3), strides=(2, 2), padding='valid', name='conv1'
        )(img_inputs)
        net = Activation('relu', name='relu1')(net)
        net = MaxPooling2D(pool_size=(3, 3), strides=(2, 2), name='pool1')(net)

        # Fire stack with the interleaved down-sampling stages.
        for fire_id, squeeze, expand in fire_specs:
            net = SqueezeNet.fire_module(
                net, fire_id=fire_id, squeeze=squeeze, expand=expand
            )
            if fire_id in pool_after:
                net = MaxPooling2D(
                    pool_size=(3, 3), strides=(2, 2), name=pool_after[fire_id]
                )(net)

        # Classification head: one feature map per class, averaged spatially.
        net = Convolution2D(
            self.n_classes, (1, 1), padding='valid', name='conv10'
        )(net)
        net = Activation('relu', name='relu10')(net)
        net = GlobalAveragePooling2D(name='pool10')(net)
        net = Activation('softmax', name='output')(net)

        return Model(img_inputs, net, name='squeezenet')
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment