Last active
November 22, 2024 07:17
-
Star
(153)
You must be signed in to star a gist -
Fork
(27)
You must be signed in to fork a gist
Revisions
-
jeasinema revised this gist
Aug 29, 2018 . 1 changed file with 28 additions and 28 deletions.There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode charactersOriginal file line number Diff line number Diff line change @@ -13,65 +13,65 @@ def weight_init(m): model.apply(weight_init) ''' if isinstance(m, nn.Conv1d): init.normal_(m.weight.data) if m.bias is not None: init.normal_(m.bias.data) elif isinstance(m, nn.Conv2d): init.xavier_normal_(m.weight.data) if m.bias is not None: init.normal_(m.bias.data) elif isinstance(m, nn.Conv3d): init.xavier_normal_(m.weight.data) if m.bias is not None: init.normal_(m.bias.data) elif isinstance(m, nn.ConvTranspose1d): init.normal_(m.weight.data) if m.bias is not None: init.normal_(m.bias.data) elif isinstance(m, nn.ConvTranspose2d): init.xavier_normal_(m.weight.data) if m.bias is not None: init.normal_(m.bias.data) elif isinstance(m, nn.ConvTranspose3d): init.xavier_normal_(m.weight.data) if m.bias is not None: init.normal_(m.bias.data) elif isinstance(m, nn.BatchNorm1d): init.normal_(m.weight.data, mean=1, std=0.02) init.constant_(m.bias.data, 0) elif isinstance(m, nn.BatchNorm2d): init.normal_(m.weight.data, mean=1, std=0.02) init.constant_(m.bias.data, 0) elif isinstance(m, nn.BatchNorm3d): init.normal_(m.weight.data, mean=1, std=0.02) init.constant_(m.bias.data, 0) elif isinstance(m, nn.Linear): init.xavier_normal_(m.weight.data) init.normal_(m.bias.data) elif isinstance(m, nn.LSTM): for param in m.parameters(): if len(param.shape) >= 2: init.orthogonal_(param.data) else: init.normal_(param.data) elif isinstance(m, nn.LSTMCell): for param in m.parameters(): if len(param.shape) >= 2: init.orthogonal_(param.data) else: init.normal_(param.data) elif isinstance(m, nn.GRU): for param in m.parameters(): if len(param.shape) >= 2: init.orthogonal_(param.data) else: init.normal_(param.data) elif isinstance(m, nn.GRUCell): for param in m.parameters(): if len(param.shape) >= 2: init.orthogonal_(param.data) else: init.normal_(param.data) if __name__ == '__main__': -
jeasinema revised this gist
Aug 29, 2018 . 1 changed file with 6 additions and 6 deletions.There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode charactersOriginal file line number Diff line number Diff line change @@ -14,27 +14,27 @@ def weight_init(m): ''' if isinstance(m, nn.Conv1d): init.normal(m.weight.data) if m.bias is not None: init.normal(m.bias.data) elif isinstance(m, nn.Conv2d): init.xavier_normal(m.weight.data) if m.bias is not None: init.normal(m.bias.data) elif isinstance(m, nn.Conv3d): init.xavier_normal(m.weight.data) if m.bias is not None: init.normal(m.bias.data) elif isinstance(m, nn.ConvTranspose1d): init.normal(m.weight.data) if m.bias is not None: init.normal(m.bias.data) elif isinstance(m, nn.ConvTranspose2d): init.xavier_normal(m.weight.data) if m.bias is not None: init.normal(m.bias.data) elif isinstance(m, nn.ConvTranspose3d): init.xavier_normal(m.weight.data) if m.bias is not None: init.normal(m.bias.data) elif isinstance(m, nn.BatchNorm1d): init.normal(m.weight.data, mean=1, std=0.02) -
jeasinema revised this gist
Aug 29, 2018 . 1 changed file with 12 additions and 6 deletions.There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode charactersOriginal file line number Diff line number Diff line change @@ -14,22 +14,28 @@ def weight_init(m): ''' if isinstance(m, nn.Conv1d): init.normal(m.weight.data) if m.bias: init.normal(m.bias.data) elif isinstance(m, nn.Conv2d): init.xavier_normal(m.weight.data) if m.bias: init.normal(m.bias.data) elif isinstance(m, nn.Conv3d): init.xavier_normal(m.weight.data) if m.bias: init.normal(m.bias.data) elif isinstance(m, nn.ConvTranspose1d): init.normal(m.weight.data) if m.bias: init.normal(m.bias.data) elif isinstance(m, nn.ConvTranspose2d): init.xavier_normal(m.weight.data) if m.bias: init.normal(m.bias.data) elif isinstance(m, nn.ConvTranspose3d): init.xavier_normal(m.weight.data) if m.bias: init.normal(m.bias.data) elif isinstance(m, nn.BatchNorm1d): init.normal(m.weight.data, mean=1, std=0.02) init.constant(m.bias.data, 0) -
jeasinema revised this gist
Apr 12, 2018 . 1 changed file with 5 additions and 0 deletions.There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode charactersOriginal file line number Diff line number Diff line change @@ -7,6 +7,11 @@ def weight_init(m): ''' Usage: model = Model() model.apply(weight_init) ''' if isinstance(m, nn.Conv1d): init.normal(m.weight.data) init.normal(m.bias.data) -
jeasinema created this gist
Apr 12, 2018 .There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode charactersOriginal file line number Diff line number Diff line change @@ -0,0 +1,67 @@ #!/usr/bin/env python # -*- coding:UTF-8 -*- import torch import torch.nn as nn import torch.nn.init as init def weight_init(m): if isinstance(m, nn.Conv1d): init.normal(m.weight.data) init.normal(m.bias.data) elif isinstance(m, nn.Conv2d): init.xavier_normal(m.weight.data) init.normal(m.bias.data) elif isinstance(m, nn.Conv3d): init.xavier_normal(m.weight.data) init.normal(m.bias.data) elif isinstance(m, nn.ConvTranspose1d): init.normal(m.weight.data) init.normal(m.bias.data) elif isinstance(m, nn.ConvTranspose2d): init.xavier_normal(m.weight.data) init.normal(m.bias.data) elif isinstance(m, nn.ConvTranspose3d): init.xavier_normal(m.weight.data) init.normal(m.bias.data) elif isinstance(m, nn.BatchNorm1d): init.normal(m.weight.data, mean=1, std=0.02) init.constant(m.bias.data, 0) elif isinstance(m, nn.BatchNorm2d): init.normal(m.weight.data, mean=1, std=0.02) init.constant(m.bias.data, 0) elif isinstance(m, nn.BatchNorm3d): init.normal(m.weight.data, mean=1, std=0.02) init.constant(m.bias.data, 0) elif isinstance(m, nn.Linear): init.xavier_normal(m.weight.data) init.normal(m.bias.data) elif isinstance(m, nn.LSTM): for param in m.parameters(): if len(param.shape) >= 2: init.orthogonal(param.data) else: init.normal(param.data) elif isinstance(m, nn.LSTMCell): for param in m.parameters(): if len(param.shape) >= 2: init.orthogonal(param.data) else: init.normal(param.data) elif isinstance(m, nn.GRU): for param in m.parameters(): if len(param.shape) >= 2: init.orthogonal(param.data) else: init.normal(param.data) elif isinstance(m, nn.GRUCell): for param in m.parameters(): if len(param.shape) >= 2: init.orthogonal(param.data) else: init.normal(param.data) if __name__ == '__main__': pass