Created
April 27, 2019 00:17
-
-
Save toluwajosh/036c0018668557e225e36cc3c4977f10 to your computer and use it in GitHub Desktop.
An attempt to reuse layers and pretrained weights of models from keras applications
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
An attempt to reuse layers and pretrained weights of models from keras applications | |
The background to this attempt is here: https://www.tensorflow.org/tutorials/images/transfer_learning | |
""" | |
from __future__ import absolute_import, division, print_function | |
import os | |
import tensorflow as tf | |
from tensorflow import keras | |
print("\n\nTensorFlow version is ", tf.__version__) | |
image_size = 160 | |
IMG_SHAPE = (image_size, image_size, 3) | |
# Create the base model from the pre-trained VGG19 network.
# include_top=False drops the fully-connected classifier head, keeping only the
# convolutional feature extractor; weights='imagenet' loads the
# ImageNet-pretrained weights.
# Alternative: pass an explicit tensor instead of input_shape, e.g.
#   tf.keras.applications.VGG19(input_tensor=keras.Input(shape=IMG_SHAPE), ...)
base_model = tf.keras.applications.VGG19(
    input_shape=IMG_SHAPE,
    include_top=False,
    weights='imagenet',
)
# Do not freeze the pretrained weights: they can be updated when a model that
# contains these layers is trained.
base_model.trainable = True
# --- Sequential model approach ---
# base_model.layers[0] is VGG19's InputLayer, so it already supplies the model
# input and no separate keras.Input(...) is needed here.
# layers[1] and layers[2] are VGG19's first two convolutions (block1_conv1 and
# block1_conv2; check with base_model.summary()). They are the *same layer
# objects* used by base_model, so their pretrained weights are shared, not
# copied: training seq_model also changes base_model's weights.
seq_model = keras.Sequential([
    base_model.layers[0],
    base_model.layers[1],
    base_model.layers[2],
    # New, randomly initialised layer stacked on top of the reused ones.
    keras.layers.Conv2D(64, (3, 3), activation='relu'),
])

# print out model summary
seq_model.summary()
# Sequential.layers omits the InputLayer, so seq_model.layers[0] is the conv
# layer shared with base_model; its `.input` describes that layer's first
# inbound node rather than this model. seq_model.input is this model's input.
print(seq_model.input)
print("\n\n")
# --- Functional API approach ---
# A fresh keras.Input takes the place of VGG19's InputLayer. Calling
# base_model.layers[1] and base_model.layers[2] on it reuses those layer
# objects (and their pretrained weights) rather than copying them.
inputs = keras.Input(shape=IMG_SHAPE)
m_layer = base_model.layers[1](inputs)
m_layer = base_model.layers[2](m_layer)
outputs = keras.layers.Conv2D(64, (3, 3), activation='relu')(m_layer)
api_model = keras.Model(inputs=inputs, outputs=outputs)

# print out model summary
api_model.summary()

# Both approaches build the same layer stack with the same output shape, and
# the two reused VGG19 conv layers are shared by base_model and every model
# built from them in this script. Each model's final Conv2D is a separate
# layer with its own random initialisation, so the models' outputs differ
# until those layers are trained or their weights copied. The printed
# summaries also differ slightly (InputLayer row, auto-generated layer names).

# You can find details of each model in the keras applications here:
# https://github.com/keras-team/keras-applications
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment