Skip to content

Instantly share code, notes, and snippets.

@sadimanna
Last active January 7, 2021 03:07
Show Gist options
  • Select an option

  • Save sadimanna/de7a5550d21ec13d8d7961746a114502 to your computer and use it in GitHub Desktop.

Select an option

Save sadimanna/de7a5550d21ec13d8d7961746a114502 to your computer and use it in GitHub Desktop.
class newVGG(nn.Module):
    """VGG-style wrapper that runs only a supplied feature extractor.

    ``forward`` applies ``features`` to the input and returns its output
    directly; no pooling or classifier head is attached by this class.

    Args:
        features: Module applied to the input in ``forward``.
        **kwargs: Extra options, stored unchanged on ``self.kwargs``. Only
            ``init_weights`` is read here: when truthy (default ``True``,
            the same default torchvision's ``VGG`` uses), every submodule is
            re-initialised by ``_initialize_weights``.
    """

    def __init__(self,
                 features: nn.Module,
                 **kwargs: Any) -> None:
        super().__init__()
        self.features = features
        self.kwargs = kwargs
        # Use .get so omitting ``init_weights`` does not raise KeyError; the
        # default mirrors torchvision's VGG (init_weights=True).
        if self.kwargs.get('init_weights', True):
            self._initialize_weights()

    def _initialize_weights(self) -> None:
        """Re-initialise Conv2d, BatchNorm2d and Linear submodules in place.

        - Conv2d: Kaiming-normal weights (``mode='fan_out'``,
          ``nonlinearity='relu'``), zero bias when a bias exists.
        - BatchNorm2d: weight 1, bias 0 (skipped when ``affine=False``, where
          both parameters are ``None``).
        - Linear: weights ~ N(0, 0.01), zero bias when a bias exists.

        Other module types are left untouched.
        """
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                # affine=False BatchNorm has no learnable weight/bias (None).
                if m.weight is not None:
                    nn.init.constant_(m.weight, 1)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                # Linear(..., bias=False) stores bias as None.
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``self.features(x)``; no further layers are applied."""
        x = self.features(x)
        return x
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment