Skip to content

Instantly share code, notes, and snippets.

@ikhlestov
Created September 12, 2017 17:18
Show Gist options
  • Save ikhlestov/031e0f4e83b968cede8df1d19f3d4714 to your computer and use it in GitHub Desktop.
Save ikhlestov/031e0f4e83b968cede8df1d19f3d4714 to your computer and use it in GitHub Desktop.
PyTorch: weight initialization
import math

import torch
import torch.nn as nn
from torch.autograd import Variable
# New way with the `init` module: the trailing-underscore functions fill the
# tensor in place (the un-suffixed `torch.nn.init.normal` is a deprecated alias).
w = torch.empty(3, 5)  # uninitialized storage; immediately overwritten below
torch.nn.init.normal_(w)
# Works for Variables too (in modern PyTorch, Variable(t) simply returns a Tensor).
w2 = Variable(w)
torch.nn.init.normal_(w2)
# Old-style direct access to the tensor's `data` attribute. Writes through
# `.data` are not tracked by autograd; `with torch.no_grad():` is the
# recommended alternative in new code.
w2.data.normal_()
# Per-module init hook; e.g. usable as `model.apply(weights_init)`, which
# calls it on every submodule.
def weights_init(m):
    """Re-initialize the parameters of a single module ``m`` in place.

    Dispatch is by class name, so any module whose class name contains
    ``'Conv'`` (Conv1d/2d/3d, ConvTranspose*, ...) or ``'BatchNorm'``
    (BatchNorm1d/2d/3d, ...) is handled; every other module is left untouched.

    * Conv*: weight ~ N(0.0, 0.02); the bias is not modified.
    * BatchNorm*: weight ~ N(1.0, 0.02); bias is set to 0.

    Parameters that are ``None`` (e.g. ``BatchNorm2d(affine=False)`` has
    neither weight nor bias) are skipped instead of raising AttributeError.
    """
    classname = m.__class__.__name__
    if 'Conv' in classname:
        if getattr(m, 'weight', None) is not None:
            m.weight.data.normal_(0.0, 0.02)
    elif 'BatchNorm' in classname:
        if getattr(m, 'weight', None) is not None:
            m.weight.data.normal_(1.0, 0.02)
        if getattr(m, 'bias', None) is not None:
            m.bias.data.fill_(0)
# for loop approach with direct access
class MyModel(nn.Module):
    """Example model that initializes its own weights in the constructor.

    Layers must be registered *before* ``_initialize_weights`` runs, because
    ``self.modules()`` only yields submodules that already exist.
    """

    def __init__(self):
        # nn.Module.__init__ creates the internal registries that
        # self.modules() walks; without it construction raises AttributeError.
        super().__init__()
        # ... register layers here, e.g. self.conv = nn.Conv2d(3, 16, 3) ...
        self._initialize_weights()

    def _initialize_weights(self):
        """Initialize every registered Conv2d / BatchNorm2d / Linear in place."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                # std = sqrt(2 / fan_out) with fan_out = k_h * k_w * out_channels
                # (He et al. normal init, fan-out mode).
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                # affine=False BatchNorm has no weight/bias to initialize.
                if m.weight is not None:
                    m.weight.data.fill_(1)
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                # Only the bias is reset; the weight keeps PyTorch's default
                # init (deliberate in this example). bias=False gives None.
                if m.bias is not None:
                    m.bias.data.zero_()
@ikhlestov
Copy link
Author

@ghaliahmed If you're asking why I don't initialize the linear layer's weights: it's because I use this code as an example, not as production code.
Or did you mean something else?

@ahmedghali
Copy link

@ikhlestov thanks for your response

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment