@rtqichen (last active May 11, 2023)
PyTorch weight normalization - works for all nn.Module (probably)
## Weight norm is now added to PyTorch as a pre-hook (torch.nn.utils.weight_norm), so use that instead :)
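
## For reference, a minimal sketch of the built-in pre-hook API (assuming a
## PyTorch version that ships torch.nn.utils.weight_norm):
import torch
import torch.nn as nn
from torch.nn.utils import weight_norm as torch_weight_norm

lin = torch_weight_norm(nn.Linear(30, 10), name='weight')  # registers weight_g and weight_v
out = lin(torch.randn(5, 30))                              # weight is recomputed before each forward call

## A standalone wrapper that works with any nn.Module: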
import torch
import torch.nn as nn
from torch.nn import Parameter
from functools import wraps

class WeightNorm(nn.Module):
    append_g = '_g'
    append_v = '_v'

    def __init__(self, module, weights):
        super(WeightNorm, self).__init__()
        self.module = module
        self.weights = weights
        self._reset()

    def _reset(self):
        for name_w in self.weights:
            w = getattr(self.module, name_w)

            # construct g,v such that w = g/||v|| * v
            g = torch.norm(w)
            v = w / g.expand_as(w)
            g = Parameter(g.data)
            v = Parameter(v.data)
            name_g = name_w + self.append_g
            name_v = name_w + self.append_v

            # remove w from parameter list
            del self.module._parameters[name_w]

            # add g and v as new parameters
            self.module.register_parameter(name_g, g)
            self.module.register_parameter(name_v, v)

    def _setweights(self):
        # rebuild each wrapped weight from its g and v before the forward pass
        for name_w in self.weights:
            name_g = name_w + self.append_g
            name_v = name_w + self.append_v
            g = getattr(self.module, name_g)
            v = getattr(self.module, name_v)
            w = v * (g / torch.norm(v)).expand_as(v)
            setattr(self.module, name_w, w)

    def forward(self, *args):
        self._setweights()
        return self.module.forward(*args)
##############################################################
## An older version using a Python decorator; it might be buggy.
## It does not work when the module is replicated (e.g. nn.DataParallel).
def _decorate(forward, module, name, name_g, name_v):
    @wraps(forward)
    def decorated_forward(*args, **kwargs):
        g = module.__getattr__(name_g)
        v = module.__getattr__(name_v)
        w = v * (g / torch.norm(v)).expand_as(v)
        module.__setattr__(name, w)
        return forward(*args, **kwargs)
    return decorated_forward


def weight_norm(module, name):
    param = module.__getattr__(name)

    # construct g,v such that w = g/||v|| * v
    g = torch.norm(param)
    v = param / g.expand_as(param)
    g = Parameter(g.data)
    v = Parameter(v.data)
    name_g = name + '_g'
    name_v = name + '_v'

    # remove w from parameter list
    del module._parameters[name]

    # add g and v as new parameters
    module.register_parameter(name_g, g)
    module.register_parameter(name_v, v)

    # construct w every time before forward is called
    module.forward = _decorate(module.forward, module, name, name_g, name_v)
    return module
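
## A minimal usage sketch of the decorator-based weight_norm above (illustrative;
## as noted, this version may break when the module is replicated):
fc = weight_norm(nn.Linear(30, 10), 'weight')
print(fc._parameters.keys())   # odict_keys(['bias', 'weight_g', 'weight_v'])
y_fc = fc(torch.randn(5, 30))  # 'weight' is rebuilt from g and v inside the decorated forward
print(y_fc.size())             # torch.Size([5, 10])

## Example usage of the WeightNorm wrapper: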
import torch
import torch.nn as nn
from pytorch_weight_norm import WeightNorm
x = torch.autograd.Variable(torch.randn(5,10,30,30))
m = nn.ConvTranspose2d(10,20,3)
y = m(x)
print(m._parameters.keys())
# odict_keys(['weight', 'bias'])
m = WeightNorm(m, ['weight'])
y_wn = m(x)
print(m.module._parameters.keys())
# odict_keys(['bias', 'weight_g', 'weight_v'])
print(torch.norm(y-y_wn).data[0])
# 1.3324766769073904e-05 (not important to get this smaller)
## can also be used within nn.Sequential,
## and is also stackable
net = nn.Sequential(
    WeightNorm(nn.Linear(30, 10), ['weight']),
    nn.ReLU(),
    WeightNorm(nn.Linear(10, 20), ['weight', 'bias']),
)
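
## Training works as usual: the optimizer only ever sees g and v (weight_g,
## weight_v) plus any unwrapped parameters, and each forward pass rebuilds the
## weight from them. A minimal sketch (optimizer and loss are illustrative):
opt = torch.optim.SGD(net.parameters(), lr=1e-2)
inp = torch.autograd.Variable(torch.randn(8, 30))
loss = net(inp).pow(2).mean()
opt.zero_grad()
loss.backward()
opt.step()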

@foowaa commented Oct 21, 2021:

Excellent work solving the deep-copy problem with weight_norm(...)! Thank you.