ResNet Variational Autoencoder for image reconstruction
import torch
from torch import nn
import torch.nn.functional as F
import abc
import pytorch_ssim
import torchvision.models as models
class AbstractAutoEncoder(nn.Module):
    # NOTE: __metaclass__ is Python 2 syntax and has no effect under Python 3,
    # so the @abc.abstractmethod declarations below are not actually enforced
    # (e.g. the subclasses here never implement latest_losses).
    __metaclass__ = abc.ABCMeta
@abc.abstractmethod
def encode(self, x):
return
@abc.abstractmethod
def decode(self, z):
return
@abc.abstractmethod
def forward(self, x, latent_vec=False):
"""model return (reconstructed_x, *)"""
return
@abc.abstractmethod
def loss_function(self, **kwargs):
"""accepts (original images, *) where * is the same as returned from forward()"""
return
@abc.abstractmethod
def latest_losses(self):
"""returns the latest losses in a dictionary. Useful for logging."""
return
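# NOTE: the loss_function methods below call calc_reconstruction_loss, which is
# not defined anywhere in this gist. A minimal sketch, assuming recon_loss_type
# is one of "mse", "l1", or "ssim" (pytorch_ssim is imported above):
def calc_reconstruction_loss(x, x_hat, recon_loss_type="mse"):
    if recon_loss_type == "mse":
        return F.mse_loss(x_hat, x)
    elif recon_loss_type == "l1":
        return F.l1_loss(x_hat, x)
    elif recon_loss_type == "ssim":
        # SSIM is a similarity score in [0, 1]; use 1 - SSIM as a loss.
        return 1 - pytorch_ssim.ssim(x_hat, x)
    else:
        raise ValueError(f"Unknown recon_loss_type: {recon_loss_type}")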
class ResNet_VAE(AbstractAutoEncoder):
def __init__(
        self, recon_loss_type, fc_hidden1=1024,
        fc_hidden2=768, drop_p=0.3, CNN_embed_dim=256):
        super(ResNet_VAE, self).__init__()
        self.recon_loss_type = recon_loss_type
        self.drop_p = drop_p  # used by the (commented-out) dropout in encode()
        self.fc_hidden1, self.fc_hidden2, self.CNN_embed_dim = fc_hidden1, fc_hidden2, CNN_embed_dim
        # CNN architectures
        self.ch1, self.ch2, self.ch3, self.ch4 = 16, 32, 64, 128
        self.k1, self.k2, self.k3, self.k4 = (5, 5), (3, 3), (3, 3), (3, 3)  # 2d kernel size
self.s1, self.s2, self.s3, self.s4 = (2, 2), (2, 2), (2, 2), (2, 2) # 2d strides
self.pd1, self.pd2, self.pd3, self.pd4 = (0, 0), (0, 0), (0, 0), (0, 0) # 2d padding
# encoding components
        resnet = models.resnet18(pretrained=True)  # torchvision >= 0.13 prefers weights=models.ResNet18_Weights.DEFAULT
modules = list(resnet.children())[:-1] # delete the last fc layer.
self.resnet = nn.Sequential(*modules)
self.fc1 = nn.Linear(resnet.fc.in_features, self.fc_hidden1)
self.bn1 = nn.BatchNorm1d(self.fc_hidden1, momentum=0.01)
self.fc2 = nn.Linear(self.fc_hidden1, self.fc_hidden2)
self.bn2 = nn.BatchNorm1d(self.fc_hidden2, momentum=0.01)
# Latent vectors mu and sigma
self.fc3_mu = nn.Linear(self.fc_hidden2, self.CNN_embed_dim) # output = CNN embedding latent variables
self.fc3_logvar = nn.Linear(self.fc_hidden2, self.CNN_embed_dim) # output = CNN embedding latent variables
# Sampling vector
self.fc4 = nn.Linear(self.CNN_embed_dim, self.fc_hidden2)
self.fc_bn4 = nn.BatchNorm1d(self.fc_hidden2)
self.fc5 = nn.Linear(self.fc_hidden2, 64 * 4 * 4)
self.fc_bn5 = nn.BatchNorm1d(64 * 4 * 4)
self.relu = nn.ReLU(inplace=True)
# Decoder
self.convTrans9 = nn.Sequential(
nn.ConvTranspose2d(in_channels=1024, out_channels=512, kernel_size=self.k4, stride=self.s4,
padding=self.pd4),
nn.BatchNorm2d(512, momentum=0.01),
nn.ReLU(inplace=True),
)
self.convTrans10 = nn.Sequential(
nn.ConvTranspose2d(in_channels=512, out_channels=256, kernel_size=self.k4, stride=self.s4,
padding=self.pd4),
nn.BatchNorm2d(256, momentum=0.01),
nn.ReLU(inplace=True),
)
self.convTrans11 = nn.Sequential(
nn.ConvTranspose2d(in_channels=256, out_channels=128, kernel_size=self.k3, stride=self.s3,
padding=self.pd3),
nn.BatchNorm2d(128, momentum=0.01),
nn.ReLU(inplace=True),
)
self.convTrans12 = nn.Sequential(
nn.ConvTranspose2d(in_channels=128, out_channels=64, kernel_size=self.k4, stride=self.s4,
padding=self.pd4),
nn.BatchNorm2d(64, momentum=0.01),
nn.ReLU(inplace=True),
)
self.convTrans6 = nn.Sequential(
nn.ConvTranspose2d(in_channels=64, out_channels=32, kernel_size=self.k4, stride=self.s4,
padding=self.pd4),
nn.BatchNorm2d(32, momentum=0.01),
nn.ReLU(inplace=True),
)
self.convTrans7 = nn.Sequential(
nn.ConvTranspose2d(in_channels=32, out_channels=8, kernel_size=self.k3, stride=self.s3,
padding=self.pd3),
nn.BatchNorm2d(8, momentum=0.01),
nn.ReLU(inplace=True),
)
self.convTrans8 = nn.Sequential(
nn.ConvTranspose2d(in_channels=8, out_channels=3, kernel_size=self.k2, stride=self.s2,
padding=self.pd2),
nn.BatchNorm2d(3, momentum=0.01),
nn.Sigmoid() # y = (y1, y2, y3) \in [0 ,1]^3
)
def encode(self, x):
x = self.resnet(x) # ResNet
x = x.view(x.size(0), -1) # flatten output of conv
# FC layers
x = self.bn1(self.fc1(x))
x = self.relu(x)
x = self.bn2(self.fc2(x))
x = self.relu(x)
# x = F.dropout(x, p=self.drop_p, training=self.training)
mu, logvar = self.fc3_mu(x), self.fc3_logvar(x)
return mu, logvar
def reparameterize(self, mu, logvar):
std = torch.exp(logvar/2)
eps = torch.randn_like(std)
return mu + eps * std
def decode(self, z):
x = self.fc_bn4(self.fc4(z))
x = self.relu(x)
x = self.fc_bn5(self.fc5(x))
        x = self.relu(x).view(-1, 1024, 1, 1)  # 64 * 4 * 4 = 1024 features, treated as 1024 channels at 1x1
x = self.convTrans9(x)
x = self.convTrans10(x)
x = self.convTrans11(x)
x = self.convTrans12(x)
x = self.convTrans6(x)
x = self.convTrans7(x)
x = self.convTrans8(x)
        x = F.interpolate(x, size=(224, 224), mode='bilinear', align_corners=False)
return x
    def forward(self, x, latent_vec=False):
mu, logvar = self.encode(x)
z = self.reparameterize(mu, logvar)
x_reconst = self.decode(z)
if latent_vec:
return x_reconst, mu, logvar, z
else:
return x_reconst, mu, logvar
def loss_function(self, x, x_hat, mu, logvar):
recon_loss = calc_reconstruction_loss(x, x_hat, self.recon_loss_type)
kl_loss = -0.5 * torch.mean(1 + logvar - mu**2 - logvar.exp())
return kl_loss + recon_loss
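# Example usage (illustrative, not part of the original gist; assumes 224x224
# RGB inputs, matching the decoder's F.interpolate target size):
#   model = ResNet_VAE(recon_loss_type="mse")
#   x_hat, mu, logvar = model(x)  # x: (N, 3, 224, 224) -> x_hat: same shape
#   loss = model.loss_function(x, x_hat, mu, logvar)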
class ResNet_CVAE(AbstractAutoEncoder):
def __init__(
        self, recon_loss_type, fc_hidden1=1024,
        fc_hidden2=768, drop_p=0.3, CNN_embed_dim=256):
        super(ResNet_CVAE, self).__init__()  # was super(ResNet_VAE, ...), which raises a TypeError here
        self.recon_loss_type = recon_loss_type
        self.fc_hidden1, self.fc_hidden2, self.CNN_embed_dim = fc_hidden1, fc_hidden2, CNN_embed_dim
        # CNN architectures
        self.ch1, self.ch2, self.ch3, self.ch4 = 16, 32, 64, 128
        self.k1, self.k2, self.k3, self.k4 = (5, 5), (3, 3), (3, 3), (3, 3)  # 2d kernel size
self.s1, self.s2, self.s3, self.s4 = (2, 2), (2, 2), (2, 2), (2, 2) # 2d strides
self.pd1, self.pd2, self.pd3, self.pd4 = (0, 0), (0, 0), (0, 0), (0, 0) # 2d padding
# encoding components
        resnet = models.resnet18(pretrained=True)  # torchvision >= 0.13 prefers weights=models.ResNet18_Weights.DEFAULT
        modules = list(resnet.children())[:-2]  # drop avgpool and fc, keep the (512, H/32, W/32) feature map
self.resnet = nn.Sequential(*modules)
self.conv1 = nn.Sequential(
nn.Conv2d(in_channels=512, out_channels=256, kernel_size=1, stride=1,
padding=self.pd4),
nn.BatchNorm2d(256, momentum=0.01),
nn.ReLU(inplace=True),
)
self.out_mu = nn.Conv2d(256, 256, kernel_size=1, stride=1)
self.out_logvar = nn.Conv2d(256, 256, kernel_size=1, stride=1)
# (256, 8, 8) -> (3, 256, 256)
self.decoder = nn.Sequential(
# (256, 8, 8) -> (256, 16, 16)
UpsamplingLayer(256, 256, activation="ReLU"),
# -> (128, 32, 32)
UpsamplingLayer(256, 128, activation="ReLU"),
# -> (64, 64, 64)
UpsamplingLayer(128, 64, activation="ReLU", type="upsample"),
# -> (32, 128, 128)
UpsamplingLayer(64, 32, activation="ReLU", bn=False),
# -> (3, 256, 256)
UpsamplingLayer(32, 3, activation="none", bn=False, type="upsample"),
# nn.Tanh()
nn.Hardtanh(-1.0, 1.0),
)
def encode(self, x):
x = self.resnet(x) # ResNet
x = self.conv1(x)
mu = self.out_mu(x)
logvar = self.out_logvar(x)
return mu, logvar
def reparameterize(self, mu, logvar):
std = torch.exp(logvar/2)
eps = torch.randn_like(std)
return mu + eps * std
def decode(self, z):
return self.decoder(z)
    def forward(self, x, latent_vec=False):
mu, logvar = self.encode(x)
z = self.reparameterize(mu, logvar)
x_reconst = self.decode(z)
if latent_vec:
return x_reconst, mu, logvar, z
else:
return x_reconst, mu, logvar
def loss_function(self, x, x_hat, mu, logvar):
recon_loss = calc_reconstruction_loss(x, x_hat, self.recon_loss_type)
kl_loss = -0.5 * torch.mean(1 + logvar - mu**2 - logvar.exp())
return kl_loss + recon_loss
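# Example usage (illustrative, not part of the original gist): the fully
# convolutional encoder downsamples 32x, so 256x256 inputs give an
# (N, 256, 8, 8) latent grid and a reconstruction in [-1, 1] via Hardtanh:
#   model = ResNet_CVAE(recon_loss_type="mse")
#   x_hat, mu, logvar = model(x)  # x: (N, 3, 256, 256) -> x_hat: same shape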
class UpsamplingLayer(nn.Module):
    def __init__(self, in_channel, out_channel, activation="none", bn=True, type="transpose"):
        super(UpsamplingLayer, self).__init__()
        self.bn = nn.BatchNorm2d(out_channel) if bn else None
        if activation == "ReLU":
            self.activation = nn.ReLU(True)
        elif activation == "none":
            self.activation = None
        else:
            raise ValueError(f"Unknown activation: {activation}")
if type == "transpose":
self.upsampler = nn.Sequential(
nn.ConvTranspose2d(in_channel, out_channel, kernel_size=2, stride=2, padding=0),
)
elif type == "upsample":
self.upsampler = nn.Sequential(
nn.UpsamplingBilinear2d(scale_factor=2),
nn.Conv2d(in_channel, out_channel, kernel_size=3, stride=1, padding=1),
)
        else:
            raise ValueError(f"Unknown upsampling type: {type}")
def forward(self, x):
x = self.upsampler(x)
        if self.activation:
            x = self.activation(x)
if self.bn:
x = self.bn(x)
return x
# Copyright 2018 The Sonnet Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
# Borrowed from https://github.com/deepmind/sonnet and ported to PyTorch
class Quantize(nn.Module):
def __init__(self, dim, n_embed, decay=0.99, eps=1e-5):
super().__init__()
self.dim = dim
self.n_embed = n_embed
self.decay = decay
self.eps = eps
embed = torch.randn(dim, n_embed)
self.register_buffer('embed', embed)
self.register_buffer('cluster_size', torch.zeros(n_embed))
self.register_buffer('embed_avg', embed.clone())
def forward(self, input):
flatten = input.reshape(-1, self.dim)
dist = (
flatten.pow(2).sum(1, keepdim=True)
- 2 * flatten @ self.embed
+ self.embed.pow(2).sum(0, keepdim=True)
)
_, embed_ind = (-dist).max(1)
embed_onehot = F.one_hot(embed_ind, self.n_embed).type(flatten.dtype)
embed_ind = embed_ind.view(*input.shape[:-1])
quantize = self.embed_code(embed_ind)
        if self.training:
            # EMA update of the codebook (exponential moving averages of
            # cluster sizes and code vectors), as in the Sonnet VQ-VAE.
            self.cluster_size.data.mul_(self.decay).add_(
                embed_onehot.sum(0), alpha=1 - self.decay
            )
            embed_sum = flatten.transpose(0, 1) @ embed_onehot
            self.embed_avg.data.mul_(self.decay).add_(embed_sum, alpha=1 - self.decay)
n = self.cluster_size.sum()
cluster_size = (
(self.cluster_size + self.eps) / (n + self.n_embed * self.eps) * n
)
embed_normalized = self.embed_avg / cluster_size.unsqueeze(0)
self.embed.data.copy_(embed_normalized)
        diff = (quantize.detach() - input).pow(2).mean()  # commitment loss term
        quantize = input + (quantize - input).detach()  # straight-through estimator
return quantize, diff, embed_ind
def embed_code(self, embed_id):
return F.embedding(embed_id, self.embed.transpose(0, 1))
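# Note: Quantize expects channels-last input of shape (N, H, W, dim);
# VQVAE.encode below permutes to this layout before quantizing and
# permutes back afterwards.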
class ResBlock(nn.Module):
def __init__(self, in_channel, channel):
super().__init__()
self.conv = nn.Sequential(
nn.ReLU(inplace=True),
nn.Conv2d(in_channel, channel, 3, padding=1),
nn.ReLU(inplace=True),
nn.Conv2d(channel, in_channel, 1),
)
def forward(self, input):
out = self.conv(input)
out += input
return out
class Encoder(nn.Module):
def __init__(self, in_channel, channel, n_res_block, n_res_channel, stride):
super().__init__()
if stride == 4:
blocks = [
nn.Conv2d(in_channel, channel // 2, 4, stride=2, padding=1),
nn.ReLU(inplace=True),
nn.Conv2d(channel // 2, channel, 4, stride=2, padding=1),
nn.ReLU(inplace=True),
nn.Conv2d(channel, channel, 3, padding=1),
]
        elif stride == 2:
            blocks = [
                nn.Conv2d(in_channel, channel // 2, 4, stride=2, padding=1),
                nn.ReLU(inplace=True),
                nn.Conv2d(channel // 2, channel, 3, padding=1),
            ]
        else:
            raise ValueError(f"Unsupported stride: {stride}")
for i in range(n_res_block):
blocks.append(ResBlock(channel, n_res_channel))
blocks.append(nn.ReLU(inplace=True))
self.blocks = nn.Sequential(*blocks)
def forward(self, input):
return self.blocks(input)
class Decoder(nn.Module):
def __init__(
self, in_channel, out_channel, channel, n_res_block, n_res_channel, stride
):
super().__init__()
blocks = [nn.Conv2d(in_channel, channel, 3, padding=1)]
for i in range(n_res_block):
blocks.append(ResBlock(channel, n_res_channel))
blocks.append(nn.ReLU(inplace=True))
if stride == 4:
blocks.extend(
[
nn.ConvTranspose2d(channel, channel // 2, 4, stride=2, padding=1),
nn.ReLU(inplace=True),
nn.ConvTranspose2d(
channel // 2, out_channel, 4, stride=2, padding=1
),
]
)
        elif stride == 2:
            blocks.append(
                nn.ConvTranspose2d(channel, out_channel, 4, stride=2, padding=1)
            )
        else:
            raise ValueError(f"Unsupported stride: {stride}")
self.blocks = nn.Sequential(*blocks)
def forward(self, input):
return self.blocks(input)
class VQVAE(AbstractAutoEncoder):
def __init__(
self,
recon_loss_type,
in_channel=3,
channel=128,
n_res_block=2,
n_res_channel=32,
embed_dim=64,
n_embed=512,
decay=0.99,
):
super().__init__()
self.recon_loss_type = recon_loss_type
self.enc_b = Encoder(in_channel, channel, n_res_block, n_res_channel, stride=4)
self.enc_t = Encoder(channel, channel, n_res_block, n_res_channel, stride=2)
self.quantize_conv_t = nn.Conv2d(channel, embed_dim, 1)
self.quantize_t = Quantize(embed_dim, n_embed)
self.dec_t = Decoder(
embed_dim, embed_dim, channel, n_res_block, n_res_channel, stride=2
)
self.quantize_conv_b = nn.Conv2d(embed_dim + channel, embed_dim, 1)
self.quantize_b = Quantize(embed_dim, n_embed)
self.upsample_t = nn.ConvTranspose2d(
embed_dim, embed_dim, 4, stride=2, padding=1
)
self.dec = Decoder(
embed_dim + embed_dim,
in_channel,
channel,
n_res_block,
n_res_channel,
stride=4,
)
def forward(self, input, latent_vec=False):
quant_t, quant_b, diff, _, _ = self.encode(input)
dec, quant = self.decode(quant_t, quant_b)
if latent_vec:
return dec, diff, quant
else:
return dec, diff
def encode(self, input):
enc_b = self.enc_b(input)
enc_t = self.enc_t(enc_b)
quant_t = self.quantize_conv_t(enc_t).permute(0, 2, 3, 1)
quant_t, diff_t, id_t = self.quantize_t(quant_t)
quant_t = quant_t.permute(0, 3, 1, 2)
diff_t = diff_t.unsqueeze(0)
dec_t = self.dec_t(quant_t)
enc_b = torch.cat([dec_t, enc_b], 1)
quant_b = self.quantize_conv_b(enc_b).permute(0, 2, 3, 1)
quant_b, diff_b, id_b = self.quantize_b(quant_b)
quant_b = quant_b.permute(0, 3, 1, 2)
diff_b = diff_b.unsqueeze(0)
return quant_t, quant_b, diff_t + diff_b, id_t, id_b
def decode(self, quant_t, quant_b):
upsample_t = self.upsample_t(quant_t)
quant = torch.cat([upsample_t, quant_b], 1)
dec = self.dec(quant)
return dec, quant
def decode_code(self, code_t, code_b):
quant_t = self.quantize_t.embed_code(code_t)
quant_t = quant_t.permute(0, 3, 1, 2)
quant_b = self.quantize_b.embed_code(code_b)
quant_b = quant_b.permute(0, 3, 1, 2)
        dec, _ = self.decode(quant_t, quant_b)  # decode() returns (dec, quant)
        return dec
    def loss_function(self, x, x_hat, diff, latent_loss_weight=0.25):
        # Note: recon_loss_type is stored in __init__ but this method always
        # uses MSE for the reconstruction term.
        criterion = nn.MSELoss()
        recon_loss = criterion(x_hat, x)
latent_loss = diff.mean()
loss = recon_loss + latent_loss_weight * latent_loss
return loss
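# ---------------------------------------------------------------------------
# Hypothetical smoke test (not part of the original gist): runs each model on
# random inputs to verify shapes end to end. Downloads the pretrained
# ResNet-18 weights on first run.
if __name__ == "__main__":
    x224 = torch.randn(4, 3, 224, 224)
    vae = ResNet_VAE(recon_loss_type="mse")
    x_hat, mu, logvar = vae(x224)
    print("ResNet_VAE:", x_hat.shape, vae.loss_function(x224, x_hat, mu, logvar).item())

    x256 = torch.randn(4, 3, 256, 256)
    cvae = ResNet_CVAE(recon_loss_type="mse")
    x_hat, mu, logvar = cvae(x256)
    print("ResNet_CVAE:", x_hat.shape, cvae.loss_function(x256, x_hat, mu, logvar).item())

    vqvae = VQVAE(recon_loss_type="mse")
    x_hat, diff = vqvae(x256)
    print("VQVAE:", x_hat.shape, vqvae.loss_function(x256, x_hat, diff).item())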