GResNet model definition.
""" | |
gresnet.py: (Good/Great/Godlike/Gangster ResNet) | |
Implementation adapted from torchvision ResNet50 v1.4. | |
""" | |
import math | |
from typing import Any, Callable, Optional, Type, Tuple, Union | |
from torch import Tensor | |
import torch | |
import torch.nn as nn | |
from ..utils import _log_api_usage_once | |
from ._api import register_model, WeightsEnum | |
from ._utils import _ovewrite_named_param | |
__all__ = [ | |
"GResNet", | |
#"resnetd50", | |
"resnete50", | |
"gresnet50", | |
"gresnet101", | |
"gresnet152", | |
] | |
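
# _Affine applies a learnable per-channel scale (gamma) and shift (beta),
# broadcast over the spatial dimensions: y = gamma * x + beta. It is the
# affine half of a BatchNorm layer with the statistics removed, which is
# what lets the "split" and "affine" norm modes below avoid batch
# statistics after the convolution.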
class _Affine(nn.Module):
    def __init__(self, num_features: int, spatial_dims: int) -> None:
        super().__init__()
        self.num_features = num_features
        self.spatial_dims = spatial_dims
        dims = (num_features,) + (1,) * spatial_dims
        # Initialize to the identity transform (gamma=1, beta=0) so the
        # layer is well-defined even without an external init pass.
        self.gamma = nn.Parameter(torch.ones(dims))
        self.beta = nn.Parameter(torch.zeros(dims))

    def forward(self, x: Tensor) -> Tensor:
        return x * self.gamma + self.beta

    def extra_repr(self) -> str:
        return "{num_features}".format(**self.__dict__)


class Affine1d(_Affine):
    def __init__(self, num_features: int) -> None:
        super().__init__(num_features, spatial_dims=1)


class Affine2d(_Affine):
    def __init__(self, num_features: int) -> None:
        super().__init__(num_features, spatial_dims=2)
def conv_blk(
    in_planes: int,
    out_planes: int,
    kernel_size: int,
    stride: int = 1,
    padding: int = 0,
    dilation: int = 1,
    groups: int = 1,
    norm_type: Optional[str] = "batch",
    act_layer: Optional[Callable[..., nn.Module]] = nn.ReLU
) -> nn.Sequential:
    layers = []
    # pre-conv norm
    if norm_type == "split":
        layers.append(nn.BatchNorm2d(in_planes, affine=False))
    # convolution
    layers.append(nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size,
        stride=stride,
        padding=padding,
        dilation=dilation,
        groups=groups,
        bias=False
    ))
    # post-conv norm
    if norm_type in ["split", "affine"]:
        layers.append(Affine2d(out_planes))
    elif norm_type == "batch":
        layers.append(nn.BatchNorm2d(out_planes, affine=True))
    # activation
    if act_layer is not None:
        layers.append(act_layer(inplace=True))
    return nn.Sequential(*layers)
class Bottleneck(nn.Module):
    # Bottleneck in torchvision places the stride for downsampling at the 3x3 convolution (self.conv2),
    # while the original implementation places the stride at the first 1x1 convolution (self.conv1),
    # per "Deep Residual Learning for Image Recognition" https://arxiv.org/abs/1512.03385.
    # This variant is also known as ResNet V1.5 and improves accuracy according to
    # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.
    expansion: int = 4

    def __init__(
        self,
        in_planes: int,
        planes: int,
        stride: int = 1,
        groups: int = 1,
        base_width: int = 64,
        downsample: Optional[nn.Module] = None,
        norm_type: Optional[str] = "batch",
        act_layer: Callable[..., nn.Module] = nn.ReLU
    ) -> None:
        super().__init__()
        self.stride = stride
        self.downsample = downsample
        width = int(planes * (base_width / 64.0)) * groups
        self.conv1 = conv_blk(
            in_planes,
            width,
            1,
            norm_type=norm_type,
            act_layer=act_layer
        )
        self.conv2 = conv_blk(
            width,
            width,
            3,
            stride=stride,
            padding=1,
            groups=groups,
            norm_type=norm_type,
            act_layer=act_layer
        )
        # No activation here; it is applied after the residual addition.
        self.conv3 = conv_blk(
            width,
            planes * self.expansion,
            1,
            norm_type=norm_type,
            act_layer=None
        )
        self.relu = act_layer(inplace=True)

    def forward(self, x: Tensor) -> Tensor:
        identity = x
        out = self.conv1(x)
        out = self.conv2(out)
        out = self.conv3(out)
        if self.downsample is not None:
            identity = self.downsample(identity)
        out += identity
        out = self.relu(out)
        return out
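
# GResNet follows the torchvision ResNet layout with a few changes visible
# below: a tiered stem of three 3x3 convolutions in place of the single 7x7
# conv + maxpool, a stride-2 first stage to compensate for the missing pool,
# and a ResNet-D-style shortcut (AvgPool2d followed by a 1x1 conv) on the
# downsampling path.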
class GResNet(nn.Module):
    def __init__(
        self,
        block: Type[Bottleneck],
        layers: Tuple[int, int, int, int],
        num_classes: int = 1000,
        zero_init_residual: bool = False,
        groups: int = 1,
        width_per_group: int = 64,
        in_planes: int = 128,
        norm_type: Optional[str] = "batch",
        act_layer: Callable[..., nn.Module] = nn.ReLU
    ) -> None:
        super().__init__()
        _log_api_usage_once(self)
        self._groups = groups
        self._in_planes = in_planes
        self._norm_type = norm_type
        self._act_layer = act_layer
        self.base_width = width_per_group
        self.stem = self._make_stem(stem_type="tiered", use_pool=False)
        self.layer1 = self._make_layer(block, 64, layers[0], stride=2)
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        self.classifier = self._make_classifier(512 * block.expansion, num_classes)
        for m in self.modules():
            if isinstance(m, (nn.Conv1d, nn.Conv2d, nn.Linear)):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
            elif isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d)):
                if m.affine:
                    nn.init.constant_(m.weight, 1)
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, (Affine1d, Affine2d)):
                nn.init.constant_(m.gamma, 1)
                nn.init.constant_(m.beta, 0)
        # Zero-initialize the scale of the last norm in each residual branch,
        # so that the branch starts at zero and each residual block behaves
        # like an identity. This improves the model by 0.2~0.3% according to
        # https://arxiv.org/abs/1706.02677. The norm is the final module of
        # the conv3 sub-block (which has act_layer=None), so index into it
        # rather than a nonexistent bn3 attribute.
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck):
                    last = m.conv3[-1]
                    if isinstance(last, nn.BatchNorm2d) and last.weight is not None:
                        nn.init.constant_(last.weight, 0)
                    elif isinstance(last, Affine2d):
                        nn.init.constant_(last.gamma, 0)
    def _make_stem(
        self,
        stem_type: str = "tiered",
        use_pool: bool = False
    ) -> nn.Sequential:
        layers = []
        if stem_type == "tiered":
            layers.append(conv_blk(
                3,
                self._in_planes // 4,
                3,
                stride=2,
                padding=1,
                norm_type="affine",
                act_layer=self._act_layer
            ))
            layers.append(conv_blk(
                self._in_planes // 4,
                self._in_planes // 2,
                3,
                padding=1,
                norm_type=self._norm_type,
                act_layer=self._act_layer
            ))
            layers.append(conv_blk(
                self._in_planes // 2,
                self._in_planes,
                3,
                padding=1,
                norm_type=self._norm_type,
                act_layer=self._act_layer
            ))
        else:
            layers.append(conv_blk(
                3,
                self._in_planes,
                7,
                stride=2,
                padding=3,
                norm_type="affine",
                act_layer=self._act_layer
            ))
        if use_pool:
            layers.append(nn.MaxPool2d(kernel_size=3, stride=2, padding=1))
        return nn.Sequential(*layers)
    def _make_layer(
        self,
        block: Type[Bottleneck],
        planes: int,
        blocks: int,
        stride: int = 1
    ) -> nn.Sequential:
        downsample = None
        if stride != 1 or self._in_planes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.AvgPool2d(stride, stride=stride, ceil_mode=True),
                conv_blk(
                    self._in_planes,
                    planes * block.expansion,
                    1,
                    stride=1,
                    groups=self._groups,
                    norm_type=self._norm_type,
                    act_layer=None
                )
            )
        layers = []
        layers.append(
            block(
                self._in_planes,
                planes,
                stride=stride,
                downsample=downsample,
                groups=self._groups,
                base_width=self.base_width,
                norm_type=self._norm_type,
                act_layer=self._act_layer
            )
        )
        self._in_planes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(
                block(
                    self._in_planes,
                    planes,
                    groups=self._groups,
                    base_width=self.base_width,
                    norm_type=self._norm_type,
                    act_layer=self._act_layer
                )
            )
        return nn.Sequential(*layers)
    def _make_classifier(
        self,
        num_features: int,
        num_classes: int = 1000
    ) -> nn.Sequential:
        layers = []
        layers.append(nn.AdaptiveAvgPool2d((1, 1)))
        layers.append(nn.Flatten())
        layers.append(nn.Linear(num_features, num_classes))
        return nn.Sequential(*layers)

    def _forward_impl(self, x: Tensor) -> Tensor:
        x = self.stem(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.classifier(x)
        return x

    def forward(self, x: Tensor) -> Tensor:
        return self._forward_impl(x)
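
# Shared builder: when pretrained weights are supplied, num_classes is
# forced to match the weights' metadata before the state dict is loaded.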
def _gresnet(
    block: Type[Bottleneck],
    layers: Tuple[int, int, int, int],
    weights: Optional[WeightsEnum],
    progress: bool,
    **kwargs: Any,
) -> GResNet:
    if weights is not None:
        _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))
    model = GResNet(block, layers, **kwargs)
    if weights is not None:
        model.load_state_dict(weights.get_state_dict(progress=progress))
    return model
# this is not quite ResNet-D (the stem is different)
#@register_model()
#def resnetd50(*, weights: Optional[WeightsEnum] = None, progress: bool = True, **kwargs: Any) -> GResNet:
#    return _gresnet(
#        Bottleneck,
#        (3, 4, 6, 3),
#        weights,
#        progress,
#        **kwargs
#    )


@register_model()
def resnete50(*, weights: Optional[WeightsEnum] = None, progress: bool = True, **kwargs: Any) -> GResNet:
    return _gresnet(
        Bottleneck,
        (3, 4, 6, 3),
        weights,
        progress,
        act_layer=nn.SiLU,
        **kwargs
    )


@register_model()
def gresnet50(*, weights: Optional[WeightsEnum] = None, progress: bool = True, **kwargs: Any) -> GResNet:
    return _gresnet(
        Bottleneck,
        (3, 4, 6, 3),
        weights,
        progress,
        act_layer=nn.SiLU,
        norm_type="split",
        **kwargs
    )


@register_model()
def gresnet101(*, weights: Optional[WeightsEnum] = None, progress: bool = True, **kwargs: Any) -> GResNet:
    return _gresnet(
        Bottleneck,
        (3, 4, 23, 3),
        weights,
        progress,
        act_layer=nn.SiLU,
        norm_type="split",
        **kwargs
    )


@register_model()
def gresnet152(*, weights: Optional[WeightsEnum] = None, progress: bool = True, **kwargs: Any) -> GResNet:
    return _gresnet(
        Bottleneck,
        (3, 8, 36, 3),
        weights,
        progress,
        act_layer=nn.SiLU,
        norm_type="split",
        **kwargs
    )