GAN training
A training loop for a GAN with hinge or WGAN-GP discriminator losses, an n_critic generator update schedule, an exponential moving average (EMA) of the generator weights, and TensorBoard logging.
import numpy as np
import torch
import torch.nn as nn
from torchvision.utils import make_grid
from tqdm import tqdm


def train(args, gen_net: nn.Module, dis_net: nn.Module, gen_optimizer, dis_optimizer, gen_avg_param,
          train_loader, epoch, writer_dict, image_counter=None, schedulers=None):
    writer = writer_dict['writer']
    gen_step = 0

    # switch models to train mode
    gen_net = gen_net.train()
    dis_net = dis_net.train()

    for iter_idx, (imgs, _) in enumerate(tqdm(train_loader)):
        global_steps = writer_dict['train_global_steps']
        if image_counter is not None:
            image_counter.step(imgs.shape[0])

        # move the real images onto the GPU
        real_imgs = imgs.type(torch.cuda.FloatTensor)
        if iter_idx == 0:
            # log a grid of real images once per epoch
            img_grid = make_grid(real_imgs[:25], nrow=5, normalize=True, scale_each=True)
            writer.add_image('real_images', img_grid, epoch)

        # sample noise as generator input
        z = torch.cuda.FloatTensor(np.random.normal(0, 1, (imgs.shape[0], args.latent_dim)))

        # ---------------------
        #  Train Discriminator
        # ---------------------
        dis_optimizer.zero_grad()

        real_validity = dis_net(real_imgs)
        fake_imgs = gen_net(z).detach()
        assert fake_imgs.size() == real_imgs.size()

        fake_validity = dis_net(fake_imgs)
        # discriminator loss
        if args.loss == 'hinge':
            d_loss = torch.mean(nn.ReLU(inplace=True)(1.0 - real_validity)) + \
                     torch.mean(nn.ReLU(inplace=True)(1.0 + fake_validity))
        elif args.loss == 'wgangp':
            # phi is set to 750 for CIFAR-10
            gradient_penalty = compute_gradient_penalty(dis_net, real_imgs, fake_imgs.detach(), args.phi)
            d_loss = -torch.mean(real_validity) + torch.mean(fake_validity) + \
                     gradient_penalty * args.lambda_gp / (args.phi ** 2)
            # drift penalty keeps the critic's real scores near zero; drift_weight defaults to 0.001
            d_loss += (torch.mean(real_validity) ** 2) * args.drift_weight
        else:
            raise NotImplementedError(args.loss)
        d_loss.backward()
        dis_optimizer.step()
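
        # the generator is updated once every args.n_critic discriminator steps
        # (n_critic = 5 is a typical choice for WGAN-style critics)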

        # -----------------
        #  Train Generator
        # -----------------
        if global_steps % args.n_critic == 0:
            gen_optimizer.zero_grad()

            gen_z = torch.cuda.FloatTensor(np.random.normal(0, 1, (args.gen_batch_size, args.latent_dim)))
            gen_imgs = gen_net(gen_z)
            fake_validity = dis_net(gen_imgs)

            # generator loss
            if args.loss == 'hinge' or args.loss == 'wgangp':
                g_loss = -torch.mean(fake_validity)
            else:
                raise NotImplementedError(args.loss)
            g_loss.backward()
            gen_optimizer.step()

            # adjust learning rate (custom schedulers whose step() returns the new LR,
            # unlike standard torch.optim.lr_scheduler classes)
            if schedulers:
                gen_scheduler, dis_scheduler = schedulers
                g_lr = gen_scheduler.step(global_steps)
                d_lr = dis_scheduler.step(global_steps)
                writer.add_scalar('LR/g_lr', g_lr, global_steps)
                writer.add_scalar('LR/d_lr', d_lr, global_steps)
            # exponential moving average (EMA) of the generator weights
            for p, avg_p in zip(
                    gen_net.net_parameters() if hasattr(gen_net, 'net_parameters') else gen_net.parameters(),
                    gen_avg_param):
                # avg_p <- 0.999 * avg_p + 0.001 * p
                avg_p.mul_(0.999).add_(p.data, alpha=0.001)
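            # with decay 0.999 the average has an effective horizon of
            # roughly 1 / (1 - 0.999) = 1000 generator updates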
            gen_step += 1

        # verbose logging
        if gen_step and iter_idx % args.print_freq == 0:
            tqdm.write(
                "[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]" %
                (epoch, args.max_epoch, iter_idx, len(train_loader), d_loss.item(), g_loss.item()))
            writer.add_scalar('d_loss', d_loss.item(), global_steps)
            writer.add_scalar('g_loss', g_loss.item(), global_steps)

        writer_dict['train_global_steps'] = global_steps + 1
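
# Note on the 'wgangp' branch above: unlike the standard WGAN-GP penalty
# (||grad|| - 1) ** 2, compute_gradient_penalty pulls the critic's gradient
# norm toward phi, while d_loss rescales the penalty by lambda_gp / phi ** 2;
# algebraically that equals lambda_gp * (||grad|| / phi - 1) ** 2, i.e. the
# usual penalty applied to the relative deviation of the gradient norm.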

def compute_gradient_penalty(D, real_samples, fake_samples, phi):
    """Calculates the gradient penalty loss for WGAN-GP (gradient norm pushed toward phi)."""
    # random weight term for interpolation between real and fake samples
    alpha = torch.Tensor(np.random.random((real_samples.size(0), 1, 1, 1))).to(real_samples.device)
    # random interpolation between real and fake samples
    interpolates = (alpha * real_samples + ((1 - alpha) * fake_samples)).requires_grad_(True)
    d_interpolates = D(interpolates)
    fake = torch.ones(real_samples.shape[0], requires_grad=False).to(real_samples.device)
    # gradient of the critic output w.r.t. the interpolates
    gradients = torch.autograd.grad(
        outputs=d_interpolates,
        inputs=interpolates,
        grad_outputs=fake,
        create_graph=True,
        retain_graph=True,
        only_inputs=True,
    )[0]
    gradients = gradients.view(gradients.size(0), -1)
    gradient_penalty = ((gradients.norm(2, dim=1) - phi) ** 2).mean()
    return gradient_penalty
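
For context, here is a minimal driver sketch showing how train() might be wired up. Everything in it is an assumption for illustration, not part of this gist: the toy Generator/Discriminator, the FakeData dataset, and all hyperparameter values are placeholders (train() itself requires a CUDA device, since it uses torch.cuda.FloatTensor).

from types import SimpleNamespace

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from torchvision import datasets, transforms


class Generator(nn.Module):
    """Toy 32x32 generator, stand-in for the real model."""
    def __init__(self, latent_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(latent_dim, 128 * 8 * 8), nn.ReLU(),
            nn.Unflatten(1, (128, 8, 8)),
            nn.ConvTranspose2d(128, 64, 4, 2, 1), nn.ReLU(),
            nn.ConvTranspose2d(64, 3, 4, 2, 1), nn.Tanh())

    def forward(self, z):
        return self.net(z)


class Discriminator(nn.Module):
    """Toy critic returning one unbounded score per image."""
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(3, 64, 4, 2, 1), nn.LeakyReLU(0.2),
            nn.Conv2d(64, 128, 4, 2, 1), nn.LeakyReLU(0.2),
            nn.Flatten(), nn.Linear(128 * 8 * 8, 1))

    def forward(self, x):
        return self.net(x).squeeze(1)


args = SimpleNamespace(
    loss='hinge', latent_dim=128, gen_batch_size=64, n_critic=5,
    print_freq=50, max_epoch=2, lambda_gp=10.0, phi=1.0, drift_weight=0.001)

gen_net = Generator(args.latent_dim).cuda()
dis_net = Discriminator().cuda()
gen_optimizer = torch.optim.Adam(gen_net.parameters(), lr=2e-4, betas=(0.0, 0.9))
dis_optimizer = torch.optim.Adam(dis_net.parameters(), lr=2e-4, betas=(0.0, 0.9))

# detached copies of the generator weights; train() updates these as the EMA
gen_avg_param = [p.data.clone() for p in gen_net.parameters()]

dataset = datasets.FakeData(size=512, image_size=(3, 32, 32),
                            transform=transforms.ToTensor())
train_loader = DataLoader(dataset, batch_size=64, shuffle=True, drop_last=True)
writer_dict = {'writer': SummaryWriter('logs'), 'train_global_steps': 0}

for epoch in range(args.max_epoch):
    train(args, gen_net, dis_net, gen_optimizer, dis_optimizer, gen_avg_param,
          train_loader, epoch, writer_dict)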
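
The EMA weights accumulated in gen_avg_param only pay off if they are swapped into the generator at evaluation time (sampling, FID, checkpoints). A minimal sketch of that swap, assuming the same parameter ordering that train() relies on; the helper names copy_params / load_params are hypothetical, not defined in this gist:

def copy_params(model):
    """Return detached copies of the model's parameters."""
    return [p.data.clone() for p in model.parameters()]


def load_params(model, new_params):
    """Overwrite the model's parameters in place (same ordering as copy_params)."""
    for p, new_p in zip(model.parameters(), new_params):
        p.data.copy_(new_p)


# evaluate with the averaged weights, then restore the raw training weights
backup_param = copy_params(gen_net)
load_params(gen_net, gen_avg_param)
# ... sample images / compute metrics with gen_net ...
load_params(gen_net, backup_param)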