This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def forward_pre_hook(self, layer_name):
    """Return a forward pre-hook that feeds the hooked module an all-zero input.

    The returned callable has the ``hook(module, inputs)`` signature expected by
    ``nn.Module.register_forward_pre_hook``. A non-None value returned by a
    pre-hook replaces the module's positional inputs (a single tensor is wrapped
    into a 1-tuple). As a result, the hooked module receives a zero tensor shaped
    like its first real input instead of the real activation.

    Args:
        layer_name: Name of the layer the hook is created for. The hook itself
            does not read it.

    Returns:
        The pre-hook closure.
    """

    def pre_hook(module, inputs):
        first = inputs[0]
        # Allocate on the incoming activation's device instead of a hard-coded
        # 'cuda:0', so the hook also works on CPU or on a different GPU.
        # The dtype is still forced to float32, as before.
        # requires_grad=True makes the zeros a fresh leaf tensor: its .grad can
        # be inspected after backward(), but no gradient reaches earlier layers
        # because the tensor is not derived from ``first``.
        return torch.zeros(
            first.shape,
            dtype=torch.float,
            device=first.device,
            requires_grad=True,
        )

    return pre_hook
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class NewModel(nn.Module):
    """Hold a torchvision ResNet-50 (built with ``pretrained=True``) plus hook bookkeeping.

    Only construction is shown here. No hooks are registered in this block, so
    ``fhooks`` starts out as an empty list.

    Args:
        output_layers: Stored unchanged on ``self.output_layers``. It presumably
            selects which layers' outputs get captured — TODO confirm with callers.
        *args: Forwarded unchanged to ``nn.Module.__init__``.
    """

    def __init__(self, output_layers, *args):
        # NOTE(review): recent PyTorch versions reject extra positional
        # arguments to nn.Module.__init__ — confirm callers never pass any.
        super().__init__(*args)
        # Insertion-ordered mapping; presumably filled with captured outputs later.
        self.selected_out = OrderedDict()
        self.output_layers = output_layers
        # Backbone network, instantiated with pretrained=True.
        self.pretrained = models.resnet50(pretrained=True)
        # Presumably collects forward-hook handles; nothing is appended here.
        self.fhooks = []
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import torch | |
| import torch.nn as nn | |
| from torchvision.models.utils import load_state_dict_from_url | |
| from typing import Type, Any, Callable, Union, List, Optional, cast | |
| from torch import Tensor | |
| from collections import OrderedDict |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class NewModel(nn.Module):
    """Model wrapper configured by a base model and fine-tuning options.

    Only ``base_model`` is stored. ``batch_norm``, ``base_out_layer`` and
    ``num_trainable_layers`` are accepted but not used by this constructor.
    """

    def __init__(
        self,
        base_model: str,
        batch_norm: bool,
        base_out_layer: int,
        num_trainable_layers: int,
    ) -> None:
        super().__init__()
        self.base_model = base_model
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class newVGG(nn.Module):
    """VGG-style container around a caller-supplied ``features`` module.

    Args:
        features: Module stored on ``self.features``.
        **kwargs: Extra options kept (as the same dict) on ``self.kwargs``.
            The ``init_weights`` key is required: if it is missing, ``KeyError``
            is raised. If its value is truthy, ``self._initialize_weights()`` is
            called; that method is not defined in this class body.
    """

    def __init__(self, features: nn.Module, **kwargs: Any) -> None:
        super().__init__()
        self.features = features
        self.kwargs = kwargs
        should_init = kwargs['init_weights']
        if should_init:
            self._initialize_weights()
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class newVGG(VGG):
    """``VGG`` subclass whose forward pass returns only the ``features`` output.

    Construction is delegated unchanged to ``VGG.__init__``. ``forward`` runs
    only ``self.features``, so any other submodules the base class builds
    (in torchvision: its pooling and classifier heads) are not executed.
    """

    def __init__(self, features: nn.Module, **kwargs: Any) -> None:
        super().__init__(features, **kwargs)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run ``self.features`` on ``x`` and return the result."""
        feature_map = self.features(x)
        return feature_map
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class NewModel(nn.Module):
    """Wrapper that records the base model, output layer and fine-tuning depth.

    Args:
        base_model: The model we want to get the output from.
        base_out_layer: The layer we want to get the output from.
        num_trainable_layer: Number of layers to fine-tune, counted from the top.
            Per the original notes, an entered value of -1 means all layers are
            fine-tuned. This constructor only stores the value; it does not
            enforce that rule.
    """

    def __init__(self, base_model, base_out_layer, num_trainable_layer):
        super().__init__()
        self.base_model = base_model
        self.base_out_layer = base_out_layer
        # The argument keeps its singular name so keyword callers keep working.
        # The attribute uses the plural name; reading the plural name as a
        # local variable here was an undefined-name (NameError) bug.
        self.num_trainable_layers = num_trainable_layer
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Build the 'resnet50' variant via new_resnet: Bottleneck blocks with stage
# depths [3, 4, 6, 3], 'layer4' passed as its ``outlayer`` argument, and
# pretrained weights requested with a download progress bar.
model = new_resnet(
    'resnet50',
    'layer4',
    Bottleneck,
    [3, 4, 6, 3],
    pretrained=True,
    progress=True,
)
# Prefer the first GPU, but fall back to CPU instead of crashing when CUDA is
# unavailable. Move inputs with ``.to(device)`` so they match the model.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
model = model.to(device)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# new_resnet: only its signature appears in this snippet; the body is not shown.
# Parameters, per the annotations:
#   arch      - str (e.g. 'resnet50' at the call site)
#   outlayer  - str (e.g. 'layer4' at the call site)
#   block     - BasicBlock or Bottleneck class
#   layers    - list of ints (per-stage block counts at the call site)
#   pretrained, progress - bools
#   **kwargs  - extra options
# It is annotated to return an IntResNet. It presumably builds IntResNet with
# output layer `outlayer` — body not shown; confirm.
| def new_resnet( | |
| arch: str, | |
| outlayer: str, | |
| block: Type[Union[BasicBlock, Bottleneck]], | |
| layers: List[int], | |
| pretrained: bool, | |
| progress: bool, | |
| **kwargs: Any | |
| ) -> IntResNet: |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class IntResNet(ResNet):
    """``ResNet`` subclass that records its child-module names up to ``output_layer``.

    Args:
        output_layer: Name of the last child module to keep in ``self._layers``.
            If no child has this name, every child name is kept.
        *args: Forwarded unchanged to ``ResNet.__init__``.
    """

    def __init__(self, output_layer, *args):
        # Stored before the base constructor runs, matching the original order.
        self.output_layer = output_layer
        super().__init__(*args)
        child_names = list(self._modules.keys())
        # Position of the first child equal to output_layer; when none matches,
        # point at the last child so the slice keeps everything.
        stop = next(
            (idx for idx, name in enumerate(child_names) if name == output_layer),
            len(child_names) - 1,
        )
        self._layers = child_names[:stop + 1]