Created
July 2, 2019 16:18
-
-
Save Eniwder/81665eac3d5af7d37be30af70ea1dc65 to your computer and use it in GitHub Desktop.
あるmodelの出力を別のmodelへ入力する例（想定はConditional GAN）
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Premise 1: the generator model.
def generator_model():
    """Build the generator model.

    Input: a flat vector of shape (batch, 100) (presumably the GAN noise vector).
    Output: a tanh-activated tensor; per the original note, the hidden layers
    (elided below) are expected to yield ``layer_x`` such that the model output
    has shape (batch, 28, 28, 1).
    """
    # Keras needs a tuple for `shape`; `(100)` is just the int 100, not a 1-tuple.
    input_layer = Input(shape=(100,))
    # ~~~~ hidden layers elided: they must define `layer_x`
    #      (expected shape (batch, 28, 28, 1)).
    output_layer = Activation('tanh')(layer_x)
    return Model(input_layer, output_layer)  # output shape=(batchN,28,28,1)
# Premise 2: the discriminator model.
def discriminator_model():
    """Build the discriminator model.

    Accepts a (batch, 28, 28, 11) tensor and emits one sigmoid score per
    sample, i.e. an output of shape (batch, 1). The hidden layers are elided
    (`# ~~~~`) and are expected to define ``layer_x``.
    """
    inputs = Input(shape=(28, 28, 11))
    # ~~~~ hidden layers elided; they produce `layer_x`
    score = Dense(1, activation="sigmoid")(layer_x)
    return Model(inputs, score)  # output shape=(batchN, 1)
# Goal: a model that post-processes generator_model's output and feeds the
# result into discriminator_model.
def conbined_model(gen_model, dis_model):
    """Chain ``gen_model`` -> concat(label) -> ``dis_model`` into one model.

    Args:
        gen_model: generator; its output is an image tensor (batch, H, W, C_img),
            e.g. (batch, 28, 28, 1).
        dis_model: discriminator; its input has the same H and W, with
            C_img + C_label channels, e.g. (batch, 28, 28, 11).

    Returns:
        A two-input Keras Model mapping ``[gen_input, label]`` to the
        discriminator's output.
    """
    # Build the graph inputs explicitly. The shapes are derived from the two
    # models, so the concatenation below always matches the discriminator's
    # input (for this gist: gen input (100,), label (28, 28, 10)).
    input_gen = Input(shape=gen_model.input_shape[1:])
    # Calling a Model like a layer reuses its weights; shape=(batchN,28,28,1)
    out_gen = gen_model(input_gen)
    height, width, img_channels = gen_model.output_shape[1:]
    label_channels = dis_model.input_shape[-1] - img_channels
    input_label = Input(shape=(height, width, label_channels))
    # Post-process: append the label channels; shape=(batchN,28,28,11)
    input_dis = Concatenate(axis=3)([out_gen, input_label])
    out_dis = dis_model(input_dis)
    # A model with several inputs takes them as a list.
    return Model([input_gen, input_label], out_dis)
# usage
gen_model = generator_model()
dis_model = discriminator_model()
# train_on_batch requires a compiled model. Binary cross-entropy matches the
# discriminator's sigmoid output; the optimizer is a placeholder choice.
dis_model.compile(optimizer="adam", loss="binary_crossentropy")
# Freeze the discriminator inside the combined model so that training
# cgan_model updates only the generator.
# NOTE(review): this ordering relies on Keras 2 semantics, where `trainable`
# is read at compile time, so dis_model's own compile above stays trainable.
# Confirm the behaviour for Keras 3, which may require re-compiling or toggling
# `trainable` around each training step.
dis_model.trainable = False
cgan_model = conbined_model(gen_model, dis_model)
cgan_model.compile(optimizer="adam", loss="binary_crossentropy")
# Xn, yn are assumed to be training data defined elsewhere.
dis_model.train_on_batch(X1, y1)
# Train the combined model, passing both of its inputs as a list.
cgan_model.train_on_batch([X2_1, X2_2], y2)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment