TensorFlow Probability MNIST VAE implementation using tf_utils (https://github.com/piojanu/tf_utils/blob/master/tf_utils/utils.py)
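The script imports AttrDict and lazy_property_with_scope from the linked tf_utils repository. As a rough guide to what those helpers are assumed to do (a dict readable via attributes, and a cached property whose ops are built inside a named variable scope), a minimal sketch is given below; the real implementations in tf_utils may differ in detail.

import functools

import tensorflow as tf


class AttrDict(dict):
    """Assumed behaviour: a dict whose keys can also be read as attributes."""
    def __getattr__(self, key):
        return self[key]


def lazy_property_with_scope(function=None, scope_name=None, **scope_kwargs):
    """Assumed behaviour: evaluate the property once, cache the result on the
    instance, and build its ops inside a tf.variable_scope."""
    if function is None:
        # Decorator used with arguments, e.g. @lazy_property_with_scope(scope_name="encoder").
        return functools.partial(lazy_property_with_scope,
                                 scope_name=scope_name, **scope_kwargs)

    attribute = '_cache_' + function.__name__
    name = scope_name or function.__name__

    @property
    @functools.wraps(function)
    def wrapper(self):
        if not hasattr(self, attribute):
            with tf.variable_scope(name, **scope_kwargs):
                setattr(self, attribute, function(self))
        return getattr(self, attribute)

    return wrapper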
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp

from tf_utils import AttrDict, lazy_property_with_scope

tfd = tfp.distributions
tfl = tf.layers
class Model(object):
    def __init__(self, data, config):
        self.data = data
        self.data_shape = list(self.data.shape[1:])
        self.config = config

        # Touch the lazy properties so the whole graph is built up front.
        self.prior
        self.posterior
        self.code
        self.likelihood
        self.samples
        self.loss
        self.optimise

    @lazy_property_with_scope
    def prior(self):
        """Standard normal distribution prior."""
        return tfd.MultivariateNormalDiag(
            loc=tf.zeros(self.config.code_size),
            scale_diag=tf.ones(self.config.code_size))

    @lazy_property_with_scope(scope_name="encoder")
    def posterior(self):
        """Approximate posterior q(z|x), a.k.a. the encoder."""
        x = tfl.Flatten()(self.data)
        x = tfl.Dense(self.config.hidden_size, activation='relu')(x)
        x = tfl.Dense(self.config.hidden_size, activation='relu')(x)
        loc = tfl.Dense(self.config.code_size)(x)
        scale = tfl.Dense(self.config.code_size, activation='softplus')(x)
        return tfd.MultivariateNormalDiag(loc=loc, scale_diag=scale)

    @lazy_property_with_scope
    def code(self):
        """Sample a code from the posterior."""
        return self.posterior.sample()

    @lazy_property_with_scope(scope_name="decoder", reuse=tf.AUTO_REUSE)
    def likelihood(self):
        """Likelihood p(x|z), a.k.a. the decoder."""
        return self._make_decoder(self.code)

    @lazy_property_with_scope(scope_name="decoder", reuse=tf.AUTO_REUSE)
    def samples(self):
        """Generate examples by decoding codes drawn from the prior."""
        return self._make_decoder(self.prior.sample(self.config.n_samples)).mean()

    @lazy_property_with_scope
    def loss(self):
        """Negative evidence lower bound, summed over pixels and averaged over the batch."""
        elbo = self.likelihood.log_prob(self.data) - tfd.kl_divergence(self.posterior, self.prior)
        return -tf.reduce_mean(elbo)

    @lazy_property_with_scope
    def optimise(self):
        """ADAM optimiser for the loss (negative ELBO)."""
        return tf.train.AdamOptimizer(self.config.learning_rate).minimize(self.loss)

    def _make_decoder(self, code):
        x = tfl.Dense(self.config.hidden_size, activation='relu')(code)
        x = tfl.Dense(self.config.hidden_size, activation='relu')(x)
        logits = tfl.Dense(np.prod(self.data_shape))(x)
        logits = tf.reshape(logits, [-1] + self.data_shape)
        # Treat the two image dimensions as one event, so log_prob sums over pixels.
        return tfd.Independent(tfd.Bernoulli(logits), reinterpreted_batch_ndims=2)
def plot_codes(ax, codes, labels):
    """Scatter-plot the 2D latent codes, coloured by digit label."""
    ax.scatter(codes[:, 0], codes[:, 1], s=2, c=labels, alpha=0.1)
    ax.set_aspect('equal')
    ax.set_xlim(codes.min() - .1, codes.max() + .1)
    ax.set_ylim(codes.min() - .1, codes.max() + .1)
    ax.tick_params(
        axis='both', which='both', left=False, bottom=False,
        labelleft=False, labelbottom=False)


def plot_samples(ax, samples):
    """Show generated images in a row of axes."""
    for index, sample in enumerate(samples):
        ax[index].imshow(sample, cmap='gray')
        ax[index].axis('off')
def create_datasets(train_set, test_set, config):
    """Build train/test input pipelines sharing one reinitialisable iterator."""
    train_dataset = tf.data.Dataset.from_tensor_slices(
        tf.convert_to_tensor(train_set, dtype=tf.float32)) \
        .map(lambda x: x / 255) \
        .shuffle(train_set.shape[0]) \
        .batch(config.batch_size)
    test_dataset = tf.data.Dataset.from_tensor_slices(
        tf.convert_to_tensor(test_set, dtype=tf.float32)) \
        .map(lambda x: x / 255) \
        .batch(test_set.shape[0])

    iterator = tf.data.Iterator.from_structure(train_dataset.output_types,
                                               train_dataset.output_shapes)
    next_batch = iterator.get_next()
    train_init_op = iterator.make_initializer(train_dataset)
    test_init_op = iterator.make_initializer(test_dataset)

    return next_batch, train_init_op, test_init_op
def train(model, train_init_op, test_init_op, test_labels, config):
    """Alternate evaluation and training for the configured number of epochs."""
    _, ax = plt.subplots(nrows=config.epochs, ncols=config.n_samples + 1, figsize=(10, 20))
    with tf.train.MonitoredSession() as sess:
        for epoch in range(config.epochs):
            # Test
            sess.run(test_init_op)
            test_loss, test_codes, test_samples = sess.run([model.loss, model.code, model.samples])
            # Plot
            ax[epoch, 0].set_ylabel('Epoch {}'.format(epoch))
            plot_codes(ax[epoch, 0], test_codes, test_labels)
            plot_samples(ax[epoch, 1:], test_samples)
            # Train
            train_losses = []
            sess.run(train_init_op)
            while True:
                try:
                    _, train_loss = sess.run([model.optimise, model.loss])
                    train_losses.append(train_loss)
                except tf.errors.OutOfRangeError:
                    break
            # Log
            print('Epoch: {:2d}/{:2d}, train loss: {:.3f}, test loss: {:.3f}'.format(
                epoch + 1, config.epochs, np.mean(train_losses), test_loss))
    plt.savefig('vae-mnist.png', dpi=300, transparent=True, bbox_inches='tight')
if __name__ == "__main__":
    config = AttrDict({
        "batch_size": 100,
        "epochs": 20,
        "n_samples": 10,
        "code_size": 2,
        "hidden_size": 200,
        "learning_rate": 0.001
    })

    (train_set, _), (test_set, test_labels) = tf.keras.datasets.mnist.load_data()
    train_set, test_set, test_labels = train_set[:], test_set[:2000], test_labels[:2000]  # DEBUG

    next_batch, train_init_op, test_init_op = create_datasets(train_set, test_set, config)
    model = Model(next_batch, config)
    train(model, train_init_op, test_init_op, test_labels, config)
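For reference, the KL term that tfd.kl_divergence evaluates analytically in the loss above, between the diagonal-Gaussian posterior N(mu, diag(sigma^2)) and the standard-normal prior, has a simple closed form. A minimal NumPy sketch; the example values are arbitrary and just mirror code_size=2 from the config.

import numpy as np


def kl_diag_gaussian_vs_standard_normal(mu, sigma):
    # Closed form of KL( N(mu, diag(sigma^2)) || N(0, I) ), summed over the code dimension:
    # 0.5 * sum_i (sigma_i^2 + mu_i^2 - 1 - log sigma_i^2)
    return 0.5 * np.sum(np.square(sigma) + np.square(mu) - 1.0 - np.log(np.square(sigma)), axis=-1)


mu = np.array([0.3, -1.2])     # arbitrary posterior mean
sigma = np.array([0.8, 1.5])   # arbitrary posterior scale
print(kl_diag_gaussian_vs_standard_normal(mu, sigma))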
Inspired by: https://danijar.com/building-variational-auto-encoders-in-tensorflow/