Skip to content

Instantly share code, notes, and snippets.

@pythonlessons
Created May 30, 2023 08:45
Show Gist options
  • Select an option

  • Save pythonlessons/7f4d8838d08d874d18579e7d17499767 to your computer and use it in GitHub Desktop.

Select an option

Save pythonlessons/7f4d8838d08d874d18579e7d17499767 to your computer and use it in GitHub Desktop.
wgan_gp
import keras
import tensorflow as tf
from keras import layers
# Define the generator model
def build_generator(noise_dim, output_channels=3, activation="tanh", alpha=0.2):
    """Build the convolutional generator for the WGAN-GP setup.

    A noise vector is projected by a Dense layer to a 4x4x512 feature map,
    upsampled by four stride-2 transposed convolutions (4 -> 8 -> 16 -> 32
    -> 64 pixels), and mapped to ``output_channels`` by a final 5x5
    convolution.

    Args:
        noise_dim: Shape of one noise sample, excluding the batch axis.
            Either an int (e.g. ``128``) or a shape tuple (e.g. ``(128,)``).
        output_channels: Number of channels in the generated image.
        activation: Activation of the output convolution. With the default
            ``"tanh"``, outputs lie in [-1, 1].
        alpha: Negative slope passed to every LeakyReLU layer.

    Returns:
        A ``keras.Model`` mapping a batch of noise samples to images of
        shape ``(batch, 64, 64, output_channels)``.
    """
    # Keras 3's Input rejects a bare int shape, so normalize ints (including
    # numpy integer scalars) to a 1-tuple. Iterables pass through as tuples.
    try:
        input_shape = tuple(noise_dim)
    except TypeError:
        input_shape = (noise_dim,)

    inputs = layers.Input(shape=input_shape, name="input")
    x = layers.Dense(4 * 4 * 512, use_bias=False)(inputs)
    x = layers.Reshape((4, 4, 512))(x)

    # Four identical upsampling stages. Each stride-2 transposed conv doubles
    # the spatial size. Layers are created in the same order as before, so
    # auto-generated layer names (and checkpoint compatibility) are unchanged.
    for filters in (512, 256, 128, 64):
        x = layers.Conv2DTranspose(
            filters, (5, 5), strides=(2, 2), padding="same", use_bias=False
        )(x)
        x = layers.BatchNormalization()(x)
        x = layers.LeakyReLU(alpha)(x)

    x = layers.Dropout(0.5)(x)
    # dtype='float32' forces the output layer to compute in float32. This
    # matters when a mixed-precision policy is active for the other layers.
    x = layers.Conv2D(
        output_channels,
        (5, 5),
        strides=(1, 1),
        padding="same",
        activation=activation,
        use_bias=False,
        dtype="float32",
    )(x)

    # Architecture invariant: 4x4 upsampled four times by 2 gives 64x64.
    assert x.shape == (None, 64, 64, output_channels), (
        f"unexpected generator output shape {x.shape}"
    )

    # Build the Model from the same Keras package as `layers`. Mixing it with
    # tf.keras breaks when tf.keras resolves to legacy Keras 2 (tf_keras).
    model = keras.Model(inputs=inputs, outputs=x)
    return model
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment