Skip to content

Instantly share code, notes, and snippets.

@sadimanna
Last active January 21, 2021 14:33
Show Gist options
  • Select an option

  • Save sadimanna/d27dcbd9851bbf2776c48c0e7cced796 to your computer and use it in GitHub Desktop.

Select an option

Save sadimanna/d27dcbd9851bbf2776c48c0e7cced796 to your computer and use it in GitHub Desktop.
class NewModel(nn.Module):
    """VGG backbone truncated at an intermediate layer, with partial fine-tuning.

    Builds one of ``vgg11``/``vgg13``/``vgg16``/``vgg19`` via ``newVGG`` with a
    convolutional stack that stops after ``base_out_layer`` configuration
    entries, optionally loads pretrained weights, and then freezes every
    top-level child of the resulting model except the last
    ``num_trainable_layers`` of them.

    Args:
        base_model: Architecture name; must be a key of ``cfg_dict``.
        batch_norm: Insert ``BatchNorm2d`` after every convolution.
        base_out_layer: Number of configuration entries (convolutions and
            ``'M'`` max-pools both count) to build before stopping. A value
            that is never reached (e.g. ``0``, negative, or larger than the
            configuration) builds the full stack.
        num_trainable_layers: How many trailing top-level children of
            ``self.vgg`` keep ``requires_grad=True``; ``-1`` means all of them.
            Values larger than the number of children also make all trainable.
        pretrained: Load weights from ``model_urls[base_model]``. Defaults to
            ``True`` (previously hard-coded).
        progress: Show a progress bar while downloading weights. Defaults to
            ``True`` (previously hard-coded).

    Raises:
        KeyError: If ``base_model`` is not a supported architecture name.
    """

    # Maps architecture names to keys of the module-level ``cfgs`` table.
    CFG_DICT = {'vgg11': 'A',
                'vgg13': 'B',
                'vgg16': 'D',
                'vgg19': 'E'}

    def __init__(self,
                 base_model: str,
                 batch_norm: bool,
                 base_out_layer: int,
                 num_trainable_layers: int,
                 pretrained: bool = True,
                 progress: bool = True) -> None:
        super().__init__()
        self.base_model = base_model
        self.batch_norm = batch_norm
        self.base_out_layer = base_out_layer
        self.num_trainable_layers = num_trainable_layers
        # Per-instance copy so mutating it cannot leak into other instances.
        self.cfg_dict = dict(self.CFG_DICT)
        if self.base_model not in self.cfg_dict:
            # Same exception type as the former bare lookup, clearer message.
            raise KeyError(f"unsupported base_model {self.base_model!r}; "
                           f"expected one of {sorted(self.cfg_dict)}")
        self.vgg = self._vgg(self.base_model,
                             self.cfg_dict[self.base_model],
                             self.batch_norm,
                             self.base_out_layer,
                             pretrained, progress)
        # NOTE(review): these are the *top-level* children of the model built
        # by newVGG (defined elsewhere). If newVGG wraps all convolutions in a
        # single Sequential, "layers" here are whole submodules, not individual
        # conv layers — confirm this is the intended freezing granularity.
        self.total_children = len(list(self.vgg.children()))
        if self.num_trainable_layers == -1:
            self.num_trainable_layers = self.total_children
        self._freeze_leading_children(self.vgg, self.num_trainable_layers)
        # Kept for backward compatibility: the former counting loop left this
        # attribute equal to the number of children it visited.
        self.children_counter = self.total_children

    @staticmethod
    def _freeze_leading_children(module: nn.Module, num_trainable: int) -> None:
        """Freeze all but the last ``num_trainable`` direct children of ``module``.

        Parameters of the leading ``len(children) - num_trainable`` children get
        ``requires_grad=False``; parameters of the remaining children get
        ``requires_grad=True``.
        """
        children = list(module.children())
        num_frozen = len(children) - num_trainable
        for index, child in enumerate(children):
            trainable = index >= num_frozen
            for param in child.parameters():
                param.requires_grad = trainable

    def make_layers(self,
                    cfg: List[Union[str, int]],
                    base_out_layer: int,
                    batch_norm: bool = False) -> nn.Sequential:
        """Build the VGG convolutional stack described by ``cfg``.

        Each integer entry ``v`` appends ``Conv2d(in_channels, v, 3, padding=1)``,
        an optional ``BatchNorm2d(v)``, and ``ReLU(inplace=True)``; each ``'M'``
        appends ``MaxPool2d(2, 2)``. Construction stops right after the
        ``base_out_layer``-th entry (1-based, pools included); if that count is
        never reached the whole ``cfg`` is built.

        Returns:
            An ``nn.Sequential`` of the built modules.
        """
        layers: List[nn.Module] = []
        in_channels = 3  # first conv expects 3 input channels (presumably RGB)
        # NOTE(review): the original indentation was lost; this restoration
        # counts every cfg entry (max-pools included) toward base_out_layer.
        for layer_count, v in enumerate(cfg, start=1):
            if v == 'M':
                layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
            else:
                out_channels = cast(int, v)
                conv2d = nn.Conv2d(in_channels, out_channels,
                                   kernel_size=3, padding=1)
                if batch_norm:
                    layers += [conv2d, nn.BatchNorm2d(out_channels),
                               nn.ReLU(inplace=True)]
                else:
                    layers += [conv2d, nn.ReLU(inplace=True)]
                in_channels = out_channels
            if layer_count == base_out_layer:
                break
        return nn.Sequential(*layers)

    def _vgg(self,
             arch: str,
             cfg: str,
             batch_norm: bool,
             base_out_layer: int,
             pretrained: bool,
             progress: bool,
             **kwargs: Any) -> 'newVGG':
        """Construct a (possibly truncated) VGG and optionally load weights.

        Args:
            arch: Architecture name used to look up ``model_urls``.
            cfg: Key into the module-level ``cfgs`` table.
            batch_norm: Forwarded to ``make_layers``.
            base_out_layer: Forwarded to ``make_layers``.
            pretrained: Download and load the checkpoint for ``arch``.
            progress: Show a progress bar while downloading.
            **kwargs: Forwarded to ``newVGG``.
        """
        features = self.make_layers(cfgs[cfg],
                                    base_out_layer,
                                    batch_norm=batch_norm)
        model = newVGG(features, **kwargs)
        if pretrained:
            state_dict = load_state_dict_from_url(model_urls[arch],
                                                  progress=progress)
            # strict=False: the truncated stack can have fewer modules than the
            # full checkpoint, so unmatched keys are ignored rather than raised.
            model.load_state_dict(state_dict, strict=False)
        return model

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run ``x`` through the wrapped VGG model and return its output."""
        return self.vgg(x)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment