Last active: February 21, 2018 17:09
-
-
Save kmcnaught/a335bb26afa66677d8dff9abbf8af138 to your computer and use it in GitHub Desktop.
Demo of a serialisation bug in Keras when loss_weights are passed as backend variables
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os
import shutil
import tempfile

import keras.backend as K
from keras import losses
from keras import optimizers
from keras.layers import Input, Dense
from keras.models import Model, load_model
def build_dense_model_two_outputs(dim_in, dim_hidden, dim_out):
    """Build a functional Keras model with one shared hidden layer and two output heads.

    Args:
        dim_in: Length of the input vector; the Input layer has shape (dim_in,).
        dim_hidden: Number of units in the shared hidden Dense layer (ReLU).
        dim_out: Number of units in each of the two output Dense layers.

    Returns:
        An uncompiled keras ``Model`` whose ``outputs`` is a list of two tensors.
    """
    inputs = Input(shape=(dim_in,), name='Input')
    shared = Dense(units=dim_hidden, activation=K.relu)(inputs)
    # Two independent Dense heads on the shared layer, so that compile() can
    # be given one loss and one loss weight per output.
    heads = [Dense(units=dim_out)(shared) for _ in range(2)]
    return Model(inputs=inputs, outputs=heads)
def save_and_reload_model(model, filepath=None):
    """Round-trip *model* through ``model.save`` / ``load_model`` and return the copy.

    This only checks that saving and loading do not raise; it does not verify
    that the reloaded model is equivalent to the original.

    Args:
        model: The keras Model to serialise.
        filepath: Optional path for the saved file. When omitted, the file is
            written inside a fresh temporary directory that is deleted
            afterwards, so no ``test.h5`` is left behind in the working
            directory.

    Returns:
        The model returned by ``load_model``.

    Raises:
        Whatever ``model.save`` or ``load_model`` raises; the temporary
        directory is still removed in that case.
    """
    if filepath is not None:
        model.save(filepath)
        return load_model(filepath)

    tmp_dir = tempfile.mkdtemp()
    try:
        tmp_path = os.path.join(tmp_dir, 'test.h5')
        model.save(tmp_path)
        return load_model(tmp_path)
    finally:
        # Best-effort cleanup, also on failure. The failure path is exactly
        # what the variable-loss-weights demo exercises.
        shutil.rmtree(tmp_dir, ignore_errors=True)
def test_scalar_weights():
    """Plain float loss weights: the save/reload round trip works."""
    model = build_dense_model_two_outputs(2, 3, 4)
    # Same MSE loss on both heads; the second head is weighted at half.
    model.compile(optimizer='adam',
                  loss=[losses.mse] * 2,
                  loss_weights=[1.0, 0.5])
    save_and_reload_model(model)
def test_variable_weights():
    """Backend-variable loss weights: the save/reload round trip fails.

    This reproduces the serialisation bug the demo is about. Per the original
    note, training with these weights would work; it is saving and loading
    that does not.
    """
    model = build_dense_model_two_outputs(2, 3, 4)
    weight_vars = [K.variable(w) for w in (1.0, 0.5)]
    model.compile(optimizer='adam',
                  loss=[losses.mse] * 2,
                  loss_weights=weight_vars)
    save_and_reload_model(model)
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.