@deeperunderstanding
Last active July 25, 2019 20:18
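# Imports assumed for standalone Keras 2.x (with TensorFlow 2, use tensorflow.keras.*).
from keras.layers import Input
from keras.models import Model
from keras.optimizers import Nadam

# train_x: 3-D array of shape (n_samples, window_size, input_dim).
# create_discriminator and create_encoder are defined elsewhere (not part of this gist).
window_size = train_x.shape[1]
input_dim = train_x.shape[2]
latent_dim = 32  # size of the continuous latent code
cat_dim = 8      # number of categories in the categorical code

# Discriminator that separates prior samples from the encoder's continuous code.
# Nadam(0.0002, 0.5): learning rate 2e-4, beta_1 = 0.5.
prior_discriminator = create_discriminator(latent_dim)
prior_discriminator.compile(loss='binary_crossentropy',
                            optimizer=Nadam(0.0002, 0.5),
                            metrics=['accuracy'])
# Freeze it inside the combined model compiled below. Keras reads this flag at
# compile time, so the discriminator compiled above still trains on its own.
prior_discriminator.trainable = False

# Discriminator for the categorical code, set up the same way.
cat_discriminator = create_discriminator(cat_dim)
cat_discriminator.compile(loss='binary_crossentropy',
                          optimizer=Nadam(0.0002, 0.5),
                          metrics=['accuracy'])
cat_discriminator.trainable = False

# The encoder returns the reconstruction, the continuous code, the
# categorical code and a fourth output that is unused here.
encoder = create_encoder(latent_dim, cat_dim, window_size, input_dim)
signal_in = Input(shape=(window_size, input_dim))
reconstructed_signal, encoded_repr, category, _ = encoder(signal_in)
is_real_prior = prior_discriminator(encoded_repr)
is_real_cat = cat_discriminator(category)

# Combined model: reconstruct the signal and push both codes toward "real".
# Total loss = 0.99 * MSE + 0.005 * BCE(prior) + 0.005 * BCE(categorical).
autoencoder = Model(signal_in, [reconstructed_signal, is_real_prior, is_real_cat])
autoencoder.compile(loss=['mse', 'binary_crossentropy', 'binary_crossentropy'],
                    loss_weights=[0.99, 0.005, 0.005],
                    optimizer=Nadam(0.0002, 0.5))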
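The script reads `window_size` and `input_dim` from axes 1 and 2 of `train_x`, so `train_x` has to be a 3-D array of shape `(n_samples, window_size, input_dim)`. The gist does not show how it is built. A minimal sketch, assuming it comes from sliding a fixed-length window over a multivariate series of shape `(timesteps, features)`; the helper `make_windows` and its `stride` parameter are hypothetical:

```python
import numpy as np

def make_windows(series, window_size, stride=1):
    """Hypothetical helper: cut a (timesteps, features) array into
    overlapping windows of shape (n_windows, window_size, features)."""
    series = np.asarray(series)
    if series.ndim != 2 or series.shape[0] < window_size:
        raise ValueError("need a (timesteps, features) array with timesteps >= window_size")
    starts = range(0, series.shape[0] - window_size + 1, stride)
    return np.stack([series[s:s + window_size] for s in starts])

# Example: 100 timesteps of 3 features, 20-step windows taken every 5 steps.
series = np.random.randn(100, 3)
train_x = make_windows(series, window_size=20, stride=5)
print(train_x.shape)  # (17, 20, 3)
```

With this layout, `train_x.shape[1]` is the window length and `train_x.shape[2]` the number of features, which is what the two assignments at the top of the script expect.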
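Both discriminators are compiled with binary cross-entropy, so each one needs batches of "real" samples (label 1) and encoder outputs (label 0). The gist does not show the training loop or the priors. A sketch, assuming the usual adversarial-autoencoder choices of a standard Gaussian prior for the continuous code and uniform one-hot vectors for the categorical code; `sample_priors` and `discriminator_batch` are hypothetical names:

```python
import numpy as np

def sample_priors(batch_size, latent_dim=32, cat_dim=8, rng=None):
    """Draw 'real' samples for the two discriminators (assumed priors):
    standard Gaussian vectors for the continuous code and uniformly
    chosen one-hot vectors for the categorical code."""
    rng = np.random.default_rng() if rng is None else rng
    z_prior = rng.standard_normal((batch_size, latent_dim))
    cat_prior = np.eye(cat_dim)[rng.integers(0, cat_dim, size=batch_size)]
    return z_prior, cat_prior

def discriminator_batch(real, fake):
    """Stack prior samples (label 1) and encoder outputs (label 0)
    into one batch for a binary-crossentropy discriminator."""
    x = np.concatenate([real, fake], axis=0)
    y = np.concatenate([np.ones((len(real), 1)), np.zeros((len(fake), 1))])
    return x, y

rng = np.random.default_rng(0)
z_prior, cat_prior = sample_priors(4, rng=rng)
fake_cat = rng.random((4, 8))  # stand-in for the encoder's categorical output
x, y = discriminator_batch(cat_prior, fake_cat)
print(x.shape, y.ravel())  # (8, 8) [1. 1. 1. 1. 0. 0. 0. 0.]
```

In a Keras training step one would then, presumably, take the fake codes from `encoder.predict(batch)`, update each discriminator with `train_on_batch(x, y)`, and train `autoencoder` with targets of ones for both discriminator outputs so the encoder learns to produce prior-like codes.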
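With several outputs, Keras minimizes the sum of each output's loss multiplied by its entry in `loss_weights`, so the reconstruction term dominates here and the two adversarial terms act as light regularizers. A NumPy sketch of that weighted sum, up to Keras's own reduction and epsilon details; it assumes the discriminator targets are all ones (the encoder is rewarded when its codes are judged real), and `combined_loss` is a hypothetical name:

```python
import numpy as np

def mse(y_true, y_pred):
    return float(np.mean((y_true - y_pred) ** 2))

def binary_crossentropy(y_true, y_pred, eps=1e-7):
    # Clip predictions away from 0 and 1 so the logarithms stay finite.
    y_pred = np.clip(y_pred, eps, 1 - eps)
    return float(np.mean(-(y_true * np.log(y_pred) + (1 - y_true) * np.log(1 - y_pred))))

def combined_loss(signal, reconstruction, p_prior, p_cat,
                  weights=(0.99, 0.005, 0.005)):
    """Weighted sum mirroring loss_weights=[0.99, 0.005, 0.005];
    adversarial targets are assumed to be all ones."""
    terms = (mse(signal, reconstruction),
             binary_crossentropy(np.ones_like(p_prior), p_prior),
             binary_crossentropy(np.ones_like(p_cat), p_cat))
    return sum(w * t for w, t in zip(weights, terms))

signal = np.zeros((2, 4, 3))
reconstruction = np.full((2, 4, 3), 0.1)  # MSE = 0.01
p_prior = np.full((2, 1), 0.5)            # BCE = ln 2 for each discriminator
p_cat = np.full((2, 1), 0.5)
total = combined_loss(signal, reconstruction, p_prior, p_cat)
print(round(total, 4))  # 0.0168
```

Even with both discriminators maximally unsure (output 0.5), the adversarial terms contribute only about 0.007 to the total, which shows how strongly these weights favour reconstruction.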