Created October 3, 2017 11:54
thomwolf/eea8989cab5ac49919df95f6f1309d80
Simple way to reproduce Keras default initialisation in a typical PyTorch NLP model
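import torch.nn as nn


def init_weights(self):
    """
    Reproduce the Keras default initialization for the Embedding and LSTM weights:
    embeddings ~ U(-0.05, 0.05) (Keras 'uniform'), input-to-hidden kernels Glorot/Xavier
    uniform, hidden-to-hidden (recurrent) kernels orthogonal, and all biases zero.
    """
    # Select recurrent-layer parameters by name (e.g. 'lstm.weight_ih_l0').
    ih = (param for name, param in self.named_parameters() if 'weight_ih' in name)
    hh = (param for name, param in self.named_parameters() if 'weight_hh' in name)
    # Note: this also matches the biases of any other layer in the model.
    b = (param for name, param in self.named_parameters() if 'bias' in name)
    # The in-place initializers (trailing underscore) run under torch.no_grad(),
    # so parameters can be passed directly instead of via .data.
    nn.init.uniform_(self.embed.weight, a=-0.05, b=0.05)
    for t in ih:
        nn.init.xavier_uniform_(t)
    for t in hh:
        nn.init.orthogonal_(t)
    for t in b:
        nn.init.constant_(t, 0)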
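A minimal, self-contained sketch of how this initialization might sit inside a typical embedding + LSTM model. The class name `EmbedLSTM`, its layer sizes, and the `unit_forget_bias` argument are hypothetical, chosen for illustration. Embeddings are drawn from U(-0.05, 0.05), the range of Keras's default `uniform` initializer. The optional `unit_forget_bias` step mirrors a Keras LSTM default that the scheme above does not cover: Keras LSTM layers default to `unit_forget_bias=True`, which starts the forget-gate bias at 1 instead of 0.

```python
import torch
import torch.nn as nn


class EmbedLSTM(nn.Module):
    """Hypothetical 'typical NLP model': an embedding layer feeding an LSTM."""

    def __init__(self, vocab_size=100, embed_dim=16, hidden_size=32, num_layers=1):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, embed_dim)
        self.lstm = nn.LSTM(embed_dim, hidden_size, num_layers=num_layers, batch_first=True)
        self.init_weights()

    def init_weights(self, unit_forget_bias=False):
        # Keras Embedding default: RandomUniform(-0.05, 0.05).
        nn.init.uniform_(self.embed.weight, a=-0.05, b=0.05)
        for name, param in self.named_parameters():
            if 'weight_ih' in name:
                nn.init.xavier_uniform_(param)  # Keras kernel_initializer='glorot_uniform'
            elif 'weight_hh' in name:
                nn.init.orthogonal_(param)      # Keras recurrent_initializer='orthogonal'
            elif 'bias' in name:
                nn.init.zeros_(param)           # Keras bias_initializer='zeros'
        if unit_forget_bias:
            # PyTorch packs LSTM gates as (input, forget, cell, output) and sums
            # bias_ih and bias_hh, so setting only the forget slice of each bias_ih
            # tensor gives a total forget-gate bias of 1, as in Keras.
            h = self.lstm.hidden_size
            with torch.no_grad():
                for name, param in self.lstm.named_parameters():
                    if name.startswith('bias_ih'):
                        param[h:2 * h].fill_(1.0)

    def forward(self, tokens):
        output, _ = self.lstm(self.embed(tokens))
        return output


if __name__ == "__main__":
    torch.manual_seed(0)
    model = EmbedLSTM()
    model.init_weights(unit_forget_bias=True)
    out = model(torch.randint(0, 100, (2, 5)))
    print(out.shape)  # torch.Size([2, 5, 32])
```

Calling `init_weights()` from `__init__` means every freshly built model starts from the Keras-style values; calling it again with `unit_forget_bias=True` re-initializes everything and then applies the forget-gate offset.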