@dipanjanS
Created August 21, 2019 15:43
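# assumption: the gist omits its imports; the tf.keras API is used here
from tensorflow import keras

INPUT_SHAPE = (192, 192, 3)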
# load pre-trained ResNet50 (ImageNet weights) without its classification head
resnet = keras.applications.resnet50.ResNet50(include_top=False, weights='imagenet',
                                              input_shape=INPUT_SHAPE)
# set all layers to be trainable (fine-tune the whole network)
resnet.trainable = True
for layer in resnet.layers:
    layer.trainable = True
# add dense and output layers
base_resnet = resnet
base_out = base_resnet.output
pool_out = keras.layers.Flatten()(base_out)
hidden1 = keras.layers.Dense(1024, activation='relu')(pool_out)
drop1 = keras.layers.Dropout(rate=0.2)(hidden1)
hidden2 = keras.layers.Dense(512, activation='relu')(drop1)
drop2 = keras.layers.Dropout(rate=0.2)(hidden2)
out = keras.layers.Dense(7, activation='softmax')(drop2)
model = keras.Model(inputs=base_resnet.input, outputs=out)
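# compile with a very small learning rate, as is typical when fine-tuning
# (older Keras releases name this argument `lr`)
model.compile(optimizer=keras.optimizers.RMSprop(learning_rate=1e-6),
              loss='categorical_crossentropy',
              metrics=['categorical_accuracy'])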
model.summary()
# Output
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to
==================================================================================================
input_3 (InputLayer)            (None, 192, 192, 3)  0
__________________________________________________________________________________________________
conv1_pad (ZeroPadding2D)       (None, 198, 198, 3)  0           input_3[0][0]
__________________________________________________________________________________________________
conv1 (Conv2D)                  (None, 96, 96, 64)   9472        conv1_pad[0][0]
__________________________________________________________________________________________________
bn_conv1 (BatchNormalization)   (None, 96, 96, 64)   256         conv1[0][0]
__________________________________________________________________________________________________
...
...
__________________________________________________________________________________________________
dense_8 (Dense)                 (None, 512)          524800      dropout_5[0][0]
__________________________________________________________________________________________________
dropout_6 (Dropout)             (None, 512)          0           dense_8[0][0]
__________________________________________________________________________________________________
dense_9 (Dense)                 (None, 7)            3591        dropout_6[0][0]
==================================================================================================
Total params: 99,614,599
Trainable params: 99,561,479
Non-trainable params: 53,120
__________________________________________________________________________________________________
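
The totals in the summary can be cross-checked by hand. The sketch below is not part of the original gist. It assumes the usual ResNet50 downsampling factor of 32 (so a 192x192 input reaches Flatten as a 6x6x2048 feature map) and a parameter count of 23,587,712 for the Keras ResNet50 base with `include_top=False`. The helper `dense_params` is a hypothetical name introduced only for illustration.

```python
def dense_params(n_inputs: int, n_units: int) -> int:
    """Parameters of a fully connected layer: one weight per input/unit pair plus one bias per unit."""
    return n_inputs * n_units + n_units

# ResNet50 reduces the spatial size by a factor of 32 and ends with 2048 channels,
# so Flatten sees 6 * 6 * 2048 features for a 192x192 input.
flat_features = (192 // 32) * (192 // 32) * 2048

# Assumption: parameter count of the Keras ResNet50 base without its top layer.
RESNET50_NO_TOP_PARAMS = 23_587_712

# Dense(1024), Dense(512) and the 7-way softmax output layer from the model definition.
head = [
    dense_params(flat_features, 1024),
    dense_params(1024, 512),
    dense_params(512, 7),
]

total = RESNET50_NO_TOP_PARAMS + sum(head)
trainable = total - 53_120  # non-trainable count taken from the summary

print(head)       # per-layer parameter counts of the new head
print(total)      # compare with "Total params"
print(trainable)  # compare with "Trainable params"
```

Most of the roughly 99.6 million parameters sit in the first Dense layer after Flatten, not in the ResNet50 base itself.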