import numpy
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader
import torchvision.transforms as T
from einops import rearrange
from tqdm import tqdm
class AE(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(1, 8, 4, 2, 1),     # (1, 28, 28) -> (8, 14, 14)
            nn.ReLU(),
            nn.Conv2d(8, 512, 4, 2, 1),   # (8, 14, 14) -> (512, 7, 7)
            nn.ReLU(),
            nn.Conv2d(512, 64, 3, 2, 1),  # (512, 7, 7) -> (64, 4, 4)
        )
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(64, 512, 3, 2, 1),  # (64, 4, 4) -> (512, 7, 7)
            nn.ReLU(),
            nn.ConvTranspose2d(512, 8, 4, 2, 1),   # (512, 7, 7) -> (8, 14, 14)
            nn.ReLU(),
            nn.ConvTranspose2d(8, 1, 4, 2, 1),     # (8, 14, 14) -> (1, 28, 28)
        )

    def loss(self, x):
        enc = self.encoder(x)
        dec = self.decoder(enc)
        return F.mse_loss(x, dec)
class VQ(nn.Module):
    def __init__(self, n_emb=512, emb_dim=64, commit_cost=0.25):
        super().__init__()
        self.n_emb = n_emb
        self.emb_dim = emb_dim
        self.emb = nn.Embedding(self.n_emb, self.emb_dim)
        self.emb.weight.data.uniform_(-1 / self.n_emb, 1 / self.n_emb)
        self.commit_cost = commit_cost

    def forward(self, inputs):
        b, c, h, w = inputs.shape
        # treat every spatial position as one emb_dim-dimensional vector
        inputs_ = rearrange(inputs, "b c h w -> (b h w) c")
        # nearest codebook entry (euclidean distance) for every vector
        idx = torch.cdist(inputs_, self.emb.weight).argmin(1)
        quantized_ = self.emb.weight[idx]
        quantized = rearrange(quantized_, "(b h w) c -> b c h w", b=b, h=h)
        # q_loss pulls encoder outputs towards their (detached) codes,
        # e_loss pulls the codebook entries towards the (detached) encoder outputs
        q_loss = F.mse_loss(inputs, quantized.detach())
        e_loss = F.mse_loss(inputs.detach(), quantized)
        qeloss = q_loss + self.commit_cost * e_loss
        # straight-through estimator: the forward pass uses the quantized values,
        # the backward pass copies gradients to the encoder outputs unchanged
        quantized = inputs + (quantized - inputs).detach()
        return qeloss, quantized, idx
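
# Quick sanity check of the quantizer interface (a minimal sketch, assuming random
# inputs with the shape the encoder produces): the forward pass returns the combined
# codebook/commitment loss, a quantized map with the input's shape, and one codebook
# index per spatial position (2 * 4 * 4 = 32 indices here).
_loss, _quant, _idx = VQ()(torch.randn(2, 64, 4, 4))
assert _quant.shape == (2, 64, 4, 4) and _idx.shape == (32,)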
class VQVAE(AE):
    def __init__(self):
        super().__init__()
        self.vq = VQ()

    def loss(self, x):
        enc = self.encoder(x)
        qloss, quant, idx = self.vq(enc)
        dec = self.decoder(quant)
        return F.mse_loss(x, dec) + qloss
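
# With this encoder, every 28x28 image is compressed to a 4x4 grid of 64-dim
# vectors, each snapped to one of the 512 codebook entries, so an image is
# represented by 16 discrete tokens. The straight-through trick in VQ.forward
# lets the reconstruction gradient flow through the quantization step back
# into the encoder.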
tfm = T.Compose([
    T.RandomRotation(360),
    T.ToTensor(),
    T.Normalize((0.1307,), (0.3081,)),
])
ds = MNIST('./', download=True, transform=tfm)
dl = DataLoader(ds, batch_size=512, shuffle=True, drop_last=True)

vqvae = VQVAE().train()
optim = torch.optim.Adam(vqvae.parameters(), lr=1e-3)
for epoch in range(10):
    loss_avg = 0
    for i, (x, _) in enumerate(tqdm(dl), start=1):
        loss = vqvae.loss(x)
        loss.backward()
        optim.step()
        optim.zero_grad()
        # running mean of the batch losses
        loss_avg += (loss.item() - loss_avg) / i
    print(epoch, loss_avg)

# reconstruct the last training batch and plot originals next to reconstructions
out = vqvae.decoder(vqvae.vq(vqvae.encoder(x))[1])
toplot = torch.stack([x, out])
toplot = rearrange(toplot, 't (b1 b2) 1 h w -> (b1 h) (b2 t w)', b1=32).detach()
plt.figure(figsize=(13, 13))
plt.imshow(toplot)
plt.axis("off")
plt.show()