@kmcnaught
Last active February 21, 2018 17:09
Demo of serialisation bug in keras with loss_weights as variables
from keras import losses
from keras.models import Model, load_model
from keras.layers import Input, Dense
import keras.backend as K
from keras import optimizers


def build_dense_model_two_outputs(dim_in, dim_hidden, dim_out):
    model_input = Input(
        shape=(dim_in,),
        name='Input')
    hidden = Dense(units=dim_hidden,
                   activation=K.relu)(model_input)
    model_output_1 = Dense(units=dim_out)(hidden)
    model_output_2 = Dense(units=dim_out)(hidden)
    model = Model(inputs=model_input,
                  outputs=[model_output_1, model_output_2])
    return model


def save_and_reload_model(model):
    # Tests only that saving and loading doesn't throw, not
    # that it does so correctly :)
    model.save('test.h5')
    del model
    model = load_model('test.h5')


def test_scalar_weights():
    # This works:
    model = build_dense_model_two_outputs(2, 3, 4)
    loss_weights = [1.0, 0.5]
    model.compile(loss=[losses.mse, losses.mse],
                  loss_weights=loss_weights,
                  optimizer='adam')
    save_and_reload_model(model)


def test_variable_weights():
    # This does not work (but training with these weights
    # would work):
    model = build_dense_model_two_outputs(2, 3, 4)
    loss_weights = [K.variable(1.0), K.variable(0.5)]
    model.compile(loss=[losses.mse, losses.mse],
                  loss_weights=loss_weights,
                  optimizer='adam')
    save_and_reload_model(model)
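

To reproduce, both tests can be run directly; per the gist, only the second is expected to throw inside save_and_reload_model. The runner below is a minimal sketch, not part of the original gist, and assumes Keras 2.x with h5py installed.

if __name__ == '__main__':
    # Scalar loss_weights: save/reload is expected to succeed.
    test_scalar_weights()
    print('scalar loss_weights: OK')

    # Variable loss_weights: per the gist, serialisation fails here,
    # even though training with these weights would work.
    test_variable_weights()
    print('variable loss_weights: OK (bug not reproduced)')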