Created December 3, 2022 08:31
-
-
Save dyigitpolat/9a2f84ae891e7dfd86127df123bb19aa to your computer and use it in GitHub Desktop.
Batch norm fusing, written by the OpenAI chatbot
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch


def fuse_linear_bn(linear, bn):
    """Fold a batch-norm layer that follows a linear layer into the linear layer.

    In eval mode (running statistics), ``bn(linear(x))`` computes::

        gamma * (W @ x + b - running_mean) / sqrt(running_var + eps) + beta

    This is itself an affine map, so it can be expressed by a single linear
    layer with::

        s  = gamma / sqrt(running_var + eps)     # one factor per output feature
        W' = W * s[:, None]                      # scale each output row of W
        b' = (b - running_mean) * s + beta

    ``linear`` is modified in place and returned. After fusion, ``linear(x)``
    matches ``bn(linear(x))`` only for the eval-mode (running-stats)
    behaviour of ``bn``; training-mode batch statistics cannot be folded.

    Args:
        linear: A ``torch.nn.Linear`` whose output feeds ``bn``. It may have
            been built with ``bias=False``.
        bn: A batch-norm layer (e.g. ``torch.nn.BatchNorm1d``) with
            ``num_features == linear.out_features`` and tracked running
            statistics. It may have been built with ``affine=False``.

    Returns:
        The same ``linear`` module, with its ``weight`` and ``bias`` replaced
        by the fused parameters.

    Raises:
        ValueError: If ``bn`` has no running statistics, or its feature count
            does not match ``linear``'s output features.
    """
    running_mean, running_var = bn.running_mean, bn.running_var
    if running_mean is None or running_var is None:
        # track_running_stats=False: there are no fixed statistics to fold.
        raise ValueError("batch norm layer has no running statistics to fuse")

    weight = linear.weight
    out_features = weight.shape[0]
    if running_mean.shape[0] != out_features:
        raise ValueError(
            f"batch norm has {running_mean.shape[0]} features but the linear "
            f"layer has {out_features} output features"
        )

    with torch.no_grad():
        # Missing pieces default to the identity: no linear bias means b = 0,
        # and affine=False batch norm means gamma = 1 and beta = 0.
        if linear.bias is not None:
            bias = linear.bias
        else:
            bias = torch.zeros(out_features, dtype=weight.dtype, device=weight.device)
        gamma = bn.weight if bn.weight is not None else torch.ones_like(running_mean)
        beta = bn.bias if bn.bias is not None else torch.zeros_like(running_mean)

        scale = gamma / torch.sqrt(running_var + bn.eps)
        # Broadcast the per-output-feature scale across the input dimension.
        fused_weight = weight * scale.unsqueeze(1)
        fused_bias = (bias - running_mean) * scale + beta

    linear.weight = torch.nn.Parameter(fused_weight.to(weight.dtype))
    linear.bias = torch.nn.Parameter(fused_bias.to(weight.dtype))
    return linear
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.