WGAN implementation
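This gist is a single-file PyTorch script that trains a WGAN with gradient penalty (WGAN-GP): a DCGAN-style generator maps 100-dimensional noise to 96x96 RGB images, a convolutional critic scores them, the critic loss is D(fake) - D(real) plus LAMBDA times a gradient penalty, and the generator is updated once every CRITIC_ITERS batches. Training images are read from an ImageFolder directory, and intermediate samples and checkpoints are written to ./debug.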
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torchvision.utils as vutils
from torch.autograd import Variable, grad
class Config(object):
    NUM_OF_GPU = 2
    USE_CUDA = torch.cuda.is_available()
    Z_CHANNELS = 100        # size of the input noise vector
    G_FEATURES = 64         # base number of feature maps in the generator
    D_FEATURES = 64         # base number of feature maps in the discriminator (critic)
    OUTPUT_CHANNELS = 3
    DATA_ROOT = './data'
    IMAGE_SIZE = 96
    BATCH_SIZE = 64
    NUM_WORKERS = 8
    LR = 0.0002             # learning rate
    BETA1 = 0.5             # beta1 for the Adam optimizer
    EPOCHES = 5000
    EPOCHES_TO_SAVE = 20    # save a checkpoint every this many epochs
    DEBUG_FOLDER = './debug'
    LAMBDA = 10             # gradient penalty coefficient
    CRITIC_ITERS = 5        # update the generator only once every CRITIC_ITERS batches
    PRE_TRAINED_G = None
    PRE_TRAINED_D = None
import os
try:
    os.makedirs(Config.DEBUG_FOLDER)
except OSError:
    pass
# custom weights initialization called on netG and netD
def weights_init(module):
    classname = module.__class__.__name__
    if classname.find('Conv') != -1:
        module.weight.data.normal_(0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        module.weight.data.normal_(1.0, 0.02)
        module.bias.data.fill_(0)
class AverageMeter(object):
    """Keeps track of the latest value, the sum, and the running average."""
    def __init__(self):
        self.reset()
    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0
    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
class _NetG(nn.Module):
    """DCGAN-style generator: maps a (Z_CHANNELS x 1 x 1) noise tensor to a 96 x 96 image."""
    def __init__(self, config):
        super(_NetG, self).__init__()
        self.num_of_gpu = config.NUM_OF_GPU
        Z_CHANNELS = config.Z_CHANNELS
        G_FEATURES = config.G_FEATURES
        OUTPUT_CHANNELS = config.OUTPUT_CHANNELS
        self.net = nn.Sequential(
            # input is noise Z, going into a convolution (Z_CHANNELS x 1 x 1)
            nn.ConvTranspose2d(Z_CHANNELS, G_FEATURES*8, kernel_size=4, bias=False),
            nn.BatchNorm2d(G_FEATURES*8),
            nn.ReLU(inplace=True),
            # state size: (G_FEATURES*8) x 4 x 4
            nn.ConvTranspose2d(G_FEATURES*8, G_FEATURES*4, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(G_FEATURES*4),
            nn.ReLU(inplace=True),
            # state size: (G_FEATURES*4) x 8 x 8
            nn.ConvTranspose2d(G_FEATURES*4, G_FEATURES*2, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(G_FEATURES*2),
            nn.ReLU(inplace=True),
            # state size: (G_FEATURES*2) x 16 x 16
            nn.ConvTranspose2d(G_FEATURES*2, G_FEATURES, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(G_FEATURES),
            nn.ReLU(inplace=True),
            # state size: (G_FEATURES) x 32 x 32
            nn.ConvTranspose2d(G_FEATURES, OUTPUT_CHANNELS, kernel_size=5, stride=3, padding=1, bias=False),
            nn.Tanh()
            # state size: (OUTPUT_CHANNELS) x 96 x 96
        )
    def forward(self, input):
        if isinstance(input.data, torch.cuda.FloatTensor) and self.num_of_gpu > 1:
            output = nn.parallel.data_parallel(self.net, input, range(self.num_of_gpu))
        else:
            output = self.net(input)
        return output
class _NetD(nn.Module):
    """DCGAN-style critic: maps a 96 x 96 image to a single unbounded score."""
    def __init__(self, config):
        super(_NetD, self).__init__()
        self.num_of_gpu = config.NUM_OF_GPU
        Z_CHANNELS = config.Z_CHANNELS
        D_FEATURES = config.D_FEATURES
        OUTPUT_CHANNELS = config.OUTPUT_CHANNELS
        # batch normalization is intentionally left out of the critic: the gradient
        # penalty is computed per sample, so WGAN-GP avoids batch norm here
        self.net = nn.Sequential(
            # input is: (OUTPUT_CHANNELS) x 96 x 96
            nn.Conv2d(OUTPUT_CHANNELS, D_FEATURES, kernel_size=5, stride=3, padding=1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # state size: (D_FEATURES) x 32 x 32
            nn.Conv2d(D_FEATURES, D_FEATURES*2, kernel_size=4, stride=2, padding=1, bias=False),
            # nn.BatchNorm2d(D_FEATURES*2),
            nn.LeakyReLU(0.2, inplace=True),
            # state size: (D_FEATURES*2) x 16 x 16
            nn.Conv2d(D_FEATURES*2, D_FEATURES*4, kernel_size=4, stride=2, padding=1, bias=False),
            # nn.BatchNorm2d(D_FEATURES*4),
            nn.LeakyReLU(0.2, inplace=True),
            # state size: (D_FEATURES*4) x 8 x 8
            nn.Conv2d(D_FEATURES*4, D_FEATURES*8, kernel_size=4, stride=2, padding=1, bias=False),
            # nn.BatchNorm2d(D_FEATURES*8),
            nn.LeakyReLU(0.2, inplace=True),
            # state size: (D_FEATURES*8) x 4 x 4
            nn.Conv2d(D_FEATURES*8, 1, kernel_size=4, stride=1, padding=0, bias=False),
        )
    def forward(self, input):
        if isinstance(input.data, torch.cuda.FloatTensor) and self.num_of_gpu > 1:
            output = nn.parallel.data_parallel(self.net, input, range(self.num_of_gpu))
        else:
            output = self.net(input)
        return output.view(-1, 1).squeeze(1)
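# Gradient penalty from "Improved Training of Wasserstein GANs" (Gulrajani et al., 2017):
# sample points x_hat uniformly along straight lines between real and generated samples,
# and penalize the critic so that ||grad_{x_hat} D(x_hat)||_2 stays close to 1.
# The training loop multiplies the returned mean penalty by config.LAMBDA.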
def cal_gradient_penalty(netD, real_data, fake_data, config=Config()):
    # use the actual batch size; the last batch of an epoch may be smaller than config.BATCH_SIZE
    batch_size = real_data.size(0)
    alpha = torch.rand(batch_size, 1)
    alpha = alpha.expand(batch_size, real_data.nelement()//batch_size).contiguous().view(real_data.size())
    alpha = alpha.cuda() if config.USE_CUDA else alpha
    interpolates = alpha * real_data + ((1-alpha) * fake_data)
    if config.USE_CUDA:
        interpolates = interpolates.cuda()
    interpolates = Variable(interpolates, requires_grad=True)
    disc_interpolates = netD(interpolates)
    grad_outputs = torch.ones(disc_interpolates.size())
    grad_outputs = grad_outputs.cuda() if config.USE_CUDA else grad_outputs
    gradients = grad(disc_interpolates, interpolates, grad_outputs=grad_outputs,
                     create_graph=True, retain_graph=True, only_inputs=True)[0]
    # flatten so the 2-norm covers the full gradient of each sample, not just the channel dimension
    gradients = gradients.view(gradients.size(0), -1)
    return ((gradients.norm(2, dim=1)-1) ** 2).mean()
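# Training procedure (as implemented below): the critic is updated on every batch with
# loss D(fake) - D(real) + LAMBDA * gradient_penalty, while the generator is updated
# only once every CRITIC_ITERS batches with loss -D(fake), following the WGAN-GP recipe.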
def train(config=Config(), **kwargs):
    transformers = transforms.Compose([
        transforms.Resize(config.IMAGE_SIZE),
        transforms.CenterCrop(config.IMAGE_SIZE),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
    dataset = datasets.ImageFolder(root=config.DATA_ROOT, transform=transformers)
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=config.BATCH_SIZE, shuffle=True,
                                             num_workers=config.NUM_WORKERS)
    netG = _NetG(config)
    if config.PRE_TRAINED_G is not None:
        netG.load_state_dict(torch.load(config.PRE_TRAINED_G))
    else:
        netG.apply(weights_init)
    netD = _NetD(config)
    if config.PRE_TRAINED_D is not None:
        netD.load_state_dict(torch.load(config.PRE_TRAINED_D))
    else:
        netD.apply(weights_init)
    # pre-allocated tensors for real images and noise
    input = torch.FloatTensor(config.BATCH_SIZE, 3, config.IMAGE_SIZE, config.IMAGE_SIZE)
    noise = torch.FloatTensor(config.BATCH_SIZE, config.Z_CHANNELS, 1, 1)
    fixed_noise = torch.FloatTensor(config.BATCH_SIZE, config.Z_CHANNELS, 1, 1).normal_(0, 1)
    if config.USE_CUDA:
        netD.cuda()
        netG.cuda()
        input = input.cuda()
        noise, fixed_noise = noise.cuda(), fixed_noise.cuda()
    fixed_noise = Variable(fixed_noise)
    # setup optimizers
    optimizerG = optim.Adam(netG.parameters(), lr=config.LR, betas=(config.BETA1, 0.999))
    optimizerD = optim.Adam(netD.parameters(), lr=config.LR, betas=(config.BETA1, 0.999))
    for epoch in range(config.EPOCHES):
        Wasserstein_Ds = AverageMeter()
        D_costs = AverageMeter()
        for batch_idx, (images, _) in enumerate(dataloader):
            #==================================================
            # update D (critic) network
            # train with real
            netD.zero_grad()
            batch_size = images.size(0)
            if config.USE_CUDA:
                images = images.cuda()
            input.resize_as_(images).copy_(images)
            inputv = Variable(input)
            D_real = netD(inputv)
            D_real = D_real.mean()
            # train with fake
            noise.resize_(batch_size, config.Z_CHANNELS, 1, 1).normal_(0, 1)
            noisev = Variable(noise)
            fake = netG(noisev)
            D_fake = netD(fake.detach())  # detach so that netG won't be affected by D_cost.backward()
            D_fake = D_fake.mean()
            penalty = cal_gradient_penalty(netD, images, fake.data, config=config) * config.LAMBDA
            D_cost = D_fake - D_real + penalty
            D_cost.backward()
            D_costs.update(D_cost.data[0])
            Wasserstein_D = D_real - D_fake
            Wasserstein_Ds.update(Wasserstein_D.data[0])
            optimizerD.step()
            #==================================================
            # update G network: maximize D(G(z)), i.e. minimize -D(G(z))
            if (batch_idx + epoch * len(dataloader)) % config.CRITIC_ITERS == 0:
                # only update the generator once every CRITIC_ITERS batches
                netG.zero_grad()
                noise.resize_(batch_size, config.Z_CHANNELS, 1, 1).normal_(0, 1)
                noisev = Variable(noise)
                fake = netG(noisev)
                G = netD(fake)
                G = -G.mean()
                G.backward()
                optimizerG.step()
            if batch_idx % config.CRITIC_ITERS == 0:
                print(f'[{epoch}/{config.EPOCHES}] [{batch_idx}/{len(dataloader)}] '
                      f'Wasserstein_D: {Wasserstein_D.data[0]:.4f}/{Wasserstein_Ds.avg:.4f} '
                      f'Loss: {-D_cost.data[0]:.4f}/{-D_costs.avg:.4f}')
            if batch_idx % 100 == 0:
                vutils.save_image(images[:64], f'{config.DEBUG_FOLDER}/real_samples.png', normalize=True)
                fake = netG(fixed_noise)
                vutils.save_image(fake.data[:64], f'{config.DEBUG_FOLDER}/fake_samples_epoch_{epoch:03d}.png', normalize=True)
        # do checkpointing
        if (epoch+1) % config.EPOCHES_TO_SAVE == 0:
            torch.save(netG.state_dict(), f'{config.DEBUG_FOLDER}/netG_epoch_{epoch}.pth')
            torch.save(netD.state_dict(), f'{config.DEBUG_FOLDER}/netD_epoch_{epoch}.pth')
if __name__ == '__main__':
    config = Config()
    train(config=config)
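After training, generator checkpoints land in DEBUG_FOLDER. Below is a minimal sampling sketch, not part of the original script: it assumes the file above is saved as wgan.py and that a checkpoint such as ./debug/netG_epoch_19.pth was produced by the checkpointing code (epoch 19 is the first save with EPOCHES_TO_SAVE = 20).
# sampling sketch: load a saved generator and write a grid of generated images
import torch
from torch.autograd import Variable
import torchvision.utils as vutils
from wgan import Config, _NetG   # assumed module name

config = Config()
netG = _NetG(config)
netG.load_state_dict(torch.load('./debug/netG_epoch_19.pth', map_location='cpu'))
netG.eval()
noise = Variable(torch.FloatTensor(64, config.Z_CHANNELS, 1, 1).normal_(0, 1))
vutils.save_image(netG(noise).data, 'generated_samples.png', normalize=True)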