Reproduce loss=NaN in CycleGAN with torch.cuda.amp
#!/usr/bin/env python
# coding: utf-8
import torch
print(torch.__version__)

import ignite
print(ignite.__file__)
print(ignite.__version__)

amp_scaling_enabled = True
amp_autocast_enabled = True

from torch.cuda.amp import autocast, GradScaler

import random

seed = 17
random.seed(seed)
_ = torch.manual_seed(seed)

import torch.nn as nn
def get_conv_inorm_relu(in_planes, out_planes, kernel_size, stride, reflection_pad=True, with_relu=True):
    layers = []
    padding = (kernel_size - 1) // 2
    if reflection_pad:
        layers.append(nn.ReflectionPad2d(padding=padding))
        padding = 0
    layers += [
        nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding),
        nn.InstanceNorm2d(out_planes, affine=False, track_running_stats=False),
    ]
    if with_relu:
        layers.append(nn.ReLU(inplace=True))
    return nn.Sequential(*layers)
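
# The helper above pads with (kernel_size - 1) // 2; with reflection_pad=True the
# padding is applied by ReflectionPad2d and the conv itself uses padding=0, so
# stride-1 blocks preserve spatial size. Illustrative check (assumed input size,
# not executed here):
#   block = get_conv_inorm_relu(3, 64, kernel_size=7, stride=1)
#   assert block(torch.rand(1, 3, 200, 200)).shape == (1, 64, 200, 200)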
def get_conv_transposed_inorm_relu(in_planes, out_planes, kernel_size, stride):
    return nn.Sequential(
        nn.ConvTranspose2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=1, output_padding=1),
        nn.InstanceNorm2d(out_planes, affine=False, track_running_stats=False),
        nn.ReLU(inplace=True)
    )
class ResidualBlock(nn.Module):
    def __init__(self, in_planes):
        super(ResidualBlock, self).__init__()
        self.conv1 = get_conv_inorm_relu(in_planes, in_planes, kernel_size=3, stride=1)
        self.conv2 = get_conv_inorm_relu(in_planes, in_planes, kernel_size=3, stride=1, with_relu=False)

    def forward(self, x):
        residual = x
        x = self.conv1(x)
        x = self.conv2(x)
        return x + residual
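
# Each residual block keeps the number of channels and the spatial size
# (kernel 3, stride 1, reflection padding 1), so the skip connection
# `x + residual` is well defined.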
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.c7s1_64 = get_conv_inorm_relu(3, 64, kernel_size=7, stride=1)
        self.d128 = get_conv_inorm_relu(64, 128, kernel_size=3, stride=2, reflection_pad=False)
        self.d256 = get_conv_inorm_relu(128, 256, kernel_size=3, stride=2, reflection_pad=False)

        self.resnet9 = nn.Sequential(*[ResidualBlock(256) for _ in range(9)])

        self.u128 = get_conv_transposed_inorm_relu(256, 128, kernel_size=3, stride=2)
        self.u64 = get_conv_transposed_inorm_relu(128, 64, kernel_size=3, stride=2)

        self.c7s1_3 = get_conv_inorm_relu(64, 3, kernel_size=7, stride=1, with_relu=False)
        # Replace the last instance norm with a tanh activation
        self.c7s1_3[-1] = nn.Tanh()

    def forward(self, x):
        # Encoding
        x = self.c7s1_64(x)
        x = self.d128(x)
        x = self.d256(x)
        # 9 residual blocks
        x = self.resnet9(x)
        # Decoding
        x = self.u128(x)
        x = self.u64(x)
        y = self.c7s1_3(x)
        return y
def get_conv_inorm_lrelu(in_planes, out_planes, stride=2, negative_slope=0.2):
    return nn.Sequential(
        nn.Conv2d(in_planes, out_planes, kernel_size=4, stride=stride, padding=1),
        nn.InstanceNorm2d(out_planes, affine=False, track_running_stats=False),
        nn.LeakyReLU(negative_slope=negative_slope, inplace=True)
    )
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.c64 = nn.Conv2d(3, 64, kernel_size=4, stride=2, padding=1)
        self.relu = nn.LeakyReLU(0.2, inplace=True)
        self.c128 = get_conv_inorm_lrelu(64, 128)
        self.c256 = get_conv_inorm_lrelu(128, 256)
        self.c512 = get_conv_inorm_lrelu(256, 512, stride=1)
        self.last_conv = nn.Conv2d(512, 1, kernel_size=4, stride=1, padding=1)

    def forward(self, x):
        x = self.c64(x)
        x = self.relu(x)
        x = self.c128(x)
        x = self.c256(x)
        x = self.c512(x)
        y = self.last_conv(x)
        return y
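
# This is the PatchGAN-style discriminator used in CycleGAN: it outputs a grid
# of per-patch real/fake decisions rather than a single scalar.
# Illustrative check (assuming the 200x200 inputs used in the run loop below,
# not executed here):
#   d = Discriminator()
#   assert d(torch.rand(1, 3, 200, 200)).shape == (1, 1, 23, 23)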
def init_weights(module):
    assert isinstance(module, nn.Module)
    if hasattr(module, "weight") and module.weight is not None:
        torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
    if hasattr(module, "bias") and module.bias is not None:
        torch.nn.init.constant_(module.bias, 0.0)
    for c in module.children():
        init_weights(c)
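
# init_weights recursively draws every weight from N(0, 0.02) and sets every
# bias to zero, matching the initialization described in the CycleGAN paper.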
assert torch.backends.cudnn.enabled, "torch.cuda.amp requires the cudnn backend to be enabled."
torch.backends.cudnn.benchmark = True

device = "cuda"

generator_A2B = Generator().to(device)
init_weights(generator_A2B)
discriminators_B = Discriminator().to(device)
init_weights(discriminators_B)

generator_B2A = Generator().to(device)
init_weights(generator_B2A)
discriminators_A = Discriminator().to(device)
init_weights(discriminators_A)

from itertools import chain
import torch.optim as optim

lr = 0.0002
beta1 = 0.5

optimizer_G = optim.Adam(chain(generator_A2B.parameters(), generator_B2A.parameters()), lr=lr, betas=(beta1, 0.999))
optimizer_D = optim.Adam(chain(discriminators_A.parameters(), discriminators_B.parameters()), lr=lr, betas=(beta1, 0.999))
def toggle_grad(model, on_or_off):
    # https://github.com/ajbrock/BigGAN-PyTorch/blob/master/utils.py#L674
    for param in model.parameters():
        param.requires_grad = on_or_off


buffer_size = 50
fake_a_buffer = []
fake_b_buffer = []
def buffer_insert_and_get(buffer, batch):
    output_batch = []
    for b in batch:
        b = b.unsqueeze(0)
        # If the buffer is not fully filled yet:
        if len(buffer) < buffer_size:
            output_batch.append(b)
            buffer.append(b.cpu())
        elif random.uniform(0, 1) > 0.5:
            # Add the newly created image to the buffer and put one from the buffer into the output
            random_index = random.randint(0, buffer_size - 1)
            output_batch.append(buffer[random_index].clone().to(device))
            buffer[random_index] = b.cpu()
        else:
            output_batch.append(b)
    return torch.cat(output_batch, dim=0)
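
# buffer_insert_and_get implements the 50-image history pool from CycleGAN:
# until the pool is full every fake image is stored (on CPU) and returned as-is;
# afterwards, with probability 0.5 a stored image is swapped out and returned in
# place of the fresh one, so the discriminators also see older fakes.
# Illustrative usage (hypothetical tensors, not executed here):
#   pool = []
#   fakes = torch.rand(4, 3, 200, 200, device=device)
#   mixed = buffer_insert_and_get(pool, fakes)   # same shape as `fakes`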
from ignite.utils import convert_tensor
import torch.nn.functional as F

lambda_value = 10.0

amp_scaler = GradScaler(enabled=amp_scaling_enabled)
def discriminators_forward_pass(discriminators, batch_real, batch_fake, fake_buffer):
    decision_real = discriminators(batch_real)
    batch_fake = buffer_insert_and_get(fake_buffer, batch_fake)
    batch_fake = batch_fake.detach()
    decision_fake = discriminators(batch_fake)
    return decision_real, decision_fake
def compute_loss_generator(batch_decision, batch_real, batch_rec, lambda_value):
    # GAN loss
    target = torch.ones_like(batch_decision)
    loss_gan = F.mse_loss(batch_decision, target)
    # Cycle-consistency loss
    loss_cycle = F.l1_loss(batch_rec, batch_real) * lambda_value
    return loss_gan + loss_cycle


def compute_loss_discriminators(decision_real, decision_fake):
    # loss = mean (D(y) − 1)^2 + mean D(G(x))^2
    loss = F.mse_loss(decision_fake, torch.zeros_like(decision_fake))
    loss += F.mse_loss(decision_real, torch.ones_like(decision_real))
    return loss
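
# Both losses follow the least-squares GAN objective from the CycleGAN paper:
# the generator is trained so that D(G(x)) approaches 1, the discriminator so
# that D(real) approaches 1 and D(fake) approaches 0; the L1 cycle term is
# weighted by lambda_value = 10. A minimal numeric sketch (made-up values):
#   d = torch.tensor([0.3, 0.9])
#   F.mse_loss(d, torch.ones_like(d))   # mean((d - 1)^2) = (0.49 + 0.01) / 2 = 0.25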
def update_fn(engine, batch):
    generator_A2B.train()
    generator_B2A.train()
    discriminators_A.train()
    discriminators_B.train()

    real_a = convert_tensor(batch['A'], device=device, non_blocking=True)
    real_b = convert_tensor(batch['B'], device=device, non_blocking=True)

    # Update generators
    # Disable grads computation for the discriminators:
    toggle_grad(discriminators_A, False)
    toggle_grad(discriminators_B, False)

    with autocast(enabled=amp_autocast_enabled):
        fake_b = generator_A2B(real_a)
        rec_a = generator_B2A(fake_b)

        fake_a = generator_B2A(real_b)
        rec_b = generator_A2B(fake_a)

        decision_fake_a = discriminators_A(fake_a)
        decision_fake_b = discriminators_B(fake_b)

        # Compute loss for generators and update generators
        # loss_a2b = GAN loss: mean (D_b(G(x)) − 1)^2 + forward cycle loss: || F(G(x)) - x ||_1
        loss_a2b = compute_loss_generator(decision_fake_b, real_a, rec_a, lambda_value)

        # loss_b2a = GAN loss: mean (D_a(F(y)) − 1)^2 + backward cycle loss: || G(F(y)) - y ||_1
        loss_b2a = compute_loss_generator(decision_fake_a, real_b, rec_b, lambda_value)

        # Total generators loss:
        loss_generators = loss_a2b + loss_b2a

    optimizer_G.zero_grad()
    amp_scaler.scale(loss_generators).backward()
    amp_scaler.step(optimizer_G)
    amp_scaler.update()

    decision_fake_a = rec_a = decision_fake_b = rec_b = None

    # Enable grads computation for the discriminators:
    toggle_grad(discriminators_A, True)
    toggle_grad(discriminators_B, True)

    with autocast(enabled=amp_autocast_enabled):
        decision_real_a, decision_fake_a = discriminators_forward_pass(discriminators_A, real_a, fake_a, fake_a_buffer)
        decision_real_b, decision_fake_b = discriminators_forward_pass(discriminators_B, real_b, fake_b, fake_b_buffer)

        # Compute loss for discriminators and update discriminators
        # loss_a = mean (D_a(y) − 1)^2 + mean D_a(F(x))^2
        loss_a = compute_loss_discriminators(decision_real_a, decision_fake_a)

        # loss_b = mean (D_b(y) − 1)^2 + mean D_b(G(x))^2
        loss_b = compute_loss_discriminators(decision_real_b, decision_fake_b)

        # Total discriminators loss:
        loss_discriminators = 0.5 * (loss_a + loss_b)

    optimizer_D.zero_grad()
    amp_scaler.scale(loss_discriminators).backward()
    amp_scaler.step(optimizer_D)
    amp_scaler.update()

    return {
        "loss_generators": loss_generators.item(),
        "loss_generator_a2b": loss_a2b.item(),
        "loss_generator_b2a": loss_b2a.item(),
        "loss_discriminators": loss_discriminators.item(),
        "loss_discriminators_a": loss_a.item(),
        "loss_discriminators_b": loss_b.item(),
    }
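
# One way to observe the behaviour this gist reproduces (a suggestion, not part
# of the original script): GradScaler skips the optimizer step and lowers its
# scale when inf/NaN gradients are found, so logging the scale each iteration
# shows when that happens, e.g.:
#   print("scale:", amp_scaler.get_scale())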
for _ in range(10):
    real_batch = {
        "A": 2.0 * torch.rand(6, 3, 200, 200) - 1.0,
        "B": 2.0 * torch.rand(6, 3, 200, 200) - 1.0
    }
    print("\nRun update")
    res = update_fn(engine=None, batch=real_batch)
    print(res)