def show_generator_output(sess, n_images, input_z, out_channel_dim, image_mode, image_path, save, show):
    """
    Show example output for the generator
    :param sess: TensorFlow session
    :param n_images: Number of images to display
    :param input_z: Input Z tensor
    :param out_channel_dim: The number of channels in the output image
    :param image_mode: The mode to use for images ("RGB" or "L")
    :param image_path: Path to save the image
    :param save: Whether to save the image to image_path
    :param show: Whether to display the image
    """
def train(epoch_count, batch_size, z_dim, learning_rate_D, learning_rate_G, beta1, get_batches, data_shape, data_image_mode, alpha):
    """
    Train the GAN
    :param epoch_count: Number of epochs
    :param batch_size: Batch size
    :param z_dim: Z dimension
    :param learning_rate_D: Learning rate for the discriminator
    :param learning_rate_G: Learning rate for the generator
    :param beta1: The exponential decay rate for the 1st moment in the optimizer
    :param get_batches: Function to get batches
    :param data_shape: Shape of the data
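The body of `train` is cut off here. A minimal sketch of the usual alternating GAN update, assuming one discriminator step followed by one generator step per batch and latent vectors drawn uniformly from [-1, 1]; `train_loop`, `d_step`, `g_step`, and `print_every` are hypothetical names standing in for the `sess.run` calls on the optimizer ops.

```python
import random
from typing import Callable, Iterable, List, Sequence, Tuple


def train_loop(
    epoch_count: int,
    batch_size: int,
    z_dim: int,
    get_batches: Callable[[int], Iterable[Sequence]],
    d_step: Callable[[Sequence, List[List[float]]], float],
    g_step: Callable[[List[List[float]]], float],
    print_every: int = 10,
) -> Tuple[List[Tuple[float, float]], int]:
    """Alternate one discriminator update and one generator update per batch.

    d_step and g_step are hypothetical stand-ins for running the optimizer ops.
    """
    losses: List[Tuple[float, float]] = []
    steps = 0
    for _ in range(epoch_count):
        for batch_images in get_batches(batch_size):
            steps += 1
            # One latent vector per image, sampled uniformly from [-1, 1] (assumed prior)
            batch_z = [[random.uniform(-1.0, 1.0) for _ in range(z_dim)]
                       for _ in range(len(batch_images))]
            d_loss = d_step(batch_images, batch_z)  # update D on real and fake images
            g_loss = g_step(batch_z)                # update G to fool D
            if steps % print_every == 0:
                losses.append((d_loss, g_loss))
    return losses, steps


def fake_batches(bs: int):
    """Three dummy batches per epoch."""
    for _ in range(3):
        yield [0.0] * bs


losses, steps = train_loop(2, 4, 5, fake_batches, lambda x, z: 0.7, lambda z: 1.2, print_every=2)
print(steps, losses)  # 6 [(0.7, 1.2), (0.7, 1.2), (0.7, 1.2)]
```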
def discriminator(x, is_reuse=False, alpha=0.2):
    ''' Build the discriminator network.
    Arguments
    ---------
    x : Input tensor for the discriminator
    is_reuse : Reuse the variables with tf.variable_scope
    alpha : Leak parameter for leaky ReLU
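The `alpha` argument is the leak parameter of leaky ReLU. A scalar sketch of that activation; the discriminator presumably applies it elementwise to tensors (in TensorFlow the same expression can be written as `tf.maximum(alpha * x, x)`), but its exact form there is not shown in this excerpt.

```python
def leaky_relu(x: float, alpha: float = 0.2) -> float:
    """Leaky ReLU: x for x > 0, alpha * x otherwise (valid for 0 <= alpha <= 1)."""
    return max(alpha * x, x)


print(leaky_relu(3.0))   # 3.0
print(leaky_relu(-1.0))  # -0.2
```

Unlike a plain ReLU, negative inputs keep a small gradient (`alpha`), which is why it is a common choice in GAN discriminators.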
def generator(z, output_channel_dim, is_train=True):
    ''' Build the generator network.
    Arguments
    ---------
    z : Input tensor for the generator
    output_channel_dim : Number of channels in the generator output
    is_train : Whether the network is being built for training
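A common DCGAN generator layout projects `z` to a small spatial map and doubles height and width with stride-2 transposed convolutions until it reaches the image size, ending with `output_channel_dim` channels. The sketch below only computes the layer shapes for such a layout; the 4x4 starting size, the 1024 starting channels, and the halving of channels per layer are assumptions, not read from the original generator.

```python
from typing import List, Tuple


def generator_shapes(image_size: int = 128, output_channel_dim: int = 3,
                     start_size: int = 4, start_channels: int = 1024) -> List[Tuple[int, int, int]]:
    """Return (height, width, channels) after each layer of a hypothetical DCGAN generator."""
    ratio = image_size // start_size
    if image_size % start_size or ratio < 2 or ratio & (ratio - 1):
        raise ValueError("image_size must be start_size times a power of two (>= 2)")
    shapes = [(start_size, start_size, start_channels)]  # after the dense projection of z
    size, channels = start_size, start_channels
    while size * 2 < image_size:
        size, channels = size * 2, channels // 2  # stride-2 'same' transposed conv doubles H and W
        shapes.append((size, size, channels))
    # Final transposed conv reaches the image size with output_channel_dim channels
    shapes.append((size * 2, size * 2, output_channel_dim))
    return shapes


for shape in generator_shapes():
    print(shape)
```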
# Size of the input image to the discriminator
real_size = (128, 128, 3)
# Size of the latent vector fed to the generator
z_dim = 100
learning_rate_D = 5e-5  # Thanks to Alexia Jolicoeur Martineau https://ajolicoeur.wordpress.com/cats/
learning_rate_G = 2e-4  # Thanks to Alexia Jolicoeur Martineau https://ajolicoeur.wordpress.com/cats/
batch_size = 64
epochs = 215
alpha = 0.2  # Leak parameter for leaky ReLU
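The hyperparameters can also be grouped into one immutable object so they travel together. A sketch using a standard-library dataclass; the `HParams` name is hypothetical, and the values are the ones listed above. Note that the discriminator learning rate is a quarter of the generator's.

```python
import math
from dataclasses import dataclass
from typing import Tuple


@dataclass(frozen=True)
class HParams:
    real_size: Tuple[int, int, int] = (128, 128, 3)  # input image size for the discriminator
    z_dim: int = 100                                 # latent vector size for the generator
    learning_rate_D: float = 5e-5
    learning_rate_G: float = 2e-4
    batch_size: int = 64
    epochs: int = 215
    alpha: float = 0.2                               # leaky ReLU leak parameter


hp = HParams()
print(math.isclose(hp.learning_rate_G / hp.learning_rate_D, 4.0))  # True
```

`frozen=True` prevents accidental reassignment of a hyperparameter mid-run.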
def model_inputs(real_dim, z_dim):
    """
    Create the model inputs
    :param real_dim: Tuple containing width, height and channels
    :param z_dim: The dimension of Z
    :return: Tuple of (tensor of real input images, tensor of z data, learning rate G, learning rate D)
    """
    # inputs_real for Discriminator
    inputs_real = tf.placeholder(tf.float32, (None, *real_dim), name='inputs_real')
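`(None, *real_dim)` unpacks the image dimensions after a leading `None`, which leaves the batch size unspecified. A pure-Python sketch of the resulting shapes; the `(None, z_dim)` shape for the Z placeholder is an assumption, since only `inputs_real` appears in this excerpt.

```python
from typing import Dict, Optional, Tuple

Shape = Tuple[Optional[int], ...]


def input_shapes(real_dim: Tuple[int, int, int], z_dim: int) -> Dict[str, Shape]:
    """Shapes of the model placeholders; None marks the variable batch dimension."""
    return {
        "inputs_real": (None, *real_dim),  # same unpacking as in model_inputs
        "inputs_z": (None, z_dim),         # assumed shape for the latent input
    }


print(input_shapes((128, 128, 3), 100))
# {'inputs_real': (None, 128, 128, 3), 'inputs_z': (None, 100)}
```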
def model_loss(input_real, input_z, output_channel_dim, alpha):
    """
    Get the loss for the discriminator and generator
    :param input_real: Images from the real dataset
    :param input_z: Z input
    :param output_channel_dim: The number of channels in the output image
    :param alpha: Leak parameter for leaky ReLU
    :return: A tuple of (discriminator loss, generator loss)
    """
    # Build the generator network
    g_model = generator(input_z, output_channel_dim)
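The excerpt stops before the losses are computed. As an illustration only, here is the standard non-saturating GAN loss on discriminator logits in pure Python; whether `model_loss` uses this formulation, label smoothing, or another loss from the linked post is not shown. `sigmoid_xent` uses the numerically stable formula documented for `tf.nn.sigmoid_cross_entropy_with_logits`; `gan_losses` is a hypothetical name.

```python
import math
from typing import Sequence, Tuple


def sigmoid_xent(logit: float, label: float) -> float:
    """Stable sigmoid cross-entropy: max(x, 0) - x * z + log(1 + exp(-|x|))."""
    return max(logit, 0.0) - logit * label + math.log1p(math.exp(-abs(logit)))


def mean(values: Sequence[float]) -> float:
    return sum(values) / len(values)


def gan_losses(d_logits_real: Sequence[float],
               d_logits_fake: Sequence[float]) -> Tuple[float, float]:
    """Return (d_loss, g_loss) for the standard non-saturating GAN objective."""
    # D should label real images 1 and generated images 0
    d_loss_real = mean([sigmoid_xent(x, 1.0) for x in d_logits_real])
    d_loss_fake = mean([sigmoid_xent(x, 0.0) for x in d_logits_fake])
    # G wants D to label its generated images 1
    g_loss = mean([sigmoid_xent(x, 1.0) for x in d_logits_fake])
    return d_loss_real + d_loss_fake, g_loss


d_loss, g_loss = gan_losses([0.0, 0.0], [0.0, 0.0])
print(round(d_loss, 4), round(g_loss, 4))  # 1.3863 0.6931
```

With all logits at 0 the discriminator outputs 0.5 everywhere, so each cross-entropy term is ln 2.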
def model_optimizers(d_loss, g_loss, lr_D, lr_G, beta1):
    """
    Get optimization operations
    :param d_loss: Discriminator loss tensor
    :param g_loss: Generator loss tensor
    :param lr_D: Learning rate for the discriminator
    :param lr_G: Learning rate for the generator
    :param beta1: The exponential decay rate for the 1st moment in the optimizer
    :return: A tuple of (discriminator training operation, generator training operation)
    """
    # Get the trainable_variables, split into G and D parts
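Splitting the trainable variables lets each optimizer update only its own network's weights. A common approach filters by variable-scope name prefix; this sketch uses plain strings in place of `tf.Variable` objects (whose `.name` attribute carries the scope), and the scope names `generator` and `discriminator` are assumptions.

```python
from typing import List, Tuple


def split_vars(var_names: List[str],
               d_prefix: str = "discriminator",
               g_prefix: str = "generator") -> Tuple[List[str], List[str]]:
    """Split variable names into (discriminator vars, generator vars) by scope prefix."""
    d_vars = [n for n in var_names if n.startswith(d_prefix)]
    g_vars = [n for n in var_names if n.startswith(g_prefix)]
    return d_vars, g_vars


names = ["generator/dense/kernel:0", "discriminator/conv2d/kernel:0",
         "generator/conv2d_transpose/bias:0"]
d_vars, g_vars = split_vars(names)
print(d_vars)  # ['discriminator/conv2d/kernel:0']
print(g_vars)  # ['generator/dense/kernel:0', 'generator/conv2d_transpose/bias:0']
```

In TensorFlow 1.x the resulting lists would typically be passed as `var_list`, e.g. `tf.train.AdamOptimizer(lr_D, beta1=beta1).minimize(d_loss, var_list=d_vars)`, and likewise for the generator with `lr_G`.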
import os
from glob import glob
import tensorflow as tf
import helper

# Load the data and train the network here
dataset = helper.Dataset(glob(os.path.join(data_resized_dir, '*.jpg')))
with tf.Graph().as_default():
    losses, samples = train(epochs, batch_size, z_dim, learning_rate_D, learning_rate_G, beta1, dataset.get_batches,
                            dataset.shape, dataset.image_mode, alpha)
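`dataset.get_batches` comes from the `helper` module, which is not shown. As an assumption about its behavior, a minimal batching generator that yields fixed-size batches and drops an incomplete final batch might look like this (`get_batches` here is a hypothetical stand-in that takes the items explicitly):

```python
from typing import Iterator, List, Sequence, TypeVar

T = TypeVar("T")


def get_batches(items: Sequence[T], batch_size: int) -> Iterator[List[T]]:
    """Yield consecutive batches of batch_size items, dropping any remainder."""
    n_full = len(items) // batch_size
    for i in range(n_full):
        yield list(items[i * batch_size:(i + 1) * batch_size])


print(list(get_batches(list(range(7)), 3)))  # [[0, 1, 2], [3, 4, 5]]
```

Dropping the remainder keeps every batch the same size, which matters when the batch shape is fixed in the graph.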