-
-
Save BlueskyFR/94337cf7fe8568ae43969b47d6884432 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# To add a new cell, type '# %%' | |
# To add a new markdown cell, type '# %% [markdown]' | |
# %% | |
import time

# Wall-clock timestamp taken when the script starts; the final statement of the
# file subtracts it from the current time to report the total run time.
global_start_time = time.time()
# %% | |
# Standard library
import glob
from pathlib import Path

# Third-party
import imageio
import matplotlib.pyplot as plt
import tensorflow as tf
from IPython import display
from tensorflow.keras.layers import (
    BatchNormalization,
    Conv2D,
    Conv2DTranspose,
    Dense,
    Dropout,
    Flatten,
    LeakyReLU,
    Reshape,
    Rescaling,
    Resizing,
)

print(f"✨ Using TensorFlow {tf.__version__}!")

# Ask TensorFlow to grow its GPU memory allocation on demand instead of
# reserving the whole device up front. This has to run before any op touches
# the GPU. `tf.config.list_physical_devices` is the stable replacement for the
# deprecated `tf.config.experimental.list_physical_devices` alias.
for device in tf.config.list_physical_devices("GPU"):
    tf.config.experimental.set_memory_growth(device, True)
def plot(img):
    """Display an image whose pixel values are in [-1, 1].

    The values are shifted and halved so they land in [0, 1], which is the
    range handed to `plt.imshow`.
    """
    rescaled = (img + 1) / 2
    plt.imshow(rescaled)
# %% [markdown] | |
# # Load the data | |
# %% | |
DATA_DIR = Path("./cats/")
IMG_SIZE = (128, 128)

# List every JPEG found under DATA_DIR (recursive glob)
dataset = tf.data.Dataset.list_files((DATA_DIR / "**/*.jpg").as_posix())

# Load images.
# channels=3 forces an RGB decode: without it a grayscale JPEG yields a single
# channel, which can neither be batched together with RGB images nor fed to
# the discriminator, whose input shape is (*IMG_SIZE, 3).
dataset = dataset.map(
    lambda file: tf.io.decode_jpeg(
        tf.io.read_file(file),
        channels=3,
    )
)

# Preprocessing: resize to IMG_SIZE, then normalize from [0, 255] to [-1, 1]
# (the same range as the generator's tanh output).
preprocess = tf.keras.Sequential([
    Resizing(*IMG_SIZE),
    Rescaling(scale=1. / 127.5, offset=-1),
])
dataset = dataset.map(lambda img: preprocess(img))

print(f"Loaded dataset of {len(dataset)} cats!")

# Show one sample as a sanity check; plot() rescales it to [0, 1] for imshow.
for sample in dataset.take(1):
    plot(sample)
# %% | |
# (train_images, train_labels), (_, _) = tf.keras.datasets.mnist.load_data() | |
# train_images = train_images.reshape(train_images.shape[0], 28, 28, 1).astype("float32") | |
# train_images = (train_images - 127.5) / 127.5 # Normalize the images to [-1, 1] | |
# %% [markdown] | |
# ## Batch, cache and shuffle the dataset | |
# %% | |
BATCH_SIZE = 256
# The shuffle buffer is as large as the dataset itself, so every element can
# end up anywhere in the shuffled order.
BUFFER_SIZE = len(dataset)

# Cache the preprocessed images, then shuffle and batch them (one step per line).
dataset = dataset.cache()
dataset = dataset.shuffle(BUFFER_SIZE)
dataset = dataset.batch(BATCH_SIZE)

print(f"Each epoch will contain {len(dataset)} batches!")
# %% [markdown] | |
# # Create the models | |
# ## The Generator | |
# %% | |
def _generator_stage(filters, strides):
    """Return one generator stage: Conv2DTranspose -> BatchNormalization -> LeakyReLU.

    A transposed convolution (a.k.a. "deconvolution") with strides=2 doubles
    the spatial resolution; with strides=1 it only changes the channel count.
    """
    return [
        Conv2DTranspose(filters=filters, kernel_size=5, strides=strides, padding="same", use_bias=False),
        BatchNormalization(),
        LeakyReLU(),
    ]


# Maps a 100-dimensional noise vector to a (128, 128, 3) image in [-1, 1].
# In the shapes below, None is the batch size.
generator = tf.keras.Sequential([
    # Project the noise vector and reshape it into an 8x8 feature map
    Dense(8 * 8 * 256, use_bias=False, input_shape=(100,)),
    BatchNormalization(),
    LeakyReLU(),
    Reshape((8, 8, 256)),
    *_generator_stage(filters=1024, strides=1),  # -> (None, 8, 8, 1024)
    *_generator_stage(filters=512, strides=2),   # -> (None, 16, 16, 512)
    *_generator_stage(filters=256, strides=1),   # -> (None, 16, 16, 256)
    *_generator_stage(filters=128, strides=2),   # -> (None, 32, 32, 128)
    *_generator_stage(filters=64, strides=2),    # -> (None, 64, 64, 64)
    # Final upsampling to RGB; tanh keeps the pixel values in [-1, 1]
    Conv2DTranspose(filters=3, kernel_size=5, strides=2, padding="same", use_bias=False, activation="tanh"),
    # -> (None, 128, 128, 3)
])
# %% [markdown] | |
# ### Test the Generator | |
# %% | |
# Smoke test: push one random latent vector through the (untrained) generator
noise = tf.random.normal(shape=(1, 100))
print(f"Noise shape: {noise.shape}")

# training=False runs the layers in inference mode (this matters for the
# BatchNormalization layers).
generated_image = generator(noise, training=False)
print(generated_image.shape)

# Drop the batch dimension before plotting
plot(generated_image[0])
# %% [markdown] | |
# ## The Discriminator | |
# | |
# Classifies the images as real or fake. Positive is real, negative is fake. | |
# %% | |
def _discriminator_stage(filters, strides, **conv_kwargs):
    """Return one discriminator stage: Conv2D -> LeakyReLU -> Dropout(0.3).

    Extra keyword arguments are forwarded to the Conv2D layer (used to pass
    `input_shape` to the very first layer).
    """
    return [
        Conv2D(filters=filters, kernel_size=5, strides=strides, padding="same", **conv_kwargs),
        LeakyReLU(),
        Dropout(0.3),
    ]


# Maps a (128, 128, 3) image to a single unbounded score (a logit).
# In the shapes below, None is the batch size.
discriminator = tf.keras.Sequential([
    *_discriminator_stage(filters=64, strides=2, input_shape=(*IMG_SIZE, 3)),  # -> (None, 64, 64, 64)
    *_discriminator_stage(filters=128, strides=2),   # -> (None, 32, 32, 128)
    *_discriminator_stage(filters=256, strides=2),   # -> (None, 16, 16, 256)
    *_discriminator_stage(filters=512, strides=1),   # -> (None, 16, 16, 512)
    # NOTE(review): with strides=1 this stage keeps the 16x16 resolution. An
    # earlier comment claimed (None, 8, 8, 1024), which would need strides=2;
    # confirm which one was intended.
    *_discriminator_stage(filters=1024, strides=1),  # -> (None, 16, 16, 1024)
    Flatten(),
    Dense(1),
])
# %% [markdown] | |
# ### Test the Discriminator on the previously generated image | |
# %% | |
# Score the image produced by the untrained generator above. The value is the
# raw output of the final Dense(1) layer; training=False runs the network in
# inference mode.
decision = discriminator(generated_image, training=False)
print(decision)
# %% [markdown] | |
# Because the weights are randomly initialized (and the biases start at zero), the output is close to 0.
# | |
# # Loss and optimizers | |
# %% | |
# Binary cross-entropy loss object shared by the discriminator and generator losses.
# from_logits=True because the discriminator ends with Dense(1) and no
# activation, so its outputs are raw logits rather than probabilities.
cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True)
# %% [markdown] | |
# ## Discriminator loss | |
# | |
# Each time, the discriminator will receive batches of both real and fake images.
# The output for a batch of real images should be an array of 1s, and an array of 0s for a fake one. | |
# %% | |
def discriminator_loss(real_output, fake_output):
    """Return the discriminator loss for one batch.

    Args:
        real_output: discriminator outputs for the real images.
        fake_output: discriminator outputs for the generated images.

    Returns:
        The sum of two cross-entropy terms: real outputs scored against a
        target of all 1s, and fake outputs scored against a target of all 0s.
    """
    real_targets = tf.ones_like(real_output)
    fake_targets = tf.zeros_like(fake_output)
    return cross_entropy(real_targets, real_output) + cross_entropy(fake_targets, fake_output)
# %% [markdown] | |
# ## Generator loss | |
# | |
# The generator loss quantifies how well it was able to trick the discriminator. If the generator is performing well, the discriminator will classify the fake images (i.e. the generated ones) as real, so its output is compared against an array of 1s.
# %% | |
def generator_loss(fake_output):
    """Return the generator loss for one batch.

    Args:
        fake_output: discriminator outputs for the generated images.

    Returns:
        The cross-entropy between `fake_output` and a target of all 1s, i.e.
        the loss is low when the discriminator rates the fakes as real.
    """
    wanted_verdict = tf.ones_like(fake_output)
    return cross_entropy(wanted_verdict, fake_output)
# %% [markdown] | |
# ## Optimizers | |
# %% | |
# The two networks are trained separately, so each one gets its own Adam
# optimizer (both with a learning rate of 1e-4).
discriminator_optimizer = tf.keras.optimizers.Adam(learning_rate=1e-4)
generator_optimizer = tf.keras.optimizers.Adam(learning_rate=1e-4)
# %% [markdown] | |
# # Define the training loop | |
# %% | |
# Number of passes over the dataset (the trailing comment keeps the value used
# for a long run).
EPOCHS = 5#1000
# Length of the noise vector fed to the generator; matches the generator's
# input_shape=(100,).
noise_dim = 100
# Number of demo images rendered per snapshot
num_examples_to_generate = 16
# A seed periodically used to generate a nice gif and visualize progression.
# It is sampled once so that every snapshot is rendered from the same noise.
demo_gif_seed = tf.random.normal((num_examples_to_generate, noise_dim))
# We use tf.function so that the function is "compiled" through the TensorFlow graph
@tf.function
def train_step(real_images):
    """Run one optimization step of both networks on a batch of real images.

    Args:
        real_images: a batch of real images, batch dimension first, as
            expected by the discriminator.

    The generator and the discriminator are updated simultaneously: both
    losses are computed in the same forward pass, then each network's own
    optimizer applies that network's gradients. Nothing is returned.
    """
    # Generate exactly as many fakes as there are real images. The last batch
    # of an epoch (or a dataset smaller than BATCH_SIZE) can hold fewer than
    # BATCH_SIZE images, and a hard-coded batch size would then train the
    # discriminator on mismatched real/fake counts.
    batch_size = tf.shape(real_images)[0]
    noise = tf.random.normal((batch_size, noise_dim))

    # One tape per network: each loss is differentiated only with respect to
    # the variables of the network it trains.
    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        fake_images = generator(noise, training=True)

        real_output = discriminator(real_images, training=True)
        fake_output = discriminator(fake_images, training=True)

        gen_loss = generator_loss(fake_output)
        disc_loss = discriminator_loss(real_output, fake_output)

    generator_gradients = gen_tape.gradient(gen_loss, generator.trainable_variables)
    discriminator_gradients = disc_tape.gradient(disc_loss, discriminator.trainable_variables)

    generator_optimizer.apply_gradients(zip(generator_gradients, generator.trainable_variables))
    discriminator_optimizer.apply_gradients(zip(discriminator_gradients, discriminator.trainable_variables))
# %% | |
def train(dataset, epochs):
    """Train the GAN for `epochs` full passes over `dataset`.

    Args:
        dataset: iterable of image batches; each batch is handed to
            `train_step`.
        epochs: number of passes over `dataset`.

    The wall-clock duration of every epoch is printed. The per-epoch and
    final snapshot hooks are currently disabled (kept below as comments).
    """
    for epoch_number in range(1, epochs + 1):
        epoch_started_at = time.time()

        for image_batch in dataset:
            train_step(image_batch)

        # Disabled: after each full pass, save a snapshot of the current
        # generator state so that a GIF can be built later.
        #display.clear_output(wait=True)
        #generate_and_save_images(generator, epoch_number, demo_gif_seed)

        elapsed = time.time() - epoch_started_at
        print(f"Time for epoch {epoch_number} is {elapsed}")

    # Disabled: final snapshot once training is over.
    #display.clear_output(wait=True)
    #generate_and_save_images(generator, epochs, demo_gif_seed)
# %% | |
def generate_and_save_images(model, epoch, test_input):
    """Render the model's output for `test_input` as an image grid and save it.

    Args:
        model: generator to sample from; called as `model(test_input, training=False)`.
        epoch: integer used in the output file name, zero-padded to 3 digits.
        test_input: batch of noise vectors, one per image to render.

    The grid is written to `image_at_epoch_<epoch>.png` in the current
    directory and then shown.
    """
    import math  # local import: only this function needs it

    predictions = model(test_input, training=False)
    count = predictions.shape[0]

    # Smallest square grid that fits every image: 4x4 for 16 samples. A
    # hard-coded 4x4 grid would raise as soon as more than 16 samples are given.
    grid = max(1, math.ceil(math.sqrt(count)))

    fig = plt.figure(figsize=(4, 4))
    for position, image in enumerate(predictions, start=1):
        plt.subplot(grid, grid, position)
        plot(image)  # plot() rescales from [-1, 1] to [0, 1] for imshow
        plt.axis("off")

    plt.savefig(f"image_at_epoch_{epoch:03d}.png")
    plt.show()
    # Release the figure so that repeated per-epoch calls do not pile up open
    # figures.
    plt.close(fig)
# %% [markdown] | |
# # Train the model | |
# %% | |
# Run the full training loop: EPOCHS passes over the batched dataset
train(dataset, EPOCHS)
# %% [markdown] | |
# # Create a GIF | |
# %% | |
# gif_file = "gan.gif" | |
# with imageio.get_writer(gif_file, mode='I') as writer: | |
# for filename in sorted(glob.glob("image*.png")): | |
# image = imageio.imread(filename) | |
# writer.append_data(image) | |
# image = imageio.imread(gif_file) | |
# writer.append_data(image) | |
# Report the wall-clock time elapsed since global_start_time was recorded at
# the top of the script.
print(f"Total run time (in seconds): {time.time() - global_start_time}")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment