## Weight norm is now added to pytorch as a pre-hook, so use that instead :)

import torch
import torch.nn as nn
from torch.nn import Parameter
from functools import wraps


class WeightNorm(nn.Module):
    append_g = '_g'
    append_v = '_v'

    def __init__(self, module, weights):
        super(WeightNorm, self).__init__()
        self.module = module
        self.weights = weights
        self._reset()

    def _reset(self):
        for name_w in self.weights:
            w = getattr(self.module, name_w)

            # construct g,v such that w = g/||v|| * v
            g = torch.norm(w)
            v = w/g.expand_as(w)
            g = Parameter(g.data)
            v = Parameter(v.data)
            name_g = name_w + self.append_g
            name_v = name_w + self.append_v

            # remove w from parameter list
            del self.module._parameters[name_w]

            # add g and v as new parameters
            self.module.register_parameter(name_g, g)
            self.module.register_parameter(name_v, v)

    def _setweights(self):
        for name_w in self.weights:
            name_g = name_w + self.append_g
            name_v = name_w + self.append_v
            g = getattr(self.module, name_g)
            v = getattr(self.module, name_v)
            w = v*(g/torch.norm(v)).expand_as(v)
            setattr(self.module, name_w, w)

    def forward(self, *args):
        self._setweights()
        return self.module.forward(*args)


##############################################################
## An older version using a python decorator but might be buggy.
## Does not work when the module is replicated (e.g. nn.DataParallel)

def _decorate(forward, module, name, name_g, name_v):
    @wraps(forward)
    def decorated_forward(*args, **kwargs):
        g = module.__getattr__(name_g)
        v = module.__getattr__(name_v)
        w = v*(g/torch.norm(v)).expand_as(v)
        module.__setattr__(name, w)
        return forward(*args, **kwargs)
    return decorated_forward


def weight_norm(module, name):
    param = module.__getattr__(name)

    # construct g,v such that w = g/||v|| * v
    g = torch.norm(param)
    v = param/g.expand_as(param)
    g = Parameter(g.data)
    v = Parameter(v.data)
    name_g = name + '_g'
    name_v = name + '_v'

    # remove w from parameter list
    del module._parameters[name]

    # add g and v as new parameters
    module.register_parameter(name_g, g)
    module.register_parameter(name_v, v)

    # construct w every time before forward is called
    module.forward = _decorate(module.forward, module, name, name_g, name_v)
    return module
import torch
import torch.nn as nn
from pytorch_weight_norm import WeightNorm

x = torch.autograd.Variable(torch.randn(5, 10, 30, 30))
m = nn.ConvTranspose2d(10, 20, 3)
y = m(x)
print(m._parameters.keys())
# odict_keys(['weight', 'bias'])

m = WeightNorm(m, ['weight'])
y_wn = m(x)
print(m.module._parameters.keys())
# odict_keys(['bias', 'weight_g', 'weight_v'])

print(torch.norm(y - y_wn).data[0])
# 1.3324766769073904e-05 (not important to get this smaller)

## can also use within sequential
## and is also stackable
net = nn.Sequential(
    WeightNorm(nn.Linear(30, 10), ['weight']),
    nn.ReLU(),
    WeightNorm(nn.Linear(10, 20), ['weight', 'bias']),
)
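For reference, here is a rough sketch of the equivalent usage with the built-in torch.nn.utils.weight_norm pre-hook mentioned at the top. Two differences to keep in mind: the built-in hook takes one parameter name per call, and with the default dim=0 it learns a g per slice along dim 0 rather than the single scalar g used in this gist.

import torch
import torch.nn as nn
from torch.nn.utils import weight_norm

# wrap the same ConvTranspose2d with the built-in hook-based API
m = weight_norm(nn.ConvTranspose2d(10, 20, 3), name='weight')
print(m._parameters.keys())
# odict_keys(['bias', 'weight_g', 'weight_v'])

# also usable inside nn.Sequential
net = nn.Sequential(
    weight_norm(nn.Linear(30, 10), name='weight'),
    nn.ReLU(),
    weight_norm(nn.Linear(10, 20), name='weight'),
)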
This breaks printing of modules for conv layers. A quick fix is to add

    if name_w == 'bias':
        self.module.bias = None

to _reset.
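For concreteness, a sketch of _reset with that fix applied: the check sits right after the original parameter is deleted, so the module keeps a plain bias attribute set to None and Conv2d's repr can still read it.

    def _reset(self):
        for name_w in self.weights:
            w = getattr(self.module, name_w)

            # construct g,v such that w = g/||v|| * v
            g = torch.norm(w)
            v = w/g.expand_as(w)
            g = Parameter(g.data)
            v = Parameter(v.data)
            name_g = name_w + self.append_g
            name_v = name_w + self.append_v

            # remove w from parameter list
            del self.module._parameters[name_w]
            if name_w == 'bias':
                # leave a plain attribute behind so __repr__ still finds self.bias
                self.module.bias = None

            # add g and v as new parameters
            self.module.register_parameter(name_g, g)
            self.module.register_parameter(name_v, v)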
EDIT: Thanks for sharing your code :)
@rtqichen I could not find the data-dependent init. I thought it was important for weight norm. Isn't it?
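The data-dependent init from the paper is indeed not in this gist. Below is a rough sketch of one way to do it on top of the built-in weight_norm (the helper name data_dependent_init is just for illustration, and this assumes a recent PyTorch with torch.no_grad): run one batch through the layer, then rescale g and shift the bias so the initial pre-activations are roughly zero-mean and unit-variance.

import torch
import torch.nn as nn
from torch.nn.utils import weight_norm

def data_dependent_init(layer, x, eps=1e-8):
    # layer is assumed to be weight_norm(nn.Linear(...)); x is an init batch
    with torch.no_grad():
        y = layer(x)                          # pre-activations, shape (batch, out)
        mu, sigma = y.mean(0), y.std(0) + eps
        # rescale g and shift b so that W x + b starts zero-mean, unit-variance
        layer.weight_g.div_(sigma.view_as(layer.weight_g))
        layer.bias.copy_((layer.bias - mu) / sigma)

x_init = torch.randn(64, 30)
layer = weight_norm(nn.Linear(30, 10), name='weight')
data_dependent_init(layer, x_init)
print(layer(x_init).mean(0))   # ~0 for every output unit
print(layer(x_init).std(0))    # ~1 for every output unit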
How can the Weight Normalization support added in PyTorch 0.2.0 be incorporated into new RNN projects?
http://pytorch.org/docs/master/nn.html#torch.nn.utils.weight_norm
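Not an authoritative answer, but one way to use the built-in hook with an RNN is to apply it per parameter name (layer 0 of nn.LSTM exposes 'weight_ih_l0' and 'weight_hh_l0'). There have been reports of rough edges with the cuDNN-fused RNNs, so it is worth double-checking on your PyTorch version; a sketch:

import torch
import torch.nn as nn
from torch.nn.utils import weight_norm

lstm = nn.LSTM(input_size=30, hidden_size=50, num_layers=1, batch_first=True)
lstm = weight_norm(lstm, name='weight_ih_l0')   # input-to-hidden weights
lstm = weight_norm(lstm, name='weight_hh_l0')   # hidden-to-hidden weights

x = torch.randn(5, 7, 30)                       # (batch, time, features)
out, (h, c) = lstm(x)
print([name for name, _ in lstm.named_parameters()])
# the reparameterized names 'weight_ih_l0_g', 'weight_ih_l0_v',
# 'weight_hh_l0_g', 'weight_hh_l0_v' now appear alongside the biases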
Hello, thanks for sharing this elegant implementation. Where can I find the newer, updated version?
Thanks!
@rtqichen Thanks for contributing this code.
@all @greaber @Smerity @ypxie @hanzhanggit
Hi everyone reading this post. I have some questions regarding weight_norm; it would be great if you could help.
I tried applying weight_norm to each convolution and linear layer (see the code here: https://github.com/xwuaustin/weight_norm/blob/master/cifar10_tutorial_weightNorm.py). However, the training loss on CIFAR-10 shows no difference from the original setting (see the picture below) over the first 10 epochs (6 iterations equal 1 epoch).
Now, the questions:
1. Is there something wrong with the code I modified? I used the code from the cifar10_tutorial in PyTorch. All I did was add weightNorm at each layer.

import torch.nn.utils.weight_norm as weightNorm

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        ### we use weight normalization after each convolution and linear transform
        self.conv1 = weightNorm(nn.Conv2d(3, 6, 5), name="weight")
        #print(self.conv1._parameters.keys())
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = weightNorm(nn.Conv2d(6, 16, 5), name="weight")
        self.fc1 = weightNorm(nn.Linear(16 * 5 * 5, 120), name="weight")
        self.fc2 = weightNorm(nn.Linear(120, 84), name="weight")
        self.fc3 = weightNorm(nn.Linear(84, 10), name="weight")
2. Is the update of the weights and bias, namely 'weight_g' and 'weight_v', done using the formulation from the paper, i.e. w = g/||v|| * v with the corresponding gradients for g and v? (See the quick check after this comment.)
3. Can we do the initialization as the paper suggested?
Thanks. Looking forward to your responses. :)
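Regarding question 2: a quick sanity check (a sketch, not part of the gist) that autograd's gradients for g and v under w = g * v/||v|| match the closed forms from the Salimans & Kingma paper, grad_g = (grad_w . v)/||v|| and grad_v = (g/||v||) * grad_w - (g * grad_g/||v||^2) * v.

import torch

torch.manual_seed(0)
v = torch.randn(20, requires_grad=True)
g = torch.tensor(1.5, requires_grad=True)
c = torch.randn(20)

w = g * v / v.norm()                 # the weight-norm reparameterization
loss = (w * c).sum()                 # arbitrary scalar loss

grad_w, grad_g, grad_v = torch.autograd.grad(loss, (w, g, v))

# closed-form gradients from the paper
grad_g_paper = (grad_w * v).sum() / v.norm()
grad_v_paper = (g / v.norm()) * grad_w - (g * grad_g_paper / v.norm()**2) * v

print(torch.allclose(grad_g, grad_g_paper))   # True
print(torch.allclose(grad_v, grad_v_paper))   # True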
Excellent work solving the deep-copy problem with weight_norm(...)!
Thank you
@rtqichen - thanks for posting the original and the updated version! Even if it was only meant for a friend, it was certainly appreciated ^_^
@greaber - I hadn't even thought of that! I presume you're not using the cuDNN LSTM and are instead running an LSTM cell timestep by timestep? If you were using the cuDNN LSTM, it would avoid this issue, since it should only call forward once per set of inputs, I think?