Skip to content

Instantly share code, notes, and snippets.

@123epsilon
Created April 27, 2021 01:50
Show Gist options
  • Save 123epsilon/7e1e414504857e85feef0dd5fb6486bd to your computer and use it in GitHub Desktop.
Save 123epsilon/7e1e414504857e85feef0dd5fb6486bd to your computer and use it in GitHub Desktop.
Simple GAN Training
# --- Hyperparameters -------------------------------------------------------
z_dim = 10            # size of the latent noise vector fed to the generator
num_epochs = 30000    # number of loop iterations; each one is a single mini-batch
batch_size = 32       # samples per mini-batch, for both real and fake data

# Target values for the BCE loss: 1 marks "real", 0 marks "fake".
real_label = 1
fake_label = 0

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Binary cross-entropy on the discriminator's scores.
# NOTE(review): nn.BCELoss expects inputs in [0, 1], so this presumably relies
# on Discriminator ending in a sigmoid — confirm against its definition.
criterion = nn.BCELoss()

# Generator (out_dim=2) and discriminator (input_dim=2) share the same
# width/depth settings. Their construction stays ahead of the fixed-noise draw
# below so the random-number stream is consumed in the same order as before.
gen = Generator(z_dim=z_dim, hidden_dim=28, n_layers=3, out_dim=2).to(device)
disc = Discriminator(input_dim=2, hidden_dim=28, n_layers=3).to(device)

# One Adam optimizer per network, both with PyTorch's default settings.
optimizerD = optim.Adam(disc.parameters())
optimizerG = optim.Adam(gen.parameters())

# A batch of 128 latent vectors drawn once and never resampled. It is not used
# by the training loop itself — presumably reused elsewhere to inspect the
# generator's output over time.
fixed_noise = torch.randn((128, z_dim), device=device)
# Main training loop.
# Each iteration processes exactly one mini-batch, so `num_epochs` is really
# the number of optimisation steps, not passes over a dataset.
# NOTE(review): sample_dist, r, distribution and mode are defined elsewhere
# (not visible here).
print("Training...")
print(device)
for epoch in range(num_epochs):
    # --- Discriminator step: maximise log(D(x)) + log(1 - D(G(z))) ---
    disc.zero_grad()

    # Real samples. torch.tensor() places its result on the CPU unless a
    # device is given, so request `device` explicitly. Otherwise the forward
    # pass below fails with a device mismatch whenever CUDA is in use.
    real_points = torch.tensor(
        sample_dist(n=batch_size, r=r, dist=distribution, mode=mode),
        dtype=torch.float,
        device=device,
    )
    # torch.full with shape (batch_size,) is already 1-D; no reshape needed.
    label = torch.full((batch_size,), real_label, dtype=torch.float, device=device)
    output = disc(real_points).view(-1)
    errD_real = criterion(output, label)
    errD_real.backward()

    # Fake samples. detach() stops this loss from backpropagating into the
    # generator; only the discriminator learns in this step.
    noise = torch.randn(batch_size, z_dim, device=device)
    fake_points = gen(noise)
    label.fill_(fake_label)
    output = disc(fake_points.detach()).view(-1)
    errD_fake = criterion(output, label)
    errD_fake.backward()

    # Total discriminator loss, kept for inspection only. Gradients were
    # already accumulated by the two backward() calls above, so detach() to
    # avoid holding on to autograd nodes.
    errD = (errD_real + errD_fake).detach()
    optimizerD.step()

    # --- Generator step: maximise log(D(G(z))) (non-saturating loss) ---
    # The fakes are scored against the *real* label, so the generator is
    # rewarded for fooling the just-updated discriminator.
    gen.zero_grad()
    label.fill_(real_label)
    output = disc(fake_points).view(-1)
    errG = criterion(output, label)
    # This backward() also writes gradients into disc's parameters. They are
    # discarded by disc.zero_grad() at the start of the next iteration.
    errG.backward()
    optimizerG.step()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment