Last active
May 11, 2017 10:07
-
-
Save zengyu714/044cab15e3607e6a89805ec7b36d7314 to your computer and use it in GitHub Desktop.
PyTorch
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Fine-tune a pretrained model while adding some new modules on top of it.
class Customize(nn.Module):
    """Wrap a pretrained backbone, adapt its first/last layers and append a new head.

    NOTE(review): ``pre_model`` is assumed to expose ``conv1`` and ``fc``
    attributes and to use them in its own ``forward`` (the torchvision ResNet
    layout) — TODO confirm for the backbone actually passed in.

    Args:
        pre_model: pretrained backbone module; it is modified in place.
        num_features: output size of the replacement ``fc`` layer, which is
            also the input size of the new head (default 500).
    """

    def __init__(self, pre_model, num_features=500):
        """Load the pretrained model, replace its first conv and last fc layers, and add new layers."""
        super(Customize, self).__init__()
        self.features = pre_model
        # To freeze the pretrained weights, uncomment:
        # for param in self.features.parameters():
        #     param.requires_grad = False

        # Accept 1-channel (grayscale) images. This conv is freshly created,
        # so the pretrained weights of the original conv1 are discarded.
        self.features.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
        # Replace the classifier with a layer producing `num_features` outputs.
        self.features.fc = nn.Linear(self.features.fc.in_features, num_features)

        # New head. Its input size must equal features.fc's output size,
        # otherwise forward() fails with a shape mismatch.
        # NOTE(review): if the backbone's forward ends with fc, there is no
        # activation between features.fc and the first Linear below, so the
        # two compose into a single affine map — consider inserting nn.ReLU().
        self.new_fc = nn.Sequential(
            nn.Linear(num_features, 100),
            nn.ReLU(),
            nn.Linear(100, 1),
        )
        self.init_weights()

    def init_weights(self):
        """Initialize the replacement fc layer: N(0, 0.02) weights and 0.01 biases."""
        self.features.fc.weight.data.normal_(0.0, 0.02)
        self.features.fc.bias.data.fill_(0.01)

    def forward(self, inputs):
        """Run the backbone, then the new head; the output's last dimension has size 1."""
        return self.new_fc(self.features(inputs))
# Initialize every fully-connected (nn.Linear) layer of `model`.
# NOTE(review): `model` is defined elsewhere (presumably a Customize instance) —
# TODO confirm. If it is, this replaces the N(0, 0.02) weight init done in
# Customize.init_weights.
from torch.nn import init

for m in model.modules():
    if isinstance(m, torch.nn.Linear):
        # Kaiming (He) normal init for the weights. The trailing-underscore
        # functions are the in-place replacements for the deprecated
        # init.kaiming_normal / init.constant.
        init.kaiming_normal_(m.weight)
        # Linear layers built with bias=False have m.bias set to None;
        # skip them instead of crashing.
        if m.bias is not None:
            init.constant_(m.bias, 0.01)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment