Created
May 16, 2023 15:57
A PyTorch method for fusing normalization statistics directly into a convolutional layer
import torch
from torch.nn import Conv2d, Parameter


@torch.no_grad()
def fuse_normalize_into_conv_2d(conv: Conv2d, mean: torch.Tensor, std: torch.Tensor) -> Conv2d:
    r"""
    Fuse normalization statistics into a convolutional layer.

    Args:
        conv: The convolutional layer to fuse the normalization statistics into.
        mean: The mean value vector with shape [in_channels].
        std: The standard deviation vector with shape [in_channels].

    Returns:
        The convolutional layer, modified in place.

    Details:
        This fusion is based on the following re-write of norm+conv. First,
        we fuse the scale (standard deviation) into the convolutional
        weights using the associative property of multiplication:
        $$
        \frac{(x-a)}{b} * w = (x-a) * \frac{w}{b}
        $$
        Next, we use the distributive property to re-write the subtraction
        of the mean in such a way that it can be lumped into a single bias
        term:
        $$
        (x-a) * \frac{w}{b} + c = x * \frac{w}{b} + (c - a * \frac{w}{b})
        $$
        Ultimately, this means the weight of the layer gets scaled:
        $$
        w \gets \frac{w}{b}
        $$
        and from the bias term we remove the convolution of the mean with
        the (scaled) weight:
        $$
        c \gets c - a * \frac{w}{b}
        $$
    """
    # Reshape the mean and standard deviation to [N, C, H, W] format.
    in_channels = conv.in_channels
    mean = mean.view(1, in_channels, 1, 1)
    std = std.view(1, in_channels, 1, 1)
    # Fuse the standard deviation into the convolutional weight.
    conv.weight[:] = conv.weight / std
    _, _, H, W = conv.weight.shape
    yc, xc = H // 2, W // 2
    # Fuse the mean into the bias: convolve the (constant) mean with the
    # scaled weight and read off a pixel whose receptive field lies fully
    # inside the input.
    mean = mean.expand(1, in_channels, H + 1, W + 1)
    offset = torch.conv2d(mean, conv.weight, padding='same')[:, :, yc:yc + 1, xc:xc + 1].squeeze()
    if conv.bias is None:  # No bias, assign one directly.
        conv.bias = Parameter(-offset)
    else:  # Adjust the existing bias term.
        conv.bias[:] = conv.bias - offset
    return conv


# Explicitly define the outward facing API of this module.
__all__ = [fuse_normalize_into_conv_2d.__name__]
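As a sanity check on the derivation, the fusion can be reproduced by hand with plain tensors. This is a minimal sketch, not part of the gist: the layer sizes are arbitrary, the kernel is unpadded so every output pixel sees a full receptive field, and the bias offset is computed as a direct sum over the kernel rather than via `torch.conv2d`.

```python
import torch

torch.manual_seed(0)
C = 3  # input channels
mean = torch.rand(C)
std = torch.rand(C) + 0.5
conv = torch.nn.Conv2d(C, 8, kernel_size=3)  # no padding

x = torch.rand(1, C, 5, 5)
# Reference: normalize the input, then convolve.
ref = conv((x - mean.view(1, C, 1, 1)) / std.view(1, C, 1, 1))

with torch.no_grad():
    # w <- w / b: scale the weight by the standard deviation.
    w = conv.weight / std.view(1, C, 1, 1)
    # c <- c - a * (w / b): sum mean * scaled-weight over channels and kernel.
    offset = (w * mean.view(1, C, 1, 1)).sum(dim=(1, 2, 3))
    bias = conv.bias - offset

# Fused: convolve the raw input with the adjusted weight and bias.
fused = torch.nn.functional.conv2d(x, w, bias)
print(torch.allclose(ref, fused, atol=1e-5))
```

With zero padding the two paths agree to floating-point tolerance; with padding the fused layer pads the raw input with zeros rather than with the normalized value, so border pixels differ slightly.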