Pytorch weight normalization - works for all nn.Module (probably)
## Weight norm is now added to pytorch as a pre-hook, so use that instead :)

import torch
import torch.nn as nn
from torch.nn import Parameter
from functools import wraps


class WeightNorm(nn.Module):
    append_g = '_g'
    append_v = '_v'

    def __init__(self, module, weights):
        super(WeightNorm, self).__init__()
        self.module = module
        self.weights = weights
        self._reset()

    def _reset(self):
        for name_w in self.weights:
            w = getattr(self.module, name_w)

            # construct g,v such that w = g/||v|| * v
            g = torch.norm(w)
            v = w / g.expand_as(w)
            g = Parameter(g.data)
            v = Parameter(v.data)
            name_g = name_w + self.append_g
            name_v = name_w + self.append_v

            # remove w from parameter list
            del self.module._parameters[name_w]

            # add g and v as new parameters
            self.module.register_parameter(name_g, g)
            self.module.register_parameter(name_v, v)

    def _setweights(self):
        for name_w in self.weights:
            name_g = name_w + self.append_g
            name_v = name_w + self.append_v
            g = getattr(self.module, name_g)
            v = getattr(self.module, name_v)
            w = v * (g / torch.norm(v)).expand_as(v)
            setattr(self.module, name_w, w)

    def forward(self, *args):
        self._setweights()
        return self.module.forward(*args)


##############################################################
## An older version using a python decorator, which might be buggy.
## Does not work when the module is replicated (e.g. nn.DataParallel).

def _decorate(forward, module, name, name_g, name_v):
    @wraps(forward)
    def decorated_forward(*args, **kwargs):
        g = module.__getattr__(name_g)
        v = module.__getattr__(name_v)
        w = v * (g / torch.norm(v)).expand_as(v)
        module.__setattr__(name, w)
        return forward(*args, **kwargs)
    return decorated_forward


def weight_norm(module, name):
    param = module.__getattr__(name)

    # construct g,v such that w = g/||v|| * v
    g = torch.norm(param)
    v = param / g.expand_as(param)
    g = Parameter(g.data)
    v = Parameter(v.data)
    name_g = name + '_g'
    name_v = name + '_v'

    # remove w from parameter list
    del module._parameters[name]

    # add g and v as new parameters
    module.register_parameter(name_g, g)
    module.register_parameter(name_v, v)

    # construct w every time before forward is called
    module.forward = _decorate(module.forward, module, name, name_g, name_v)
    return module
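As the note at the top of the file says, weight normalization has since been built into PyTorch itself as a forward pre-hook. A minimal sketch of that built-in API is below; unlike the wrapper above, which uses a single scalar g = ||w||, the built-in hook keeps one norm per slice along dim (dim=0 by default), and the layer shapes here are only illustrative.

import torch
import torch.nn as nn

# reparameterize `weight` into `weight_g` and `weight_v` via a forward pre-hook
m = nn.utils.weight_norm(nn.Linear(30, 10), name='weight')
print([n for n, _ in m.named_parameters()])
# expected: ['bias', 'weight_g', 'weight_v']

x = torch.randn(5, 30)
y = m(x)  # `weight` is recomputed from g and v before each forward

# the hook can be removed, baking the current weight back into the module
nn.utils.remove_weight_norm(m)

The gist's own usage example, wrapping a ConvTranspose2d and stacking inside nn.Sequential, follows.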
import torch
import torch.nn as nn
from pytorch_weight_norm import WeightNorm

x = torch.autograd.Variable(torch.randn(5, 10, 30, 30))
m = nn.ConvTranspose2d(10, 20, 3)
y = m(x)
print(m._parameters.keys())
# odict_keys(['weight', 'bias'])

m = WeightNorm(m, ['weight'])
y_wn = m(x)
print(m.module._parameters.keys())
# odict_keys(['bias', 'weight_g', 'weight_v'])

print(torch.norm(y - y_wn).data[0])
# 1.3324766769073904e-05 (not important to get this smaller)

## can also be used within nn.Sequential,
## and is also stackable
net = nn.Sequential(
    WeightNorm(nn.Linear(30, 10), ['weight']),
    nn.ReLU(),
    WeightNorm(nn.Linear(10, 20), ['weight', 'bias']),
)
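Because forward() rebuilds w from weight_g and weight_v on every call, only g and v are registered parameters and only they receive gradients. A minimal training-step sketch against the wrapper, assuming the class above is importable as pytorch_weight_norm (the module name used in the usage example) and a reasonably recent PyTorch; the loss here is a placeholder.

import torch
import torch.nn as nn
from pytorch_weight_norm import WeightNorm

net = WeightNorm(nn.Linear(30, 10), ['weight'])
opt = torch.optim.SGD(net.parameters(), lr=0.1)

x = torch.randn(5, 30)
loss = net(x).pow(2).mean()  # placeholder loss
loss.backward()
opt.step()

# only the reparameterized tensors show up as parameters
print([n for n, _ in net.named_parameters()])
# expected: ['module.bias', 'module.weight_g', 'module.weight_v']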
Excellent work solving the weight_norm(...) deep copy problem!
Thank you
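The deep-copy point in the comment comes from the class-based wrapper keeping weight_g and weight_v in the wrapped module's parameter dict rather than in a closure over module.forward, so the whole thing copies and replicates like any other nn.Module (which, per its own comment, the decorator version at the bottom of the file cannot do). A rough sketch of that check, assuming the same pytorch_weight_norm import; doubling weight_g is arbitrary.

import copy
import torch
import torch.nn as nn
from pytorch_weight_norm import WeightNorm

m = WeightNorm(nn.Linear(10, 20), ['weight'])
m2 = copy.deepcopy(m)

# the copy carries its own weight_g / weight_v parameters...
print([n for n, _ in m2.named_parameters()])
# expected: ['module.bias', 'module.weight_g', 'module.weight_v']

# ...and changing the copy leaves the original untouched
with torch.no_grad():
    m2.module.weight_g.mul_(2.0)
print(torch.equal(m.module.weight_g, m2.module.weight_g))
# expected: False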