Skip to content

Instantly share code, notes, and snippets.

@dansbecker
Created May 23, 2016 20:32
Show Gist options
  • Save dansbecker/361335eaf9fa5f8243ee95ae5f8f1046 to your computer and use it in GitHub Desktop.
Save dansbecker/361335eaf9fa5f8243ee95ae5f8f1046 to your computer and use it in GitHub Desktop.
Gist showing a caveat in Keras model saving and loading: uncommenting the commented-out Convolution2D lines (lines 25–26 of the original gist) causes an error.
from keras.layers import Input, Dense, merge, Flatten
from keras.layers.convolutional import Convolution2D
from keras.models import Model, model_from_json
from keras import backend as K
def make_model(img_edge_size):
    """Build a two-image model that shares one convolutional featurizer.

    A small sub-model (``reusable_img_featurizer``) is applied to both the
    ``start_img`` and ``end_img`` inputs. The start image's features are
    concatenated with ``start_box_mask`` along axis 1. Both branches are then
    flattened, concatenated, and fed to a single linear ``Dense`` unit named
    ``x0``.

    Args:
        img_edge_size: Height and width of the square image and mask inputs.

    Returns:
        A Keras ``Model`` with inputs ``[start_img, end_img, start_box_mask]``
        and the single output ``x0``.
    """
    # Shapes are channels-first: (channels, rows, cols).
    img_shape = (3, img_edge_size, img_edge_size)
    mask_shape = (1, img_edge_size, img_edge_size)

    generic_img = Input(shape=img_shape)
    # dim_ordering='th' pins the channels-first layout that the shapes above
    # and the concat_axis=1 merge below rely on, instead of inheriting
    # whatever image_dim_ordering happens to be configured for Keras.
    layer = Convolution2D(20, 3, 3, activation='relu', border_mode='same',
                          dim_ordering='th')(generic_img)
    reusable_img_featurizer = Model(generic_img, layer)

    # The same featurizer (shared weights) processes both images.
    start_img = Input(shape=img_shape, name='start_img')
    start_img_features = reusable_img_featurizer(start_img)
    end_img = Input(shape=img_shape, name='end_img')
    end_img_features = reusable_img_featurizer(end_img)

    # Append the mask as an extra channel of the start-image features.
    start_box_mask = Input(shape=mask_shape, name='start_box_mask')
    start_img_features = merge([start_img_features, start_box_mask],
                               mode='concat', concat_axis=1)
    # NOTE(review): per the gist description, uncommenting this layer (a
    # Convolution2D applied to a merge output) causes an error in the
    # save/load round trip; it is intentionally left disabled.
    #start_img_features = Convolution2D(40, 3, 3, activation='relu', border_mode='same',
    # dim_ordering='th')(start_img_features)

    start_img_features = Flatten()(start_img_features)
    end_img_features = Flatten()(end_img_features)
    layer = merge([start_img_features, end_img_features],
                  mode='concat',
                  concat_axis=1)
    x0 = Dense(1, activation='linear', name='x0')(layer)
    my_model = Model(input=[start_img, end_img, start_box_mask],
                     output=[x0])
    return my_model
def load_model(model_spec, weights_fname):
    """Rebuild a model from a JSON architecture file and load its weights.

    Args:
        model_spec: Path to a JSON architecture file, as produced by
            ``Model.to_json()``.
        weights_fname: Path to a weights file, as produced by
            ``Model.save_weights()``.

    Returns:
        The reconstructed model with its weights loaded.
    """
    print('Loading saved model')
    # Use a context manager so the file handle is closed deterministically,
    # not whenever the anonymous file object happens to be garbage-collected.
    with open(model_spec, 'r') as spec_file:
        my_model = model_from_json(spec_file.read())
    my_model.load_weights(weights_fname)
    return my_model
def save_model(my_model, model_spec, weights_fname):
    """Persist a model's architecture as JSON and its weights to a separate file.

    Args:
        my_model: Model exposing ``to_json()`` and ``save_weights()``.
        model_spec: Path the JSON architecture string is written to.
        weights_fname: Path the weights are written to; an existing file is
            overwritten (``overwrite=True``).
    """
    print('Saving model')
    json_string = my_model.to_json()
    # Closing the file (via the context manager) flushes it before we return,
    # so a later reader is guaranteed to see the complete JSON. The original
    # relied on garbage collection to close the handle.
    with open(model_spec, 'w') as spec_file:
        spec_file.write(json_string)
    my_model.save_weights(weights_fname, overwrite=True)
if __name__ == "__main__":
    # Round trip: build the model, write architecture + weights to disk,
    # then rebuild it from those files.
    img_edge_size = 100
    model_spec, weights_fname = 'model_architecture.json', 'model_weights.h5'

    my_model = make_model(img_edge_size)
    save_model(my_model, model_spec, weights_fname)
    load_model(model_spec, weights_fname)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment