Strange behavior when freeing buffers
# Description:
# This script is a minimal example of strange behavior around freeing autograd buffers. As written, it triggers
# the error reported by PyTorch:
# "RuntimeError: Trying to backward through the graph a second time, but the buffers have already been freed.
# Specify retain_graph=True when calling backward the first time."
#
# Several statements are marked with inline comments describing changes that make the error disappear.
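#
# Background (general PyTorch behaviour, not specific to this script): backward() frees the intermediate
# buffers of the graph it traverses, so a second backward pass over the same graph raises exactly this
# RuntimeError unless the first call passed retain_graph=True. A standalone illustration sits just above
# the driver lines at the bottom of this file.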
import torch
from torch import nn, cuda
from torch.autograd import Variable, grad


class Flatten(nn.Module):
    def forward(self, x):
        return x.view(x.size()[0], -1)
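
# Note: Flatten collapses every non-batch dimension, e.g. (N, C, H, W) -> (N, C * H * W),
# so the model below can end in a plain nn.Linear layer.
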
class BrokenBlock(nn.Module):
    def __init__(self, dim):
        super(BrokenBlock, self).__init__()
        self.conv_block = nn.Sequential(*[nn.InstanceNorm2d(dim, affine=False),
                                          nn.ReLU(inplace=True)])  # change inplace=False and error disappears

    def forward(self, x):
        return self.conv_block(x)
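
# Note: nn.ReLU(inplace=True) overwrites its input tensor instead of allocating a new one; the inline
# comment above marks this flag as one of the toggles that makes the error appear or disappear.
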
class Di(nn.Module):
    def __init__(self, input_shape):
        super(Di, self).__init__()
        input_nc, h, w = input_shape
        sequence = [BrokenBlock(input_nc),
                    BrokenBlock(input_nc),  # comment this line and error disappears
                    Flatten(),
                    nn.Linear(input_nc * h * w, 1)]
        self.model = nn.Sequential(*sequence)

    def forward(self, input, parallel_mode=True):
        if isinstance(input.data, cuda.FloatTensor) and parallel_mode:
            return nn.parallel.data_parallel(self.model, input, [0])
        else:
            return self.model(input)
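
# Note: with parallel_mode=True the forward pass is routed through nn.parallel.data_parallel on GPU 0
# instead of calling the nn.Sequential directly; the comments in backward_D_wgan and
# __calc_gradient_penalty below mark this as another toggle for the error.
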
class BrokenModel():
    def __init__(self):
        self.batch_size = 1
        self.netD_A = Di((3, 256, 256)).cuda(device=0)
        self.optimizer_D_A = torch.optim.Adam(self.netD_A.parameters())
        self.real_A = Variable(torch.ones((1, 3, 256, 256)).cuda())
        self.real_B = Variable(torch.ones((1, 3, 256, 256)).cuda())

    def backward_D_wgan(self, netD, real, fake):
        loss_D_real = -netD.forward(real, parallel_mode=True).mean()  # change parallel_mode=False and error disappears
        loss_D_fake = netD.forward(fake, parallel_mode=True).mean()  # change parallel_mode=False and error disappears
        gradient_penalty = self.__calc_gradient_penalty(netD, real, fake)
        # separate Di backward pass for 3 parts
        gradient_penalty.backward()  # comment this line and error disappears
        loss_D_real.backward()
        loss_D_fake.backward()
        loss_D = loss_D_fake - loss_D_real + gradient_penalty
        return loss_D
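
    # Illustrative alternative, not part of the original gist: the standard WGAN-GP recipe accumulates
    # the same gradients with a single backward() over the summed loss, so no graph has to survive a
    # second backward pass. A minimal sketch:
    def backward_D_wgan_combined(self, netD, real, fake):
        loss_D_real = -netD.forward(real, parallel_mode=True).mean()
        loss_D_fake = netD.forward(fake, parallel_mode=True).mean()
        gradient_penalty = self.__calc_gradient_penalty(netD, real, fake)
        loss_D = loss_D_real + loss_D_fake + gradient_penalty
        loss_D.backward()  # single backward over the combined loss
        return loss_D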

    def __calc_gradient_penalty(self, netD, real, fake):
        alpha = torch.rand(self.batch_size, 1, 1, 1).expand_as(real).cuda()
        interpolated = alpha * real.data + (1 - alpha) * fake.data
        interpolated = Variable(interpolated, requires_grad=True).cuda()
        # Calculate probability of interpolated examples
        prob_interpolated = netD.forward(interpolated, parallel_mode=True)  # change parallel_mode=False and error disappears
        # Calculate gradients of probabilities with respect to examples
        gradients = grad(outputs=prob_interpolated, inputs=interpolated,
                         grad_outputs=torch.ones(prob_interpolated.size()).cuda(),
                         create_graph=True, retain_graph=True)[0]
        gradients_flatten = gradients.view(self.batch_size, -1)
        gradients_norm = torch.sqrt(torch.sum(gradients_flatten ** 2, dim=1) + 1e-12)
        return ((gradients_norm - 1) ** 2).mean()
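
    # Note: __calc_gradient_penalty implements the gradient penalty from WGAN-GP (Gulrajani et al.,
    # "Improved Training of Wasserstein GANs", 2017): the critic output is evaluated at random
    # interpolations of real and fake samples and the squared deviation of the gradient norm from 1
    # is penalised; create_graph=True keeps the penalty itself differentiable.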

    def optimize_parameters(self):
        self.optimizer_D_A.zero_grad()
        self.backward_D_wgan(self.netD_A, self.real_A, self.real_B)
        self.optimizer_D_A.step()
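

# Standalone illustration (not part of the original gist): the same RuntimeError can be reproduced
# with a two-operation graph, and retain_graph=True on the first backward call avoids it.
_x = torch.ones(1, requires_grad=True)
_y = (_x * 2).sum()
_y.backward(retain_graph=True)  # keep the graph's buffers alive for a second pass
_y.backward()                   # without retain_graph above, this line would raise the RuntimeError
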
model = BrokenModel()
model.optimize_parameters()
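# Running the two lines above reproduces the RuntimeError quoted in the header comment unless one of
# the changes marked with "error disappears" is applied.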