## Weight norm is now added to pytorch as a pre-hook, so use that instead :)

import torch
import torch.nn as nn
from torch.nn import Parameter
from functools import wraps


class WeightNorm(nn.Module):
    append_g = '_g'
    append_v = '_v'

    def __init__(self, module, weights):
        super(WeightNorm, self).__init__()
        self.module = module
        self.weights = weights
        self._reset()

    def _reset(self):
        for name_w in self.weights:
            w = getattr(self.module, name_w)

            # construct g,v such that w = g/||v|| * v
            g = torch.norm(w)
            v = w / g.expand_as(w)
            g = Parameter(g.data)
            v = Parameter(v.data)
            name_g = name_w + self.append_g
            name_v = name_w + self.append_v

            # remove w from parameter list
            del self.module._parameters[name_w]

            # add g and v as new parameters
            self.module.register_parameter(name_g, g)
            self.module.register_parameter(name_v, v)

    def _setweights(self):
        for name_w in self.weights:
            name_g = name_w + self.append_g
            name_v = name_w + self.append_v
            g = getattr(self.module, name_g)
            v = getattr(self.module, name_v)
            w = v * (g / torch.norm(v)).expand_as(v)
            setattr(self.module, name_w, w)

    def forward(self, *args):
        self._setweights()
        return self.module.forward(*args)


##############################################################
## An older version using a python decorator but might be buggy.
## Does not work when the module is replicated (e.g. nn.DataParallel)

def _decorate(forward, module, name, name_g, name_v):
    @wraps(forward)
    def decorated_forward(*args, **kwargs):
        g = module.__getattr__(name_g)
        v = module.__getattr__(name_v)
        w = v * (g / torch.norm(v)).expand_as(v)
        module.__setattr__(name, w)
        return forward(*args, **kwargs)
    return decorated_forward


def weight_norm(module, name):
    param = module.__getattr__(name)

    # construct g,v such that w = g/||v|| * v
    g = torch.norm(param)
    v = param / g.expand_as(param)
    g = Parameter(g.data)
    v = Parameter(v.data)
    name_g = name + '_g'
    name_v = name + '_v'

    # remove w from parameter list
    del module._parameters[name]

    # add g and v as new parameters
    module.register_parameter(name_g, g)
    module.register_parameter(name_v, v)

    # construct w every time before forward is called
    module.forward = _decorate(module.forward, module, name, name_g, name_v)
    return module
import torch
import torch.nn as nn
from pytorch_weight_norm import WeightNorm

x = torch.autograd.Variable(torch.randn(5, 10, 30, 30))
m = nn.ConvTranspose2d(10, 20, 3)
y = m(x)
print(m._parameters.keys())
# odict_keys(['weight', 'bias'])

m = WeightNorm(m, ['weight'])
y_wn = m(x)
print(m.module._parameters.keys())
# odict_keys(['bias', 'weight_g', 'weight_v'])

print(torch.norm(y - y_wn).data[0])
# 1.3324766769073904e-05 (not important to get this smaller)

## can also use within sequential
## and is also stackable
net = nn.Sequential(
    WeightNorm(nn.Linear(30, 10), ['weight']),
    nn.ReLU(),
    WeightNorm(nn.Linear(10, 20), ['weight', 'bias']),
)
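As the note at the top of the file says, recent PyTorch ships weight norm as a forward pre-hook. A minimal sketch of the equivalent usage with the built-in torch.nn.utils.weight_norm (the resulting parameter names, weight_g and weight_v, mirror the example above; the printed output is indicative, not verbatim):

import torch
import torch.nn as nn
from torch.nn.utils import weight_norm

m = weight_norm(nn.ConvTranspose2d(10, 20, 3), name='weight')
print([name for name, _ in m.named_parameters()])
# expect entries like 'bias', 'weight_g', 'weight_v'

x = torch.randn(5, 10, 30, 30)
y = m(x)  # the pre-hook recomputes 'weight' from weight_g and weight_v before each forward

One difference worth noting: the built-in hook keeps a separate norm per slice along dim=0 by default, while the WeightNorm class above uses a single scalar g = ||w|| for the whole tensor.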
@rtqichen I could not find the data-dependent init. I thought it was important for weight norm. Isn't it?
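For reference, the data-dependent init in the paper runs one minibatch through the layer using only the direction v/||v|| and then sets g and b from the resulting statistics. A rough sketch for an nn.Linear wrapped with the built-in torch.nn.utils.weight_norm (default dim=0, so weight_g has shape (out_features, 1)); the helper name data_dependent_init and the init_scale argument are purely illustrative:

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils import weight_norm

def data_dependent_init(layer, x, init_scale=1.0):
    # layer: nn.Linear already wrapped with weight_norm (per-row g, default dim=0)
    # x: a representative minibatch of inputs, shape (batch, in_features)
    with torch.no_grad():
        v = layer.weight_v
        # pre-activations using only the direction v / ||v|| (per output unit)
        t = F.linear(x, v / v.norm(dim=1, keepdim=True))
        mean, std = t.mean(dim=0), t.std(dim=0)
        # set g = init_scale / std and b = -mean * g, so the init batch comes out
        # roughly zero-mean, unit-variance per output unit
        g = init_scale / (std + 1e-8)
        layer.weight_g.copy_(g.view(-1, 1))
        if layer.bias is not None:
            layer.bias.copy_(-mean * g)

layer = weight_norm(nn.Linear(30, 10))
data_dependent_init(layer, torch.randn(64, 30))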
How can the PyTorch 0.2.0 support for Weight Normalization be incorporated into new RNN projects?
http://pytorch.org/docs/master/nn.html#torch.nn.utils.weight_norm
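For what it's worth, the built-in hook can be attached directly to an RNN's named weight tensors. A minimal sketch (weight_ih_l0 / weight_hh_l0 are the standard nn.LSTM parameter names; behaviour with cuDNN's flattened weights has varied across versions, so treat this as illustrative):

import torch
import torch.nn as nn
from torch.nn.utils import weight_norm

rnn = nn.LSTM(input_size=10, hidden_size=20, num_layers=1)
# reparameterize the layer-0 input-to-hidden and hidden-to-hidden weights
rnn = weight_norm(rnn, name='weight_ih_l0')
rnn = weight_norm(rnn, name='weight_hh_l0')

x = torch.randn(7, 3, 10)  # (seq_len, batch, input_size)
out, (h, c) = rnn(x)
print([n for n, _ in rnn.named_parameters()])
# expect entries like 'weight_ih_l0_g' and 'weight_ih_l0_v' alongside the biases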
Hello, thanks for sharing this elegant implementation. Where can I find the newer, updated version?
Thanks!
@rtqichen Thanks for contributing this code.
@all @greaber @Smerity @ypxie @hanzhanggit
Hi everyone reading this post. I have some questions regarding weight_norm. It would be great if you could help.
I tried applying weight_norm to each convolution and linear layer (check the code here: https://github.com/xwuaustin/weight_norm/blob/master/cifar10_tutorial_weightNorm.py ). However, the training loss on CIFAR-10 shows no difference from the original setting (see the picture below) over the first 10 epochs (6 iterations equal 1 epoch).
Now questions:
1. Is there something wrong with the code I modified? I used the code from the cifar10_tutorial in PyTorch. All I did was add weightNorm at each layer.
from torch.nn.utils import weight_norm as weightNorm

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        ### we use weight normalization after each convolution and linear transform
        self.conv1 = weightNorm(nn.Conv2d(3, 6, 5), name="weight")
        #print(self.conv1._parameters.keys())
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = weightNorm(nn.Conv2d(6, 16, 5), name="weight")
        self.fc1 = weightNorm(nn.Linear(16 * 5 * 5, 120), name="weight")
        self.fc2 = weightNorm(nn.Linear(120, 84), name="weight")
        self.fc3 = weightNorm(nn.Linear(84, 10), name="weight")
2. Are the weights and bias, namely 'weight_g' and 'weight_v', updated using the formulation from the paper?
3. Can we do the initialization as the paper suggested?
Thanks. Looking forward to your responses. :)
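Regarding question 2, for reference: in the Salimans & Kingma (2016) paper the weight is reparameterized as

w = \frac{g}{\lVert v \rVert}\, v, \qquad
\nabla_g L = \frac{\nabla_w L \cdot v}{\lVert v \rVert}, \qquad
\nabla_v L = \frac{g}{\lVert v \rVert}\, \nabla_w L \;-\; \frac{g\, \nabla_g L}{\lVert v \rVert^{2}}\, v,

so weight_g and weight_v are the parameters the optimizer actually updates, and w is simply recomputed from them before each forward pass; there is no separate update rule for w.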
Excellent work solving the weight_norm(...) deep-copy problem!
Thank you
This breaks printing of modules for conv layers. A quick fix is to add a line to _reset.
EDIT: Thanks for sharing your code :)
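The exact snippet for that quick fix isn't shown above, but it presumably amounts to restoring the plain attribute (e.g. module.weight) after it is removed from the parameter list. As an assumption, something like calling _setweights at the end of _reset:

    def _reset(self):
        # ... existing code that registers name_g and name_v ...
        # assumed fix: immediately recreate the plain attribute (e.g. self.module.weight)
        # so that printing the module works before the first forward call
        self._setweights()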