Skip to content

Instantly share code, notes, and snippets.

@Kajiyu
Created August 9, 2018 09:56
Show Gist options
  • Save Kajiyu/9babf9bd011bbb315cc82f8324c0307c to your computer and use it in GitHub Desktop.
Save Kajiyu/9babf9bd011bbb315cc82f8324c0307c to your computer and use it in GitHub Desktop.
PyTorch implementation of "Spectral Normalization" for vanilla RNNs.
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Most of this code is borrowed from niffler92's project.
https://github.com/niffler92/SNGAN
"""
import torch
import torch.nn as nn
from torch.autograd import Variable
class RNNSpectralNorm(nn.Module):
    """Keep the input/recurrent weights of a wrapped RNN at unit spectral norm.

    Before every forward pass, the largest singular value of each
    ``weight_ih_l<k>`` and ``weight_hh_l<k>`` matrix of ``module`` (plus the
    ``..._reverse`` matrices when ``module.bidirectional`` is true) is
    estimated by power iteration. Each weight is then divided in place by
    its estimate. The power-iteration vectors are stored as buffers on
    ``module`` and reused across calls, so the estimate is refined over time.

    The division goes through ``weight.data``, so autograd does not see it:
    no gradient flows through the normalisation factor. This is a projection
    step, not the reparameterisation done by ``torch.nn.utils.spectral_norm``.

    Args:
        module: RNN module exposing ``num_layers`` and ``weight_ih_l<k>`` /
            ``weight_hh_l<k>`` parameters (e.g. ``nn.RNN``).
        niter: Power-iteration steps per forward pass; must be >= 1.

    Raises:
        ValueError: If ``niter`` is smaller than 1.

    Buffers registered on ``module`` (``<kind>`` is ``ih`` or ``hh``, ``<k>``
    is the layer index, optionally followed by ``_reverse``):
        ``u_<kind><k>``: left singular-vector estimate, shape ``(rows, 1)``.
        ``v_<kind><k>``: right singular-vector estimate, shape ``(1, cols)``.
        ``w_<kind><k>``: alias of the weight itself, kept only so that
            ``state_dict`` keys stay the same as before.
    """

    # Added to vector norms and used as a floor for the spectral-norm estimate,
    # so that an all-zero weight is left as-is instead of becoming NaN.
    _EPS = 1e-12

    def __init__(self, module, niter=5):
        super().__init__()
        if niter < 1:
            raise ValueError('niter must be >= 1, got {}'.format(niter))
        self.module = module
        self.niter = niter
        self.init_params(module)

    @staticmethod
    def _weight_keys(module):
        """Yield ``(weight_name, buffer_suffix)`` for every ih/hh weight of ``module``.

        ``buffer_suffix`` looks like ``'ih0'`` or ``'hh1_reverse'``. All ``ih``
        weights are yielded first, then all ``hh`` weights.
        """
        if getattr(module, 'bidirectional', False):
            directions = ('', '_reverse')
        else:
            directions = ('',)
        for kind in ('ih', 'hh'):
            for layer in range(module.num_layers):
                for direction in directions:
                    weight_name = 'weight_' + kind + '_l' + str(layer) + direction
                    yield weight_name, kind + str(layer) + direction

    @staticmethod
    def init_params(module):
        """Register the power-iteration buffers for every ih/hh weight of ``module``."""
        for weight_name, suffix in RNNSpectralNorm._weight_keys(module):
            weight = getattr(module, weight_name)
            height = weight.size(0)
            width = weight.view(height, -1).size(1)
            # Allocate on the weight's device/dtype so wrapping an RNN that has
            # already been moved (e.g. .cuda() or .double()) works.
            factory = {'device': weight.device, 'dtype': weight.dtype}
            module.register_buffer('u_' + suffix, torch.randn(height, 1, **factory))
            module.register_buffer('v_' + suffix, torch.randn(1, width, **factory))
            # Redundant alias of the weight; kept for state_dict key compatibility.
            module.register_buffer('w_' + suffix, weight)

    @staticmethod
    def update_params(module, niter):
        """Run ``niter`` power-iteration steps per weight and rescale it in place.

        The refined ``u``/``v`` vectors are written back into the module's
        buffers, so the next call continues from them.
        """
        buffers = module._buffers
        eps = RNNSpectralNorm._EPS
        # The power iteration is bookkeeping only; recording an autograd graph
        # for it would just waste memory and time.
        with torch.no_grad():
            for weight_name, suffix in RNNSpectralNorm._weight_keys(module):
                weight = getattr(module, weight_name)
                w_mat = weight.view(weight.size(0), -1)
                u_buf = buffers['u_' + suffix]
                v_buf = buffers['v_' + suffix]
                # v is stored as a row (1, cols) but used as a column here.
                u, v = u_buf, v_buf.t()
                for _ in range(niter):
                    v = w_mat.t() @ u
                    v = v / (v.norm(p=2) + eps)
                    u = w_mat @ v
                    u = u / (u.norm(p=2) + eps)
                # u^T W v estimates the largest singular value (a (1, 1) tensor).
                sigma = (u.t() @ w_mat @ v).clamp(min=eps)
                # Via .data so the rescaling is invisible to autograd.
                weight.data /= sigma
                # Persist the refined vectors for the next forward pass.
                u_buf.copy_(u)
                v_buf.copy_(v.t())

    def forward(self, x, chx=None):
        """Normalise the wrapped module's weights, then return ``module(x, chx)``.

        ``chx`` is the initial hidden state; it defaults to ``None`` and is
        forwarded unchanged to the wrapped module.
        """
        self.update_params(self.module, self.niter)
        return self.module(x, chx)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment