Skip to content

Instantly share code, notes, and snippets.

@rubenfiszel
Created June 17, 2017 03:04
Show Gist options
  • Save rubenfiszel/ebb3e9979a75163852ab95a2aea53c54 to your computer and use it in GitHub Desktop.
Save rubenfiszel/ebb3e9979a75163852ab95a2aea53c54 to your computer and use it in GitHub Desktop.
from keras import layers
from keras.layers import Activation
from keras import models
def residual_network(x):
    """Build a bottleneck residual network on top of the tensor ``x``.

    The graph is: a strided 7x7 convolution, ReLU and 3x3 max-pooling,
    followed by stacks of bottleneck blocks with 256, 512, 1024 and 2048
    output channels, global average pooling and a 10-unit ``Dense`` layer.

    Args:
        x: Input tensor to build on. Assumed to be a channels-last image
            batch (the script in this file feeds ``Input(shape=(32, 32, 3))``).

    Returns:
        The output tensor of the final ``Dense(10)`` layer. No activation is
        specified on that layer.
    """

    def resnet_conv(channels, kernel_size, strides, y):
        """Apply a convolution to ``y`` after explicit zero padding.

        The amount of padding depends only on ``kernel_size`` (``k - 1``
        in total per spatial axis, split as evenly as possible), after
        which the convolution itself runs with ``padding='valid'``.

        Args:
            channels: Number of output filters.
            kernel_size: ``(height, width)`` of the convolution kernel.
            strides: Strides passed to ``Conv2D``.
            y: Input tensor.

        Returns:
            The convolved tensor.
        """
        # ZeroPadding2D interprets a pair of pairs as
        # ((top, bottom), (left, right)), so BOTH spatial axes must be
        # listed. Padding only one axis yields non-square feature maps whose
        # shape no longer matches the 'same'-padded shortcut branch.
        padding = tuple(
            ((k - 1) // 2, (k - 1) - (k - 1) // 2) for k in kernel_size
        )
        padded = layers.ZeroPadding2D(padding=padding)(y)
        return layers.Conv2D(
            channels,
            kernel_size=kernel_size,
            strides=strides,
            padding='valid',
        )(padded)

    def bottleneck(y, nb_channels_out, nb_channels_in, _strides=(1, 1), _project_shortcut=False):
        """Apply one 1x1 -> 3x3 -> 1x1 residual block to ``y``.

        Args:
            y: Input tensor.
            nb_channels_out: Filters of the final 1x1 convolution (and of the
                projected shortcut, when one is used).
            nb_channels_in: Filters of the first 1x1 and the 3x3 convolution.
            _strides: Strides of the 3x3 convolution and of the shortcut
                projection.
            _project_shortcut: Force a 1x1 convolution on the shortcut even
                when ``_strides`` is ``(1, 1)``.

        Returns:
            ``relu(shortcut + residual)``.
        """
        shortcut = y

        y = layers.Conv2D(nb_channels_in, kernel_size=(1, 1), strides=(1, 1), padding='same')(y)
        y = Activation('relu')(y)

        if _strides == (1, 1):
            y = layers.Conv2D(nb_channels_in, kernel_size=(3, 3), strides=_strides, padding="same")(y)
        else:
            # Strided case: pad explicitly, then convolve with 'valid'.
            y = resnet_conv(nb_channels_in, (3, 3), _strides, y)
        y = Activation('relu')(y)

        y = layers.Conv2D(nb_channels_out, kernel_size=(1, 1), strides=(1, 1), padding='same')(y)

        # The identity shortcut only works when shape is unchanged; project
        # it whenever the caller asks for it or the block downsamples.
        if _project_shortcut or _strides != (1, 1):
            shortcut = layers.Conv2D(nb_channels_out, kernel_size=(1, 1), strides=_strides, padding='same')(shortcut)

        y = layers.add([shortcut, y])
        y = Activation('relu')(y)
        return y

    # Stem: strided 7x7 convolution followed by overlapping max-pooling.
    x = resnet_conv(64, (7, 7), (2, 2), x)
    x = Activation('relu')(x)
    x = layers.MaxPool2D(pool_size=(3, 3), strides=(2, 2), padding='same')(x)

    # 256-channel stage; only the first block needs a projected shortcut
    # because it changes the channel count without striding.
    for i in range(3):
        x = bottleneck(x, 256, 64, _project_shortcut=(i == 0))

    # NOTE(review): this stride-2 block precedes a four-block stage that also
    # starts with stride 2, giving five 512-channel blocks and two consecutive
    # downsamples. The usual ResNet-50 layout has four blocks here -- confirm
    # that the extra block is intended.
    x = bottleneck(x, 512, 128, _strides=(2, 2))

    for i in range(4):
        strides = (2, 2) if i == 0 else (1, 1)
        x = bottleneck(x, 512, 128, _strides=strides)

    for i in range(6):
        strides = (2, 2) if i == 0 else (1, 1)
        x = bottleneck(x, 1024, 256, _strides=strides)

    for i in range(3):
        strides = (2, 2) if i == 0 else (1, 1)
        x = bottleneck(x, 2048, 512, _strides=strides)

    x = layers.GlobalAveragePooling2D()(x)
    x = layers.Dense(10)(x)
    return x
# Build the network for 32x32 three-channel inputs and show its layer table.
image_tensor = layers.Input(shape=(32, 32, 3))
network_output = residual_network(image_tensor)
model = models.Model(inputs=[image_tensor], outputs=[network_output])
# summary() prints the table itself and returns None, so wrapping it in
# print() would emit a stray "None" line after the table.
model.summary()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment