Skip to content

Instantly share code, notes, and snippets.

@somewacko
Created April 15, 2016 02:48
Show Gist options
Save somewacko/41ce4cccfcfe8f0bff1b4ad82d0ee451 to your computer and use it in GitHub Desktop.
Keras VGG-16 model with functional API
from keras.layers import Dense, Dropout, Flatten, Input
from keras.layers import Convolution2D, MaxPooling2D, ZeroPadding2D
from keras.models import Model
from keras.utils.layer_utils import print_summary
VGG_PATH = 'vgg16_weights.h5'

# Input shape fed to VGG-16. The first Dense(4096) layer of the pretrained
# weights expects the flattened 512x7x7 feature map that five 2x2 pools
# produce from a 224x224 image, so the spatial size must be 224x224.
# NOTE(review): channels-first (3, H, W) assumes Keras 1's default 'th'
# dim_ordering -- confirm against the keras.json in use.
VISUAL_SHAPE = (3, 224, 224)

# (number of 3x3 conv layers, number of filters) for each of the five
# VGG-16 convolutional stages, in network order.
VGG16_STAGES = ((2, 64), (2, 128), (3, 256), (3, 512), (3, 512))


def _conv_stage(x, nb_convs, nb_filter):
    """Apply `nb_convs` (1-pixel zero-pad + 3x3 ReLU conv) pairs, then a 2x2 max-pool.

    Args:
        x: Keras tensor to build on.
        nb_convs: number of pad+conv pairs in this stage.
        nb_filter: number of filters for every conv in this stage.

    Returns:
        The Keras tensor output by the stage's max-pooling layer.
    """
    for _ in range(nb_convs):
        # Padding by 1 before each unpadded 3x3 conv keeps the spatial size fixed.
        x = ZeroPadding2D((1, 1))(x)
        x = Convolution2D(nb_filter, 3, 3, activation='relu')(x)
    # Halve the spatial resolution at the end of each stage.
    return MaxPooling2D((2, 2), strides=(2, 2))(x)


vis_input = Input(shape=VISUAL_SHAPE, name="vis_input")

# Convolutional feature extractor: layers are created in exactly the same
# order as the hand-unrolled version, so auto-generated layer names and
# order-based weight loading are unaffected.
x = vis_input
for nb_convs, nb_filter in VGG16_STAGES:
    x = _conv_stage(x, nb_convs, nb_filter)

# Fully connected classifier head.
x = Flatten()(x)
x = Dense(4096, activation='relu')(x)
x = Dropout(0.5)(x)
x = Dense(4096, activation='relu')(x)
x = Dropout(0.5)(x)
# 1000-way ImageNet output. Softmax yields class probabilities; the previous
# 'relu' clipped negative scores and produced no distribution.
x = Dense(1000, activation='softmax')(x)

model = Model(input=vis_input, output=x)
print_summary(model.layers)
# NOTE(review): weights are matched to this model's layers, so the layer
# sequence above must match the one the weight file was saved from.
model.load_weights(VGG_PATH)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment