Last active
June 29, 2022 12:45
-
-
Save qweliant/313451c37bed4207c2589da0e6c609fb to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"papermill": { | |
"duration": 0.046525, | |
"end_time": "2020-10-31T18:42:52.121984", | |
"exception": false, | |
"start_time": "2020-10-31T18:42:52.075459", | |
"status": "completed" | |
}, | |
"tags": [] | |
}, | |
"source": [ | |
"# CycleGAN to generate Monet-style images" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"papermill": { | |
"duration": 0.042563, | |
"end_time": "2020-10-31T18:42:52.208643", | |
"exception": false, | |
"start_time": "2020-10-31T18:42:52.166080", | |
"status": "completed" | |
}, | |
"tags": [] | |
}, | |
"source": [ | |
"This notebook is inspired by Amy Jang's tutorial notebook." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"_cell_guid": "b1076dfc-b9ad-4769-8c92-a6c4dae69d19", | |
"_uuid": "8f2839f25d086af736a60e9eeb907d3b93b6e0e5", | |
"papermill": { | |
"duration": 2.85108, | |
"end_time": "2020-10-31T18:42:55.103424", | |
"exception": false, | |
"start_time": "2020-10-31T18:42:52.252344", | |
"status": "completed" | |
}, | |
"tags": [ | |
"outputPrepend" | |
] | |
}, | |
"outputs": [], | |
"source": [ | |
"# This Python 3 environment comes with many helpful analytics libraries installed\n", | |
"# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python\n", | |
"# For example, here's several helpful packages to load\n", | |
"\n", | |
"import numpy as np # linear algebra\n", | |
"import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)\n", | |
"\n", | |
"# Input data files are available in the read-only \"../input/\" directory\n", | |
"# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory\n", | |
"count = 0\n", | |
"import os\n", | |
"for dirname, _, filenames in os.walk('.'):\n", | |
" for filename in filenames:\n", | |
" count+=1\n", | |
" print(os.path.join(dirname, filename))\n", | |
"print(count)\n", | |
"\n", | |
"# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using \"Save & Run All\" \n", | |
"# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"_cell_guid": "79c7e3d0-c299-4dcb-8224-4455121ee9b0", | |
"_uuid": "d629ff2d2480ee46fbb7e2d37f6b5fab8052498a", | |
"papermill": { | |
"duration": 8.628127, | |
"end_time": "2020-10-31T18:43:03.803794", | |
"exception": false, | |
"start_time": "2020-10-31T18:42:55.175667", | |
"status": "completed" | |
}, | |
"tags": [] | |
}, | |
"outputs": [], | |
"source": [ | |
"import os\n", | |
"import math\n", | |
"import random\n", | |
"import matplotlib.pyplot as plt\n", | |
"import cv2\n", | |
"import tensorflow as tf\n", | |
"from tensorflow import keras\n", | |
"from tensorflow.keras import layers\n", | |
"import tensorflow_addons as tfa\n", | |
"import tensorflow_datasets as tfds\n", | |
"import matplotlib.pyplot as plt\n", | |
"import seaborn as sns\n", | |
"import tensorflow.keras.backend as K\n", | |
"import os, random, json, PIL, shutil, re, imageio, glob\n", | |
"from tensorflow.keras import Model, losses, optimizers\n", | |
"from tensorflow.keras.callbacks import Callback\n", | |
"# from kaggle_datasets import KaggleDatasets" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"papermill": { | |
"duration": 5.925849, | |
"end_time": "2020-10-31T18:43:09.790327", | |
"exception": false, | |
"start_time": "2020-10-31T18:43:03.864478", | |
"status": "completed" | |
}, | |
"tags": [] | |
}, | |
"outputs": [], | |
"source": [ | |
"# Configuring TPU\n", | |
"\n", | |
"try:\n", | |
" tpu = tf.distribute.cluster_resolver.TPUClusterResolver()\n", | |
" print(f'Running on TPU {tpu.master()}')\n", | |
"except ValueError:\n", | |
" tpu = None\n", | |
"\n", | |
"if tpu:\n", | |
" tf.config.experimental_connect_to_cluster(tpu)\n", | |
" tf.tpu.experimental.initialize_tpu_system(tpu)\n", | |
" strategy = tf.distribute.experimental.TPUStrategy(tpu)\n", | |
"else:\n", | |
" strategy = tf.distribute.get_strategy()\n", | |
"\n", | |
"\n", | |
"REPLICAS = strategy.num_replicas_in_sync\n", | |
"AUTO = tf.data.experimental.AUTOTUNE\n", | |
"print(f'REPLICAS: {REPLICAS}')" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"papermill": { | |
"duration": 0.069045, | |
"end_time": "2020-10-31T18:43:09.919831", | |
"exception": false, | |
"start_time": "2020-10-31T18:43:09.850786", | |
"status": "completed" | |
}, | |
"tags": [] | |
}, | |
"outputs": [], | |
"source": [ | |
"BASE_PATH = '.\\data\\input'\n", | |
"MONET_PATH = os.path.join(BASE_PATH, 'monet_jpg')\n", | |
"PHOTO_PATH = os.path.join(BASE_PATH, 'photo_jpg')" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"papermill": { | |
"duration": 34.419907, | |
"end_time": "2020-10-31T18:43:44.401071", | |
"exception": false, | |
"start_time": "2020-10-31T18:43:09.981164", | |
"status": "completed" | |
}, | |
"tags": [] | |
}, | |
"outputs": [], | |
"source": [ | |
"def show_folder_info(path):\n", | |
" d_image_sizes = {}\n", | |
" for image_name in os.listdir(path):\n", | |
" image = cv2.imread(os.path.join(path, image_name))\n", | |
" d_image_sizes[image.shape] = d_image_sizes.get(image.shape, 0) + 1\n", | |
" \n", | |
" for size, count in d_image_sizes.items():\n", | |
" print(f'shape: {size}\\tcount: {count}')\n", | |
"\n", | |
"\n", | |
"print('Monet images:')\n", | |
"show_folder_info(MONET_PATH)\n", | |
"\n", | |
"print('Photo images:')\n", | |
"show_folder_info(PHOTO_PATH)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"papermill": { | |
"duration": 0.06185, | |
"end_time": "2020-10-31T18:43:44.524609", | |
"exception": false, | |
"start_time": "2020-10-31T18:43:44.462759", | |
"status": "completed" | |
}, | |
"tags": [] | |
}, | |
"source": [ | |
" # Visualization" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"papermill": { | |
"duration": 0.062213, | |
"end_time": "2020-10-31T18:43:44.649069", | |
"exception": false, | |
"start_time": "2020-10-31T18:43:44.586856", | |
"status": "completed" | |
}, | |
"tags": [] | |
}, | |
"source": [ | |
"## (1) Batch visualization" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"papermill": { | |
"duration": 0.07742, | |
"end_time": "2020-10-31T18:43:44.788923", | |
"exception": false, | |
"start_time": "2020-10-31T18:43:44.711503", | |
"status": "completed" | |
}, | |
"tags": [] | |
}, | |
"outputs": [], | |
"source": [ | |
"def batch_visualization(path, n_images, is_random=True, figsize=(16, 16)):\n", | |
" plt.figure(figsize=figsize)\n", | |
" \n", | |
" w = int(n_images ** .5)\n", | |
" h = math.ceil(n_images / w)\n", | |
" \n", | |
" all_names = os.listdir(path)\n", | |
" \n", | |
" image_names = all_names[:n_images]\n", | |
" if is_random:\n", | |
" image_names = random.sample(all_names, n_images)\n", | |
" \n", | |
" for ind, image_name in enumerate(image_names):\n", | |
" img = cv2.imread(os.path.join(path, image_name))\n", | |
" img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) \n", | |
" plt.subplot(h, w, ind + 1)\n", | |
" plt.imshow(img)\n", | |
" plt.axis('off')\n", | |
" \n", | |
" plt.show()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"papermill": { | |
"duration": 1.814201, | |
"end_time": "2020-10-31T18:43:46.666580", | |
"exception": false, | |
"start_time": "2020-10-31T18:43:44.852379", | |
"status": "completed" | |
}, | |
"tags": [] | |
}, | |
"outputs": [], | |
"source": [ | |
"batch_visualization(MONET_PATH, 12, is_random=True, figsize=(23, 23))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"papermill": { | |
"duration": 1.536706, | |
"end_time": "2020-10-31T18:43:48.323517", | |
"exception": false, | |
"start_time": "2020-10-31T18:43:46.786811", | |
"status": "completed" | |
}, | |
"tags": [] | |
}, | |
"outputs": [], | |
"source": [ | |
"batch_visualization(PHOTO_PATH, 12, is_random=True, figsize=(23, 23))" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"papermill": { | |
"duration": 0.165469, | |
"end_time": "2020-10-31T18:43:48.663620", | |
"exception": false, | |
"start_time": "2020-10-31T18:43:48.498151", | |
"status": "completed" | |
}, | |
"tags": [] | |
}, | |
"source": [ | |
"## (2) Colour histograms" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"papermill": { | |
"duration": 0.182251, | |
"end_time": "2020-10-31T18:43:49.010644", | |
"exception": false, | |
"start_time": "2020-10-31T18:43:48.828393", | |
"status": "completed" | |
}, | |
"tags": [] | |
}, | |
"outputs": [], | |
"source": [ | |
"# Plotting colour histograms for Monet paintings\n", | |
"def color_hist_visualization(image_path, figsize=(16, 4)):\n", | |
" plt.figure(figsize=figsize)\n", | |
" colors = ['red', 'green', 'blue']\n", | |
" \n", | |
" img = cv2.imread(image_path)\n", | |
" img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) \n", | |
" plt.subplot(1, 4, 1)\n", | |
" plt.imshow(img)\n", | |
" plt.axis('off')\n", | |
" \n", | |
" for i in range(len(colors)):\n", | |
" plt.subplot(1, 4, i + 2)\n", | |
" plt.hist(\n", | |
" img[:, :, i].reshape(-1),\n", | |
" bins=25,\n", | |
" alpha=0.5,\n", | |
" color=colors[i],\n", | |
" density=True\n", | |
" )\n", | |
" plt.xlim(0, 255)\n", | |
" plt.xticks([])\n", | |
" plt.yticks([])\n", | |
" plt.show()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"papermill": { | |
"duration": 0.585163, | |
"end_time": "2020-10-31T18:43:49.765505", | |
"exception": false, | |
"start_time": "2020-10-31T18:43:49.180342", | |
"status": "completed" | |
}, | |
"tags": [] | |
}, | |
"outputs": [], | |
"source": [ | |
"img_path = '.\\data\\input'" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"papermill": { | |
"duration": 0.568353, | |
"end_time": "2020-10-31T18:43:50.511956", | |
"exception": false, | |
"start_time": "2020-10-31T18:43:49.943603", | |
"status": "completed" | |
}, | |
"tags": [] | |
}, | |
"outputs": [], | |
"source": [ | |
"img_path = \"./data/input/monet_jpg/f486c1655f.jpg\"\n", | |
"color_hist_visualization(img_path)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"papermill": { | |
"duration": 0.558595, | |
"end_time": "2020-10-31T18:43:51.242454", | |
"exception": false, | |
"start_time": "2020-10-31T18:43:50.683859", | |
"status": "completed" | |
}, | |
"tags": [] | |
}, | |
"outputs": [], | |
"source": [ | |
"img_path = './data/input/monet_jpg/e510a74d3c.jpg'\n", | |
"color_hist_visualization(img_path)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"papermill": { | |
"duration": 0.168091, | |
"end_time": "2020-10-31T18:43:51.586551", | |
"exception": false, | |
"start_time": "2020-10-31T18:43:51.418460", | |
"status": "completed" | |
}, | |
"tags": [] | |
}, | |
"source": [ | |
"## (3) Individual channels visualization" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"papermill": { | |
"duration": 0.184858, | |
"end_time": "2020-10-31T18:43:51.941841", | |
"exception": false, | |
"start_time": "2020-10-31T18:43:51.756983", | |
"status": "completed" | |
}, | |
"tags": [] | |
}, | |
"outputs": [], | |
"source": [ | |
"# Individual channels visualization \n", | |
"def channels_visualization(image_path, figsize=(16, 4)):\n", | |
" plt.figure(figsize=figsize)\n", | |
" \n", | |
" img = cv2.imread(image_path)\n", | |
" img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) \n", | |
" plt.subplot(1, 4, 1)\n", | |
" plt.imshow(img)\n", | |
" plt.axis('off')\n", | |
" \n", | |
" for i in range(3):\n", | |
" plt.subplot(1, 4, i + 2)\n", | |
" tmp_img = np.full_like(img, 0)\n", | |
" tmp_img[:, :, i] = img[:, :, i]\n", | |
" plt.imshow(tmp_img)\n", | |
" plt.xlim(0, 255)\n", | |
" plt.xticks([])\n", | |
" plt.yticks([])\n", | |
" plt.show()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"papermill": { | |
"duration": 0.870707, | |
"end_time": "2020-10-31T18:43:52.983038", | |
"exception": false, | |
"start_time": "2020-10-31T18:43:52.112331", | |
"status": "completed" | |
}, | |
"tags": [] | |
}, | |
"outputs": [], | |
"source": [ | |
"img_path = './data/input/monet_jpg/f486c1655f.jpg'\n", | |
"color_hist_visualization(img_path)\n", | |
"channels_visualization(img_path)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"papermill": { | |
"duration": 0.183109, | |
"end_time": "2020-10-31T18:43:53.353790", | |
"exception": false, | |
"start_time": "2020-10-31T18:43:53.170681", | |
"status": "completed" | |
}, | |
"tags": [] | |
}, | |
"source": [ | |
"## (4) Greyscale visualization" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"papermill": { | |
"duration": 0.196752, | |
"end_time": "2020-10-31T18:43:53.731590", | |
"exception": false, | |
"start_time": "2020-10-31T18:43:53.534838", | |
"status": "completed" | |
}, | |
"tags": [] | |
}, | |
"outputs": [], | |
"source": [ | |
"# Grayscale visualization\n", | |
"def grayscale_visualization(image_path, figsize=(8, 4)):\n", | |
" plt.figure(figsize=figsize)\n", | |
" \n", | |
" img = cv2.imread(image_path)\n", | |
" img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) \n", | |
" plt.subplot(1, 2, 1)\n", | |
" plt.imshow(img)\n", | |
" plt.axis('off')\n", | |
" \n", | |
" plt.subplot(1, 2, 2)\n", | |
" tmp_img = np.full_like(img, 0)\n", | |
" for i in range(3):\n", | |
" tmp_img[:, :, i] = img.mean(axis=-1)\n", | |
" plt.imshow(tmp_img)\n", | |
" plt.axis('off')\n", | |
" \n", | |
" \n", | |
" plt.show()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"papermill": { | |
"duration": 0.375279, | |
"end_time": "2020-10-31T18:43:54.289694", | |
"exception": false, | |
"start_time": "2020-10-31T18:43:53.914415", | |
"status": "completed" | |
}, | |
"tags": [] | |
}, | |
"outputs": [], | |
"source": [ | |
"img_path = './data/input/monet_jpg/e291f8144f.jpg'\n", | |
"grayscale_visualization(img_path)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"papermill": { | |
"duration": 0.188143, | |
"end_time": "2020-10-31T18:43:54.672789", | |
"exception": false, | |
"start_time": "2020-10-31T18:43:54.484646", | |
"status": "completed" | |
}, | |
"tags": [] | |
}, | |
"source": [ | |
"# Load the datasets" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"papermill": { | |
"duration": 0.187474, | |
"end_time": "2020-10-31T18:43:55.808808", | |
"exception": false, | |
"start_time": "2020-10-31T18:43:55.621334", | |
"status": "completed" | |
}, | |
"tags": [] | |
}, | |
"source": [ | |
"### Load the TFRecord files" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"papermill": { | |
"duration": 0.53228, | |
"end_time": "2020-10-31T18:43:56.535276", | |
"exception": false, | |
"start_time": "2020-10-31T18:43:56.002996", | |
"status": "completed" | |
}, | |
"tags": [] | |
}, | |
"outputs": [], | |
"source": [ | |
"# Get the Google Cloud Storage path URI (GCS path) for Kaggle Datasets\n", | |
"GCS_PATH = BASE_PATH\n", | |
"\n", | |
"GCS_PATH" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"papermill": { | |
"duration": 0.36483, | |
"end_time": "2020-10-31T18:43:57.089447", | |
"exception": false, | |
"start_time": "2020-10-31T18:43:56.724617", | |
"status": "completed" | |
}, | |
"tags": [] | |
}, | |
"outputs": [], | |
"source": [ | |
"# Obtain two lists of files that match the given patterns specified in str()\n", | |
"def count_data_items(filenames):\n", | |
" n = [int(re.compile(r\"-([0-9]*)\\.\").search(filename).group(1)) for filename in filenames]\n", | |
" return np.sum(n)\n", | |
"\n", | |
"MONET_FILENAMES = tf.io.gfile.glob(str(GCS_PATH + '/monet_tfrec/*.tfrec'))\n", | |
"PHOTO_FILENAMES = tf.io.gfile.glob(str(GCS_PATH + '/photo_tfrec/*.tfrec'))\n", | |
"\n", | |
"n_monet_samples = count_data_items(MONET_FILENAMES)\n", | |
"n_photo_samples = count_data_items(PHOTO_FILENAMES)\n", | |
"\n", | |
"print('Number of Monet TFRecord Files:', len(MONET_FILENAMES))\n", | |
"print('Number of Photo TFRecord Files:', len(PHOTO_FILENAMES))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"papermill": { | |
"duration": 0.207281, | |
"end_time": "2020-10-31T18:43:57.494014", | |
"exception": false, | |
"start_time": "2020-10-31T18:43:57.286733", | |
"status": "completed" | |
}, | |
"tags": [] | |
}, | |
"outputs": [], | |
"source": [ | |
"BUFFER_SIZE = 1000\n", | |
"BATCH_SIZE = 4\n", | |
"EPOCHS_NUM = 1\n", | |
"IMG_WIDTH = 256\n", | |
"IMG_HEIGHT = 256" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"papermill": { | |
"duration": 0.471598, | |
"end_time": "2020-10-31T18:43:58.157376", | |
"exception": false, | |
"start_time": "2020-10-31T18:43:57.685778", | |
"status": "completed" | |
}, | |
"tags": [] | |
}, | |
"outputs": [], | |
"source": [ | |
"def decode_image(image):\n", | |
" # Decode a JPEG-encoded image to a uint8 tensor.\n", | |
" image = tf.image.decode_jpeg(image, channels=3)\n", | |
" \n", | |
" # Normalize the image to the range of the tanh activation function [-1, 1] for \n", | |
" # inputs to the generator and discriminator in GAN model \n", | |
" # (i.e. the pixel values are divided by (255/2) to form a value of in a range of [0, 2] and then subtract by 1\n", | |
" # to result into a range of [-1, 1])\n", | |
" image = (tf.cast(image, tf.float32) / 127.5) - 1 \n", | |
" \n", | |
" # Reshape the tensor using (256, 256, 3) where 3 is number of channels: Red, Green, and Blue \n", | |
" image = tf.reshape(image, [IMG_HEIGHT, IMG_WIDTH, 3]) \n", | |
" return image\n", | |
"\n", | |
"def read_tfrecord(example):\n", | |
" # Define TFRecord format \n", | |
" tfrecord_format = {\n", | |
" \"image_name\": tf.io.FixedLenFeature([], tf.string),\n", | |
" \"image\": tf.io.FixedLenFeature([], tf.string),\n", | |
" \"target\": tf.io.FixedLenFeature([], tf.string)\n", | |
" }\n", | |
" # Parse a single example\n", | |
" example = tf.io.parse_single_example(example, tfrecord_format) \n", | |
" # Decode a JPEG image to a uint8 tensor by calling decode_image()\n", | |
" image = decode_image(example['image']) \n", | |
" \n", | |
" return image # Return an image tensor" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"papermill": { | |
"duration": 0.209108, | |
"end_time": "2020-10-31T18:43:58.556829", | |
"exception": false, | |
"start_time": "2020-10-31T18:43:58.347721", | |
"status": "completed" | |
}, | |
"tags": [] | |
}, | |
"outputs": [], | |
"source": [ | |
"def data_augment(image):\n", | |
" p_rotate = tf.random.uniform([], 0, 1.0, dtype=tf.float32)\n", | |
" p_spatial = tf.random.uniform([], 0, 1.0, dtype=tf.float32)\n", | |
" p_crop = tf.random.uniform([], 0, 1.0, dtype=tf.float32)\n", | |
" \n", | |
" # Apply jitter\n", | |
" if p_crop > .5:\n", | |
" image = tf.image.resize(image, [286, 286])\n", | |
" image = tf.image.random_crop(image, size=[256, 256, 3])\n", | |
" if p_crop > .9:\n", | |
" image = tf.image.resize(image, [300, 300])\n", | |
" image = tf.image.random_crop(image, size=[256, 256, 3])\n", | |
" \n", | |
" # Random rotation\n", | |
" if p_rotate > .9:\n", | |
" image = tf.image.rot90(image, k=3) # rotate 270º\n", | |
" elif p_rotate > .7:\n", | |
" image = tf.image.rot90(image, k=2) # rotate 180º\n", | |
" elif p_rotate > .5:\n", | |
" image = tf.image.rot90(image, k=1) # rotate 90º\n", | |
" \n", | |
" # Random mirroring\n", | |
" if p_spatial > .6:\n", | |
" image = tf.image.random_flip_left_right(image)\n", | |
" image = tf.image.random_flip_up_down(image)\n", | |
" if p_spatial > .9:\n", | |
" image = tf.image.transpose(image)\n", | |
" \n", | |
" return image" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"papermill": { | |
"duration": 0.20462, | |
"end_time": "2020-10-31T18:43:58.951841", | |
"exception": false, | |
"start_time": "2020-10-31T18:43:58.747221", | |
"status": "completed" | |
}, | |
"tags": [] | |
}, | |
"outputs": [], | |
"source": [ | |
"# Set it to tf.data.experimental.AUTOTUNE which will prompt \n", | |
"# the tf.data runtime to tune the value dynamically at runtime.\n", | |
"AUTOTUNE = tf.data.experimental.AUTOTUNE \n", | |
"\n", | |
"def load_dataset(filenames, labeled=True, ordered=False):\n", | |
" dataset = tf.data.TFRecordDataset(filenames)\n", | |
" # map a dataset with a mapping function read_tfrecord and \n", | |
" # Number of parallel calls is set to AUTOTUNE constant previously defined\n", | |
" dataset = dataset.map(read_tfrecord, num_parallel_calls=AUTOTUNE)\n", | |
" return dataset" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"papermill": { | |
"duration": 0.420785, | |
"end_time": "2020-10-31T18:43:59.563444", | |
"exception": false, | |
"start_time": "2020-10-31T18:43:59.142659", | |
"status": "completed" | |
}, | |
"tags": [] | |
}, | |
"outputs": [], | |
"source": [ | |
"BATCHSIZE = 1\n", | |
"monet_ds = load_dataset(MONET_FILENAMES, labeled=True).batch(BATCHSIZE, drop_remainder=True)\n", | |
"photo_ds = load_dataset(PHOTO_FILENAMES, labeled=True).batch(BATCHSIZE, drop_remainder=True)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"papermill": { | |
"duration": 0.204843, | |
"end_time": "2020-10-31T18:43:59.960075", | |
"exception": false, | |
"start_time": "2020-10-31T18:43:59.755232", | |
"status": "completed" | |
}, | |
"tags": [] | |
}, | |
"outputs": [], | |
"source": [ | |
"def get_gan_dataset(monet_files, photo_files, augment=None, repeat=True, shuffle=True, batch_size=1):\n", | |
"\n", | |
" monet_ds = load_dataset(monet_files)\n", | |
" photo_ds = load_dataset(photo_files)\n", | |
" \n", | |
" if augment:\n", | |
" monet_ds = monet_ds.map(augment, num_parallel_calls=AUTO)\n", | |
" photo_ds = photo_ds.map(augment, num_parallel_calls=AUTO)\n", | |
"\n", | |
" if repeat:\n", | |
" monet_ds = monet_ds.repeat()\n", | |
" photo_ds = photo_ds.repeat()\n", | |
" \n", | |
" if shuffle:\n", | |
" monet_ds = monet_ds.shuffle(2048)\n", | |
" photo_ds = photo_ds.shuffle(2048)\n", | |
" \n", | |
" monet_ds = monet_ds.batch(batch_size, drop_remainder=True)\n", | |
" photo_ds = photo_ds.batch(batch_size, drop_remainder=True)\n", | |
" monet_ds = monet_ds.cache()\n", | |
" photo_ds = photo_ds.cache()\n", | |
" monet_ds = monet_ds.prefetch(AUTO)\n", | |
" photo_ds = photo_ds.prefetch(AUTO)\n", | |
" \n", | |
" gan_ds = tf.data.Dataset.zip((monet_ds, photo_ds))\n", | |
" \n", | |
" return gan_ds" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"papermill": { | |
"duration": 1.425059, | |
"end_time": "2020-10-31T18:44:01.574670", | |
"exception": false, | |
"start_time": "2020-10-31T18:44:00.149611", | |
"status": "completed" | |
}, | |
"tags": [] | |
}, | |
"outputs": [], | |
"source": [ | |
"full_dataset = get_gan_dataset(MONET_FILENAMES, PHOTO_FILENAMES, augment=data_augment, repeat=True, shuffle=True, batch_size=BATCH_SIZE)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"papermill": { | |
"duration": 7.591707, | |
"end_time": "2020-10-31T18:44:09.358003", | |
"exception": false, | |
"start_time": "2020-10-31T18:44:01.766296", | |
"status": "completed" | |
}, | |
"tags": [] | |
}, | |
"outputs": [], | |
"source": [ | |
"example_monet , example_photo = next(iter(full_dataset))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"papermill": { | |
"duration": 0.201931, | |
"end_time": "2020-10-31T18:44:09.749754", | |
"exception": false, | |
"start_time": "2020-10-31T18:44:09.547823", | |
"status": "completed" | |
}, | |
"tags": [] | |
}, | |
"outputs": [], | |
"source": [ | |
"def view_image(ds, nrows=1, ncols=5):\n", | |
" ds_iter = iter(ds)\n", | |
" # image = next(iter(ds)) # extract 1 from the dataset\n", | |
" # image = image.numpy() # convert the image tensor to NumPy ndarrays.\n", | |
"\n", | |
" fig = plt.figure(figsize=(25, nrows * 5.05 )) # figsize with Width, Height\n", | |
" \n", | |
" # loop thru all the images (number of rows * number of columns)\n", | |
" for i in range(ncols * nrows):\n", | |
" image = next(ds_iter)\n", | |
" image = image.numpy()\n", | |
" ax = fig.add_subplot(nrows, ncols, i+1, xticks=[], yticks=[])\n", | |
" ax.imshow(image[0] * 0.5 + .5) # rescale the data in [0, 1] for display" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"papermill": { | |
"duration": 1.321906, | |
"end_time": "2020-10-31T18:44:11.258934", | |
"exception": false, | |
"start_time": "2020-10-31T18:44:09.937028", | |
"status": "completed" | |
}, | |
"tags": [] | |
}, | |
"outputs": [], | |
"source": [ | |
"view_image(monet_ds,2, 5)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"papermill": { | |
"duration": 1.370989, | |
"end_time": "2020-10-31T18:44:12.868971", | |
"exception": false, | |
"start_time": "2020-10-31T18:44:11.497982", | |
"status": "completed" | |
}, | |
"tags": [] | |
}, | |
"outputs": [], | |
"source": [ | |
"view_image(photo_ds,2, 5)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"papermill": { | |
"duration": 0.274559, | |
"end_time": "2020-10-31T18:44:13.408351", | |
"exception": false, | |
"start_time": "2020-10-31T18:44:13.133792", | |
"status": "completed" | |
}, | |
"tags": [] | |
}, | |
"source": [ | |
"## Build the generator" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"papermill": { | |
"duration": 0.279973, | |
"end_time": "2020-10-31T18:44:13.956723", | |
"exception": false, | |
"start_time": "2020-10-31T18:44:13.676750", | |
"status": "completed" | |
}, | |
"tags": [] | |
}, | |
"outputs": [], | |
"source": [ | |
"OUTPUT_CHANNELS = 3\n", | |
"\n", | |
"def downsample(filters, size, apply_instancenorm=True):\n", | |
" initializer = tf.random_normal_initializer(0., 0.02)\n", | |
" gamma_init = keras.initializers.RandomNormal(mean=0.0, stddev=0.02)\n", | |
"\n", | |
" result = keras.Sequential()\n", | |
" result.add(layers.Conv2D(filters, size, strides=2, padding='same',\n", | |
" kernel_initializer=initializer, use_bias=False))\n", | |
"\n", | |
" if apply_instancenorm:\n", | |
" result.add(tfa.layers.InstanceNormalization(gamma_initializer=gamma_init))\n", | |
"\n", | |
" result.add(layers.LeakyReLU())\n", | |
"\n", | |
" return result" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"papermill": { | |
"duration": 0.281864, | |
"end_time": "2020-10-31T18:44:14.502105", | |
"exception": false, | |
"start_time": "2020-10-31T18:44:14.220241", | |
"status": "completed" | |
}, | |
"tags": [] | |
}, | |
"outputs": [], | |
"source": [ | |
"def upsample(filters, size, apply_dropout=False):\n", | |
" initializer = tf.random_normal_initializer(0., 0.02)\n", | |
" gamma_init = keras.initializers.RandomNormal(mean=0.0, stddev=0.02)\n", | |
"\n", | |
" result = keras.Sequential()\n", | |
" result.add(layers.Conv2DTranspose(filters, size, strides=2,\n", | |
" padding='same',\n", | |
" kernel_initializer=initializer,\n", | |
" use_bias=False))\n", | |
"\n", | |
" result.add(tfa.layers.InstanceNormalization(gamma_initializer=gamma_init))\n", | |
"\n", | |
" if apply_dropout:\n", | |
" result.add(layers.Dropout(0.5))\n", | |
"\n", | |
" result.add(layers.ReLU())\n", | |
"\n", | |
" return result" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"papermill": { | |
"duration": 0.359497, | |
"end_time": "2020-10-31T18:44:15.125754", | |
"exception": false, | |
"start_time": "2020-10-31T18:44:14.766257", | |
"status": "completed" | |
}, | |
"tags": [] | |
}, | |
"outputs": [], | |
"source": [ | |
"def Generator():\n", | |
" inputs = layers.Input(shape=[256,256,3])\n", | |
"\n", | |
" # bs = batch size\n", | |
" down_stack = [\n", | |
" downsample(64, 4, apply_instancenorm=False), # (bs, 128, 128, 64)\n", | |
" downsample(128, 4), # (bs, 64, 64, 128)\n", | |
" downsample(256, 4), # (bs, 32, 32, 256)\n", | |
" downsample(512, 4), # (bs, 16, 16, 512)\n", | |
" downsample(512, 4), # (bs, 8, 8, 512)\n", | |
" downsample(512, 4), # (bs, 4, 4, 512)\n", | |
" downsample(512, 4), # (bs, 2, 2, 512)\n", | |
" downsample(512, 4), # (bs, 1, 1, 512)\n", | |
" ]\n", | |
"\n", | |
" up_stack = [\n", | |
" upsample(512, 4, apply_dropout=True), # (bs, 2, 2, 1024)\n", | |
" upsample(512, 4, apply_dropout=True), # (bs, 4, 4, 1024)\n", | |
" upsample(512, 4, apply_dropout=True), # (bs, 8, 8, 1024)\n", | |
" upsample(512, 4), # (bs, 16, 16, 1024)\n", | |
" upsample(256, 4), # (bs, 32, 32, 512)\n", | |
" upsample(128, 4), # (bs, 64, 64, 256)\n", | |
" upsample(64, 4), # (bs, 128, 128, 128)\n", | |
" ]\n", | |
"\n", | |
" initializer = tf.random_normal_initializer(0., 0.02)\n", | |
" last = layers.Conv2DTranspose(OUTPUT_CHANNELS, 4,\n", | |
" strides=2,\n", | |
" padding='same',\n", | |
" kernel_initializer=initializer,\n", | |
" activation='tanh') # (bs, 256, 256, 3)\n", | |
"\n", | |
" x = inputs\n", | |
"\n", | |
" # Downsampling through the model\n", | |
" skips = []\n", | |
" for down in down_stack:\n", | |
" x = down(x)\n", | |
" skips.append(x)\n", | |
"\n", | |
" skips = reversed(skips[:-1])\n", | |
"\n", | |
" # Upsampling and establishing the skip connections\n", | |
" for up, skip in zip(up_stack, skips):\n", | |
" x = up(x)\n", | |
" x = layers.Concatenate()([x, skip])\n", | |
"\n", | |
" x = last(x)\n", | |
"\n", | |
" return keras.Model(inputs=inputs, outputs=x)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"papermill": { | |
"duration": 0.258077, | |
"end_time": "2020-10-31T18:44:15.651407", | |
"exception": false, | |
"start_time": "2020-10-31T18:44:15.393330", | |
"status": "completed" | |
}, | |
"tags": [] | |
}, | |
"source": [ | |
"## Build the discriminator" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"papermill": { | |
"duration": 0.282605, | |
"end_time": "2020-10-31T18:44:16.194729", | |
"exception": false, | |
"start_time": "2020-10-31T18:44:15.912124", | |
"status": "completed" | |
}, | |
"tags": [] | |
}, | |
"outputs": [], | |
"source": [ | |
"def Discriminator():\n", | |
" initializer = tf.random_normal_initializer(0., 0.02)\n", | |
" gamma_init = keras.initializers.RandomNormal(mean=0.0, stddev=0.02)\n", | |
"\n", | |
" inp = layers.Input(shape=[256, 256, 3], name='input_image')\n", | |
"\n", | |
" x = inp\n", | |
"\n", | |
" down1 = downsample(64, 4, False)(x) # (bs, 128, 128, 64)\n", | |
" down2 = downsample(128, 4)(down1) # (bs, 64, 64, 128)\n", | |
" down3 = downsample(256, 4)(down2) # (bs, 32, 32, 256)\n", | |
"\n", | |
" zero_pad1 = layers.ZeroPadding2D()(down3) # (bs, 34, 34, 256)\n", | |
" conv = layers.Conv2D(512, 4, strides=1,\n", | |
" kernel_initializer=initializer,\n", | |
" use_bias=False)(zero_pad1) # (bs, 31, 31, 512)\n", | |
"\n", | |
" norm1 = tfa.layers.InstanceNormalization(gamma_initializer=gamma_init)(conv)\n", | |
"\n", | |
" leaky_relu = layers.LeakyReLU()(norm1)\n", | |
"\n", | |
" zero_pad2 = layers.ZeroPadding2D()(leaky_relu) # (bs, 33, 33, 512)\n", | |
"\n", | |
" last = layers.Conv2D(1, 4, strides=1,\n", | |
" kernel_initializer=initializer)(zero_pad2) # (bs, 30, 30, 1)\n", | |
"\n", | |
" return tf.keras.Model(inputs=inp, outputs=last)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"papermill": { | |
"duration": 12.419442, | |
"end_time": "2020-10-31T18:44:28.877438", | |
"exception": false, | |
"start_time": "2020-10-31T18:44:16.457996", | |
"status": "completed" | |
}, | |
"tags": [] | |
}, | |
"outputs": [], | |
"source": [ | |
"with strategy.scope():\n", | |
" monet_generator = Generator() # transforms photos to Monet-esque paintings\n", | |
" photo_generator = Generator() # transforms Monet paintings to be more like photos\n", | |
"\n", | |
" monet_discriminator = Discriminator() # differentiates real Monet paintings and generated Monet paintings\n", | |
" photo_discriminator = Discriminator() # differentiates real photos and generated photos" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"papermill": { | |
"duration": 1.133666, | |
"end_time": "2020-10-31T18:44:30.301047", | |
"exception": false, | |
"start_time": "2020-10-31T18:44:29.167381", | |
"status": "completed" | |
}, | |
"tags": [] | |
}, | |
"outputs": [], | |
"source": [ | |
"to_monet = monet_generator(example_photo)\n", | |
"\n", | |
"plt.subplot(1, 2, 1)\n", | |
"plt.title(\"Original Photo\")\n", | |
"plt.imshow(example_photo[0] * 0.5 + 0.5)\n", | |
"\n", | |
"plt.subplot(1, 2, 2)\n", | |
"plt.title(\"Monet-esque Photo\")\n", | |
"plt.imshow(to_monet[0] * 0.5 + 0.5)\n", | |
"plt.show()" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"papermill": { | |
"duration": 0.27006, | |
"end_time": "2020-10-31T18:44:30.905610", | |
"exception": false, | |
"start_time": "2020-10-31T18:44:30.635550", | |
"status": "completed" | |
}, | |
"tags": [] | |
}, | |
"source": [ | |
"## Build the CycleGAN model" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"papermill": { | |
"duration": 0.310948, | |
"end_time": "2020-10-31T18:44:31.493564", | |
"exception": false, | |
"start_time": "2020-10-31T18:44:31.182616", | |
"status": "completed" | |
}, | |
"tags": [] | |
}, | |
"outputs": [], | |
"source": [ | |
"class CycleGan(keras.Model):\n", | |
" def __init__(\n", | |
" self,\n", | |
" monet_generator,\n", | |
" photo_generator,\n", | |
" monet_discriminator,\n", | |
" photo_discriminator,\n", | |
" lambda_cycle=10,\n", | |
" ):\n", | |
" super(CycleGan, self).__init__()\n", | |
" self.m_gen = monet_generator\n", | |
" self.p_gen = photo_generator\n", | |
" self.m_disc = monet_discriminator\n", | |
" self.p_disc = photo_discriminator\n", | |
" self.lambda_cycle = lambda_cycle\n", | |
" \n", | |
" def compile(\n", | |
" self,\n", | |
" m_gen_optimizer,\n", | |
" p_gen_optimizer,\n", | |
" m_disc_optimizer,\n", | |
" p_disc_optimizer,\n", | |
" gen_loss_fn,\n", | |
" disc_loss_fn,\n", | |
" cycle_loss_fn,\n", | |
" identity_loss_fn\n", | |
" ):\n", | |
" super(CycleGan, self).compile()\n", | |
" self.m_gen_optimizer = m_gen_optimizer\n", | |
" self.p_gen_optimizer = p_gen_optimizer\n", | |
" self.m_disc_optimizer = m_disc_optimizer\n", | |
" self.p_disc_optimizer = p_disc_optimizer\n", | |
" self.gen_loss_fn = gen_loss_fn\n", | |
" self.disc_loss_fn = disc_loss_fn\n", | |
" self.cycle_loss_fn = cycle_loss_fn\n", | |
" self.identity_loss_fn = identity_loss_fn\n", | |
" \n", | |
" def train_step(self, batch_data):\n", | |
" real_monet, real_photo = batch_data\n", | |
" \n", | |
" with tf.GradientTape(persistent=True) as tape:\n", | |
" # photo to monet back to photo\n", | |
" fake_monet = self.m_gen(real_photo, training=True)\n", | |
" cycled_photo = self.p_gen(fake_monet, training=True)\n", | |
"\n", | |
" # monet to photo back to monet\n", | |
" fake_photo = self.p_gen(real_monet, training=True)\n", | |
" cycled_monet = self.m_gen(fake_photo, training=True)\n", | |
"\n", | |
" # generating itself\n", | |
" same_monet = self.m_gen(real_monet, training=True)\n", | |
" same_photo = self.p_gen(real_photo, training=True)\n", | |
"\n", | |
" # discriminator used to check, inputing real images\n", | |
" disc_real_monet = self.m_disc(real_monet, training=True)\n", | |
" disc_real_photo = self.p_disc(real_photo, training=True)\n", | |
"\n", | |
" # discriminator used to check, inputing fake images\n", | |
" disc_fake_monet = self.m_disc(fake_monet, training=True)\n", | |
" disc_fake_photo = self.p_disc(fake_photo, training=True)\n", | |
"\n", | |
" # evaluates generator loss\n", | |
" monet_gen_loss = self.gen_loss_fn(disc_fake_monet)\n", | |
" photo_gen_loss = self.gen_loss_fn(disc_fake_photo)\n", | |
"\n", | |
" # evaluates total cycle consistency loss\n", | |
" total_cycle_loss = self.cycle_loss_fn(real_monet, cycled_monet, self.lambda_cycle) + self.cycle_loss_fn(real_photo, cycled_photo, self.lambda_cycle)\n", | |
"\n", | |
" # evaluates total generator loss\n", | |
" total_monet_gen_loss = monet_gen_loss + total_cycle_loss + self.identity_loss_fn(real_monet, same_monet, self.lambda_cycle)\n", | |
" total_photo_gen_loss = photo_gen_loss + total_cycle_loss + self.identity_loss_fn(real_photo, same_photo, self.lambda_cycle)\n", | |
"\n", | |
" # evaluates discriminator loss\n", | |
" monet_disc_loss = self.disc_loss_fn(disc_real_monet, disc_fake_monet)\n", | |
" photo_disc_loss = self.disc_loss_fn(disc_real_photo, disc_fake_photo)\n", | |
"\n", | |
" # Calculate the gradients for generator and discriminator\n", | |
" monet_generator_gradients = tape.gradient(total_monet_gen_loss,\n", | |
" self.m_gen.trainable_variables)\n", | |
" photo_generator_gradients = tape.gradient(total_photo_gen_loss,\n", | |
" self.p_gen.trainable_variables)\n", | |
"\n", | |
" monet_discriminator_gradients = tape.gradient(monet_disc_loss,\n", | |
" self.m_disc.trainable_variables)\n", | |
" photo_discriminator_gradients = tape.gradient(photo_disc_loss,\n", | |
" self.p_disc.trainable_variables)\n", | |
"\n", | |
" # Apply the gradients to the optimizer\n", | |
" self.m_gen_optimizer.apply_gradients(zip(monet_generator_gradients,\n", | |
" self.m_gen.trainable_variables))\n", | |
"\n", | |
" self.p_gen_optimizer.apply_gradients(zip(photo_generator_gradients,\n", | |
" self.p_gen.trainable_variables))\n", | |
"\n", | |
" self.m_disc_optimizer.apply_gradients(zip(monet_discriminator_gradients,\n", | |
" self.m_disc.trainable_variables))\n", | |
"\n", | |
" self.p_disc_optimizer.apply_gradients(zip(photo_discriminator_gradients,\n", | |
" self.p_disc.trainable_variables))\n", | |
" \n", | |
" return {\n", | |
" \"monet_gen_loss\": total_monet_gen_loss,\n", | |
" \"photo_gen_loss\": total_photo_gen_loss,\n", | |
" \"monet_disc_loss\": monet_disc_loss,\n", | |
" \"photo_disc_loss\": photo_disc_loss\n", | |
" }" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"papermill": { | |
"duration": 0.266165, | |
"end_time": "2020-10-31T18:44:32.042221", | |
"exception": false, | |
"start_time": "2020-10-31T18:44:31.776056", | |
"status": "completed" | |
}, | |
"tags": [] | |
}, | |
"source": [ | |
"# Loss functions" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"papermill": { | |
"duration": 0.28944, | |
"end_time": "2020-10-31T18:44:32.605875", | |
"exception": false, | |
"start_time": "2020-10-31T18:44:32.316435", | |
"status": "completed" | |
}, | |
"tags": [] | |
}, | |
"outputs": [], | |
"source": [ | |
"with strategy.scope():\n", | |
" # Discriminator loss {0: fake, 1: real} (The discriminator loss outputs the average of the real and generated loss)\n", | |
" def discriminator_loss(real, generated):\n", | |
" real_loss = losses.BinaryCrossentropy(from_logits=True, reduction=losses.Reduction.NONE)(tf.ones_like(real), real)\n", | |
"\n", | |
" generated_loss = losses.BinaryCrossentropy(from_logits=True, reduction=losses.Reduction.NONE)(tf.zeros_like(generated), generated)\n", | |
"\n", | |
" total_disc_loss = real_loss + generated_loss\n", | |
"\n", | |
" return total_disc_loss * 0.5\n", | |
" \n", | |
" # Generator loss\n", | |
" def generator_loss(generated):\n", | |
" return losses.BinaryCrossentropy(from_logits=True, reduction=losses.Reduction.NONE)(tf.ones_like(generated), generated)\n", | |
" \n", | |
" \n", | |
" # Cycle consistency loss (measures if original photo and the twice transformed photo to be similar to one another)\n", | |
" with strategy.scope():\n", | |
" def calc_cycle_loss(real_image, cycled_image, LAMBDA):\n", | |
" loss1 = tf.reduce_mean(tf.abs(real_image - cycled_image))\n", | |
"\n", | |
" return LAMBDA * loss1\n", | |
"\n", | |
" # Identity loss (compares the image with its generator (i.e. photo with photo generator))\n", | |
" with strategy.scope():\n", | |
" def identity_loss(real_image, same_image, LAMBDA):\n", | |
" loss = tf.reduce_mean(tf.abs(real_image - same_image))\n", | |
" return LAMBDA * 0.5 * loss" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"papermill": { | |
"duration": 0.277524, | |
"end_time": "2020-10-31T18:44:33.157169", | |
"exception": false, | |
"start_time": "2020-10-31T18:44:32.879645", | |
"status": "completed" | |
}, | |
"tags": [] | |
}, | |
"outputs": [], | |
"source": [ | |
"with strategy.scope():\n", | |
" monet_generator_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)\n", | |
" photo_generator_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)\n", | |
"\n", | |
" monet_discriminator_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)\n", | |
" photo_discriminator_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"papermill": { | |
"duration": 0.360152, | |
"end_time": "2020-10-31T18:44:33.794927", | |
"exception": false, | |
"start_time": "2020-10-31T18:44:33.434775", | |
"status": "completed" | |
}, | |
"tags": [] | |
}, | |
"outputs": [], | |
"source": [ | |
"with strategy.scope():\n", | |
" cycle_gan_model = CycleGan(\n", | |
" monet_generator, photo_generator, \n", | |
" monet_discriminator, photo_discriminator\n", | |
" )\n", | |
"\n", | |
" cycle_gan_model.compile(\n", | |
" m_gen_optimizer = monet_generator_optimizer,\n", | |
" p_gen_optimizer = photo_generator_optimizer,\n", | |
" m_disc_optimizer = monet_discriminator_optimizer,\n", | |
" p_disc_optimizer = photo_discriminator_optimizer,\n", | |
" gen_loss_fn = generator_loss,\n", | |
" disc_loss_fn = discriminator_loss,\n", | |
" cycle_loss_fn = calc_cycle_loss,\n", | |
" identity_loss_fn = identity_loss\n", | |
" )" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"papermill": { | |
"duration": 7499.281619, | |
"end_time": "2020-10-31T20:49:33.345104", | |
"exception": false, | |
"start_time": "2020-10-31T18:44:34.063485", | |
"status": "completed" | |
}, | |
"tags": [] | |
}, | |
"outputs": [], | |
"source": [ | |
"cycle_gan_model.fit(\n", | |
" full_dataset,\n", | |
" epochs=EPOCHS_NUM,\n", | |
" steps_per_epoch=(max(n_monet_samples, n_photo_samples)//BATCH_SIZE),\n", | |
")" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"papermill": { | |
"duration": 21.863544, | |
"end_time": "2020-10-31T20:50:17.606469", | |
"exception": false, | |
"start_time": "2020-10-31T20:49:55.742925", | |
"status": "completed" | |
}, | |
"tags": [] | |
}, | |
"source": [ | |
"## Display generated photos" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"papermill": { | |
"duration": 21.848811, | |
"end_time": "2020-10-31T20:51:01.877384", | |
"exception": false, | |
"start_time": "2020-10-31T20:50:40.028573", | |
"status": "completed" | |
}, | |
"tags": [] | |
}, | |
"outputs": [], | |
"source": [ | |
"def display_generated_samples(ds, model, n_samples):\n", | |
" ds_iter = iter(ds)\n", | |
" for n_sample in range(n_samples):\n", | |
" example_sample = next(ds_iter)\n", | |
" generated_sample = model.predict(example_sample)\n", | |
" \n", | |
" plt.subplot(121)\n", | |
" plt.title(\"Input image\")\n", | |
" plt.imshow(example_sample[0] * 0.5 + 0.5)\n", | |
" plt.axis('off')\n", | |
"\n", | |
" plt.subplot(122)\n", | |
" plt.title(\"Generated image\")\n", | |
" plt.imshow(generated_sample[0] * 0.5 + 0.5)\n", | |
" plt.axis('off')\n", | |
" plt.show()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"papermill": { | |
"duration": 32.517448, | |
"end_time": "2020-10-31T20:51:56.071666", | |
"exception": false, | |
"start_time": "2020-10-31T20:51:23.554218", | |
"status": "completed" | |
}, | |
"tags": [] | |
}, | |
"outputs": [], | |
"source": [ | |
"display_generated_samples(load_dataset(PHOTO_FILENAMES).batch(1), monet_generator, 7)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"papermill": { | |
"duration": 22.357062, | |
"end_time": "2020-10-31T20:52:40.336541", | |
"exception": false, | |
"start_time": "2020-10-31T20:52:17.979479", | |
"status": "completed" | |
}, | |
"tags": [] | |
}, | |
"source": [ | |
"## Predict and save the images" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"papermill": { | |
"duration": 22.222689, | |
"end_time": "2020-10-31T20:53:24.734178", | |
"exception": false, | |
"start_time": "2020-10-31T20:53:02.511489", | |
"status": "completed" | |
}, | |
"tags": [] | |
}, | |
"outputs": [], | |
"source": [ | |
"# import PIL\n", | |
"# def predict_and_save(input_ds, generator_model, output_path):\n", | |
"# i = 1\n", | |
"# for img in input_ds:\n", | |
"# prediction = generator_model(img, training=False)[0].numpy() # make predition\n", | |
"# prediction = (prediction * 127.5 + 127.5).astype(np.uint8) # re-scale\n", | |
"# im = PIL.Image.fromarray(prediction)\n", | |
"# im.save(f'{output_path}{str(i)}.jpg')\n", | |
"# i += 1" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"papermill": { | |
"duration": 2262.505584, | |
"end_time": "2020-10-31T21:31:29.466891", | |
"exception": false, | |
"start_time": "2020-10-31T20:53:46.961307", | |
"status": "completed" | |
}, | |
"tags": [] | |
}, | |
"outputs": [], | |
"source": [ | |
"# import os\n", | |
"# os.makedirs('../images/') # Create folder to save generated images\n", | |
"\n", | |
"# predict_and_save(load_dataset(PHOTO_FILENAMES).batch(1), monet_generator, '../images/')" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"papermill": { | |
"duration": 26.821222, | |
"end_time": "2020-10-31T21:32:18.288346", | |
"exception": false, | |
"start_time": "2020-10-31T21:31:51.467124", | |
"status": "completed" | |
}, | |
"tags": [] | |
}, | |
"outputs": [], | |
"source": [ | |
"# import shutil\n", | |
"# shutil.make_archive('/kaggle/working/images/', 'zip', '../images')\n", | |
"\n", | |
"# print(f\"Number of generated samples: {len([name for name in os.listdir('../images/') if os.path.isfile(os.path.join('../images/', name))])}\")" | |
] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.9.2" | |
}, | |
"metadata": { | |
"interpreter": { | |
"hash": "fb48522066f4682e524069d6228dfc70e55d613a5dc94c6da58df1d85977e5e4" | |
} | |
}, | |
"papermill": { | |
"duration": 10193.306434, | |
"end_time": "2020-10-31T21:32:40.470231", | |
"environment_variables": {}, | |
"exception": null, | |
"input_path": "__notebook__.ipynb", | |
"output_path": "__notebook__.ipynb", | |
"parameters": {}, | |
"start_time": "2020-10-31T18:42:47.163797", | |
"version": "2.1.0" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 4 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment