

@rosinality
Created December 9, 2017 15:37
Implementation of Spectral Normalization for PyTorch
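
Spectral normalization (Miyato et al., 2018, "Spectral Normalization for Generative Adversarial Networks") stabilizes GAN training by constraining the Lipschitz constant of the discriminator: each weight matrix W is divided by an estimate of its largest singular value sigma(W), so the effective weight W / sigma(W) has spectral norm 1. Here sigma(W) is estimated with a single step of power iteration per forward pass (see SpectralNorm.compute_weight below).
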
import torch
from torch import nn
from torch.nn import init
from torch.nn import functional as F
from torch.autograd import Variable


def init_conv(conv, glu=True):
    init.kaiming_normal(conv.weight)
    if conv.bias is not None:
        conv.bias.data.zero_()
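

# NOTE: init_linear is called by Discriminator below but its definition is
# missing from this snippet; this is a minimal reconstruction assumed to
# mirror init_conv.
def init_linear(linear):
    init.kaiming_normal(linear.weight)
    if linear.bias is not None:
        linear.bias.data.zero_()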
class ConvBlock(nn.Module):
    def __init__(self, in_channel, out_channel, kernel_size,
                 padding, stride, bn=True):
        super().__init__()

        self.conv = nn.Conv2d(in_channel, out_channel, kernel_size,
                              stride, padding, bias=False)
        self.use_bn = bn

        if bn:
            self.bn = nn.BatchNorm2d(out_channel)

        init_conv(self.conv)
        # Wrap the convolution so its weight is renormalized by its spectral
        # norm on every forward pass.
        self.conv = spectral_norm(self.conv)

    def forward(self, input):
        out = self.conv(input)
        if self.use_bn:
            out = self.bn(out)
        out = F.leaky_relu(out, negative_slope=0.2)

        return out
class Discriminator(nn.Module):
    def __init__(self, n_class=10):
        super().__init__()

        self.conv = nn.Sequential(ConvBlock(3, 64, [3, 3], 1, 1, bn=False),
                                  ConvBlock(64, 64, [3, 3], 1, 2, bn=False),
                                  ConvBlock(64, 64, [3, 3], 1, 2, bn=False),
                                  ConvBlock(64, 128, [3, 3], 1, 2, bn=False),
                                  ConvBlock(128, 256, [3, 3], 1, 2, bn=False),
                                  # ConvBlock(256, 256, [3, 3], 1, 1),
                                  ConvBlock(256, 512, [3, 3], 1, 2, bn=False))

        # Initialize before wrapping with spectral_norm (as ConvBlock does),
        # so the initialization reaches the trainable weight_orig parameter
        # rather than the derived weight, which is recomputed every forward.
        linear = nn.Linear(4 * 4 * 512, 1 + n_class)
        init_linear(linear)
        self.linear = spectral_norm(linear)

    def forward(self, input):
        out = self.conv(input)
        out = self.linear(out.view(input.size(0), -1))

        # First column: real/fake probability; remaining columns: class logits.
        return F.sigmoid(out[:, 0]), out[:, 1:]
class SpectralNorm:
    def __init__(self, name):
        self.name = name

    def compute_weight(self, module):
        weight = getattr(module, self.name + '_orig')
        u = getattr(module, self.name + '_u')
        size = weight.size()
        # Flatten the weight into a 2D matrix of shape (out_channels, -1).
        weight_mat = weight.contiguous().view(size[0], -1)
        # u is a plain attribute, not a buffer, so it does not move with
        # module.cuda(); copy it over explicitly if needed.
        if weight_mat.is_cuda:
            u = u.cuda()
        # One step of power iteration to estimate the largest singular value:
        # v <- W^T u / ||W^T u||, u <- W v / ||W v||, sigma ~= u^T W v.
        v = weight_mat.t() @ u
        v = v / v.norm()
        u = weight_mat @ v
        u = u / u.norm()
        # Divide the weight by the estimated spectral norm.
        weight_sn = weight_mat / (u.t() @ weight_mat @ v)
        weight_sn = weight_sn.view(*size)

        return weight_sn, Variable(u.data)

    @staticmethod
    def apply(module, name):
        fn = SpectralNorm(name)

        weight = getattr(module, name)
        # Keep the unnormalized weight as the trainable parameter and expose
        # the normalized weight under the original attribute name.
        del module._parameters[name]
        module.register_parameter(name + '_orig', nn.Parameter(weight.data))
        input_size = weight.size(0)
        u = Variable(torch.randn(input_size, 1) * 0.1, requires_grad=False)
        setattr(module, name + '_u', u)
        setattr(module, name, fn.compute_weight(module)[0])
        # Recompute the normalized weight before every forward pass.
        module.register_forward_pre_hook(fn)

        return fn

    def __call__(self, module, input):
        weight_sn, u = self.compute_weight(module)
        setattr(module, self.name, weight_sn)
        setattr(module, self.name + '_u', u)


def spectral_norm(module, name='weight'):
    SpectralNorm.apply(module, name)

    return module
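

# A quick sanity check, not part of the original gist: with 3x3 kernels,
# padding 1, and five stride-2 blocks, a 128x128 RGB input leaves the conv
# stack at 512 x 4 x 4, matching the 4 * 4 * 512 input of the final linear
# layer. The batch size of 8 is arbitrary.
if __name__ == '__main__':
    disc = Discriminator(n_class=10)
    input = Variable(torch.randn(8, 3, 128, 128))
    real_fake, class_logits = disc(input)
    print(real_fake.size())     # torch.Size([8])
    print(class_logits.size())  # torch.Size([8, 10])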