Skip to content

Instantly share code, notes, and snippets.

@jeasinema
Last active November 22, 2024 07:17

Revisions

  1. jeasinema revised this gist Aug 29, 2018. 1 changed file with 28 additions and 28 deletions.
    56 changes: 28 additions & 28 deletions weight_init.py
    Original file line number Diff line number Diff line change
    @@ -13,65 +13,65 @@ def weight_init(m):
    model.apply(weight_init)
    '''
    if isinstance(m, nn.Conv1d):
    init.normal(m.weight.data)
    init.normal_(m.weight.data)
    if m.bias is not None:
    init.normal(m.bias.data)
    init.normal_(m.bias.data)
    elif isinstance(m, nn.Conv2d):
    init.xavier_normal(m.weight.data)
    init.xavier_normal_(m.weight.data)
    if m.bias is not None:
    init.normal(m.bias.data)
    init.normal_(m.bias.data)
    elif isinstance(m, nn.Conv3d):
    init.xavier_normal(m.weight.data)
    init.xavier_normal_(m.weight.data)
    if m.bias is not None:
    init.normal(m.bias.data)
    init.normal_(m.bias.data)
    elif isinstance(m, nn.ConvTranspose1d):
    init.normal(m.weight.data)
    init.normal_(m.weight.data)
    if m.bias is not None:
    init.normal(m.bias.data)
    init.normal_(m.bias.data)
    elif isinstance(m, nn.ConvTranspose2d):
    init.xavier_normal(m.weight.data)
    init.xavier_normal_(m.weight.data)
    if m.bias is not None:
    init.normal(m.bias.data)
    init.normal_(m.bias.data)
    elif isinstance(m, nn.ConvTranspose3d):
    init.xavier_normal(m.weight.data)
    init.xavier_normal_(m.weight.data)
    if m.bias is not None:
    init.normal(m.bias.data)
    init.normal_(m.bias.data)
    elif isinstance(m, nn.BatchNorm1d):
    init.normal(m.weight.data, mean=1, std=0.02)
    init.constant(m.bias.data, 0)
    init.normal_(m.weight.data, mean=1, std=0.02)
    init.constant_(m.bias.data, 0)
    elif isinstance(m, nn.BatchNorm2d):
    init.normal(m.weight.data, mean=1, std=0.02)
    init.constant(m.bias.data, 0)
    init.normal_(m.weight.data, mean=1, std=0.02)
    init.constant_(m.bias.data, 0)
    elif isinstance(m, nn.BatchNorm3d):
    init.normal(m.weight.data, mean=1, std=0.02)
    init.constant(m.bias.data, 0)
    init.normal_(m.weight.data, mean=1, std=0.02)
    init.constant_(m.bias.data, 0)
    elif isinstance(m, nn.Linear):
    init.xavier_normal(m.weight.data)
    init.normal(m.bias.data)
    init.xavier_normal_(m.weight.data)
    init.normal_(m.bias.data)
    elif isinstance(m, nn.LSTM):
    for param in m.parameters():
    if len(param.shape) >= 2:
    init.orthogonal(param.data)
    init.orthogonal_(param.data)
    else:
    init.normal(param.data)
    init.normal_(param.data)
    elif isinstance(m, nn.LSTMCell):
    for param in m.parameters():
    if len(param.shape) >= 2:
    init.orthogonal(param.data)
    init.orthogonal_(param.data)
    else:
    init.normal(param.data)
    init.normal_(param.data)
    elif isinstance(m, nn.GRU):
    for param in m.parameters():
    if len(param.shape) >= 2:
    init.orthogonal(param.data)
    init.orthogonal_(param.data)
    else:
    init.normal(param.data)
    init.normal_(param.data)
    elif isinstance(m, nn.GRUCell):
    for param in m.parameters():
    if len(param.shape) >= 2:
    init.orthogonal(param.data)
    init.orthogonal_(param.data)
    else:
    init.normal(param.data)
    init.normal_(param.data)


    if __name__ == '__main__':
  2. jeasinema revised this gist Aug 29, 2018. 1 changed file with 6 additions and 6 deletions.
    12 changes: 6 additions & 6 deletions weight_init.py
    Original file line number Diff line number Diff line change
    @@ -14,27 +14,27 @@ def weight_init(m):
    '''
    if isinstance(m, nn.Conv1d):
    init.normal(m.weight.data)
    if m.bias:
    if m.bias is not None:
    init.normal(m.bias.data)
    elif isinstance(m, nn.Conv2d):
    init.xavier_normal(m.weight.data)
    if m.bias:
    if m.bias is not None:
    init.normal(m.bias.data)
    elif isinstance(m, nn.Conv3d):
    init.xavier_normal(m.weight.data)
    if m.bias:
    if m.bias is not None:
    init.normal(m.bias.data)
    elif isinstance(m, nn.ConvTranspose1d):
    init.normal(m.weight.data)
    if m.bias:
    if m.bias is not None:
    init.normal(m.bias.data)
    elif isinstance(m, nn.ConvTranspose2d):
    init.xavier_normal(m.weight.data)
    if m.bias:
    if m.bias is not None:
    init.normal(m.bias.data)
    elif isinstance(m, nn.ConvTranspose3d):
    init.xavier_normal(m.weight.data)
    if m.bias:
    if m.bias is not None:
    init.normal(m.bias.data)
    elif isinstance(m, nn.BatchNorm1d):
    init.normal(m.weight.data, mean=1, std=0.02)
  3. jeasinema revised this gist Aug 29, 2018. 1 changed file with 12 additions and 6 deletions.
    18 changes: 12 additions & 6 deletions weight_init.py
    Original file line number Diff line number Diff line change
    @@ -14,22 +14,28 @@ def weight_init(m):
    '''
    if isinstance(m, nn.Conv1d):
    init.normal(m.weight.data)
    init.normal(m.bias.data)
    if m.bias:
    init.normal(m.bias.data)
    elif isinstance(m, nn.Conv2d):
    init.xavier_normal(m.weight.data)
    init.normal(m.bias.data)
    if m.bias:
    init.normal(m.bias.data)
    elif isinstance(m, nn.Conv3d):
    init.xavier_normal(m.weight.data)
    init.normal(m.bias.data)
    if m.bias:
    init.normal(m.bias.data)
    elif isinstance(m, nn.ConvTranspose1d):
    init.normal(m.weight.data)
    init.normal(m.bias.data)
    if m.bias:
    init.normal(m.bias.data)
    elif isinstance(m, nn.ConvTranspose2d):
    init.xavier_normal(m.weight.data)
    init.normal(m.bias.data)
    if m.bias:
    init.normal(m.bias.data)
    elif isinstance(m, nn.ConvTranspose3d):
    init.xavier_normal(m.weight.data)
    init.normal(m.bias.data)
    if m.bias:
    init.normal(m.bias.data)
    elif isinstance(m, nn.BatchNorm1d):
    init.normal(m.weight.data, mean=1, std=0.02)
    init.constant(m.bias.data, 0)
  4. jeasinema revised this gist Apr 12, 2018. 1 changed file with 5 additions and 0 deletions.
    5 changes: 5 additions & 0 deletions weight_init.py
    Original file line number Diff line number Diff line change
    @@ -7,6 +7,11 @@


    def weight_init(m):
    '''
    Usage:
    model = Model()
    model.apply(weight_init)
    '''
    if isinstance(m, nn.Conv1d):
    init.normal(m.weight.data)
    init.normal(m.bias.data)
  5. jeasinema created this gist Apr 12, 2018.
    67 changes: 67 additions & 0 deletions weight_init.py
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,67 @@
    #!/usr/bin/env python
    # -*- coding:UTF-8 -*-

    import torch
    import torch.nn as nn
    import torch.nn.init as init


    def weight_init(m):
        """Initialize the parameters of a single module in place.

        Usage:
            model = Model()
            model.apply(weight_init)

        The initialization scheme depends on the module type:

        * ``Conv1d`` / ``ConvTranspose1d``: weight and bias drawn from a
          standard normal distribution.
        * ``Conv2d`` / ``Conv3d`` / ``ConvTranspose2d`` / ``ConvTranspose3d``
          and ``Linear``: Xavier-normal weight, standard-normal bias.
        * ``BatchNorm1d`` / ``BatchNorm2d`` / ``BatchNorm3d``: weight drawn
          from N(1, 0.02), bias set to 0.
        * ``LSTM`` / ``LSTMCell`` / ``GRU`` / ``GRUCell``: orthogonal
          initialization for parameters with two or more dimensions,
          standard normal for the rest.

        Modules of any other type are left untouched.

        Args:
            m: The module to initialize (typically supplied by
                ``nn.Module.apply``).

        Returns:
            None. Parameters are modified in place.
        """
        if isinstance(m, (nn.Conv1d, nn.ConvTranspose1d)):
            init.normal_(m.weight.data)
            # ``bias`` is None when the layer was built with ``bias=False``.
            if m.bias is not None:
                init.normal_(m.bias.data)
        elif isinstance(
            m,
            (nn.Conv2d, nn.Conv3d, nn.ConvTranspose2d, nn.ConvTranspose3d),
        ):
            init.xavier_normal_(m.weight.data)
            if m.bias is not None:
                init.normal_(m.bias.data)
        elif isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)):
            # Both parameters are None when the layer uses ``affine=False``.
            if m.weight is not None:
                init.normal_(m.weight.data, mean=1, std=0.02)
            if m.bias is not None:
                init.constant_(m.bias.data, 0)
        elif isinstance(m, nn.Linear):
            init.xavier_normal_(m.weight.data)
            if m.bias is not None:
                init.normal_(m.bias.data)
        elif isinstance(m, (nn.LSTM, nn.LSTMCell, nn.GRU, nn.GRUCell)):
            for param in m.parameters():
                # Orthogonal init needs at least a 2-D tensor; the remaining
                # (1-D) parameters fall back to a standard normal.
                if len(param.shape) >= 2:
                    init.orthogonal_(param.data)
                else:
                    init.normal_(param.data)


    if __name__ == '__main__':
        # Nothing to execute when run as a script; the module only
        # provides ``weight_init`` for import.
        pass