Skip to content

Instantly share code, notes, and snippets.

@SubhadityaMukherjee
Created February 13, 2020 06:46
Show Gist options
  • Save SubhadityaMukherjee/a0f3ca4d7945e811b69b5f8c776d1533 to your computer and use it in GitHub Desktop.
Save SubhadityaMukherjee/a0f3ca4d7945e811b69b5f8c776d1533 to your computer and use it in GitHub Desktop.
train
# Adversarial training loop for a generator (netG) / discriminator (netD) pair.
#
# Expects these names to be defined earlier in the script: num_epochs,
# dataloader, device, netD, netG, criterion, optimizerD, optimizerG,
# real_label, fake_label, nz, fixed_noise, torch, transforms, vutils.
#
# Produces:
#   G_losses / D_losses -- per-iteration generator / discriminator loss values
#   img_list            -- image grids of netG(fixed_noise), one per snapshot
#   epochs/<epoch>.png  -- the last image of each fixed-noise snapshot
import os

img_list = []
G_losses = []
D_losses = []
iters = 0

# Saving an image does not create missing directories, so make sure the
# snapshot target exists before the first save.
os.makedirs('epochs', exist_ok=True)

for epoch in range(num_epochs):

    for i, data in enumerate(dataloader, 0):
        # Part 1: update the discriminator on one real and one generated batch.
        netD.zero_grad()

        real_cpu = data[0].to(device)
        b_size = real_cpu.size(0)
        # Explicit float dtype: torch.full otherwise infers the dtype from the
        # fill value, so an integer real_label would yield an integer tensor
        # that cannot be compared against the discriminator's float output.
        label = torch.full((b_size, ), real_label,
                           dtype=torch.float, device=device)

        # Real batch: gradients are accumulated, the step happens below.
        output = netD(real_cpu).view(-1)
        errD_real = criterion(output, label)
        errD_real.backward()
        D_x = output.mean().item()

        # Generated batch, scored against the fake label.
        noise = torch.randn(b_size, nz, 1, 1, device=device)
        fake = netG(noise)
        label.fill_(fake_label)

        # detach() keeps this backward pass from reaching the generator.
        output = netD(fake.detach()).view(-1)
        errD_fake = criterion(output, label)
        errD_fake.backward()
        D_G_z1 = output.mean().item()

        # Reported loss only; both parts were already back-propagated above.
        errD = errD_real + errD_fake

        optimizerD.step()

        # Part 2: update the generator. The fake batch is scored against the
        # real label, and it is re-scored because netD was just updated.
        netG.zero_grad()
        label.fill_(real_label)

        output = netD(fake).view(-1)

        errG = criterion(output, label)
        errG.backward()
        D_G_z2 = output.mean().item()
        optimizerG.step()

        # Part 3: progress report every 50 batches, loss history every batch.
        if i % 50 == 0:
            print(
                '[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\tD(x): %.4f\tD(G(z)): %.4f / %.4f'
                % (epoch, num_epochs, i, len(dataloader), errD.item(),
                   errG.item(), D_x, D_G_z1, D_G_z2))

        G_losses.append(errG.item())
        D_losses.append(errD.item())

        # Snapshot the generator on fixed_noise every 500 iterations and on
        # the very last batch of the final epoch.
        is_last_batch = ((epoch == num_epochs - 1) and
                         (i == len(dataloader) - 1))
        if (iters % 500 == 0) or is_last_batch:
            with torch.no_grad():
                fake = netG(fixed_noise).detach().cpu()
                # The file is named after the epoch, so a later snapshot in
                # the same epoch overwrites the earlier one.
                # NOTE(review): ToPILImage does not rescale float inputs; if
                # netG emits values outside [0, 1] the saved image will be
                # distorted -- confirm the generator's output range.
                transforms.ToPILImage()(
                    fake[-1]).convert("RGB").save(f'epochs/{epoch}.png')
            img_list.append(vutils.make_grid(fake, padding=2, normalize=True))

        iters += 1
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment