Last active
February 10, 2021 04:13
-
-
Save sadimanna/306ce405d23bdea13408d8088ed9a4e9 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class NewModel(nn.Module):
    """Wrap an ``IntResNet`` built from a named ResNet variant and freeze its
    lower children so only the top ones are fine-tuned.

    Args:
        base_model: key into ``self.model_dict`` selecting the architecture
            (e.g. ``'resnet18'``, ``'resnext50_32x4d'``). An unknown key
            raises ``KeyError``.
        base_out_layer: the layer we want to get the output from; passed
            unchanged to ``IntResNet`` (its exact semantics live in
            ``IntResNet`` -- TODO confirm against that class).
        num_trainable_layer: number of top-level children of the wrapped
            network to keep trainable, counted from the LAST child backwards.
            ``-1`` makes every child trainable; ``0`` freezes all of them.
        pretrained: if True (the default, matching the previous hard-coded
            behaviour), download and load weights from ``_MODEL_URLS``.
    """

    # Weight URLs used when ``pretrained`` is True. Earlier these existed only
    # as a commented-out string inside ``new_resnet``, which left the name
    # ``model_urls`` unresolved in this code.
    _MODEL_URLS = {
        'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
        'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
        'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
        'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
        'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
        'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',
        'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth',
        'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth',
        'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth',
    }

    # base_model : The model we want to get the output from
    # base_out_layer : The layer we want to get output from
    # num_trainable_layer : Number of layers we want to finetune (counted from the top)
    #   if the entered value is -1, then all the layers are fine-tuned
    def __init__(self, base_model, base_out_layer, num_trainable_layer,
                 pretrained=True):
        super().__init__()
        self.base_model = base_model
        self.base_out_layer = base_out_layer
        # The constructor parameter is singular (``num_trainable_layer``); the
        # attribute keeps its plural name for any code that reads it.
        self.num_trainable_layers = num_trainable_layer
        self.model_dict = {
            'resnet18': {'block': BasicBlock, 'layers': [2, 2, 2, 2], 'kwargs': {}},
            'resnet34': {'block': BasicBlock, 'layers': [3, 4, 6, 3], 'kwargs': {}},
            'resnet50': {'block': Bottleneck, 'layers': [3, 4, 6, 3], 'kwargs': {}},
            'resnet101': {'block': Bottleneck, 'layers': [3, 4, 23, 3], 'kwargs': {}},
            'resnet152': {'block': Bottleneck, 'layers': [3, 8, 36, 3], 'kwargs': {}},
            'resnext50_32x4d': {'block': Bottleneck, 'layers': [3, 4, 6, 3],
                                'kwargs': {'groups': 32, 'width_per_group': 4}},
            'resnext101_32x8d': {'block': Bottleneck, 'layers': [3, 4, 23, 3],
                                 'kwargs': {'groups': 32, 'width_per_group': 8}},
            'wide_resnet50_2': {'block': Bottleneck, 'layers': [3, 4, 6, 3],
                                'kwargs': {'width_per_group': 64 * 2}},
            'wide_resnet101_2': {'block': Bottleneck, 'layers': [3, 4, 23, 3],
                                 'kwargs': {'width_per_group': 64 * 2}},
        }

        # PRETRAINED MODEL
        config = self.model_dict[self.base_model]
        self.resnet = self.new_resnet(self.base_model, self.base_out_layer,
                                      config['block'], config['layers'],
                                      pretrained, True,
                                      **config['kwargs'])
        self.layers = list(self.resnet._modules.keys())

        # FREEZING LAYERS
        children = list(self.resnet.children())
        self.total_children = len(children)
        if self.num_trainable_layers == -1:
            # -1 means "fine-tune everything". The old code wrote this to a
            # misspelled attribute, so -1 ended up freezing every child.
            self.num_trainable_layers = self.total_children
        num_frozen = self.total_children - self.num_trainable_layers
        for index, child in enumerate(children):
            # Only the last ``num_trainable_layers`` children stay trainable.
            trainable = index >= num_frozen
            for param in child.parameters():
                param.requires_grad = trainable
        # Kept for compatibility: the old loop left this equal to the child count.
        self.children_counter = self.total_children

    def new_resnet(self,
                   arch: str,
                   outlayer: str,
                   block: Type[Union[BasicBlock, Bottleneck]],
                   layers: List[int],
                   pretrained: bool,
                   progress: bool,
                   **kwargs: Any
                   ) -> IntResNet:
        """Build an ``IntResNet`` and optionally load pretrained weights.

        Args:
            arch: architecture key; used to look up the weight URL.
            outlayer: forwarded as the first argument of ``IntResNet``.
            block: residual block class (``BasicBlock`` or ``Bottleneck``).
            layers: number of blocks in each stage.
            pretrained: if True, download weights and load them with
                ``load_state_dict``.
            progress: forwarded to ``load_state_dict_from_url``.
            **kwargs: extra keyword arguments forwarded to ``IntResNet``.

        Returns:
            The constructed ``IntResNet``.

        Raises:
            KeyError: if ``pretrained`` is True and ``arch`` has no entry in
                ``_MODEL_URLS``.
        """
        model = IntResNet(outlayer, block, layers, **kwargs)
        if pretrained:
            state_dict = load_state_dict_from_url(self._MODEL_URLS[arch],
                                                  progress=progress)
            model.load_state_dict(state_dict)
        return model

    def forward(self, x):
        """Run ``x`` through the wrapped network and return its output."""
        return self.resnet(x)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment