Skip to content

Instantly share code, notes, and snippets.

@KentaKudo
Created February 5, 2018 18:12
Show Gist options
  • Save KentaKudo/c7b2601718888739f969af02e2733f17 to your computer and use it in GitHub Desktop.
Save KentaKudo/c7b2601718888739f969af02e2733f17 to your computer and use it in GitHub Desktop.
def res_net(shape, *, num_classes=10, weight_decay=1e-4):
    """Build and compile a ResNet-style image classifier.

    The network is a 7x7 stem convolution followed by four stages of
    two-convolution residual blocks, using the 3-4-6-3 stage layout of
    ResNet-34 with 64, 128, 256 and 512 filters.  Spatial downsampling is
    done only by the 2x2 max-pooling layers (one after the stem and one
    between consecutive stages); the residual blocks keep the spatial size.
    For a 32x32x3 input the stage outputs are 16x16x64, 8x8x128, 4x4x256
    and 2x2x512, followed by global average pooling and a softmax layer.

    Args:
        shape: Input shape passed to ``keras.layers.Input``,
            e.g. ``(32, 32, 3)``.
        num_classes: Number of units in the final softmax layer.
        weight_decay: L2 factor applied to every convolution kernel.

    Returns:
        A ``keras.models.Model`` compiled with ``categorical_crossentropy``
        loss (so targets are expected one-hot encoded), the ``adam``
        optimizer and the ``accuracy`` metric.
    """
    from keras.models import Model
    from keras.layers import (
        Activation,
        Add,
        BatchNormalization,
        Conv2D,
        Dense,
        GlobalAveragePooling2D,
        Input,
        MaxPooling2D,
    )
    from keras.regularizers import l2

    def conv(filters, kernel_size):
        # Every convolution in the network shares this configuration:
        # stride 1, 'same' padding (spatial size preserved), L2-regularized
        # kernel, ReLU activation.
        return Conv2D(filters, kernel_size,
                      strides=(1, 1),
                      padding='same',
                      kernel_regularizer=l2(weight_decay),
                      activation='relu')

    def resblock(filters, kernel_size=(3, 3), increase=False):
        """Return a callable applying one residual block to a tensor.

        With ``increase=True`` the shortcut goes through a 1x1 convolution
        so it has ``filters`` channels and can be added to the residual
        branch; use it on the first block of a stage whose channel count
        differs from the previous one.  The block never changes the
        spatial size.
        """
        def _res_block(x):
            x_ = conv(filters, kernel_size)(x)
            x_ = BatchNormalization()(x_)
            # NOTE(review): the branch's last conv (and the projection
            # below) are ReLU-activated before the Add; the canonical
            # ResNet keeps them linear -- confirm this is intended.
            x_ = conv(filters, kernel_size)(x_)
            if increase:
                # Stride 1, not 2: the caller already halves the spatial
                # size with MaxPooling2D before each stage.  Striding here
                # as well downsampled twice per stage and shrank a 32x32
                # input to 1x1 before the last pooling layer, which then
                # failed with a non-positive output size.
                x = conv(filters, (1, 1))(x)
            x = Add()([x_, x])
            x = BatchNormalization()(x)
            return Activation('relu')(x)
        return _res_block

    inputs = Input(shape=shape)
    # Stem: 7x7 conv + BN, then halve the spatial size (32x32x3 -> 16x16x64).
    x = conv(64, (7, 7))(inputs)
    x = BatchNormalization()(x)
    x = MaxPooling2D(pool_size=(2, 2))(x)

    # (filters, number of residual blocks) for each stage.
    stages = ((64, 3), (128, 4), (256, 6), (512, 3))
    for stage_index, (filters, n_blocks) in enumerate(stages):
        if stage_index > 0:
            # Halve the spatial size between stages.
            x = MaxPooling2D(pool_size=(2, 2))(x)
        # First block of each stage projects the shortcut to `filters`.
        x = resblock(filters, increase=True)(x)
        for _ in range(n_blocks - 1):
            x = resblock(filters)(x)

    # 2x2x512 (for a 32x32 input) -> 512 -> num_classes.
    x = GlobalAveragePooling2D()(x)
    y = Dense(num_classes, activation='softmax')(x)

    model = Model(inputs=inputs, outputs=y)
    model.compile(
        loss='categorical_crossentropy',
        optimizer='adam',
        metrics=['accuracy'],
    )
    return model
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment