Last active
April 15, 2020 20:00
-
-
Save emuccino/d7b07999254f0c664774ebd9cfb71d32 to your computer and use it in GitHub Desktop.
Compile GAN
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| from tensorflow.keras.models import Model | |
| from tensorflow.keras.layers import Input,Dense,BatchNormalization,Concatenate,GaussianNoise | |
| from tensorflow.keras.optimizers import Nadam | |
# Width of the generator's random seed vector: the natural log of
# len(train_df) * len(data), rounded up to the next integer.
# (train_df and data are defined elsewhere -- presumably the training frame
# and the column collection; confirm against the notebook setup.)
latent_dim = int(
    np.ceil(np.log(len(data) * len(train_df)))
)
# Function for building the generator network.
def compile_generator():
    """Build and compile the conditional generator network.

    Relies on module-level names defined elsewhere: ``latent_dim``, ``data``,
    ``numeric_data``, ``string_data``, ``n_tokens`` (column name -> number of
    softmax categories) and ``n_embeddings`` (column name -> hidden width).

    Inputs (a dict): ``'latent'`` with shape ``(latent_dim,)`` and
    ``'target'`` with shape ``(2,)`` (presumably a one-hot target class --
    confirm against the training code).

    Outputs (a dict keyed by column name): a single ``tanh`` unit for every
    name in ``numeric_data``, and a softmax over ``n_tokens[name]``
    categories for every name in ``n_tokens``.

    Returns:
        The compiled Keras ``Model`` (categorical cross-entropy loss, Nadam
        optimizer with ``clipnorm=1.``).
    """
    # Input for the random seed vectors.
    latent_inputs = Input(shape=(latent_dim,), name='latent')
    # Input for the target specification.
    target_inputs = Input(shape=(2,), name='target')
    inputs = {'latent': latent_inputs, 'target': target_inputs}

    # Shared trunk: two ReLU + batch-norm layers over [latent, target].
    net = Concatenate()([latent_inputs, target_inputs])
    for _ in range(2):
        net = Dense(32 + len(data), activation='relu',
                    kernel_initializer='he_uniform')(net)
        net = BatchNormalization()(net)

    outputs = {}
    # Numeric data outputs. tanh bounds each value to (-1, 1); presumably the
    # numeric columns are scaled to that range -- confirm.
    for name in numeric_data:
        outputs[name] = Dense(1, activation='tanh',
                              kernel_initializer='glorot_uniform',
                              name=name)(net)

    # Shared hidden layer for all categorical (string) columns.
    string_nets = Dense(len(string_data), activation='relu',
                        kernel_initializer='he_uniform')(net)
    string_nets = BatchNormalization()(string_nets)

    # String data outputs: a per-column hidden layer followed by a softmax.
    for name, n_token in n_tokens.items():
        string_net = Dense(n_embeddings[name], activation='relu',
                           kernel_initializer='he_uniform')(string_nets)
        string_net = BatchNormalization()(string_net)
        # The softmax must consume this column's own branch. Wiring it to the
        # shared trunk (`net`) would leave the string layers above
        # disconnected from the model, so they would never be used or trained.
        outputs[name] = Dense(n_token, activation='softmax',
                              kernel_initializer='glorot_uniform',
                              name=name)(string_net)

    generator = Model(inputs=inputs, outputs=outputs)
    generator.compile(loss='categorical_crossentropy',
                      optimizer=Nadam(clipnorm=1.))
    return generator
# Function for building the discriminator network.
def compile_discriminator():
    """Build and compile the discriminator / classifier network.

    Relies on module-level names defined elsewhere: ``numeric_data``,
    ``n_tokens`` (column name -> number of categories), ``n_embeddings``
    (column name -> hidden width), ``string_data`` and ``data``.

    Inputs (a dict keyed by column name, numeric columns first): shape
    ``(1,)`` for each name in ``numeric_data``, and shape
    ``(n_tokens[name],)`` for each categorical column. Gaussian noise
    (stddev 0.01 for numeric, 0.05 for categorical) is applied to every
    input; per Keras ``GaussianNoise`` it is only active during training.

    Output: a 3-way softmax (presumably the two target classes plus
    "generated" -- confirm against the training labels).

    Returns:
        The compiled Keras ``Model`` (categorical cross-entropy loss, Nadam
        optimizer with ``clipnorm=1.``).
    """
    model_inputs = {}
    numeric_branches = []
    categorical_branches = []

    # Numeric data inputs: a single noisy value each.
    for col in numeric_data:
        model_inputs[col] = Input(shape=(1,), name=col)
        noisy_value = GaussianNoise(0.01)(model_inputs[col])
        numeric_branches.append(noisy_value)

    # String data inputs: a noisy one-hot vector embedded by a ReLU layer.
    for col, vocab_size in n_tokens.items():
        model_inputs[col] = Input(shape=(vocab_size,), name=col)
        noisy_onehot = GaussianNoise(0.05)(model_inputs[col])
        embedded = Dense(n_embeddings[col], activation='relu',
                         kernel_initializer='he_uniform')(noisy_onehot)
        categorical_branches.append(embedded)

    # Merge all categorical embeddings and squeeze them into one summary.
    merged_categorical = Concatenate()(categorical_branches)
    merged_categorical = BatchNormalization()(merged_categorical)
    categorical_summary = Dense(len(string_data), activation='relu',
                                kernel_initializer='he_uniform')(merged_categorical)

    # Joint trunk over numeric values + categorical summary.
    hidden = Concatenate()(numeric_branches + [categorical_summary])
    hidden = BatchNormalization()(hidden)
    hidden_width = 64 + len(data)
    for _ in range(2):
        hidden = Dense(hidden_width, activation='relu',
                       kernel_initializer='he_uniform')(hidden)
        hidden = BatchNormalization()(hidden)

    # Discrimination / classification head.
    verdict = Dense(3, activation='softmax',
                    kernel_initializer='glorot_uniform')(hidden)

    model = Model(inputs=model_inputs, outputs=verdict)
    model.compile(loss='categorical_crossentropy',
                  optimizer=Nadam(clipnorm=1.))
    return model
# Function for building the combined GAN network.
def compile_gan():
    """Build and compile the stacked generator -> discriminator model.

    Uses the module-level ``generator``, ``discriminator`` and ``latent_dim``,
    so both networks must already exist when this is called.

    ``discriminator.trainable`` is set to ``False`` before compiling. As a
    result, fitting the returned model updates only the generator's weights.

    Inputs (a dict, same names as the generator's): ``'latent'`` with shape
    ``(latent_dim,)`` and ``'target'`` with shape ``(2,)``.

    Returns:
        The compiled Keras ``Model`` whose output is the discriminator's
        prediction for the generated sample (categorical cross-entropy loss,
        Nadam optimizer with ``clipnorm=1.``).
    """
    # Freeze the discriminator inside the stacked model. A discriminator that
    # was compiled before this point keeps its own training behaviour in
    # tf.keras, which captures `trainable` at compile time -- confirm for the
    # installed Keras version.
    discriminator.trainable = False

    # Input for the random seed vectors.
    latent_inputs = Input(shape=(latent_dim,), name='latent')
    # Input for the target specification.
    target_inputs = Input(shape=(2,), name='target')
    inputs = {'latent': latent_inputs, 'target': target_inputs}

    # The generator was built with dict inputs, so pass the dict. This matches
    # tensors by input name instead of relying on Keras flattening the dict
    # keys alphabetically ('latent' < 'target') to line up with a list.
    net = generator(inputs)
    outputs = discriminator(net)

    gan = Model(inputs=inputs, outputs=outputs)
    gan.compile(loss='categorical_crossentropy', optimizer=Nadam(clipnorm=1.))
    return gan
# Build all three networks. compile_generator() runs before
# compile_discriminator(). compile_gan() reads the module-level `generator`
# and `discriminator`, so it has to run after both have been assigned.
generator, discriminator = compile_generator(), compile_discriminator()
gan = compile_gan()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment