Skip to content

Instantly share code, notes, and snippets.

@NewJerseyStyle
Created March 4, 2020 04:06
Show Gist options
  • Save NewJerseyStyle/766f8e0258472603376fe02870e0dfab to your computer and use it in GitHub Desktop.
Save NewJerseyStyle/766f8e0258472603376fe02870e0dfab to your computer and use it in GitHub Desktop.
A VAE class in pytorch
# train or use
from dataloader import dataloader
import model
import torch
import torch.nn as nn
from tqdm import trange, tqdm
class vae_module(object):
"""docstring for vae_module"""
def __init__(self, num_latent, state_size, img_trans=None,
dataset_folder="images", dataset_format=".jpg",
filter_size=[3, 3, 3, 3, 3], channel_in=3,
train=True):
super(vae_module, self).__init__()
# init trainloader
if train:
trainset = dataloader(dataset_folder,
dataset_format, img_trans)
self.trainloader = torch.utils.data.DataLoader(trainset,
batch_size=32, shuffle=True)
self.model = model.VAE(num_latent,
state_size, filter_size,
channel_in)
def train(self, iters=26, print_every=5, print_func=None):
#print after every 5 iterations
device = ('cuda' if torch.cuda.is_available() else 'cpu')
import torch.optim as optim
optimizer = optim.Adam(self.model.parameters(), lr=1e-3)
self._train(iters, device, optimizer, print_every, print_func)
######The function which we will call for training our model
def _train(self, iters, device, optimizer, print_every, print_f=None):
counter = 0
for i in trange(iters):
self.model.train()
self.model.to(device)
for images in tqdm(self.trainloader):
images = images.to(device)
optimizer.zero_grad()
out, mean, logvar = self.model(images)
loss = self.VAE_loss(out, images, mean, logvar)
loss.backward()
optimizer.step()
if(counter % print_every == 0):
self.model.eval()
if print_f:
print_f(loss.data.cpu().sum().numpy())
else:
# print("loss.sum(): ", loss.data.cpu().sum().numpy())
tqdm.write("loss.sum(): ", loss.data.cpu().sum().numpy())
counter += 1
def VAE_loss(self, out, target, mean, logvar):
category1 = nn.BCELoss()
# print("out: ", out.data.cpu().sum().numpy())
# print("target: ", target.data.cpu().sum().numpy())
tqdm.write("out: %s" %out.data.cpu().sum().numpy())
tqdm.write("target: %s" %target.data.cpu().sum().numpy())
bce_loss = category1(out, target)
# # print("BCELoss: ", bce_loss.data.cpu().sum().numpy())
# # print("MSELoss: ", nn.MSELoss()(out.float(), target.float()).data.cpu().numpy())
# tqdm.write("BCELoss: %s" %bce_loss.data.cpu().sum().numpy())
# tqdm.write("MSELoss: %s" %nn.MSELoss()(out.float(), target.float()).data.cpu().numpy())
#We will scale the following losses with this factor
scaling_factor = out.shape[0]*out.shape[1]*out.shape[2]*out.shape[3]
####Now we are gonna define the KL divergence loss
# 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
kl_loss = -0.5 * torch.sum(1 + logvar - mean**2 - torch.exp(logvar))
kl_loss /= scaling_factor
return bce_loss + kl_loss
def save(self, path="vae.pt"):
torch.save(self.model.state_dict(), path)
def load(self, path="vae.pt"):
if torch.cuda.is_available():
self.model.load_state_dict(torch.load(path))
else:
self.model.load_state_dict(torch.load(path, map_location='cpu'))
self.model.eval()
def encode(self, data):
m, l = self.model.enc_func(data)
return self.model.get_hidden(m, l)
def decode(self, data):
return self.model.dec_func(data)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment