import os

# DCGAN-style training loop. It alternates one discriminator (netD) step and one
# generator (netG) step per batch, and records:
#   img_list  - image grids of netG's output on `fixed_noise` (periodic snapshots)
#   G_losses  - generator loss per iteration
#   D_losses  - discriminator loss (real + fake) per iteration
# NOTE(review): num_epochs, dataloader, device, netD, netG, optimizerD,
# optimizerG, criterion, nz, fixed_noise, real_label, fake_label and vutils
# (torchvision.utils) are defined elsewhere in the file; not visible here.
img_list = []
G_losses = []
D_losses = []
iters = 0
num_batches = len(dataloader)
# Snapshot images are written to epochs/<epoch>.png. The directory is not
# created anywhere else in view, so make sure it exists before the first save.
os.makedirs('epochs', exist_ok=True)

for epoch in range(num_epochs):
    for i, data in enumerate(dataloader):
        # Part 1: update D on a real batch (target real_label) and on a fake
        # batch (target fake_label). Gradients from both backward() calls
        # accumulate before a single optimizer step.
        netD.zero_grad()
        real_cpu = data[0].to(device)
        b_size = real_cpu.size(0)
        # Explicit float dtype: torch.full infers an integer dtype from an int
        # fill value (e.g. real_label = 1), and BCE-style losses need floats.
        label = torch.full((b_size, ), real_label, dtype=torch.float,
                           device=device)
        output = netD(real_cpu).view(-1)
        errD_real = criterion(output, label)
        errD_real.backward()
        D_x = output.mean().item()

        noise = torch.randn(b_size, nz, 1, 1, device=device)
        fake = netG(noise)
        label.fill_(fake_label)
        # detach(): this step must not push gradients into netG.
        output = netD(fake.detach()).view(-1)
        errD_fake = criterion(output, label)
        errD_fake.backward()
        D_G_z1 = output.mean().item()
        errD = errD_real + errD_fake
        optimizerD.step()

        # Part 2: update G. The same fakes are scored by the freshly updated D
        # but compared against real_label, so G is trained to make D call
        # them real. Stray gradients this leaves in netD are cleared by
        # netD.zero_grad() at the start of the next iteration.
        netG.zero_grad()
        label.fill_(real_label)
        output = netD(fake).view(-1)
        errG = criterion(output, label)
        errG.backward()
        D_G_z2 = output.mean().item()
        optimizerG.step()

        # Part 3: logging, loss history and periodic snapshots.
        loss_d = errD.item()
        loss_g = errG.item()
        if i % 50 == 0:
            print(
                '[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\tD(x): %.4f\tD(G(z)): %.4f / %.4f'
                % (epoch, num_epochs, i, num_batches, loss_d,
                   loss_g, D_x, D_G_z1, D_G_z2))
        G_losses.append(loss_g)
        D_losses.append(loss_d)

        is_last_step = (epoch == num_epochs - 1) and (i == num_batches - 1)
        if iters % 500 == 0 or is_last_step:
            # `fake` is deliberately reused here; the training batch of this
            # iteration has already been consumed.
            with torch.no_grad():
                fake = netG(fixed_noise).cpu()
            # ToPILImage assumes float pixels in [0, 1], but the generator's
            # output range is not guaranteed (the grid below is normalized for
            # the same reason). Min-max normalize the single image on save.
            # Later snapshots in the same epoch overwrite earlier ones.
            vutils.save_image(fake[-1], f'epochs/{epoch}.png', normalize=True)
            img_list.append(vutils.make_grid(fake, padding=2, normalize=True))
        iters += 1
# Source: GitHub gist SubhadityaMukherjee/a0f3ca4d7945e811b69b5f8c776d1533
# (file "train"), created February 13, 2020.