Skip to content

Instantly share code, notes, and snippets.

@emuccino
Last active April 15, 2020 20:00
Show Gist options
  • Save emuccino/d7b07999254f0c664774ebd9cfb71d32 to your computer and use it in GitHub Desktop.
Compile GAN
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input,Dense,BatchNormalization,Concatenate,GaussianNoise
from tensorflow.keras.optimizers import Nadam
# Latent dimension scales logarithmically with the total number of
# table cells (rows x columns) in the training data.
_n_cells = len(train_df) * len(data)
latent_dim = int(np.ceil(np.log(_n_cells)))
#function for building generator network
def compile_generator():
    """Build and compile the generator network.

    Inputs:  'latent' (latent_dim,) random seed vector and
             'target' (2,) one-hot target specification.
    Outputs: one head per feature — a tanh scalar for each numeric
             column, a softmax over tokens for each string column.
    Returns the compiled keras Model.
    """
    # input for random seed vectors
    latent_inputs = Input(shape=(latent_dim,), name='latent')
    # input for target specification
    target_inputs = Input(shape=(2,), name='target')
    inputs = {'latent': latent_inputs, 'target': target_inputs}
    net = Concatenate()([latent_inputs, target_inputs])
    # shared trunk: two Dense+BatchNorm stages
    for _ in range(2):
        net = Dense(32 + len(data), activation='relu',
                    kernel_initializer='he_uniform')(net)
        net = BatchNormalization()(net)
    outputs = {}
    # numeric data outputs (tanh bounds each value to [-1, 1];
    # presumably numeric columns are scaled to that range upstream — verify)
    for name in numeric_data:
        outputs[name] = Dense(1, activation='tanh',
                              kernel_initializer='glorot_uniform',
                              name=name)(net)
    # shared projection feeding all string-feature branches
    string_nets = Dense(len(string_data), activation='relu',
                        kernel_initializer='he_uniform')(net)
    string_nets = BatchNormalization()(string_nets)
    # string data outputs: one softmax head per categorical feature
    for name, n_token in n_tokens.items():
        string_net = Dense(n_embeddings[name], activation='relu',
                           kernel_initializer='he_uniform')(string_nets)
        string_net = BatchNormalization()(string_net)
        # BUG FIX: the softmax head must consume the per-feature branch
        # `string_net`; the original applied it to the shared trunk `net`,
        # leaving the Dense/BatchNormalization branch above disconnected.
        outputs[name] = Dense(n_token, activation='softmax',
                              kernel_initializer='glorot_uniform',
                              name=name)(string_net)
    generator = Model(inputs=inputs, outputs=outputs)
    # NOTE(review): generator is trained through the combined GAN model;
    # this compile mainly attaches an optimizer/loss for standalone use.
    generator.compile(loss='categorical_crossentropy',
                      optimizer=Nadam(clipnorm=1.))
    return generator
#function for building discriminator network
def compile_discriminator():
    """Build and compile the discriminator/classifier network.

    Takes one named input per data column (scalar for numeric columns,
    one-hot of size n_tokens[col] for string columns), adds Gaussian
    noise for regularization, and emits a 3-way softmax.
    Returns the compiled keras Model.
    """
    inputs = {}
    numeric_branches = []
    string_branches = []
    # numeric data inputs: a little noise on each scalar
    for col in numeric_data:
        scalar_in = Input(shape=(1,), name=col)
        inputs[col] = scalar_in
        numeric_branches.append(GaussianNoise(0.01)(scalar_in))
    # string data inputs: noise, then a per-feature dense projection
    for col, token_count in n_tokens.items():
        onehot_in = Input(shape=(token_count,), name=col)
        inputs[col] = onehot_in
        branch = GaussianNoise(0.05)(onehot_in)
        branch = Dense(n_embeddings[col], activation='relu',
                       kernel_initializer='he_uniform')(branch)
        string_branches.append(branch)
    # merge all string branches and compress before joining numerics
    merged_strings = Concatenate()(string_branches)
    merged_strings = BatchNormalization()(merged_strings)
    merged_strings = Dense(len(string_data), activation='relu',
                           kernel_initializer='he_uniform')(merged_strings)
    net = Concatenate()(numeric_branches + [merged_strings])
    net = BatchNormalization()(net)
    # shared trunk: two Dense+BatchNorm stages
    for _ in range(2):
        net = Dense(64 + len(data), activation='relu',
                    kernel_initializer='he_uniform')(net)
        net = BatchNormalization()(net)
    # discrimination/classification head (3 classes)
    outputs = Dense(3, activation='softmax',
                    kernel_initializer='glorot_uniform')(net)
    discriminator = Model(inputs=inputs, outputs=outputs)
    discriminator.compile(loss='categorical_crossentropy',
                          optimizer=Nadam(clipnorm=1.))
    return discriminator
#function for building GAN network
def compile_gan():
    """Chain the generator into the (frozen) discriminator and compile
    the combined GAN model used to train the generator.

    Reads the module-level `generator` and `discriminator` models.
    Returns the compiled keras Model.
    """
    # freeze discriminator weights inside the combined model
    discriminator.trainable = False
    # input for random seed vectors
    seed_in = Input(shape=(latent_dim,), name='latent')
    # input for target specification
    spec_in = Input(shape=(2,), name='target')
    generated = generator([seed_in, spec_in])
    verdict = discriminator(generated)
    gan = Model(inputs={'latent': seed_in, 'target': spec_in},
                outputs=verdict)
    gan.compile(loss='categorical_crossentropy',
                optimizer=Nadam(clipnorm=1.))
    return gan
#compile GAN network
# Order matters: compile_gan() reads the module-level `generator` and
# `discriminator` created by the two preceding calls.
generator = compile_generator()
discriminator = compile_discriminator()
gan = compile_gan()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment