import os
import time
import pickle
import logging
import argparse
import numpy as np
import tensorflow as tf
from tensorflow.python.keras.models import Model
from tensorflow.python.keras import backend as K
from tensorflow.python.keras.utils import plot_model
from tensorflow.python.keras.layers import Lambda, Input, Dense
# from tensorflow.python.keras.losses import mse, binary_crossentropy
from colorlog import ColoredFormatter

LR = 1e-4
N_ITERS = 10000
BATCH_SIZE = 500

# Reparameterization trick:
# instead of sampling z ~ Q(z|X) directly, sample epsilon ~ N(0, I) and compute
# z = z_mean + sqrt(var) * epsilon, which keeps the graph differentiable
# with respect to z_mean and z_log_var.
def sampling(args):
    """Reparameterization trick by sampling from an isotropic unit Gaussian.

    # Arguments:
        args (tensor): mean and log of variance of Q(z|X)

    # Returns:
        z (tensor): sampled latent vector
    """
    z_mean, z_log_var = args
    batch = K.shape(z_mean)[0]
    dim = K.int_shape(z_mean)[1]
    # by default, random_normal has mean=0 and std=1.0
    epsilon = K.random_normal(shape=(batch, dim))
    return z_mean + K.exp(0.5 * z_log_var) * epsilon


def vae(original_dim, intermediate_dims=[8192, 512], latent_dim=16, mse=True):
    # VAE model = encoder + decoder

    # build encoder model
    inputs = Input(shape=(original_dim,), name='encoder_input')
    x = inputs
    for count, num_neurons in enumerate(intermediate_dims):
        # chain the intermediate layers; applying each Dense to `inputs` directly
        # would discard all but the last hidden layer
        x = Dense(num_neurons, activation=tf.nn.leaky_relu, name='enc_intermediate{}'.format(count))(x)
    z_mean = Dense(latent_dim, name='z_mean')(x)
    z_log_var = Dense(latent_dim, name='z_log_var')(x)

    # use reparameterization trick to push the sampling out as input
    # note that "output_shape" isn't necessary with the TensorFlow backend
    z = Lambda(sampling, output_shape=(latent_dim,), name='z')([z_mean, z_log_var])

    # instantiate encoder model
    encoder = Model(inputs, [z_mean, z_log_var, z], name='encoder')
    encoder.summary()
    plot_model(encoder, to_file='vae_mlp_encoder.png', show_shapes=True)

    # build decoder model (mirror the encoder widths)
    latent_inputs = Input(shape=(latent_dim,), name='z_sampling')
    x = latent_inputs
    for count, num_neurons in enumerate(reversed(intermediate_dims)):
        x = Dense(num_neurons, activation=tf.nn.leaky_relu, name='dec_intermediate{}'.format(count))(x)
    outputs = Dense(original_dim, activation='sigmoid')(x)

    # instantiate decoder model
    decoder = Model(latent_inputs, outputs, name='decoder')
    decoder.summary()
    plot_model(decoder, to_file='vae_mlp_decoder.png', show_shapes=True)

    # instantiate VAE model
    outputs = decoder(encoder(inputs)[2])
    model = Model(inputs, outputs, name='vae_mlp')

    # VAE loss = (mse_loss or xent_loss) + kl_loss
    if mse:
        reconstruction_loss = tf.keras.losses.mse(inputs, outputs)
    else:
        reconstruction_loss = tf.keras.losses.binary_crossentropy(inputs, outputs)
    reconstruction_loss *= original_dim
    # closed-form KL divergence between N(z_mean, exp(z_log_var)) and N(0, I)
    kl_loss = 1 + z_log_var - K.square(z_mean) - K.exp(z_log_var)
    kl_loss = K.sum(kl_loss, axis=-1)
    kl_loss *= -0.5
    vae_loss = K.mean(reconstruction_loss + kl_loss)
    return model, vae_loss
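

# The training script below calls get_model_memory_usage(), which the original gist
# never defines. The sketch here is an assumption: it follows the commonly used rough
# estimate (per-sample activation sizes plus parameter counts, float32 = 4 bytes each)
# and should be read as an approximate upper bound, not an exact measurement.
def get_model_memory_usage(batch_size, model):
    """Rough estimate, in GB, of the memory needed to train `model` at `batch_size`."""
    shapes_mem_count = 0
    for layer in model.layers:
        out_shapes = layer.output_shape
        # nested sub-models (the encoder/decoder) can report a list of output shapes
        if not isinstance(out_shapes, list):
            out_shapes = [out_shapes]
        for shape in out_shapes:
            single_layer_mem = 1
            for s in shape:
                if s is not None:
                    single_layer_mem *= s
            shapes_mem_count += single_layer_mem
    trainable_count = int(np.sum([K.count_params(w) for w in model.trainable_weights]))
    non_trainable_count = int(np.sum([K.count_params(w) for w in model.non_trainable_weights]))
    total_bytes = 4.0 * batch_size * shapes_mem_count + 4.0 * (trainable_count + non_trainable_count)
    return total_bytes / (1024.0 ** 3)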


if __name__ == '__main__':
    log_level = logging.INFO
    logging.root.setLevel(log_level)
    formatter = ColoredFormatter("%(log_color)s%(levelname)s:%(message)s %(reset)s")
    stream = logging.StreamHandler()
    stream.setLevel(log_level)
    stream.setFormatter(formatter)
    log = logging.getLogger('pythonConfig')
    log.setLevel(log_level)
    log.addHandler(stream)

    sift_vlad_k, sift_vlad_niters = 128, 100
    vlad_path = '/home/am893/frog_reidentification/hpc_output/vlad_{}_{}.npy'.format(sift_vlad_k, sift_vlad_niters)
    vlad_labels_path = '/home/am893/frog_reidentification/hpc_output/vlad_{}_{}.pkl'.format(sift_vlad_k, sift_vlad_niters)

    if os.path.exists(vlad_path):
        # VLAD data exists on filesystem, load it
        vlad_vectors = np.load(vlad_path)
        log.info(vlad_vectors.shape)
        with open(vlad_labels_path, 'rb') as f:
            vlad_vector_labels = np.asarray(pickle.load(f))
    else:
        raise Exception('VLAD file does not exist')

    graph = tf.Graph()
    with graph.as_default():
        global_step = tf.Variable(0, trainable=False)

        with tf.name_scope("learning_rate"):
            learning_rate = tf.train.exponential_decay(LR, global_step, N_ITERS, 0.8, staircase=True)

        with tf.variable_scope('variational_autoencoder', reuse=tf.AUTO_REUSE):
            # each row of vlad_vectors is one VLAD descriptor, so the input
            # dimensionality is the number of columns, not the number of rows
            original_dim = vlad_vectors.shape[1]
            # the Keras Input layer inside vae() already provides the graph's
            # input placeholder, so no separate tf.placeholder is needed here
            vae_model, vae_loss = vae(original_dim)

        mem_use = get_model_memory_usage(BATCH_SIZE, vae_model)
        gpu_avail_mem = 15.13  # GB
        mem_use_proportion = mem_use / gpu_avail_mem
        outstr = 'With a batch size of {} the model memory usage (per GPU) is: {:.2f}GB/{:.2f}GB ({:.0%})'.format(
            BATCH_SIZE, mem_use, gpu_avail_mem, mem_use_proportion)
        if mem_use_proportion > 0.9:
            log.critical(outstr)
        else:
            log.warning(outstr)

        opt = tf.train.AdamOptimizer(learning_rate)
        train_op = opt.minimize(vae_loss, global_step)

        with tf.Session(graph=graph) as sess:
            for v in tf.global_variables():
                log.info(v.name)
            sess.run(tf.global_variables_initializer())

            for i in range(N_ITERS):
                rand_index = np.random.choice(vlad_vectors.shape[0], size=BATCH_SIZE)
                batch_x = vlad_vectors[rand_index]
                # feed the model's own input tensor with the sampled batch
                _, loss = sess.run([train_op, vae_loss], feed_dict={vae_model.input: batch_x})
                log.info('Iter {}, Loss: {}'.format(i, loss))
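
            # Not part of the original gist: a minimal sketch of how the trained encoder
            # might be used once training finishes. It assumes the encoder sub-model is
            # still reachable from vae_model under the name 'encoder' (set inside vae())
            # and uses z_mean as a deterministic latent embedding of each VLAD vector.
            encoder = vae_model.get_layer('encoder')
            z_mean_tensor = encoder.outputs[0]  # encoder outputs are [z_mean, z_log_var, z]
            latent_embeddings = sess.run(z_mean_tensor, feed_dict={vae_model.input: vlad_vectors})
            log.info('Latent embeddings shape: {}'.format(latent_embeddings.shape))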