Created: March 6, 2019 22:05
Save phisad/4d8b7fd86cb44f66132648d91052b1d8 to your computer and use it in GitHub Desktop.
Keras example: disconnected graphs when combining multiple models
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def test_connected_models(self):
    """Combine two Keras sub-models into a single model with the Functional API.

    model3 is built directly from the input tensors and output tensors of
    model1 and model2. No new Input tensors are introduced in between, so
    every tensor from the inputs to the Concatenate output belongs to one
    connected graph, and Model(...) can be constructed.
    """
    input1 = Input(shape=(100,))
    dense1 = Dense(1)(input1)
    model1 = Model(input1, dense1)
    input2 = Input(shape=(200,))
    dense2 = Dense(2)(input2)
    model2 = Model(input2, dense2)
    # This works because there are no intermediate Inputs: Concatenate is
    # applied to tensors that are already connected to input1 / input2.
    # The best solution for a complex graph is not to use intermediate Models
    # but to use the Functional API and only produce one model at the end.
    dense3 = Concatenate()([model1.output, model2.output])
    model3 = Model(inputs=[model1.input, model2.input], outputs=dense3)
    # summary() prints the table itself and returns None; wrapping it in
    # print() would emit a stray "None" line.
    model3.summary()
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def test_disconnected_graph(self):
    """Show how reusing a sub-model's internal tensors yields a disconnected graph.

    model3 is built on its own Input tensors (input11, input12). Its output
    tensor is therefore reachable only from those Inputs, not from
    model1.input / model2.input. Building a Model from the latter to that
    tensor fails with a "Graph disconnected" ValueError.

    The fix is to call model3 as a layer on the outputs of model1 and model2.
    This re-applies model3's layers and produces a new output tensor that is
    connected to the outer inputs.

    Raises:
        AssertionError: if the disconnected construction unexpectedly succeeds.
    """
    input1 = Input(shape=(100,))
    dense1 = Dense(1)(input1)
    model1 = Model(input1, dense1)
    input2 = Input(shape=(200,))
    dense2 = Dense(2)(input2)
    model2 = Model(input2, dense2)
    input11 = Input(shape=(1,))
    input12 = Input(shape=(2,))
    dense3 = Concatenate()([input11, input12])
    model3 = Model(inputs=[input11, input12], outputs=dense3)
    # summary() prints the table itself and returns None; do not wrap it in print().
    model3.summary()

    # model3.layers[-1] is the Concatenate layer, so its .output is dense3,
    # the same tensor as model3.output. It was computed from input11 / input12,
    # so it is disconnected from model1.input / model2.input.
    disconnected_output = model3.layers[-1].output
    try:
        Model(inputs=[model1.input, model2.input], outputs=disconnected_output)
    except ValueError as err:
        print("Expected failure:", err)
    else:
        raise AssertionError("expected a 'Graph disconnected' ValueError")

    # Correct composition: call model3 on the upstream outputs, which wires
    # its layers to tensors derived from model1.input / model2.input.
    model123_output = model3([model1.output, model2.output])
    model123 = Model(inputs=[model1.input, model2.input], outputs=model123_output)
    model123.summary()
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.