Minimal example demonstrating how a variational autoencoder frequently generates unrealistic samples when optimized to learn a simple 2D bimodal distribution, in contrast to an autoregressive RNN trained on the same data.
# Adapted from: https://github.com/pytorch/examples/blob/main/vae/main.py.
import torch
import torch.utils.data
from torch import nn, optim
from torch.nn import functional as F
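# Overview: two generative models are fit to the same two-point "distribution"
# consisting of the equally weighted outcomes (0, 0) and (1, 1). The VAE maps a 2D
# input through a 20-dimensional latent code and decodes it into four logits, i.e.,
# two 2-way categorical distributions. The RNN defined below models the same data
# autoregressively, predicting each coordinate from the previous one.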
class VAE(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(2, 50)
        self.fc21 = nn.Linear(50, 20)
        self.fc22 = nn.Linear(50, 20)
        self.fc3 = nn.Linear(20, 50)
        self.fc4 = nn.Linear(50, 4)
    def encode(self, x):
        h1 = F.relu(self.fc1(x))
        return (self.fc21(h1), self.fc22(h1))
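    # Reparameterization trick: draw z = mu + sigma * eps with eps ~ N(0, I) so the
    # sampling step stays differentiable with respect to mu and logvar.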
    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std
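    # The decoder maps a 20-dimensional latent vector to four logits: one 2-way
    # categorical distribution per output coordinate.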
    def decode(self, z):
        h3 = F.relu(self.fc3(z))
        return self.fc4(h3)
    def forward(self, x):
        (mu, logvar) = self.encode(x)
        z = self.reparameterize(mu, logvar)
        return (self.decode(z), mu, logvar)
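# Autoregressive baseline: a plain RNN reads the coordinates one at a time and
# outputs a 2-way categorical distribution for the current coordinate given the
# previous ones.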
class RNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.rnn = nn.RNN(1, 25)
        self.linear = nn.Linear(25, 2)
    def forward(self, x):
        (hiddens, _) = self.rnn(x)
        return self.linear(hiddens)
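# Shared loss: negative log-likelihood of the class labels plus (for the VAE) the KL
# divergence between the approximate posterior and the standard normal prior.
# Note: is_vae is a module-level flag set in the __main__ block below.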
def loss_function(recon_x, y, mu=None, logvar=None):
    # Reconstruction + KL divergence losses summed over all elements and batch.
    recon_x = recon_x.reshape(-1, 2)
    log_probs = F.log_softmax(recon_x, dim=1)
    NLL = F.nll_loss(log_probs, y, reduction="sum")
    KLD = 0
    if is_vae:
        # See Appendix B from VAE paper:
        # Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014.
        # https://arxiv.org/abs/1312.6114
        # 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
        KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return (NLL, KLD)
if __name__ == "__main__":
    # Fall back to CPU when CUDA is unavailable (the original script assumed a GPU).
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = VAE().to(device)
    print(model)
    n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Parameters: {n_params}")
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
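    # Training data: the encoder sees the two points (-1, -1) and (1, 1); the targets
    # are the corresponding class labels (0, 0) and (1, 1), flattened to match the
    # reshaped decoder output.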
    X = torch.Tensor([[-1, -1], [1, 1]]).to(device)
    y = torch.LongTensor([0, 0, 1, 1]).to(device)
    epochs = 100
    updates = 1024
    is_vae = True
    best_train_loss = float("inf")
    patience = 5
    no_improvement = 0
    lr_drops = 0
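    # Each "epoch" performs `updates` full-batch gradient steps. If the best loss is
    # not beaten for `patience` consecutive epochs, the learning rate is cut by 10x;
    # after the second such trigger, training stops early.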
    for epoch in range(1, epochs + 1):
        model.train()
        NLL_total = 0
        KLD_total = 0
        train_loss = 0
        for _ in range(updates):
            optimizer.zero_grad()
            (recon_batch, mu, logvar) = model(X)
            (NLL, KLD) = loss_function(recon_batch, y, mu, logvar)
            loss = NLL + KLD
            loss.backward()
            NLL_total += NLL.item() / len(X)
            KLD_total += KLD.item() / len(X)
            train_loss += loss.item() / len(X)
            optimizer.step()
        NLL_total /= updates
        KLD_total /= updates
        train_loss /= updates
        if train_loss < best_train_loss:
            best_train_loss = train_loss
            no_improvement = 0
        else:
            no_improvement += 1
            if no_improvement == patience:
                lr_drops += 1
                if lr_drops == 2:
                    break
                print("Reducing learning rate.")
                no_improvement = 0
                for g in optimizer.param_groups:
                    g["lr"] *= 0.1
        print(f"====> Epoch: {epoch} Average loss: {train_loss:.4f}")
        print(f"====> Epoch: {epoch} Best average loss: {best_train_loss:.4f}")
        print(f"====> Epoch: {epoch} Average NLL: {NLL_total:.4f}")
        print(f"====> Epoch: {epoch} Average KLD: {KLD_total:.4f}")
    with torch.no_grad():
        samples = model.decode(torch.randn(500, 20).to(device)).cpu().reshape(-1, 2)
        probs = torch.softmax(samples, dim=1)
        samples = torch.multinomial(probs, 1).reshape(-1, 2).numpy()
        both_zeros = ((samples[:, 0] == 0) & (samples[:, 1] == 0)).sum()
        print(f"both_zeros %: {100 * both_zeros / len(samples)}")
        both_ones = ((samples[:, 0] == 1) & (samples[:, 1] == 1)).sum()
        print(f"both_ones %: {100 * both_ones / len(samples)}")
        different = (samples[:, 0] != samples[:, 1]).sum()
        print(f"different %: {100 * different / len(samples)}")
    model = RNN().to(device)
    print(model)
    n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Parameters: {n_params}")
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
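    # RNN inputs are sequence-first with shape (seq_len=2, batch=2, 1): step 0 is a
    # start token of 0, and step 1 is the first coordinate encoded as -1 or 1. The
    # targets are the same labels y = [0, 0, 1, 1] used above.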
    X = torch.Tensor([[0, -1], [0, 1]]).to(device).T.unsqueeze(2)
    is_vae = False
    best_train_loss = float("inf")
    no_improvement = 0
    lr_drops = 0
    for epoch in range(1, epochs + 1):
        model.train()
        train_loss = 0
        for _ in range(updates):
            optimizer.zero_grad()
            recon_batch = model(X).permute(1, 0, 2)
            loss = loss_function(recon_batch, y)[0]
            loss.backward()
            train_loss += loss.item() / len(X)
            optimizer.step()
        train_loss /= updates
        if train_loss < best_train_loss:
            best_train_loss = train_loss
            no_improvement = 0
        else:
            no_improvement += 1
            if no_improvement == patience:
                lr_drops += 1
                if lr_drops == 2:
                    break
                print("Reducing learning rate.")
                no_improvement = 0
                for g in optimizer.param_groups:
                    g["lr"] *= 0.1
        print(f"====> Epoch: {epoch} Average loss: {train_loss:.4f}")
        print(f"====> Epoch: {epoch} Best average loss: {best_train_loss:.4f}")
    with torch.no_grad():
        samples = []
        for sample in range(500):
            X = torch.zeros(2, 1).to(device)
            sample_vals = []
            for step in range(2):
                preds = model(X.unsqueeze(0).permute(1, 0, 2))
                probs = torch.softmax(preds.permute(1, 0, 2).squeeze(0)[step], dim=0)
                sample_val = torch.multinomial(probs, 1)
                sample_vals.append(sample_val.item())
                if step == 0:
                    X[step + 1] = 2 * sample_val - 1
            samples.append(sample_vals)
        samples = torch.Tensor(samples).cpu().numpy()
        both_zeros = ((samples[:, 0] == 0) & (samples[:, 1] == 0)).sum()
        print(f"both_zeros %: {100 * both_zeros / len(samples)}")
        both_ones = ((samples[:, 0] == 1) & (samples[:, 1] == 1)).sum()
        print(f"both_ones %: {100 * both_ones / len(samples)}")
        different = (samples[:, 0] != samples[:, 1]).sum()
        print(f"different %: {100 * different / len(samples)}")