@yearofthewhopper
Created February 22, 2020 18:26
import math

import torch.nn as nn
from torch.nn import init


# Per-module weight initializers, meant to be passed to nn.Module.apply().
# All four handle Conv and Linear weights the same way and share the
# BatchNorm2d branch (scale ~ N(1.0, 0.02), bias = 0).

def weights_init_normal(m):
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find("Linear") != -1:
        init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find("BatchNorm2d") != -1:
        init.normal_(m.weight.data, 1.0, 0.02)
        init.constant_(m.bias.data, 0.0)


def weights_init_xavier(m):
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        init.xavier_normal_(m.weight.data, gain=0.02)
    elif classname.find("Linear") != -1:
        init.xavier_normal_(m.weight.data, gain=0.02)
    elif classname.find("BatchNorm2d") != -1:
        init.normal_(m.weight.data, 1.0, 0.02)
        init.constant_(m.bias.data, 0.0)


def weights_init_kaiming(m):
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
    elif classname.find("Linear") != -1:
        init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
    elif classname.find("BatchNorm2d") != -1:
        init.normal_(m.weight.data, 1.0, 0.02)
        init.constant_(m.bias.data, 0.0)


def weights_init_orthogonal(m):
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        init.orthogonal_(m.weight.data, gain=1)   # in-place orthogonal_ (init.orthogonal is deprecated)
    elif classname.find("Linear") != -1:
        init.orthogonal_(m.weight.data, gain=1)
    elif classname.find("BatchNorm2d") != -1:
        init.normal_(m.weight.data, 1.0, 0.02)    # 1-D BatchNorm weights cannot be made orthogonal
        init.constant_(m.bias.data, 0.0)
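
These initializers are written for nn.Module.apply(), which calls the function on every submodule (and on the module itself). As a quick sanity check, here is a minimal sketch (assuming only that PyTorch is installed) that applies the normal variant to a single convolution and inspects the result:

conv = nn.Conv2d(3, 16, kernel_size=3)
conv.apply(weights_init_normal)      # apply() works on a single module too
print(conv.weight.std().item())      # should land close to 0.02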
def init_layer(L):
    print("init_layer")
    # He-style initialization; n = k_h * k_w * out_channels, which is what
    # PyTorch calls fan-out.
    if isinstance(L, nn.Conv2d):
        n = L.kernel_size[0] * L.kernel_size[1] * L.out_channels
        L.weight.data.normal_(0, math.sqrt(2.0 / float(n)))
    elif isinstance(L, nn.BatchNorm2d):
        L.weight.data.fill_(1)
        L.bias.data.fill_(0)
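
Like the functions above, init_layer can be dispatched over a whole network with apply(), or called on one layer directly. A minimal sketch (the layer sizes here are illustrative) checking that the resulting spread tracks sqrt(2/n):

conv = nn.Conv2d(64, 128, kernel_size=3)
init_layer(conv)
n = 3 * 3 * 128                      # k_h * k_w * out_channels
print(conv.weight.std().item())      # should be near math.sqrt(2.0 / n) ≈ 0.0417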
def _initialization(self):
    # Written as a method of an nn.Module subclass (hence `self`).
    # First pass: Linear layers get N(0, 0.02) weights and zero bias.
    for m in self.modules():
        if isinstance(m, nn.Linear):
            m.weight.data.normal_(0.0, 0.02)
            m.bias.data.fill_(0)
    # Second pass: Kaiming for convolutions, constants for norm layers.
    for m in self.modules():
        if isinstance(m, nn.Conv2d):
            init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
        elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
            init.constant_(m.weight, 1)
            init.constant_(m.bias, 0)
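
Since _initialization takes self, it belongs inside a module class. A minimal sketch of the wiring (the class name and layer sizes are hypothetical):

class SmallNet(nn.Module):
    _initialization = _initialization    # bind the function above as a method

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 8, kernel_size=3)
        self.bn = nn.BatchNorm2d(8)
        self.fc = nn.Linear(8, 10)
        self._initialization()           # initialize every submodule up front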
def init_weights(net, init_type='normal'):
    print('initialization method [%s]' % init_type)
    if init_type == 'normal':
        net.apply(weights_init_normal)
    elif init_type == 'xavier':
        net.apply(weights_init_xavier)
    elif init_type == 'kaiming':
        net.apply(weights_init_kaiming)
    elif init_type == 'orthogonal':
        net.apply(weights_init_orthogonal)
    else:
        raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
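
Putting it together: init_weights dispatches one of the four initializers over an entire network, and raises NotImplementedError for unknown names. A minimal usage sketch (the model here is illustrative):

net = nn.Sequential(
    nn.Conv2d(3, 16, kernel_size=3),
    nn.BatchNorm2d(16),
    nn.ReLU(),
)
init_weights(net, init_type='kaiming')   # prints: initialization method [kaiming]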