Skip to content

Instantly share code, notes, and snippets.

@shanecandoit
Created October 24, 2022 22:11
Show Gist options
  • Save shanecandoit/c305d9db30ad6836b8df911b1c0ea7b1 to your computer and use it in GitHub Desktop.
Save shanecandoit/c305d9db30ad6836b8df911b1c0ea7b1 to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
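{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Conditional GAN\n",
"\n",
"Trains a class-conditional GAN on MNIST digits, then interpolates between two digit classes. Based on the Keras example:\n",
"\n",
"https://keras.io/examples/generative/conditional_gan/"
]
},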
{
"cell_type": "code",
"execution_count": 43,
"metadata": {},
"outputs": [],
"source": [
"import tensorflow as tf\n",
"import tensorflow\n",
"from tensorflow import keras\n",
"from tensorflow.keras import layers\n",
"\n",
"from tensorflow_docs.vis import embed\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"import imageio\n"
]
},
{
"cell_type": "code",
"execution_count": 44,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'2.8.0'"
]
},
"execution_count": 44,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"tensorflow.__version__"
]
},
{
"cell_type": "code",
"execution_count": 45,
"metadata": {},
"outputs": [],
"source": [
"# !pip install -q git+https://github.com/tensorflow/docs\n"
]
},
{
"cell_type": "code",
"execution_count": 46,
"metadata": {},
"outputs": [],
"source": [
"# Shared hyperparameters for the data pipeline and both networks.\n",
"latent_dim = 128   # length of the generator's noise vector\n",
"image_size = 28    # images are image_size x image_size pixels\n",
"num_channels = 1   # channels per image\n",
"num_classes = 10   # width of the one-hot label\n",
"batch_size = 64    # samples per training batch"
]
},
{
"cell_type": "code",
"execution_count": 47,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"shape training data: (70000, 28, 28, 1)\n",
"shape training labels: (70000, 10)\n"
]
}
],
"source": [
"# get data\n",
"(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()\n",
"all_digits = np.concatenate([x_train, x_test], axis=0)\n",
"all_labels = np.concatenate([y_train, y_test], axis=0)\n",
"\n",
"# scale\n",
"all_digits = all_digits.astype(\"float32\") / 255.0\n",
"all_digits = np.reshape(all_digits, (-1, image_size, image_size, 1))\n",
"all_labels = keras.utils.to_categorical(all_labels, num_classes)\n",
"\n",
"# data set\n",
"dataset = tf.data.Dataset.from_tensor_slices((all_digits, all_labels))\n",
"dataset = dataset.shuffle(buffer_size=1024).batch(batch_size)\n",
"\n",
"print('shape training data: ', all_digits.shape)\n",
"print('shape training labels: ', all_labels.shape)"
]
},
{
"cell_type": "code",
"execution_count": 48,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"138 11\n"
]
}
],
"source": [
"# add category to noise input\n",
"generator_in_channels = latent_dim + num_classes\n",
"discrimintaor_in_channels = num_channels + num_classes\n",
"print(generator_in_channels, discrimintaor_in_channels)"
]
},
{
"cell_type": "code",
"execution_count": 49,
"metadata": {},
"outputs": [],
"source": [
"# Discriminator: scores an image stacked with its label channels.\n",
"# The last layer is a bare Dense(1), so the score is an unactivated logit.\n",
"discriminator = keras.Sequential(\n",
"    [\n",
"        layers.InputLayer((image_size, image_size, discrimintaor_in_channels)),\n",
"        layers.Conv2D(64, kernel_size=3, strides=2, padding=\"same\"),\n",
"        layers.LeakyReLU(alpha=0.2),\n",
"        layers.Conv2D(128, kernel_size=3, strides=2, padding=\"same\"),\n",
"        layers.LeakyReLU(alpha=0.2),\n",
"        layers.GlobalMaxPooling2D(),\n",
"        layers.Dense(1),\n",
"    ],\n",
"    name=\"discriminator\",\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 50,
"metadata": {},
"outputs": [],
"source": [
"# Generator: maps a [noise | one-hot label] vector to a single-channel image.\n",
"# A 7x7 feature map is upsampled twice by stride-2 transposed convolutions;\n",
"# the final sigmoid keeps pixel values in (0, 1).\n",
"generator = keras.Sequential(\n",
"    [\n",
"        layers.InputLayer((generator_in_channels,)),\n",
"        layers.Dense(7 * 7 * generator_in_channels),\n",
"        layers.LeakyReLU(alpha=0.2),\n",
"        layers.Reshape((7, 7, generator_in_channels)),\n",
"        layers.Conv2DTranspose(128, kernel_size=4, strides=2, padding=\"same\"),\n",
"        layers.LeakyReLU(alpha=0.2),\n",
"        layers.Conv2DTranspose(128, kernel_size=4, strides=2, padding=\"same\"),\n",
"        layers.LeakyReLU(alpha=0.2),\n",
"        layers.Conv2D(1, kernel_size=7, padding=\"same\", activation=\"sigmoid\"),\n",
"    ],\n",
"    name=\"generator\",\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 51,
"metadata": {},
"outputs": [],
"source": [
"# cond gan model\n",
"class ConditionalGAN(keras.Model):\n",
"    \"\"\"GAN in which both networks are conditioned on a one-hot class label.\n",
"\n",
"    The generator is fed noise concatenated with the label; the discriminator\n",
"    is fed an image with the label appended as extra channels.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, discriminator, generator, latent_dim) -> None:\n",
"        super().__init__()\n",
"        self.discriminator = discriminator\n",
"        self.generator = generator\n",
"        self.latent_dim = latent_dim\n",
"        # Running means of the per-batch losses, returned from train_step().\n",
"        self.gen_loss_tracker = keras.metrics.Mean(name=\"gener_loss\")\n",
"        self.disc_loss_tracker = keras.metrics.Mean(name=\"discrim_loss\")\n",
"\n",
"    @property\n",
"    def metrics(self):\n",
"        \"\"\"The two loss trackers, exposed through the Keras `metrics` hook.\"\"\"\n",
"        return [self.gen_loss_tracker, self.disc_loss_tracker]\n",
"\n",
"    def compile(self, d_optimizer, g_optimizer, loss_fn):\n",
"        \"\"\"Store one optimizer per network and the loss shared by both steps.\"\"\"\n",
"        super().compile()\n",
"        self.d_optimizer = d_optimizer\n",
"        self.g_optimizer = g_optimizer\n",
"        self.loss_fn = loss_fn\n",
"\n",
"    def train_step(self, data):\n",
"        \"\"\"Update the discriminator once, then the generator once.\n",
"\n",
"        Args:\n",
"            data: an `(images, one_hot_labels)` batch.\n",
"\n",
"        Returns:\n",
"            Dict holding the running mean generator and discriminator losses.\n",
"        \"\"\"\n",
"        # unpack data\n",
"        real_images, one_hot_labels = data\n",
"\n",
"        # Broadcast each label over the image grid so that channel c of the map\n",
"        # is constant and equal to label[c]:\n",
"        # (batch, classes) -> (batch, image_size, image_size, classes).\n",
"        # tf.repeat without `axis` flattens its input, so repeat + reshape\n",
"        # produced spatial bands instead of one channel per class.\n",
"        image_one_hot_labels = tf.tile(\n",
"            one_hot_labels[:, None, None, :], [1, image_size, image_size, 1]\n",
"        )\n",
"\n",
"        # random points, conditioned on this batch's labels\n",
"        batch_size = tf.shape(real_images)[0]\n",
"        random_latent_vectors = tf.random.normal(shape=(batch_size, self.latent_dim))\n",
"        random_latent_vectors = tf.concat([random_latent_vectors, one_hot_labels], axis=1)\n",
"\n",
"        # decode noise (outside any tape: the generator is not updated here)\n",
"        generated_images = self.generator(random_latent_vectors)\n",
"\n",
"        # combine: fake samples first, real samples second\n",
"        fake_image_and_labels = tf.concat([generated_images, image_one_hot_labels], axis=-1)\n",
"        real_image_and_labels = tf.concat([real_images, image_one_hot_labels], axis=-1)\n",
"        combined_images = tf.concat([fake_image_and_labels, real_image_and_labels], axis=0)\n",
"\n",
"        # assemble targets in the same order: 1 = generated, 0 = real\n",
"        labels = tf.concat(\n",
"            [tf.ones((batch_size, 1)), tf.zeros((batch_size, 1))], axis=0\n",
"        )\n",
"\n",
"        # train discriminator\n",
"        with tf.GradientTape() as tape:\n",
"            predictions = self.discriminator(combined_images)\n",
"            d_loss = self.loss_fn(labels, predictions)\n",
"        grads = tape.gradient(d_loss, self.discriminator.trainable_weights)\n",
"        self.d_optimizer.apply_gradients(\n",
"            zip(grads, self.discriminator.trainable_weights)\n",
"        )\n",
"\n",
"        # sample fresh noise for the generator step\n",
"        random_latent_vectors = tf.random.normal(shape=(batch_size, self.latent_dim))\n",
"        random_vector_labels = tf.concat(\n",
"            [random_latent_vectors, one_hot_labels], axis=1\n",
"        )\n",
"\n",
"        # targets claiming every generated sample is real (0)\n",
"        misleading_labels = tf.zeros((batch_size, 1))\n",
"\n",
"        # train generator; only the generator's weights are updated\n",
"        with tf.GradientTape() as tape:\n",
"            fake_images = self.generator(random_vector_labels)\n",
"            fake_image_and_labels = tf.concat([fake_images, image_one_hot_labels], axis=-1)\n",
"            predictions = self.discriminator(fake_image_and_labels)\n",
"            g_loss = self.loss_fn(misleading_labels, predictions)\n",
"        grads = tape.gradient(g_loss, self.generator.trainable_weights)\n",
"        self.g_optimizer.apply_gradients(zip(grads, self.generator.trainable_weights))\n",
"\n",
"        # monitor loss\n",
"        self.gen_loss_tracker.update_state(g_loss)\n",
"        self.disc_loss_tracker.update_state(d_loss)\n",
"        return {\n",
"            \"gener_loss\": self.gen_loss_tracker.result(),\n",
"            \"discrim_loss\": self.disc_loss_tracker.result(),\n",
"        }"
]
},
{
"cell_type": "code",
"execution_count": 53,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/10\n",
"1094/1094 [==============================] - 1984s 2s/step - gener_loss: 1.4442 - discrim_loss: 0.4190\n",
"Epoch 2/10\n",
"1094/1094 [==============================] - 1838s 2s/step - gener_loss: 1.3046 - discrim_loss: 0.4924\n",
"Epoch 3/10\n",
"1094/1094 [==============================] - 1685s 2s/step - gener_loss: 1.3184 - discrim_loss: 0.4690\n",
"Epoch 4/10\n",
"1094/1094 [==============================] - 1370s 1s/step - gener_loss: 1.6242 - discrim_loss: 0.3623\n",
"Epoch 5/10\n",
"1094/1094 [==============================] - 1370s 1s/step - gener_loss: 1.5247 - discrim_loss: 0.4728\n",
"Epoch 6/10\n",
"1094/1094 [==============================] - 1539s 1s/step - gener_loss: 0.9526 - discrim_loss: 0.6141\n",
"Epoch 7/10\n",
"1094/1094 [==============================] - 1464s 1s/step - gener_loss: 0.8776 - discrim_loss: 0.6420\n",
"Epoch 8/10\n",
"1094/1094 [==============================] - 1817s 2s/step - gener_loss: 0.8361 - discrim_loss: 0.6550\n",
"Epoch 9/10\n",
"1094/1094 [==============================] - 2156s 2s/step - gener_loss: 0.8318 - discrim_loss: 0.6583\n",
"Epoch 10/10\n",
"1094/1094 [==============================] - 1661s 2s/step - gener_loss: 0.8078 - discrim_loss: 0.6661\n"
]
},
{
"data": {
"text/plain": [
"<keras.callbacks.History at 0x7f8caf4efd30>"
]
},
"execution_count": 53,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# train cond gan\n",
"cond_gan = ConditionalGAN(discriminator, generator, latent_dim=latent_dim)\n",
"cond_gan.compile(\n",
" d_optimizer=keras.optimizers.Adam(learning_rate=0.0002),\n",
" g_optimizer=keras.optimizers.Adam(learning_rate=0.0002),\n",
" loss_fn=keras.losses.BinaryCrossentropy(from_logits=True),\n",
")\n",
"\n",
"cond_gan.fit(dataset, epochs=10)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<BatchDataset element_spec=(TensorSpec(shape=(None, 28, 28, 1), dtype=tf.float32, name=None), TensorSpec(shape=(None, 10), dtype=tf.float32, name=None))>"
]
},
"execution_count": 17,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"dataset"
]
},
{
"cell_type": "code",
"execution_count": 54,
"metadata": {},
"outputs": [],
"source": [
"# interpolate\n",
"trained_gen = cond_gan.generator\n",
"\n",
"# number of frames in the sequence, both endpoint classes included\n",
"num_interpolations = 9\n",
"\n",
"# Sample ONE noise vector and copy it to every row so that only the label\n",
"# changes between frames. axis=0 is required: without it tf.repeat flattens\n",
"# its input and repeats single elements, giving every row different noise.\n",
"interpolation_noise = tf.random.normal(shape=(1, latent_dim))\n",
"interpolation_noise = tf.repeat(interpolation_noise, repeats=num_interpolations, axis=0)\n",
"\n",
"\n",
"def interpolate_class(first_number, second_number):\n",
"    \"\"\"Generate images whose label blends linearly from one class to another.\n",
"\n",
"    Args:\n",
"        first_number: class index the sequence starts at.\n",
"        second_number: class index the sequence ends at.\n",
"\n",
"    Returns:\n",
"        The output of `trained_gen.predict` for `num_interpolations` inputs\n",
"        that share `interpolation_noise` and differ only in their label.\n",
"    \"\"\"\n",
"    first_label = keras.utils.to_categorical([first_number], num_classes)\n",
"    second_label = keras.utils.to_categorical([second_number], num_classes)\n",
"    first_label = tf.cast(first_label, tf.float32)\n",
"    second_label = tf.cast(second_label, tf.float32)\n",
"\n",
"    # Blend weights as a column so they broadcast against the (1, classes) labels.\n",
"    percent_second = tf.linspace(0.0, 1.0, num_interpolations)[:, None]\n",
"    percent_second = tf.cast(percent_second, tf.float32)\n",
"    interpolation_labels = (first_label * (1 - percent_second)) + (second_label * percent_second)\n",
"\n",
"    # combine noise and blended labels into generator inputs\n",
"    noise_and_labels = tf.concat([interpolation_noise, interpolation_labels], axis=1)\n",
"    fake = trained_gen.predict(noise_and_labels)\n",
"    return fake\n",
"\n",
"\n",
"start_class = 1\n",
"end_class = 5\n",
"\n",
"fake_images = interpolate_class(start_class, end_class)\n"
]
},
{
"cell_type": "code",
"execution_count": 55,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<img src=\"\"/>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"execution_count": 55,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"fake_images *= 255.\n",
"conv_images = fake_images.astype(np.uint8)\n",
"conv_images = tf.image.resize(conv_images, (96, 96)).numpy().astype(np.uint8)\n",
"imageio.mimsave('animation.gif', conv_images, fps=1)\n",
"embed.embed_file('animation.gif')"
]
},
{
"cell_type": "code",
"execution_count": 65,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"drwxr-xr-x@ 11 sandy staff 352B Oct 20 12:51 js_movies\n",
"drwxr-xr-x@ 50 sandy staff 1.6K Oct 20 12:51 RustPython\n",
"-rw-r--r--@ 1 sandy staff 36K Oct 24 13:52 animation.gif\n",
"-rw-r--r-- 1 sandy staff 6.0M Oct 24 16:13 cond_gan_weights.h5\n",
"-rw-r--r--@ 1 sandy staff 65K Oct 24 16:14 condi_gan.ipynb\n"
]
}
],
"source": [
"!ls -ltrh | tail -n5"
]
},
{
"cell_type": "code",
"execution_count": 61,
"metadata": {},
"outputs": [],
"source": [
"#cond_gan.save('cond_gan.tf')\n",
"cond_gan.save_weights('cond_gan_weights.h5')\n",
"\n",
"# Model.save_spec() takes a `dynamic_batch` flag and only returns the input\n",
"# specs; it never writes a file. Persist each network's architecture as JSON\n",
"# via to_json() instead, in one file per network.\n",
"for net in (cond_gan.generator, cond_gan.discriminator):\n",
"    with open(f'cond_gan_{net.name}_spec.json', 'w') as spec_file:\n",
"        spec_file.write(net.to_json())"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3.9.7 ('base')",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.7"
},
"orig_nbformat": 4,
"vscode": {
"interpreter": {
"hash": "40d3a090f54c6569ab1632332b64b2c03c39dcf918b08424e98f38b5ae0af88f"
}
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment