Created September 12, 2017 17:18
pytorch: weights initialization
import math

import torch
import torch.nn as nn
from torch.autograd import Variable

# new way with `init` module
w = torch.Tensor(3, 5)
torch.nn.init.normal(w)
# works for Variables also
w2 = Variable(w)
torch.nn.init.normal(w2)
# old-style direct access to the tensor's .data attribute
w2.data.normal_()

# example for some module
def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        m.weight.data.normal_(0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)

# for-loop approach with direct access
class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                m.bias.data.zero_()
Hello @febriy
It's just an example function that can be applied to the whole network to initialize the corresponding layers accordingly (in this case, convolution and BatchNorm). Here is an example:
net = nn.Sequential(
    nn.Linear(2, 2),
    nn.Conv2d(1, 20, 5),
    nn.BatchNorm2d(20),  # nn.BatchNorm is not a module class; use the 2d variant
)
net.apply(weights_init)
In the code above, the Conv2d and BatchNorm layers will be reinitialized by the weights_init function.
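As a self-contained sanity check of that claim (using `nn.BatchNorm2d`, since `nn.BatchNorm` does not exist as a module class), one can apply `weights_init` and inspect the BatchNorm parameters afterwards:

```python
import torch
import torch.nn as nn

def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        m.weight.data.normal_(0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)

net = nn.Sequential(
    nn.Linear(2, 2),
    nn.Conv2d(1, 20, 5),
    nn.BatchNorm2d(20),
)
net.apply(weights_init)

# apply() visited every submodule: the BatchNorm weights are now
# drawn from N(1.0, 0.02) and its biases are zeroed, while the
# Linear layer kept its default initialization
bn = net[2]
```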
Why not the Linear layer too?
@ghaliahmed If you're asking why I don't initialize the Linear layer: that's only because I use this code as an example, not as production code.
Or did you mean something else?
@ikhlestov thanks for your response
Hi there,
If you don't mind sharing, may I know what is happening in the code here:

# example for some module
def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        m.weight.data.normal_(0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)

Thank you!