Skip to content

Instantly share code, notes, and snippets.

@GongXinyuu
Last active February 7, 2021 02:58
Show Gist options
  • Select an option

  • Save GongXinyuu/3ec4fd87cad87cadadfdb08f137299b2 to your computer and use it in GitHub Desktop.

Select an option

Save GongXinyuu/3ec4fd87cad87cadadfdb08f137299b2 to your computer and use it in GitHub Desktop.
GAN training
def train(args, gen_net: nn.Module, dis_net: nn.Module, gen_optimizer, dis_optimizer, gen_avg_param,
          train_loader, epoch, writer_dict, image_counter=None, schedulers=None):
    """Run one epoch of adversarial GAN training.

    Alternates discriminator updates (every batch) with generator updates
    (every ``args.n_critic`` global steps), maintains an exponential moving
    average of the generator parameters, and logs losses/LR to TensorBoard.

    Args:
        args: config namespace; fields used here: latent_dim, loss ('hinge' or
            'wgangp'), phi, lambda_gp, drift_weight, n_critic, gen_batch_size,
            print_freq, max_epoch.
        gen_net: generator network (trained in place).
        dis_net: discriminator network (trained in place).
        gen_optimizer, dis_optimizer: optimizers for the two networks.
        gen_avg_param: iterable of tensors holding the EMA of the generator
            parameters; updated in place.
        train_loader: yields (images, labels); labels are ignored.
        epoch: current epoch index (for logging only).
        writer_dict: {'writer': SummaryWriter, 'train_global_steps': int};
            the step counter is advanced in place each batch.
        image_counter: optional counter with a .step(batch_size) method.
        schedulers: optional (gen_scheduler, dis_scheduler) pair whose
            .step(global_steps) returns the new learning rate.

    Raises:
        NotImplementedError: if args.loss is neither 'hinge' nor 'wgangp'.
    """
    writer = writer_dict['writer']
    gen_step = 0

    # train mode
    gen_net = gen_net.train()
    dis_net = dis_net.train()

    for iter_idx, (imgs, _) in enumerate(tqdm(train_loader)):
        global_steps = writer_dict['train_global_steps']
        if image_counter is not None:
            image_counter.step(imgs.shape[0])

        # Adversarial ground truths (moved to GPU as float)
        real_imgs = imgs.type(torch.cuda.FloatTensor)
        if iter_idx == 0:
            # log a grid of real images once per epoch
            img_grid = make_grid(real_imgs[:25], nrow=5, normalize=True, scale_each=True)
            writer.add_image('real_images', img_grid, epoch)

        # Sample noise as generator input
        z = torch.cuda.FloatTensor(np.random.normal(0, 1, (imgs.shape[0], args.latent_dim)))

        # ---------------------
        #  Train Discriminator
        # ---------------------
        dis_optimizer.zero_grad()

        real_validity = dis_net(real_imgs)
        # detach so the D step does not backprop into the generator
        fake_imgs = gen_net(z).detach()
        assert fake_imgs.size() == real_imgs.size()
        fake_validity = dis_net(fake_imgs)

        # cal loss
        if args.loss == 'hinge':
            d_loss = torch.mean(nn.ReLU(inplace=True)(1.0 - real_validity)) + \
                     torch.mean(nn.ReLU(inplace=True)(1 + fake_validity))
        elif args.loss == 'wgangp':
            ## phi is set 750 in cifar10
            # FIX: fake_imgs is already detached above — the extra .detach()
            # was redundant.
            gradient_penalty = compute_gradient_penalty(dis_net, real_imgs, fake_imgs, args.phi)
            d_loss = -torch.mean(real_validity) + torch.mean(fake_validity) + gradient_penalty * args.lambda_gp / (
                args.phi ** 2)
            ## drift_weight is set to 0.001 by default
            d_loss += (torch.mean(real_validity) ** 2) * args.drift_weight
        else:
            raise NotImplementedError(args.loss)
        d_loss.backward()
        dis_optimizer.step()

        # -----------------
        #  Train Generator
        # -----------------
        if global_steps % args.n_critic == 0:
            gen_optimizer.zero_grad()

            gen_z = torch.cuda.FloatTensor(np.random.normal(0, 1, (args.gen_batch_size, args.latent_dim)))
            gen_imgs = gen_net(gen_z)
            fake_validity = dis_net(gen_imgs)

            # cal loss
            if args.loss == 'hinge' or args.loss == 'wgangp':
                g_loss = -torch.mean(fake_validity)
            else:
                raise NotImplementedError(args.loss)
            g_loss.backward()
            gen_optimizer.step()

            # adjust learning rate
            if schedulers:
                gen_scheduler, dis_scheduler = schedulers
                g_lr = gen_scheduler.step(global_steps)
                d_lr = dis_scheduler.step(global_steps)
                writer.add_scalar('LR/g_lr', g_lr, global_steps)
                writer.add_scalar('LR/d_lr', d_lr, global_steps)

            # moving average weight: EMA of generator parameters
            for p, avg_p in zip(
                    gen_net.net_parameters() if hasattr(gen_net, 'net_parameters') else gen_net.parameters(),
                    gen_avg_param):
                # FIX: add_(scalar, tensor) was deprecated and removed; the
                # supported spelling is add_(tensor, alpha=scalar).
                avg_p.mul_(0.999).add_(p.data, alpha=0.001)
            gen_step += 1

        # verbose
        if gen_step and iter_idx % args.print_freq == 0:
            tqdm.write(
                "[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]" %
                (epoch, args.max_epoch, iter_idx % len(train_loader), len(train_loader), d_loss.item(), g_loss.item()))
            writer.add_scalar('d_loss', d_loss.item(), global_steps)
            writer.add_scalar('g_loss', g_loss.item(), global_steps)

        writer_dict['train_global_steps'] = global_steps + 1
def compute_gradient_penalty(D, real_samples, fake_samples, phi):
    """Compute the gradient penalty term for a WGAN-GP discriminator loss.

    Draws a random per-sample interpolation between real and fake samples,
    evaluates the discriminator on it, and penalizes the squared deviation of
    the per-sample input-gradient L2 norm from ``phi`` (standard WGAN-GP uses
    phi == 1).

    Args:
        D: discriminator; callable mapping an (N, C, H, W) tensor to a
            per-sample score tensor.
        real_samples: real batch, shape (N, C, H, W).
        fake_samples: fake batch, same shape as ``real_samples``.
        phi: target gradient norm.

    Returns:
        Scalar tensor: mean of (||grad||_2 - phi)^2 over the batch.
    """
    # Random weight term for interpolation between real and fake samples;
    # one alpha per sample, broadcast over (C, H, W).
    alpha = torch.Tensor(np.random.random((real_samples.size(0), 1, 1, 1))).to(real_samples.device)
    # Get random interpolation between real and fake samples
    interpolates = (alpha * real_samples + ((1 - alpha) * fake_samples)).requires_grad_(True)
    d_interpolates = D(interpolates)
    # FIX: grad_outputs must match d_interpolates' shape exactly. The original
    # torch.ones(N) breaks when D returns (N, 1); ones_like is shape- and
    # device-correct in all cases.
    fake = torch.ones_like(d_interpolates, requires_grad=False)
    # Get gradient w.r.t. interpolates
    gradients = torch.autograd.grad(
        outputs=d_interpolates,
        inputs=interpolates,
        grad_outputs=fake,
        create_graph=True,   # keep graph so the penalty itself is differentiable
        retain_graph=True,
        only_inputs=True,
    )[0]
    gradients = gradients.view(gradients.size(0), -1)
    gradient_penalty = ((gradients.norm(2, dim=1) - phi) ** 2).mean()
    return gradient_penalty
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment