rtqichen / pytorch_weight_norm.py
Last active May 11, 2023 06:58
PyTorch weight normalization - works for all nn.Module subclasses (probably)
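For orientation, here is a minimal pure-Python sketch of the reparameterization that weight normalization is generally understood to perform: a weight is split into a direction `v` and a magnitude `g`, and rebuilt as `w = g * v / ||v||`. The helper name `weight_norm_rows` and the choice of one norm per row are assumptions made for illustration. This is not the gist's own code, which operates on `nn.Module` parameters.

```python
import math


def weight_norm_rows(v, g):
    """Rebuild each weight row as w_i = g_i * v_i / ||v_i|| (hypothetical helper)."""
    w = []
    for row, scale in zip(v, g):
        # Euclidean norm of the direction vector for this row.
        norm = math.sqrt(sum(x * x for x in row))
        w.append([scale * x / norm for x in row])
    return w


# Two rows of direction parameters and one magnitude per row.
v = [[3.0, 4.0], [0.0, 2.0]]
g = [10.0, 1.0]
w = weight_norm_rows(v, g)
```

Each row of `w` ends up with a Euclidean norm equal to the matching entry of `g`, so magnitude and direction can be learned separately.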
## Weight norm is now built into PyTorch (torch.nn.utils.weight_norm, implemented as a forward pre-hook), so use that instead :)
import torch
import torch.nn as nn
from torch.nn import Parameter
from functools import wraps
class WeightNorm(nn.Module):
    append_g = '_g'
    append_v = '_v'