Skip to content

Instantly share code, notes, and snippets.

@enijkamp
Created August 14, 2018 22:58
Show Gist options
  • Save enijkamp/21ffeac6e2181c686aa4cf9dc913ebf7 to your computer and use it in GitHub Desktop.
gradient penalty
def calc_gradient_penalty(netD, real_data, fake_data, LAMBDA=10, BATCH_SIZE=128, HW=(64, 64)):
    """Compute the WGAN-GP gradient penalty (Gulrajani et al., 2017).

    Penalizes the discriminator when the L2 norm of its gradient,
    taken at random points interpolated between real and fake samples,
    deviates from 1.

    Args:
        netD: discriminator callable mapping (BATCH_SIZE, 3, H, W) tensors to scores.
        real_data: real samples with BATCH_SIZE * 3 * H * W elements.
        fake_data: generated samples; reshaped to (BATCH_SIZE, 3, H, W).
        LAMBDA: weight of the penalty term.
        BATCH_SIZE: number of samples in the batch.
        HW: spatial size (H, W) of the images (3 channels are assumed).

    Returns:
        Scalar tensor: LAMBDA * mean((||grad D(x_hat)||_2 - 1) ** 2),
        differentiable so it can be added to the discriminator loss.
    """
    # Derive the device from the data instead of relying on a module-level
    # global `device` (which was not defined in this file).
    device = real_data.device

    # One interpolation coefficient per sample, broadcast across C/H/W.
    alpha = torch.rand(BATCH_SIZE, 1, device=device)
    alpha = alpha.expand(BATCH_SIZE, int(real_data.nelement() / BATCH_SIZE)).contiguous()
    alpha = alpha.view(BATCH_SIZE, 3, HW[0], HW[1])

    fake_data = fake_data.view(BATCH_SIZE, 3, HW[0], HW[1])

    # Random points on the segments between real and fake samples.
    # detach() so the penalty does not backpropagate into the generator.
    interpolates = alpha * real_data.detach() + (1 - alpha) * fake_data.detach()
    interpolates.requires_grad_(True)

    disc_interpolates = netD(interpolates)

    # d(D(x_hat)) / d(x_hat); create_graph=True makes the penalty itself
    # differentiable. ones_like allocates grad_outputs directly on the
    # right device/dtype (the original built on CPU and copied).
    gradients = autograd.grad(
        outputs=disc_interpolates,
        inputs=interpolates,
        grad_outputs=torch.ones_like(disc_interpolates),
        create_graph=True,
        retain_graph=True,
        only_inputs=True,
    )[0]
    gradients = gradients.view(gradients.size(0), -1)

    # Two-sided penalty: (||grad||_2 - 1)^2, averaged over the batch.
    gradient_penalty = LAMBDA * ((gradients.norm(2, dim=1) - 1) ** 2).mean()
    return gradient_penalty
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment