VAE experiments
# Trains a VAE on dynamically binarized MNIST and evaluates IWAE bounds on the test set.
# Based on implementations:
# - vae core https://github.com/pytorch/examples/blob/master/vae/main.py
# - miwae https://github.com/yoonholee/pytorch-vae
# - notes on VAE from the article at https://iopscience.iop.org/article/10.3847/PSJ/ab9a52 (but can be taken from elsewhere too)
from __future__ import print_function
import argparse
import os
import torch
import torch.utils.data
from torch import nn, optim
from torch.nn import functional as F
from torchvision import datasets, transforms
from torchvision.utils import save_image
from torch.distributions.bernoulli import Bernoulli
from torch.distributions.normal import Normal
from PIL import Image
import numpy as np
parser = argparse.ArgumentParser(description='VAE MNIST Example')
parser.add_argument('--batch-size', type=int, default=20, metavar='N',
                    help='input batch size for training (default: 20)')
parser.add_argument('--epochs', type=int, default=4000, metavar='N',
                    help='number of epochs to train (default: 4000)')
parser.add_argument('--no-cuda', action='store_true', default=False,
                    help='disables CUDA training')
parser.add_argument('--seed', type=int, default=1, metavar='S',
                    help='random seed (default: 1)')
parser.add_argument('--log-interval', type=int, default=20, metavar='N',
                    help='how many batches to wait before logging training status')
parser.add_argument('--k', type=int, default=1)
parser.add_argument('--M', type=int, default=1)
args = parser.parse_args()
args.cuda = not args.no_cuda and torch.cuda.is_available()
args.log_interval = 1  # overrides the CLI flag; reassigned again before training below

torch.manual_seed(args.seed)
device = torch.device("cuda" if args.cuda else "cpu")
print("running on", device)

path = "./MNIST"
kwargs = {'num_workers': 1, 'pin_memory': True} if args.cuda else {}
class stochMNIST(datasets.MNIST):
    """ Gets a new stochastic binarization of MNIST at each call. """
    def __getitem__(self, index):
        if self.train:
            img, target = self.train_data[index], self.train_labels[index]
        else:
            img, target = self.test_data[index], self.test_labels[index]
        img = Image.fromarray(img.numpy(), mode='L')
        img = transforms.ToTensor()(img)
        img = torch.bernoulli(img)  # stochastically binarize
        return img, target

    def get_mean_img(self):
        imgs = self.train_data.type(torch.float) / 255
        mean_img = imgs.mean(0).reshape(-1).numpy()
        return mean_img
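# Resampling the Bernoulli noise at every access ("dynamic binarization") means the
# model never sees the same binarization of a digit twice; this acts as data
# augmentation and avoids overfitting one fixed binarization of MNIST.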
train_loader = torch.utils.data.DataLoader(
    stochMNIST(path, train=True, download=True, transform=transforms.ToTensor()),
    batch_size=args.batch_size, shuffle=True, **kwargs)
test_loader = torch.utils.data.DataLoader(
    stochMNIST(path, train=False, transform=transforms.ToTensor()),
    batch_size=args.batch_size, shuffle=True, **kwargs)
# (the transform argument is effectively unused: stochMNIST.__getitem__ applies ToTensor itself)

def debug_shape(item):
    return item.cpu().detach().numpy().shape
class VAE(nn.Module):
    def __init__(self, hidden_size=400, latent_size=20):
        super(VAE, self).__init__()
        # encoder layers
        self.fc11 = nn.Linear(784, hidden_size)
        self.fc12 = nn.Linear(hidden_size, hidden_size)
        self.fc21 = nn.Linear(hidden_size, latent_size)
        self.fc22 = nn.Linear(hidden_size, latent_size)
        # decoder layers
        self.fc31 = nn.Linear(latent_size, hidden_size)
        self.fc32 = nn.Linear(hidden_size, hidden_size)
        self.fc4 = nn.Linear(hidden_size, 784)
        self.hidden_size = hidden_size
        self.latent_size = latent_size
        self.prior_distribution = Normal(torch.zeros([self.latent_size]).to(device),
                                         torch.ones([self.latent_size]).to(device))

    def encode(self, x):
        x = torch.tanh(self.fc11(x))
        x = torch.tanh(self.fc12(x))
        mu_enc = self.fc21(x)
        std_enc = self.fc22(x)
        # softplus maps the raw fc22 output to a positive standard deviation
        return Normal(mu_enc, F.softplus(std_enc))

    def reparameterize(self, mu, logvar):
        # unused here: Normal.rsample() in forward() already applies the reparameterization trick
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        x = torch.tanh(self.fc31(z))
        x = torch.tanh(self.fc32(x))
        x = self.fc4(x)
        return Bernoulli(logits=x)
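    # Note: Bernoulli(logits=...).log_prob(target) is the negative binary cross-entropy
    # per pixel, so summing it over the 784 pixels in elbo() below gives log p(x|z)
    # for a binarized image.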
    def forward(self, x, M, k):
        input_x = x.view(-1, 784).to(device)
        # encoder distribution ~ q(z|x, params) = Normal(real input_x; encoder_into_Mu, encoder_into_Std)
        z_distribution = self.encode(input_x)
        # sample z values from this distribution (rsample keeps gradients flowing)
        z = z_distribution.rsample(torch.Size([M, k]))
        # reconstruction distribution ~ p(x|z, params) = Bernoulli(sampled z)
        x_distribution = self.decode(z)
        # prior distribution ~ p(z) = Normal(sampled z; 0s, 1s), built once in __init__
        elbo = self.elbo(input_x, z, x_distribution, z_distribution)  # (M, k, batch_size)
        elbo_iwae = self.logmeanexp(elbo, 1).squeeze(1)               # (M, batch_size)
        loss = -torch.mean(elbo_iwae, 0)                              # (batch_size,)
        return x_distribution.probs, elbo, loss
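    # The IWAE bound averages the k importance weights inside the log:
    #   L_k = E_{z_1..z_k ~ q(z|x)} [ log (1/k) * sum_i p(x, z_i) / q(z_i|x) ]
    # logmeanexp evaluates log((1/k) * sum_i exp(w_i)) stably: with m = max_i w_i,
    #   log mean_i exp(w_i) = m + log mean_i exp(w_i - m),
    # so nothing larger than exp(0) is ever exponentiated.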
    def logmeanexp(self, inputs, dim=1):
        if inputs.size(dim) == 1:
            return inputs
        else:
            input_max = inputs.max(dim, keepdim=True)[0]
            return (inputs - input_max).exp().mean(dim, keepdim=True).log() + input_max
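    # Per-sample log importance weight, computed below as
    #   log w = log p(x|z) + log p(z) - log q(z|x) = lpxz - kl
    # For k = 1, logmeanexp is the identity and this reduces to the usual
    # single-sample ELBO estimate.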
    def elbo(self, input_x, z, x_distribution, z_distribution):
        lpxz = x_distribution.log_prob(input_x).sum(-1)    # log p(x|z)
        lpz = self.prior_distribution.log_prob(z).sum(-1)  # log p(z)
        lqzx = z_distribution.log_prob(z).sum(-1)          # log q(z|x)
        kl = -lpz + lqzx
        return -kl + lpxz
args.log_interval = 500  # log every 500 batches (= 10000 images at batch size 20)
M = args.M
k = args.k
#M = 5
#k = 5

model = VAE().to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-3)
def train(epoch):
    model.train()
    train_loss = 0
    for batch_idx, (data, _) in enumerate(train_loader):
        data = data.to(device)
        optimizer.zero_grad()
        _, elbo, loss_mk = model(data, M, k)
        loss = loss_mk.mean()
        loss.backward()
        train_loss += loss.item()
        optimizer.step()
        if batch_idx % args.log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader),
                loss.item()))  # / len(data)
def test(epoch):
    #print_metrics = ((epoch-1) % 10) == 0
    print_metrics = True
    if print_metrics:
        model.eval()
        with torch.no_grad():
            # Tests: IWAE with k, IWAE with 64, IWAE with 5000,
            # all computed by subsampling one set of 5000 importance samples.
            elbos = []
            for data, _ in test_loader:
                _, elbo, _ = model(data, M=1, k=5000)
                elbos.append(elbo.squeeze(0))  # (5000, batch_size)
            k_to_run = [k, 64, 5000]
            all_losses = []
            for k_for_loss in k_to_run:
                losses = []
                for elbo in elbos:
                    # use only the first k_for_loss importance samples of each test batch
                    losses.append(model.logmeanexp(elbo[:k_for_loss], 0).cpu().numpy().flatten())
                loss = np.concatenate(losses).mean()
                all_losses.append(-loss)
            test_loss_iwae_k, test_loss_iwae64, test_loss_iwae5000 = all_losses
            print('====>Test metrics: IWAE M=', M, ',k=', k, ' || epoch', epoch)
            print("IWAE-64: ", test_loss_iwae64)
            print("logˆp(x) = IWAE-5000: ", test_loss_iwae5000)
            print("−KL(Q||P): ", test_loss_iwae64 - test_loss_iwae5000)
            print("---------------")
if __name__ == "__main__":
    os.makedirs('results', exist_ok=True)  # save_image below expects this directory
    for epoch in range(1, args.epochs + 1):
        train(epoch)
        test(epoch)
        with torch.no_grad():
            sample = torch.randn(64, 20).to(device)
            sample = model.decode(sample).probs.cpu()
            save_image(sample.view(64, 1, 28, 28), 'results/sample_epoch' + str(epoch).zfill(4) + '.png')
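For reference, a minimal standalone sanity check (new code, not part of the gist) that the stabilized logmeanexp matches the identity log-mean-exp = logsumexp − log k, using PyTorch's built-in torch.logsumexp; the shapes and k = 8 are arbitrary:

import torch

def logmeanexp(inputs, dim=1):
    # same computation as VAE.logmeanexp above
    if inputs.size(dim) == 1:
        return inputs
    input_max = inputs.max(dim, keepdim=True)[0]
    return (inputs - input_max).exp().mean(dim, keepdim=True).log() + input_max

x = 50 * torch.randn(4, 8, 20)  # (M, k, batch); large magnitudes to stress overflow
ref = torch.logsumexp(x, dim=1, keepdim=True) - torch.log(torch.tensor(8.0))
assert torch.allclose(logmeanexp(x, dim=1), ref, atol=1e-4)
print("logmeanexp matches logsumexp - log k")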
Training output:
Train Epoch: 1 [0/60000 (0%)] Loss: 544.618103
Train Epoch: 1 [10000/60000 (17%)] Loss: 138.734665
Train Epoch: 1 [20000/60000 (33%)] Loss: 115.618584
Train Epoch: 1 [30000/60000 (50%)] Loss: 116.206688
Train Epoch: 1 [40000/60000 (67%)] Loss: 111.551384
Train Epoch: 1 [50000/60000 (83%)] Loss: 120.980362
====>Test metrics: IWAE M= 8 ,k= 8 || epoch 1
IWAE-64: 101.856346
logˆp(x) = IWAE-5000: 101.97784
−KL(Q||P): -0.12149048
---------------
Train Epoch: 2 [0/60000 (0%)] Loss: 101.694847
Train Epoch: 2 [10000/60000 (17%)] Loss: 94.400818
Train Epoch: 2 [20000/60000 (33%)] Loss: 108.299316
Train Epoch: 2 [30000/60000 (50%)] Loss: 101.753235
Train Epoch: 2 [40000/60000 (67%)] Loss: 104.659843
Train Epoch: 2 [50000/60000 (83%)] Loss: 99.216331
====>Test metrics: IWAE M= 8 ,k= 8 || epoch 2
IWAE-64: 97.00398
logˆp(x) = IWAE-5000: 97.345924
−KL(Q||P): -0.34194183
---------------
Train Epoch: 3 [0/60000 (0%)] Loss: 106.367607
Train Epoch: 3 [10000/60000 (17%)] Loss: 102.621948
Train Epoch: 3 [20000/60000 (33%)] Loss: 93.247398
Train Epoch: 3 [30000/60000 (50%)] Loss: 109.849731
Train Epoch: 3 [40000/60000 (67%)] Loss: 105.828445
Train Epoch: 3 [50000/60000 (83%)] Loss: 93.767998
====>Test metrics: IWAE M= 8 ,k= 8 || epoch 3
IWAE-64: 95.24419
logˆp(x) = IWAE-5000: 95.411156
−KL(Q||P): -0.1669693
---------------
Train Epoch: 4 [0/60000 (0%)] Loss: 97.471848
Train Epoch: 4 [10000/60000 (17%)] Loss: 103.686646
Train Epoch: 4 [20000/60000 (33%)] Loss: 102.596367
Train Epoch: 4 [30000/60000 (50%)] Loss: 93.631889
Train Epoch: 4 [40000/60000 (67%)] Loss: 90.186600
Train Epoch: 4 [50000/60000 (83%)] Loss: 100.661491
====>Test metrics: IWAE M= 8 ,k= 8 || epoch 4
IWAE-64: 94.75698
logˆp(x) = IWAE-5000: 94.39016
−KL(Q||P): 0.3668213
---------------
Train Epoch: 5 [0/60000 (0%)] Loss: 109.656487
Train Epoch: 5 [10000/60000 (17%)] Loss: 89.555992
Train Epoch: 5 [20000/60000 (33%)] Loss: 97.195396
Train Epoch: 5 [30000/60000 (50%)] Loss: 100.248428
Train Epoch: 5 [40000/60000 (67%)] Loss: 104.410034
Train Epoch: 5 [50000/60000 (83%)] Loss: 104.687523
====>Test metrics: IWAE M= 8 ,k= 8 || epoch 5
IWAE-64: 94.646225
logˆp(x) = IWAE-5000: 93.69805
−KL(Q||P): 0.9481735
---------------
Train Epoch: 6 [0/60000 (0%)] Loss: 97.765373
Train Epoch: 6 [10000/60000 (17%)] Loss: 107.476028
Train Epoch: 6 [20000/60000 (33%)] Loss: 97.607529
Train Epoch: 6 [30000/60000 (50%)] Loss: 105.302513
Train Epoch: 6 [40000/60000 (67%)] Loss: 109.760330
Train Epoch: 6 [50000/60000 (83%)] Loss: 96.512207
====>Test metrics: IWAE M= 8 ,k= 8 || epoch 6
IWAE-64: 92.74505
logˆp(x) = IWAE-5000: 92.943184
−KL(Q||P): -0.19813538
---------------
Train Epoch: 7 [0/60000 (0%)] Loss: 92.339188
Train Epoch: 7 [10000/60000 (17%)] Loss: 98.588173
Train Epoch: 7 [20000/60000 (33%)] Loss: 89.058235
Train Epoch: 7 [30000/60000 (50%)] Loss: 88.806847
Train Epoch: 7 [40000/60000 (67%)] Loss: 96.309105
Train Epoch: 7 [50000/60000 (83%)] Loss: 94.803154
====>Test metrics: IWAE M= 8 ,k= 8 || epoch 7
IWAE-64: 91.760635
logˆp(x) = IWAE-5000: 92.285355
−KL(Q||P): -0.52471924
---------------
Train Epoch: 8 [0/60000 (0%)] Loss: 87.517845
Train Epoch: 8 [10000/60000 (17%)] Loss: 99.885033
Train Epoch: 8 [20000/60000 (33%)] Loss: 104.214409
Train Epoch: 8 [30000/60000 (50%)] Loss: 97.933716
Train Epoch: 8 [40000/60000 (67%)] Loss: 99.270409
Train Epoch: 8 [50000/60000 (83%)] Loss: 100.278252
====>Test metrics: IWAE M= 8 ,k= 8 || epoch 8
IWAE-64: 90.84507
logˆp(x) = IWAE-5000: 92.03044
−KL(Q||P): -1.1853714
---------------
Train Epoch: 9 [0/60000 (0%)] Loss: 105.112419
Train Epoch: 9 [10000/60000 (17%)] Loss: 91.309120
Train Epoch: 9 [20000/60000 (33%)] Loss: 96.311066
Train Epoch: 9 [30000/60000 (50%)] Loss: 103.695045
Train Epoch: 9 [40000/60000 (67%)] Loss: 102.628288
Train Epoch: 9 [50000/60000 (83%)] Loss: 94.594231
====>Test metrics: IWAE M= 8 ,k= 8 || epoch 9
IWAE-64: 92.28972
logˆp(x) = IWAE-5000: 91.84398
−KL(Q||P): 0.44573975
---------------
Train Epoch: 10 [0/60000 (0%)] Loss: 102.444267
Train Epoch: 10 [10000/60000 (17%)] Loss: 98.669945
Train Epoch: 10 [20000/60000 (33%)] Loss: 91.118675
Train Epoch: 10 [30000/60000 (50%)] Loss: 96.950302
Train Epoch: 10 [40000/60000 (67%)] Loss: 107.136940
Train Epoch: 10 [50000/60000 (83%)] Loss: 97.390648
====>Test metrics: IWAE M= 8 ,k= 8 || epoch 10
IWAE-64: 91.120186
logˆp(x) = IWAE-5000: 91.54648
−KL(Q||P): -0.42629242
---------------
Train Epoch: 11 [0/60000 (0%)] Loss: 109.360893
Train Epoch: 11 [10000/60000 (17%)] Loss: 104.534805
Train Epoch: 11 [20000/60000 (33%)] Loss: 104.689880
Train Epoch: 11 [30000/60000 (50%)] Loss: 103.057434
Train Epoch: 11 [40000/60000 (67%)] Loss: 105.310524
Train Epoch: 11 [50000/60000 (83%)] Loss: 92.356544
====>Test metrics: IWAE M= 8 ,k= 8 || epoch 11
IWAE-64: 91.86742
logˆp(x) = IWAE-5000: 91.24765
−KL(Q||P): 0.61976624
---------------
Train Epoch: 12 [0/60000 (0%)] Loss: 94.402412
Train Epoch: 12 [10000/60000 (17%)] Loss: 102.237709
Train Epoch: 12 [20000/60000 (33%)] Loss: 92.341049
Train Epoch: 12 [30000/60000 (50%)] Loss: 89.968994
Train Epoch: 12 [40000/60000 (67%)] Loss: 91.344337
Train Epoch: 12 [50000/60000 (83%)] Loss: 99.439751
====>Test metrics: IWAE M= 8 ,k= 8 || epoch 12
IWAE-64: 89.93662
logˆp(x) = IWAE-5000: 91.165276
−KL(Q||P): -1.228653
---------------
Train Epoch: 13 [0/60000 (0%)] Loss: 91.478836
Train Epoch: 13 [10000/60000 (17%)] Loss: 94.909088
Train Epoch: 13 [20000/60000 (33%)] Loss: 91.767891
Train Epoch: 13 [30000/60000 (50%)] Loss: 92.367569
Train Epoch: 13 [40000/60000 (67%)] Loss: 107.229668
Train Epoch: 13 [50000/60000 (83%)] Loss: 98.232750
====>Test metrics: IWAE M= 8 ,k= 8 || epoch 13
IWAE-64: 91.38028
logˆp(x) = IWAE-5000: 90.99059
−KL(Q||P): 0.38968658
---------------
Train Epoch: 14 [0/60000 (0%)] Loss: 90.363869
Train Epoch: 14 [10000/60000 (17%)] Loss: 99.742142
Train Epoch: 14 [20000/60000 (33%)] Loss: 91.261124
Train Epoch: 14 [30000/60000 (50%)] Loss: 90.453880
Train Epoch: 14 [40000/60000 (67%)] Loss: 98.580307
Train Epoch: 14 [50000/60000 (83%)] Loss: 99.148628
====>Test metrics: IWAE M= 8 ,k= 8 || epoch 14
IWAE-64: 90.65437
logˆp(x) = IWAE-5000: 90.89895
−KL(Q||P): -0.2445755
---------------
Train Epoch: 15 [0/60000 (0%)] Loss: 108.186623
Train Epoch: 15 [10000/60000 (17%)] Loss: 92.393219
Train Epoch: 15 [20000/60000 (33%)] Loss: 100.103477
Train Epoch: 15 [30000/60000 (50%)] Loss: 85.533005
Train Epoch: 15 [40000/60000 (67%)] Loss: 103.622581
Train Epoch: 15 [50000/60000 (83%)] Loss: 102.047340
====>Test metrics: IWAE M= 8 ,k= 8 || epoch 15
IWAE-64: 90.0911
logˆp(x) = IWAE-5000: 90.891205
−KL(Q||P): -0.80010223
---------------
Train Epoch: 16 [0/60000 (0%)] Loss: 98.122261
Train Epoch: 16 [10000/60000 (17%)] Loss: 92.934647
Train Epoch: 16 [20000/60000 (33%)] Loss: 85.830734
Train Epoch: 16 [30000/60000 (50%)] Loss: 95.870377
Train Epoch: 16 [40000/60000 (67%)] Loss: 93.688805
Train Epoch: 16 [50000/60000 (83%)] Loss: 90.419800
====>Test metrics: IWAE M= 8 ,k= 8 || epoch 16
IWAE-64: 89.59951
logˆp(x) = IWAE-5000: 90.65003
−KL(Q||P): -1.0505219
---------------
Train Epoch: 17 [0/60000 (0%)] Loss: 93.840065
Train Epoch: 17 [10000/60000 (17%)] Loss: 86.847694
Train Epoch: 17 [20000/60000 (33%)] Loss: 98.986687
Train Epoch: 17 [30000/60000 (50%)] Loss: 98.521729
Train Epoch: 17 [40000/60000 (67%)] Loss: 99.243057
Train Epoch: 17 [50000/60000 (83%)] Loss: 91.025291
====>Test metrics: IWAE M= 8 ,k= 8 || epoch 17
IWAE-64: 90.11647
logˆp(x) = IWAE-5000: 90.590324
−KL(Q||P): -0.47385406
---------------
Train Epoch: 18 [0/60000 (0%)] Loss: 94.464935
Train Epoch: 18 [10000/60000 (17%)] Loss: 99.852882
Train Epoch: 18 [20000/60000 (33%)] Loss: 91.386147
Train Epoch: 18 [30000/60000 (50%)] Loss: 90.344818
Train Epoch: 18 [40000/60000 (67%)] Loss: 92.691124
Train Epoch: 18 [50000/60000 (83%)] Loss: 97.712929
====>Test metrics: IWAE M= 8 ,k= 8 || epoch 18
IWAE-64: 90.838585
logˆp(x) = IWAE-5000: 90.65541
−KL(Q||P): 0.18317413
---------------
Train Epoch: 19 [0/60000 (0%)] Loss: 99.118088
Train Epoch: 19 [10000/60000 (17%)] Loss: 105.104935
Train Epoch: 19 [20000/60000 (33%)] Loss: 94.164665
Train Epoch: 19 [30000/60000 (50%)] Loss: 100.436256
Train Epoch: 19 [40000/60000 (67%)] Loss: 90.244896
Train Epoch: 19 [50000/60000 (83%)] Loss: 86.268738
====>Test metrics: IWAE M= 8 ,k= 8 || epoch 19
IWAE-64: 89.80083
logˆp(x) = IWAE-5000: 90.411835
−KL(Q||P): -0.6110077
---------------
Train Epoch: 20 [0/60000 (0%)] Loss: 105.900833
Train Epoch: 20 [10000/60000 (17%)] Loss: 85.296181
Train Epoch: 20 [20000/60000 (33%)] Loss: 102.006134
Train Epoch: 20 [30000/60000 (50%)] Loss: 91.458534
Train Epoch: 20 [40000/60000 (67%)] Loss: 98.606804
Train Epoch: 20 [50000/60000 (83%)] Loss: 92.486732
====>Test metrics: IWAE M= 8 ,k= 8 || epoch 20
IWAE-64: 89.921814
logˆp(x) = IWAE-5000: 90.39562
−KL(Q||P): -0.4738083
---------------
Train Epoch: 21 [0/60000 (0%)] Loss: 86.889793
Train Epoch: 21 [10000/60000 (17%)] Loss: 93.808105
Train Epoch: 21 [20000/60000 (33%)] Loss: 85.814552
Train Epoch: 21 [30000/60000 (50%)] Loss: 97.433723
Train Epoch: 21 [40000/60000 (67%)] Loss: 92.292229
Train Epoch: 21 [50000/60000 (83%)] Loss: 84.512245
====>Test metrics: IWAE M= 8 ,k= 8 || epoch 21
IWAE-64: 90.70811
logˆp(x) = IWAE-5000: 90.39505
−KL(Q||P): 0.31305695
---------------
Train Epoch: 22 [0/60000 (0%)] Loss: 97.888206
Train Epoch: 22 [10000/60000 (17%)] Loss: 95.112480
Train Epoch: 22 [20000/60000 (33%)] Loss: 96.822960
Train Epoch: 22 [30000/60000 (50%)] Loss: 105.579887
Train Epoch: 22 [40000/60000 (67%)] Loss: 88.926628
Train Epoch: 22 [50000/60000 (83%)] Loss: 83.429054
====>Test metrics: IWAE M= 8 ,k= 8 || epoch 22
IWAE-64: 90.45231
logˆp(x) = IWAE-5000: 90.28255
−KL(Q||P): 0.16976166
---------------
Train Epoch: 23 [0/60000 (0%)] Loss: 89.223228
Train Epoch: 23 [10000/60000 (17%)] Loss: 93.890137
Train Epoch: 23 [20000/60000 (33%)] Loss: 93.568741
Train Epoch: 23 [30000/60000 (50%)] Loss: 88.926697
Train Epoch: 23 [40000/60000 (67%)] Loss: 92.509758
Train Epoch: 23 [50000/60000 (83%)] Loss: 99.818192
====>Test metrics: IWAE M= 8 ,k= 8 || epoch 23
IWAE-64: 89.77586
logˆp(x) = IWAE-5000: 90.075615
−KL(Q||P): -0.29975128
---------------
Train Epoch: 24 [0/60000 (0%)] Loss: 82.467995
Train Epoch: 24 [10000/60000 (17%)] Loss: 95.007713
Train Epoch: 24 [20000/60000 (33%)] Loss: 102.897850
Train Epoch: 24 [30000/60000 (50%)] Loss: 103.482498
Train Epoch: 24 [40000/60000 (67%)] Loss: 94.505943
Train Epoch: 24 [50000/60000 (83%)] Loss: 97.068161
====>Test metrics: IWAE M= 8 ,k= 8 || epoch 24
IWAE-64: 90.61882
logˆp(x) = IWAE-5000: 90.10386
−KL(Q||P): 0.51496124
---------------
Train Epoch: 25 [0/60000 (0%)] Loss: 88.563004
Train Epoch: 25 [10000/60000 (17%)] Loss: 96.062202
Train Epoch: 25 [20000/60000 (33%)] Loss: 91.589104
Train Epoch: 25 [30000/60000 (50%)] Loss: 100.115807
Train Epoch: 25 [40000/60000 (67%)] Loss: 97.718956
Train Epoch: 25 [50000/60000 (83%)] Loss: 92.590294
====>Test metrics: IWAE M= 8 ,k= 8 || epoch 25
IWAE-64: 90.77629
logˆp(x) = IWAE-5000: 90.21303
−KL(Q||P): 0.56326294
---------------
Train Epoch: 26 [0/60000 (0%)] Loss: 90.206627
Train Epoch: 26 [10000/60000 (17%)] Loss: 95.104202
Train Epoch: 26 [20000/60000 (33%)] Loss: 99.151428
Train Epoch: 26 [30000/60000 (50%)] Loss: 93.590454
Train Epoch: 26 [40000/60000 (67%)] Loss: 92.422302
Train Epoch: 26 [50000/60000 (83%)] Loss: 103.758888
====>Test metrics: IWAE M= 8 ,k= 8 || epoch 26
IWAE-64: 90.73986
logˆp(x) = IWAE-5000: 90.08644
−KL(Q||P): 0.6534195
---------------
Train Epoch: 27 [0/60000 (0%)] Loss: 98.630524
Train Epoch: 27 [10000/60000 (17%)] Loss: 84.656273
Train Epoch: 27 [20000/60000 (33%)] Loss: 102.395241
Train Epoch: 27 [30000/60000 (50%)] Loss: 103.834000
Train Epoch: 27 [40000/60000 (67%)] Loss: 86.922234
Train Epoch: 27 [50000/60000 (83%)] Loss: 111.384987
====>Test metrics: IWAE M= 8 ,k= 8 || epoch 27
IWAE-64: 89.31232
logˆp(x) = IWAE-5000: 89.94471
−KL(Q||P): -0.6323929
---------------
Train Epoch: 28 [0/60000 (0%)] Loss: 90.933304
Train Epoch: 28 [10000/60000 (17%)] Loss: 99.818108
Train Epoch: 28 [20000/60000 (33%)] Loss: 87.769615
Train Epoch: 28 [30000/60000 (50%)] Loss: 94.958702
Train Epoch: 28 [40000/60000 (67%)] Loss: 93.918137
Train Epoch: 28 [50000/60000 (83%)] Loss: 99.295448
====>Test metrics: IWAE M= 8 ,k= 8 || epoch 28
IWAE-64: 89.74451
logˆp(x) = IWAE-5000: 90.16274
−KL(Q||P): -0.41823578
---------------
Train Epoch: 29 [0/60000 (0%)] Loss: 88.518005
Train Epoch: 29 [10000/60000 (17%)] Loss: 92.207855
Train Epoch: 29 [20000/60000 (33%)] Loss: 100.995888
Train Epoch: 29 [30000/60000 (50%)] Loss: 81.873978
Train Epoch: 29 [40000/60000 (67%)] Loss: 106.279015
Train Epoch: 29 [50000/60000 (83%)] Loss: 98.464409
====>Test metrics: IWAE M= 8 ,k= 8 || epoch 29
IWAE-64: 88.7935
logˆp(x) = IWAE-5000: 90.025635
−KL(Q||P): -1.232132
---------------