Last active
June 20, 2018 17:34
-
-
Save artificialsoph/fbcd6f519e78d0d760efc6dc85c6910e to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"metadata": { | |
"ExecuteTime": { | |
"end_time": "2018-06-20T17:10:56.184264Z", | |
"start_time": "2018-06-20T17:10:56.180289Z" | |
} | |
}, | |
"outputs": [], | |
"source": [ | |
"from keras.datasets import mnist\n", | |
"from keras.layers import Input, Dense, Reshape, Flatten, Dropout, multiply, concatenate\n", | |
"from keras.layers import BatchNormalization, Activation, Embedding, ZeroPadding2D, Lambda\n", | |
"from keras.layers.advanced_activations import LeakyReLU\n", | |
"from keras.layers.convolutional import UpSampling2D, Conv2D\n", | |
"from keras.models import Sequential, Model\n", | |
"from keras.optimizers import Adam\n", | |
"from keras.utils import to_categorical\n", | |
"import keras.backend as K\n", | |
"\n", | |
"import matplotlib.pyplot as plt\n", | |
"\n", | |
"import numpy as np" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 8, | |
"metadata": { | |
"ExecuteTime": { | |
"end_time": "2018-06-20T17:21:39.596837Z", | |
"start_time": "2018-06-20T17:21:39.572870Z" | |
} | |
}, | |
"outputs": [], | |
"source": [ | |
"class INFOGAN():\n", | |
"    \"\"\"InfoGAN built with Keras: generator, discriminator and recognition network Q.\n", | |
"\n", | |
"    The generator input is N(0, 1) noise concatenated with a one-hot categorical\n", | |
"    code; Q is trained to recover that code from generated images. The defaults\n", | |
"    below are set up for 28x28x1 MNIST digits with 10 code categories.\n", | |
"    \"\"\"\n", | |
"    def __init__(self):\n", | |
"        \n", | |
"        # SOPH: review these params and update accordingly\n", | |
"        self.img_rows = 28\n", | |
"        self.img_cols = 28\n", | |
"        self.channels = 1\n", | |
"        self.num_classes = 10\n", | |
"        self.img_shape = (self.img_rows, self.img_cols, self.channels)\n", | |
"        # Generator input width: (latent_dim - num_classes) noise values\n", | |
"        # followed by a num_classes-wide one-hot code.\n", | |
"        self.latent_dim = 72\n", | |
"\n", | |
"\n", | |
"        optimizer = Adam(0.0002, 0.5)\n", | |
"        losses = ['binary_crossentropy', self.mutual_info_loss]\n", | |
"\n", | |
"        # Build the discriminator and the recognition network (they share a conv trunk)\n", | |
"        self.discriminator, self.auxilliary = self.build_disk_and_q_net()\n", | |
"\n", | |
"        self.discriminator.compile(loss=['binary_crossentropy'],\n", | |
"            optimizer=optimizer,\n", | |
"            metrics=['accuracy'])\n", | |
"\n", | |
"        # Compile the recognition network Q\n", | |
"        self.auxilliary.compile(loss=[self.mutual_info_loss],\n", | |
"            optimizer=optimizer,\n", | |
"            metrics=['accuracy'])\n", | |
"\n", | |
"        # Build the generator\n", | |
"        self.generator = self.build_generator()\n", | |
"\n", | |
"        # The generator takes noise + the categorical code as input\n", | |
"        # and generates an image for that code\n", | |
"        gen_input = Input(shape=(self.latent_dim,))\n", | |
"        img = self.generator(gen_input)\n", | |
"\n", | |
"        # For the combined model we will only train the generator.\n", | |
"        # Keras reads `trainable` when a model is compiled, so the discriminator\n", | |
"        # compiled above still learns in its own train_on_batch calls; the\n", | |
"        # 'Discrepancy between trainable weights' UserWarning is expected here.\n", | |
"        # NOTE(review): self.auxilliary (Q) is not frozen, so its layers, presumably\n", | |
"        # including the trunk shared with the discriminator, still receive\n", | |
"        # gradients through the combined model -- confirm this is intended.\n", | |
"        self.discriminator.trainable = False\n", | |
"\n", | |
"        # The discriminator takes generated image as input and determines validity\n", | |
"        valid = self.discriminator(img)\n", | |
"        # The recognition network produces the label\n", | |
"        target_label = self.auxilliary(img)\n", | |
"\n", | |
"        # The combined model (stacked generator, discriminator and Q). Its output\n", | |
"        # order fixes what train_on_batch returns:\n", | |
"        # [total loss, adversarial (G) loss, mutual-info (Q) loss].\n", | |
"        self.combined = Model(gen_input, [valid, target_label])\n", | |
"        self.combined.compile(loss=losses,\n", | |
"            optimizer=optimizer)\n", | |
"\n", | |
"\n", | |
"    def build_generator(self):\n", | |
"        \"\"\"Build the generator: a (latent_dim,) vector -> an image in [-1, 1].\n", | |
"\n", | |
"        The 7x7 seed is upsampled twice (x2 each), so outputs are\n", | |
"        28x28xchannels (matching the default 28x28 img_shape); the final tanh\n", | |
"        keeps pixel values in [-1, 1].\n", | |
"        \"\"\"\n", | |
"        #SOPH: there's a chance that the generator doesn't generate images with dimensionality exactly like your \n", | |
"        # training images. If that's the case, some annoying debugging might be in order. I'd recommend\n", | |
"        # tweaking the layers in the generator separately until you get the dimensionality right and then moving that code\n", | |
"        # here after. Otherwise, consider scaling all of your training images to be square\n", | |
"        model = Sequential()\n", | |
"\n", | |
"        model.add(Dense(128 * 7 * 7, activation=\"relu\", input_dim=self.latent_dim))\n", | |
"        model.add(Reshape((7, 7, 128)))\n", | |
"        model.add(BatchNormalization(momentum=0.8))\n", | |
"        model.add(UpSampling2D())\n", | |
"        model.add(Conv2D(128, kernel_size=3, padding=\"same\"))\n", | |
"        model.add(Activation(\"relu\"))\n", | |
"        model.add(BatchNormalization(momentum=0.8))\n", | |
"        model.add(UpSampling2D())\n", | |
"        model.add(Conv2D(64, kernel_size=3, padding=\"same\"))\n", | |
"        model.add(Activation(\"relu\"))\n", | |
"        model.add(BatchNormalization(momentum=0.8))\n", | |
"        model.add(Conv2D(self.channels, kernel_size=3, padding='same'))\n", | |
"        model.add(Activation(\"tanh\"))\n", | |
"\n", | |
"        gen_input = Input(shape=(self.latent_dim,))\n", | |
"        img = model(gen_input)\n", | |
"\n", | |
"        print(\"gen summary:\")\n", | |
"        model.summary()\n", | |
"\n", | |
"        return Model(gen_input, img)\n", | |
"\n", | |
"\n", | |
"    def build_disk_and_q_net(self):\n", | |
"        \"\"\"Build the discriminator and the recognition network Q on a shared trunk.\n", | |
"\n", | |
"        Returns:\n", | |
"            (disk, q): `disk` maps an img_shape image to a sigmoid validity score;\n", | |
"            `q` maps the same image to a num_classes softmax over the code. Both\n", | |
"            models reuse one Sequential conv trunk, so they share its weights.\n", | |
"        \"\"\"\n", | |
"        img = Input(shape=self.img_shape)\n", | |
"\n", | |
"        # Shared layers between discriminator and recognition network\n", | |
"        model = Sequential()\n", | |
"        model.add(Conv2D(64, kernel_size=3, strides=2, input_shape=self.img_shape, padding=\"same\"))\n", | |
"        model.add(LeakyReLU(alpha=0.2))\n", | |
"        model.add(Dropout(0.25))\n", | |
"        model.add(Conv2D(128, kernel_size=3, strides=2, padding=\"same\"))\n", | |
"        model.add(ZeroPadding2D(padding=((0,1),(0,1))))\n", | |
"        model.add(LeakyReLU(alpha=0.2))\n", | |
"        model.add(Dropout(0.25))\n", | |
"        model.add(BatchNormalization(momentum=0.8))\n", | |
"        model.add(Conv2D(256, kernel_size=3, strides=2, padding=\"same\"))\n", | |
"        model.add(LeakyReLU(alpha=0.2))\n", | |
"        model.add(Dropout(0.25))\n", | |
"        model.add(BatchNormalization(momentum=0.8))\n", | |
"        model.add(Conv2D(512, kernel_size=3, strides=2, padding=\"same\"))\n", | |
"        model.add(LeakyReLU(alpha=0.2))\n", | |
"        model.add(Dropout(0.25))\n", | |
"        model.add(BatchNormalization(momentum=0.8))\n", | |
"        model.add(Flatten())\n", | |
"\n", | |
"        img_embedding = model(img)\n", | |
"\n", | |
"        # Discriminator\n", | |
"        validity = Dense(1, activation='sigmoid')(img_embedding)\n", | |
"\n", | |
"        # Recognition\n", | |
"        q_net = Dense(128, activation='relu')(img_embedding)\n", | |
"        # SOPH: this activation should reflect that you're doing multi-label with binary labels. \n", | |
"        # So, rather than softmax, what should you have?\n", | |
"        label = Dense(self.num_classes, activation='softmax')(q_net)\n", | |
"        \n", | |
"        disk = Model(img, validity)\n", | |
"        \n", | |
"        print(\"disk summary:\")\n", | |
"        disk.summary()\n", | |
"        \n", | |
"        q = Model(img, label)\n", | |
"        \n", | |
"        print(\"q summary\")\n", | |
"        q.summary()\n", | |
"\n", | |
"        # Return discriminator and recognition network\n", | |
"        return disk, q\n", | |
"\n", | |
"\n", | |
"    def mutual_info_loss(self, c, c_given_x):\n", | |
"        \"\"\"InfoGAN recognition loss; minimizing it maximizes mutual information.\n", | |
"\n", | |
"        `c` is the sampled categorical code (y_true) and `c_given_x` is Q's\n", | |
"        predicted distribution over it (y_pred). The first term is the\n", | |
"        cross-entropy of Q's prediction against c: minimizing it maximizes the\n", | |
"        variational lower bound on I(c; G(z, c)). The second term is the entropy\n", | |
"        of c; it depends only on y_true, so it adds no gradient (and is ~0 for\n", | |
"        one-hot codes). eps guards against log(0).\n", | |
"        \"\"\"\n", | |
"        eps = 1e-8\n", | |
"        conditional_entropy = K.mean(- K.sum(K.log(c_given_x + eps) * c, axis=1))\n", | |
"        entropy = K.mean(- K.sum(K.log(c + eps) * c, axis=1))\n", | |
"\n", | |
"        return conditional_entropy + entropy\n", | |
"\n", | |
"    def sample_generator_input(self, batch_size):\n", | |
"        \"\"\"Sample `batch_size` generator inputs.\n", | |
"\n", | |
"        Returns:\n", | |
"            (sampled_noise, sampled_labels): N(0, 1) noise of shape\n", | |
"            (batch_size, latent_dim - num_classes) and one-hot codes of shape\n", | |
"            (batch_size, num_classes). Concatenated along axis 1 they form a\n", | |
"            (batch_size, latent_dim) generator input.\n", | |
"        \"\"\"\n", | |
"        # Generator inputs. The noise width is derived from latent_dim rather than\n", | |
"        # hard-coded (62 for the defaults) so noise + code always match the\n", | |
"        # generator's Input(shape=(latent_dim,)) when either size is changed.\n", | |
"        noise_dim = self.latent_dim - self.num_classes\n", | |
"        sampled_noise = np.random.normal(0, 1, (batch_size, noise_dim))\n", | |
"        \n", | |
"        \n", | |
"        # SOPH: this code generates random combinations of your labels so that it can generate randomized images\n", | |
"        # the following code should be removed and replaced with something that generates random values for your 40\n", | |
"        # labels\n", | |
"        sampled_labels = np.random.randint(0, self.num_classes, batch_size).reshape(-1, 1)\n", | |
"        sampled_labels = to_categorical(sampled_labels, num_classes=self.num_classes)\n", | |
"\n", | |
"        return sampled_noise, sampled_labels\n", | |
"\n", | |
"    def train(self, epochs, batch_size=128, sample_interval=50):\n", | |
"        \"\"\"Alternate discriminator and generator/Q updates for `epochs` steps.\n", | |
"\n", | |
"        Each step (called an epoch here) uses one random batch of real images,\n", | |
"        not a full pass over the dataset. Every `sample_interval` steps the\n", | |
"        losses are printed and a grid of samples is saved via sample_images.\n", | |
"        \"\"\"\n", | |
"\n", | |
"        # SOPH: replace the following code block with code that loads celebA. \n", | |
"        # x_train should be [n_examples x pixels x pixels]\n", | |
"        # y_train should be [n_examples x n_labels]\n", | |
"        # Load the dataset\n", | |
"        (X_train, y_train), (_, _) = mnist.load_data()\n", | |
"        \n", | |
"        # Rescale -1 to 1\n", | |
"        X_train = (X_train.astype(np.float32) - 127.5) / 127.5\n", | |
"        X_train = np.expand_dims(X_train, axis=3)\n", | |
"        y_train = y_train.reshape(-1, 1)\n", | |
"        # SOPH: this is the end of the loading code. You may have to replace/remove any of these lines\n", | |
"\n", | |
"        # Adversarial ground truths\n", | |
"        valid = np.ones((batch_size, 1))\n", | |
"        fake = np.zeros((batch_size, 1))\n", | |
"\n", | |
"        for epoch in range(epochs):\n", | |
"\n", | |
"            # ---------------------\n", | |
"            #  Train Discriminator\n", | |
"            # ---------------------\n", | |
"\n", | |
"            # Select a random batch of real images\n", | |
"            idx = np.random.randint(0, X_train.shape[0], batch_size)\n", | |
"            imgs = X_train[idx]\n", | |
"\n", | |
"            # Sample noise and categorical labels\n", | |
"            sampled_noise, sampled_labels = self.sample_generator_input(batch_size)\n", | |
"            gen_input = np.concatenate((sampled_noise, sampled_labels), axis=1)\n", | |
"\n", | |
"            # Generate a batch of new images\n", | |
"            gen_imgs = self.generator.predict(gen_input)\n", | |
"\n", | |
"            # Train on real and generated data\n", | |
"            d_loss_real = self.discriminator.train_on_batch(imgs, valid)\n", | |
"            d_loss_fake = self.discriminator.train_on_batch(gen_imgs, fake)\n", | |
"\n", | |
"            # Avg. loss\n", | |
"            d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)\n", | |
"\n", | |
"            # ---------------------\n", | |
"            #  Train Generator and Q-network\n", | |
"            # ---------------------\n", | |
"\n", | |
"            g_loss = self.combined.train_on_batch(gen_input, [valid, sampled_labels])\n", | |
"\n", | |
"            # Plot the progress\n", | |
"\n", | |
"            # If at save interval => save generated image samples\n", | |
"            if epoch % sample_interval == 0:\n", | |
"                # g_loss follows the combined model's outputs:\n", | |
"                # [total, adversarial (G) loss, mutual-info (Q) loss]\n", | |
"                print (\"%d [D loss: %.2f, acc.: %.2f%%] [Q loss: %.2f] [G loss: %.2f]\" % (epoch, d_loss[0], 100*d_loss[1], g_loss[2], g_loss[1]))\n", | |
"                self.sample_images(epoch)\n", | |
"\n", | |
"    def sample_images(self, epoch):\n", | |
"        \"\"\"Save an r x c grid of generated samples to images/<epoch>.png.\n", | |
"\n", | |
"        Column i shows r samples generated with categorical code i and fresh\n", | |
"        noise; pixels are mapped from the generator's [-1, 1] range to [0, 1]\n", | |
"        and only channel 0 is drawn (grayscale).\n", | |
"        \"\"\"\n", | |
"        import os\n", | |
"\n", | |
"        # SOPH: this'll be hard to adjust so consider just removing it or greatly simplifying it to start out.\n", | |
"        # this generates those image matrices.\n", | |
"        r, c = 10, 10\n", | |
"\n", | |
"        fig, axs = plt.subplots(r, c)\n", | |
"        for i in range(c):\n", | |
"            # One noise row per grid row; code i is shared down column i\n", | |
"            sampled_noise, _ = self.sample_generator_input(r)\n", | |
"            label = to_categorical(np.full(fill_value=i, shape=(r,1)), num_classes=self.num_classes)\n", | |
"            gen_input = np.concatenate((sampled_noise, label), axis=1)\n", | |
"            gen_imgs = self.generator.predict(gen_input)\n", | |
"            gen_imgs = 0.5 * gen_imgs + 0.5\n", | |
"            for j in range(r):\n", | |
"                axs[j,i].imshow(gen_imgs[j,:,:,0], cmap='gray')\n", | |
"                axs[j,i].axis('off')\n", | |
"        # savefig does not create missing directories\n", | |
"        os.makedirs(\"images\", exist_ok=True)\n", | |
"        fig.savefig(\"images/%d.png\" % epoch)\n", | |
"        plt.close(fig)\n", | |
"\n", | |
"    def save_model(self):\n", | |
"        \"\"\"Save the generator and discriminator under saved_model/.\n", | |
"\n", | |
"        Each model is written as <name>.json (architecture, via to_json) plus\n", | |
"        <name>_weights.hdf5 (weights, via save_weights). The recognition\n", | |
"        network Q (self.auxilliary) and the combined model are not saved.\n", | |
"        \"\"\"\n", | |
"        import os\n", | |
"\n", | |
"        def save(model, model_name):\n", | |
"            model_path = \"saved_model/%s.json\" % model_name\n", | |
"            weights_path = \"saved_model/%s_weights.hdf5\" % model_name\n", | |
"            # open() does not create missing directories\n", | |
"            os.makedirs(\"saved_model\", exist_ok=True)\n", | |
"            json_string = model.to_json()\n", | |
"            # The context manager flushes and closes the file deterministically\n", | |
"            with open(model_path, 'w') as arch_file:\n", | |
"                arch_file.write(json_string)\n", | |
"            model.save_weights(weights_path)\n", | |
"\n", | |
"        save(self.generator, \"generator\")\n", | |
"        save(self.discriminator, \"discriminator\")" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 22, | |
"metadata": { | |
"ExecuteTime": { | |
"end_time": "2018-06-20T17:32:02.063801Z", | |
"start_time": "2018-06-20T17:32:02.060492Z" | |
} | |
}, | |
"outputs": [], | |
"source": [ | |
"import keras" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 23, | |
"metadata": { | |
"ExecuteTime": { | |
"end_time": "2018-06-20T17:32:53.104698Z", | |
"start_time": "2018-06-20T17:32:51.882258Z" | |
} | |
}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"gen summary:\n", | |
"_________________________________________________________________\n", | |
"Layer (type) Output Shape Param # \n", | |
"=================================================================\n", | |
"dense_21 (Dense) (None, 18304) 750464 \n", | |
"_________________________________________________________________\n", | |
"reshape_9 (Reshape) (None, 13, 11, 128) 0 \n", | |
"_________________________________________________________________\n", | |
"batch_normalization_39 (Batc (None, 13, 11, 128) 512 \n", | |
"_________________________________________________________________\n", | |
"up_sampling2d_19 (UpSampling (None, 26, 22, 128) 0 \n", | |
"_________________________________________________________________\n", | |
"conv2d_43 (Conv2D) (None, 26, 22, 128) 147584 \n", | |
"_________________________________________________________________\n", | |
"activation_27 (Activation) (None, 26, 22, 128) 0 \n", | |
"_________________________________________________________________\n", | |
"batch_normalization_40 (Batc (None, 26, 22, 128) 512 \n", | |
"_________________________________________________________________\n", | |
"up_sampling2d_20 (UpSampling (None, 52, 44, 128) 0 \n", | |
"_________________________________________________________________\n", | |
"conv2d_44 (Conv2D) (None, 52, 44, 128) 147584 \n", | |
"_________________________________________________________________\n", | |
"activation_28 (Activation) (None, 52, 44, 128) 0 \n", | |
"_________________________________________________________________\n", | |
"batch_normalization_41 (Batc (None, 52, 44, 128) 512 \n", | |
"_________________________________________________________________\n", | |
"up_sampling2d_21 (UpSampling (None, 104, 88, 128) 0 \n", | |
"_________________________________________________________________\n", | |
"zero_padding2d_5 (ZeroPaddin (None, 109, 89, 128) 0 \n", | |
"_________________________________________________________________\n", | |
"conv2d_45 (Conv2D) (None, 109, 89, 128) 147584 \n", | |
"_________________________________________________________________\n", | |
"activation_29 (Activation) (None, 109, 89, 128) 0 \n", | |
"_________________________________________________________________\n", | |
"batch_normalization_42 (Batc (None, 109, 89, 128) 512 \n", | |
"_________________________________________________________________\n", | |
"up_sampling2d_22 (UpSampling (None, 218, 178, 128) 0 \n", | |
"_________________________________________________________________\n", | |
"conv2d_46 (Conv2D) (None, 218, 178, 64) 73792 \n", | |
"_________________________________________________________________\n", | |
"activation_30 (Activation) (None, 218, 178, 64) 0 \n", | |
"_________________________________________________________________\n", | |
"batch_normalization_43 (Batc (None, 218, 178, 64) 256 \n", | |
"_________________________________________________________________\n", | |
"conv2d_47 (Conv2D) (None, 218, 178, 3) 1731 \n", | |
"_________________________________________________________________\n", | |
"activation_31 (Activation) (None, 218, 178, 3) 0 \n", | |
"=================================================================\n", | |
"Total params: 1,271,043\n", | |
"Trainable params: 1,269,891\n", | |
"Non-trainable params: 1,152\n", | |
"_________________________________________________________________\n" | |
] | |
} | |
], | |
"source": [ | |
"# Scratch: size a generator for 218x178x3 images (presumably CelebA, per the\n", | |
"# SOPH notes in INFOGAN). Spatial size after each stage:\n", | |
"#   13x11 -> up 26x22 -> up 52x44 -> up 104x88 -> pad 109x89 -> up 218x178\n", | |
"# gen_input_dim is presumably the 40 CelebA labels -- TODO confirm, and widen it\n", | |
"# if noise dimensions are added to the generator input.\n", | |
"gen_input_dim = 40\n", | |
"\n", | |
"model = Sequential()\n", | |
"\n", | |
"model.add(Dense(128 * 13 * 11, activation=\"relu\", input_dim=gen_input_dim))\n", | |
"model.add(Reshape((13, 11, 128)))\n", | |
"model.add(BatchNormalization(momentum=0.8))\n", | |
"\n", | |
"model.add(UpSampling2D())\n", | |
"model.add(Conv2D(128, kernel_size=3, padding=\"same\"))\n", | |
"model.add(Activation(\"relu\"))\n", | |
"model.add(BatchNormalization(momentum=0.8))\n", | |
"\n", | |
"model.add(UpSampling2D())\n", | |
"model.add(Conv2D(128, kernel_size=3, padding=\"same\"))\n", | |
"model.add(Activation(\"relu\"))\n", | |
"model.add(BatchNormalization(momentum=0.8))\n", | |
"\n", | |
"model.add(UpSampling2D())\n", | |
"# Pad 104x88 -> 109x89 (3 rows top, 2 rows bottom, 1 column right) so the\n", | |
"# final 2x upsampling lands exactly on 218x178\n", | |
"model.add(ZeroPadding2D(((3,2), (0,1))))\n", | |
"\n", | |
"model.add(Conv2D(128, kernel_size=3, padding=\"same\"))\n", | |
"model.add(Activation(\"relu\"))\n", | |
"model.add(BatchNormalization(momentum=0.8))\n", | |
"\n", | |
"model.add(UpSampling2D())\n", | |
"model.add(Conv2D(64, kernel_size=3, padding=\"same\"))\n", | |
"model.add(Activation(\"relu\"))\n", | |
"model.add(BatchNormalization(momentum=0.8))\n", | |
"model.add(Conv2D(3, kernel_size=3, padding='same'))\n", | |
"model.add(Activation(\"tanh\"))\n", | |
"\n", | |
"gen_input = Input(shape=(gen_input_dim,))\n", | |
"img = model(gen_input)\n", | |
"\n", | |
"print(\"gen summary:\")\n", | |
"model.summary()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 21, | |
"metadata": { | |
"ExecuteTime": { | |
"end_time": "2018-06-20T17:31:27.572321Z", | |
"start_time": "2018-06-20T17:31:27.568083Z" | |
} | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"(109.0, 89.0)" | |
] | |
}, | |
"execution_count": 21, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# Halve the 218x178 target to get the size needed before the last 2x upsampling\n", | |
"(218 / 2, 178 / 2)\n" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 9, | |
"metadata": { | |
"ExecuteTime": { | |
"end_time": "2018-06-20T17:21:56.125215Z", | |
"start_time": "2018-06-20T17:21:46.546713Z" | |
}, | |
"scrolled": false | |
}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"disk summary:\n", | |
"_________________________________________________________________\n", | |
"Layer (type) Output Shape Param # \n", | |
"=================================================================\n", | |
"input_10 (InputLayer) (None, 28, 28, 1) 0 \n", | |
"_________________________________________________________________\n", | |
"sequential_7 (Sequential) (None, 2048) 1553408 \n", | |
"_________________________________________________________________\n", | |
"dense_13 (Dense) (None, 1) 2049 \n", | |
"=================================================================\n", | |
"Total params: 1,555,457\n", | |
"Trainable params: 1,553,665\n", | |
"Non-trainable params: 1,792\n", | |
"_________________________________________________________________\n", | |
"q summary\n", | |
"_________________________________________________________________\n", | |
"Layer (type) Output Shape Param # \n", | |
"=================================================================\n", | |
"input_10 (InputLayer) (None, 28, 28, 1) 0 \n", | |
"_________________________________________________________________\n", | |
"sequential_7 (Sequential) (None, 2048) 1553408 \n", | |
"_________________________________________________________________\n", | |
"dense_14 (Dense) (None, 128) 262272 \n", | |
"_________________________________________________________________\n", | |
"dense_15 (Dense) (None, 10) 1290 \n", | |
"=================================================================\n", | |
"Total params: 1,816,970\n", | |
"Trainable params: 1,815,178\n", | |
"Non-trainable params: 1,792\n", | |
"_________________________________________________________________\n", | |
"gen summary:\n", | |
"_________________________________________________________________\n", | |
"Layer (type) Output Shape Param # \n", | |
"=================================================================\n", | |
"dense_16 (Dense) (None, 6272) 457856 \n", | |
"_________________________________________________________________\n", | |
"reshape_4 (Reshape) (None, 7, 7, 128) 0 \n", | |
"_________________________________________________________________\n", | |
"batch_normalization_22 (Batc (None, 7, 7, 128) 512 \n", | |
"_________________________________________________________________\n", | |
"up_sampling2d_7 (UpSampling2 (None, 14, 14, 128) 0 \n", | |
"_________________________________________________________________\n", | |
"conv2d_26 (Conv2D) (None, 14, 14, 128) 147584 \n", | |
"_________________________________________________________________\n", | |
"activation_10 (Activation) (None, 14, 14, 128) 0 \n", | |
"_________________________________________________________________\n", | |
"batch_normalization_23 (Batc (None, 14, 14, 128) 512 \n", | |
"_________________________________________________________________\n", | |
"up_sampling2d_8 (UpSampling2 (None, 28, 28, 128) 0 \n", | |
"_________________________________________________________________\n", | |
"conv2d_27 (Conv2D) (None, 28, 28, 64) 73792 \n", | |
"_________________________________________________________________\n", | |
"activation_11 (Activation) (None, 28, 28, 64) 0 \n", | |
"_________________________________________________________________\n", | |
"batch_normalization_24 (Batc (None, 28, 28, 64) 256 \n", | |
"_________________________________________________________________\n", | |
"conv2d_28 (Conv2D) (None, 28, 28, 1) 577 \n", | |
"_________________________________________________________________\n", | |
"activation_12 (Activation) (None, 28, 28, 1) 0 \n", | |
"=================================================================\n", | |
"Total params: 681,089\n", | |
"Trainable params: 680,449\n", | |
"Non-trainable params: 640\n", | |
"_________________________________________________________________\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"/home/ubuntu/miniconda3/lib/python3.6/site-packages/keras/engine/training.py:975: UserWarning: Discrepancy between trainable weights and collected trainable weights, did you set `model.trainable` without calling `model.compile` after ?\n", | |
" 'Discrepancy between trainable weights and collected trainable'\n" | |
] | |
}, | |
{ | |
"ename": "KeyboardInterrupt", | |
"evalue": "", | |
"output_type": "error", | |
"traceback": [ | |
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", | |
"\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", | |
"\u001b[0;32m<ipython-input-9-489e5ad1b538>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0minfogan\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mINFOGAN\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 2\u001b[0;31m \u001b[0minfogan\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrain\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mepochs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m50000\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbatch_size\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m128\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msample_interval\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m50\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", | |
"\u001b[0;32m<ipython-input-8-36aa60c01d3d>\u001b[0m in \u001b[0;36mtrain\u001b[0;34m(self, epochs, batch_size, sample_interval)\u001b[0m\n\u001b[1;32m 196\u001b[0m \u001b[0;31m# ---------------------\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 197\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 198\u001b[0;31m \u001b[0mg_loss\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcombined\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrain_on_batch\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mgen_input\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0mvalid\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msampled_labels\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 199\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 200\u001b[0m \u001b[0;31m# Plot the progress\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", | |
"\u001b[0;32m~/miniconda3/lib/python3.6/site-packages/keras/engine/training.py\u001b[0m in \u001b[0;36mtrain_on_batch\u001b[0;34m(self, x, y, sample_weight, class_weight)\u001b[0m\n\u001b[1;32m 1881\u001b[0m \u001b[0mins\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mx\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0my\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0msample_weights\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1882\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_make_train_function\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1883\u001b[0;31m \u001b[0moutputs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrain_function\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mins\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1884\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0moutputs\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1885\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0moutputs\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", | |
"\u001b[0;32m~/miniconda3/lib/python3.6/site-packages/keras/backend/tensorflow_backend.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, inputs)\u001b[0m\n\u001b[1;32m 2478\u001b[0m \u001b[0mfeed_dict\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mtensor\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mvalue\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2479\u001b[0m \u001b[0mfetches\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0moutputs\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mupdates_op\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfetches\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 2480\u001b[0;31m \u001b[0msession\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mget_session\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 2481\u001b[0m updated = session.run(fetches=fetches, feed_dict=feed_dict,\n\u001b[1;32m 2482\u001b[0m **self.session_kwargs)\n", | |
"\u001b[0;32m~/miniconda3/lib/python3.6/site-packages/keras/backend/tensorflow_backend.py\u001b[0m in \u001b[0;36mget_session\u001b[0;34m()\u001b[0m\n\u001b[1;32m 191\u001b[0m \u001b[0;31m# not already marked as initialized.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 192\u001b[0m is_initialized = session.run(\n\u001b[0;32m--> 193\u001b[0;31m [tf.is_variable_initialized(v) for v in candidate_vars])\n\u001b[0m\u001b[1;32m 194\u001b[0m \u001b[0muninitialized_vars\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 195\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mflag\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mv\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mzip\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mis_initialized\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcandidate_vars\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", | |
"\u001b[0;32m~/miniconda3/lib/python3.6/site-packages/tensorflow/python/client/session.py\u001b[0m in \u001b[0;36mrun\u001b[0;34m(self, fetches, feed_dict, options, run_metadata)\u001b[0m\n\u001b[1;32m 898\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 899\u001b[0m result = self._run(None, fetches, feed_dict, options_ptr,\n\u001b[0;32m--> 900\u001b[0;31m run_metadata_ptr)\n\u001b[0m\u001b[1;32m 901\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mrun_metadata\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 902\u001b[0m \u001b[0mproto_data\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtf_session\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mTF_GetBuffer\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mrun_metadata_ptr\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", | |
"\u001b[0;32m~/miniconda3/lib/python3.6/site-packages/tensorflow/python/client/session.py\u001b[0m in \u001b[0;36m_run\u001b[0;34m(self, handle, fetches, feed_dict, options, run_metadata)\u001b[0m\n\u001b[1;32m 1133\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mfinal_fetches\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0mfinal_targets\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mhandle\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0mfeed_dict_tensor\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1134\u001b[0m results = self._do_run(handle, final_targets, final_fetches,\n\u001b[0;32m-> 1135\u001b[0;31m feed_dict_tensor, options, run_metadata)\n\u001b[0m\u001b[1;32m 1136\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1137\u001b[0m \u001b[0mresults\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", | |
"\u001b[0;32m~/miniconda3/lib/python3.6/site-packages/tensorflow/python/client/session.py\u001b[0m in \u001b[0;36m_do_run\u001b[0;34m(self, handle, target_list, fetch_list, feed_dict, options, run_metadata)\u001b[0m\n\u001b[1;32m 1314\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mhandle\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1315\u001b[0m return self._do_call(_run_fn, feeds, fetches, targets, options,\n\u001b[0;32m-> 1316\u001b[0;31m run_metadata)\n\u001b[0m\u001b[1;32m 1317\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1318\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_do_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0m_prun_fn\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mhandle\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfeeds\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfetches\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", | |
"\u001b[0;32m~/miniconda3/lib/python3.6/site-packages/tensorflow/python/client/session.py\u001b[0m in \u001b[0;36m_do_call\u001b[0;34m(self, fn, *args)\u001b[0m\n\u001b[1;32m 1320\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m_do_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfn\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1321\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1322\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mfn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1323\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0merrors\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mOpError\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0me\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1324\u001b[0m \u001b[0mmessage\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mcompat\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mas_text\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0me\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmessage\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", | |
"\u001b[0;32m~/miniconda3/lib/python3.6/site-packages/tensorflow/python/client/session.py\u001b[0m in \u001b[0;36m_run_fn\u001b[0;34m(feed_dict, fetch_list, target_list, options, run_metadata)\u001b[0m\n\u001b[1;32m 1303\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m_run_fn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfeed_dict\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfetch_list\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtarget_list\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0moptions\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mrun_metadata\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1304\u001b[0m \u001b[0;31m# Ensure any changes to the graph are reflected in the runtime.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1305\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_extend_graph\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1306\u001b[0m return self._call_tf_sessionrun(\n\u001b[1;32m 1307\u001b[0m options, feed_dict, fetch_list, target_list, run_metadata)\n", | |
"\u001b[0;32m~/miniconda3/lib/python3.6/site-packages/tensorflow/python/client/session.py\u001b[0m in \u001b[0;36m_extend_graph\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 1338\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_created_with_new_api\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1339\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_graph\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_lock\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;31m# pylint: disable=protected-access\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1340\u001b[0;31m \u001b[0mtf_session\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mExtendSession\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_session\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1341\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1342\u001b[0m \u001b[0;31m# Ensure any changes to the graph are reflected in the runtime.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", | |
"\u001b[0;31mKeyboardInterrupt\u001b[0m: " | |
] | |
} | |
], | |
"source": [ | |
"# Build the networks (disk, Q and generator summaries are printed), then train.\n", | |
"infogan = INFOGAN()\n", | |
"infogan.train(\n", | |
"    epochs=50000,\n", | |
"    batch_size=128,\n", | |
"    sample_interval=50,\n", | |
")" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3", | |
"language": "python", | |
"name": "python3" | |
}, | |
"toc": { | |
"nav_menu": {}, | |
"number_sections": true, | |
"sideBar": true, | |
"skip_h1_title": false, | |
"toc_cell": false, | |
"toc_position": {}, | |
"toc_section_display": "block", | |
"toc_window_display": false | |
}, | |
"varInspector": { | |
"cols": { | |
"lenName": 16, | |
"lenType": 16, | |
"lenVar": 40 | |
}, | |
"kernels_config": { | |
"python": { | |
"delete_cmd_postfix": "", | |
"delete_cmd_prefix": "del ", | |
"library": "var_list.py", | |
"varRefreshCmd": "print(var_dic_list())" | |
}, | |
"r": { | |
"delete_cmd_postfix": ") ", | |
"delete_cmd_prefix": "rm(", | |
"library": "var_list.r", | |
"varRefreshCmd": "cat(var_dic_list()) " | |
} | |
}, | |
"types_to_exclude": [ | |
"module", | |
"function", | |
"builtin_function_or_method", | |
"instance", | |
"_Feature" | |
], | |
"window_display": false | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 2 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment