Last active
July 30, 2020 08:36
-
-
Save MLWhiz/ab4703770d91e1e2fc264c859f392f6a to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Lists to keep track of progress/losses for later plotting.
img_list = []   # periodic snapshots of G's output on fixed_noise
G_losses = []   # generator loss, one entry per iteration
D_losses = []   # discriminator loss, one entry per iteration
iters = 0       # global iteration counter across all epochs

# Number of training epochs
num_epochs = 50
# Batch size during training.
# NOTE(review): informational only in this block — the effective batch
# size is whatever the dataloader yields (see b_size below).
batch_size = 128

print("Starting Training Loop...")
# For each epoch
for epoch in range(num_epochs):
    # For each batch in the dataloader
    for i, data in enumerate(dataloader, 0):
        ############################
        # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
        # Here we:
        # A. train the discriminator on real data
        # B. Create some fake images from Generator using Noise
        # C. train the discriminator on fake data
        ###########################
        # Training Discriminator on real data
        netD.zero_grad()
        # Format batch
        real_cpu = data[0].to(device)
        b_size = real_cpu.size(0)
        # BCELoss needs float targets. Without an explicit dtype,
        # torch.full infers the dtype from real_label and builds a
        # LongTensor when real_label is an int, which raises on
        # PyTorch >= 1.5 — hence dtype=torch.float.
        label = torch.full((b_size,), real_label,
                           dtype=torch.float, device=device)
        # Forward pass real batch through D
        output = netD(real_cpu).view(-1)
        # Calculate loss on real batch
        errD_real = criterion(output, label)
        # Calculate gradients for D in backward pass
        errD_real.backward()
        D_x = output.mean().item()

        ## Create a batch of fake images using generator
        # Generate noise to send as input to the generator
        noise = torch.randn(b_size, nz, 1, 1, device=device)
        # Generate fake image batch with G
        fake = netG(noise)
        label.fill_(fake_label)
        # Classify fake batch with D. detach() stops gradients from
        # this pass flowing back into G — only D is updated here.
        output = netD(fake.detach()).view(-1)
        # Calculate D's loss on the fake batch
        errD_fake = criterion(output, label)
        # Calculate the gradients for this batch; they accumulate on
        # top of the real-batch gradients from errD_real.backward().
        errD_fake.backward()
        D_G_z1 = output.mean().item()
        # Add the losses from the all-real and all-fake batches
        # (for reporting — the gradients were already accumulated).
        errD = errD_real + errD_fake
        # Update D
        optimizerD.step()

        ############################
        # (2) Update G network: maximize log(D(G(z)))
        # Here we:
        # A. Find the discriminator output on Fake images
        # B. Calculate Generators loss based on this output. Note that the label is 1 for generator.
        # C. Update Generator
        ###########################
        netG.zero_grad()
        label.fill_(real_label)  # fake labels are real for generator cost
        # Since we just updated D, perform another forward pass of
        # all-fake batch through D (no detach: G needs these gradients).
        output = netD(fake).view(-1)
        # Calculate G's loss based on this output
        errG = criterion(output, label)
        # Calculate gradients for G
        errG.backward()
        D_G_z2 = output.mean().item()
        # Update G
        optimizerG.step()

        # Output training stats every 1000th iteration in an epoch.
        # (The original comment said "every 50th", contradicting the
        # code — the modulus below is authoritative.)
        if i % 1000 == 0:
            print('[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\tD(x): %.4f\tD(G(z)): %.4f / %.4f'
                  % (epoch, num_epochs, i, len(dataloader),
                     errD.item(), errG.item(), D_x, D_G_z1, D_G_z2))

        # Save Losses for plotting later
        G_losses.append(errG.item())
        D_losses.append(errD.item())

        # Check how the generator is doing by saving G's output on a
        # fixed_noise vector (every 250 iters, and at the very end).
        if (iters % 250 == 0) or ((epoch == num_epochs-1) and (i == len(dataloader)-1)):
            with torch.no_grad():
                fake = netG(fixed_noise).detach().cpu()
            img_list.append(vutils.make_grid(fake, padding=2, normalize=True))

        iters += 1
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment