Created
May 23, 2016 20:32
-
-
Save dansbecker/361335eaf9fa5f8243ee95ae5f8f1046 to your computer and use it in GitHub Desktop.
Gist demonstrating a caveat in Keras model saving and loading: uncommenting the commented-out second `Convolution2D` layer in `make_model` (lines 25-26 of the gist) causes an error.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from keras.layers import Input, Dense, merge, Flatten | |
from keras.layers.convolutional import Convolution2D | |
from keras.models import Model, model_from_json | |
from keras import backend as K | |
def make_model(img_edge_size):
    """Build a model over two images plus a mask that share one conv featurizer.

    A small convolutional sub-model is created once and applied to both the
    start and end images, so both images go through the same featurizer.
    The start image's features are concatenated (axis 1) with a one-plane
    box mask; both feature maps are then flattened, concatenated, and fed to
    a single linear unit named 'x0'.

    Args:
        img_edge_size: height and width of the square input images.

    Returns:
        A Keras ``Model`` with inputs ``[start_img, end_img, start_box_mask]``
        and a single output ``x0``.
    """
    edge = img_edge_size
    # Shapes put the plane count first and the merges concatenate on axis 1,
    # i.e. channels-first ('th') ordering is assumed. NOTE(review): the first
    # conv below relies on the backend's default dim_ordering -- confirm it
    # is 'th' in the target Keras config.
    rgb_shape = (3, edge, edge)
    mask_shape = (1, edge, edge)

    # Shared featurizer: wrapping the conv layer in its own Model lets the
    # same sub-model be applied to both images below.
    featurizer_in = Input(shape=rgb_shape)
    featurizer_out = Convolution2D(20, 3, 3, activation='relu',
                                   border_mode='same')(featurizer_in)
    featurizer = Model(featurizer_in, featurizer_out)

    # Layers are created in a fixed order on purpose: Keras derives the
    # auto-generated layer names (and hence the saved JSON) from it.
    start_input = Input(shape=rgb_shape, name='start_img')
    start_feats = featurizer(start_input)
    end_input = Input(shape=rgb_shape, name='end_img')
    end_feats = featurizer(end_input)

    mask_input = Input(shape=mask_shape, name='start_box_mask')
    start_feats = merge([start_feats, mask_input],
                        mode='concat', concat_axis=1)

    # Per the gist description, uncommenting this extra conv layer causes an
    # error in the save/load round trip performed by the __main__ block.
    # start_feats = Convolution2D(40, 3, 3, activation='relu', border_mode='same',
    #                             dim_ordering='th')(start_feats)

    combined = merge([Flatten()(start_feats), Flatten()(end_feats)],
                     mode='concat',
                     concat_axis=1)
    prediction = Dense(1, activation='linear', name='x0')(combined)
    return Model(input=[start_input, end_input, mask_input],
                 output=[prediction])
def load_model(model_spec, weights_fname):
    """Rebuild a model from a JSON architecture file and load its weights.

    Args:
        model_spec: path to a JSON architecture file (e.g. one written by
            ``save_model``), parsed with ``model_from_json``.
        weights_fname: path to the weights file passed to ``load_weights``.

    Returns:
        The reconstructed model with its weights loaded.
    """
    print('Loading saved model')
    # Read inside a context manager so the handle is closed deterministically
    # rather than whenever the temporary file object is garbage-collected.
    with open(model_spec, 'r') as spec_file:
        json_string = spec_file.read()
    my_model = model_from_json(json_string)
    my_model.load_weights(weights_fname)
    return my_model
def save_model(my_model, model_spec, weights_fname):
    """Write a model's architecture as JSON and its weights to disk.

    Args:
        my_model: model exposing ``to_json()`` and ``save_weights()``.
        model_spec: destination path for the JSON architecture string.
        weights_fname: destination path for the weights; an existing file is
            overwritten (``overwrite=True``).
    """
    print('Saving model')
    json_string = my_model.to_json()
    # Context manager guarantees the JSON is flushed and the handle closed
    # before returning, so an immediate reload sees the complete file.
    with open(model_spec, 'w') as spec_file:
        spec_file.write(json_string)
    my_model.save_weights(weights_fname, overwrite=True)
if __name__ == "__main__":
    # Round-trip demo: build the model, save architecture + weights, then
    # load them back. Per the gist description, an error occurs once the
    # extra Convolution2D layer in make_model is uncommented.
    edge_size = 100
    arch_path = 'model_architecture.json'
    weights_path = 'model_weights.h5'
    save_model(make_model(edge_size), arch_path, weights_path)
    load_model(arch_path, weights_path)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment