Skip to content

Instantly share code, notes, and snippets.

@fzliu
Created March 25, 2024 12:42
Show Gist options
  • Save fzliu/006d2043dc1e90d68ae562c5bde8066c to your computer and use it in GitHub Desktop.
Save fzliu/006d2043dc1e90d68ae562c5bde8066c to your computer and use it in GitHub Desktop.
GResNet model definition.
"""
gresnet.py: (Good/Great/Godlike/Gangster ResNet)
Implementation adapted from torchvision ResNet50 v1.4.
"""
import math
from typing import Any, Callable, Optional, Type, Tuple, Union
from torch import Tensor
import torch
import torch.nn as nn
from ..utils import _log_api_usage_once
from ._api import register_model, WeightsEnum
from ._utils import _ovewrite_named_param
# Public names exported by this module. The ResNet-D builder ("resnetd50")
# is intentionally absent while its definition remains commented out.
__all__ = ["GResNet", "resnete50", "gresnet50", "gresnet101", "gresnet152"]
class _Affine(nn.Module):
    """Learnable per-channel scale and shift: ``y = x * gamma + beta``.

    ``gamma`` and ``beta`` have shape ``(num_features,) + (1,) * spatial_dims``
    so they broadcast over the trailing ``spatial_dims`` dimensions of the
    input (presumably laid out as ``(N, C, *spatial)`` with
    ``C == num_features`` -- confirm at call sites).

    Args:
        num_features: number of channels ``C``.
        spatial_dims: number of trailing spatial dimensions to broadcast over.
    """

    def __init__(self, num_features: int, spatial_dims: int) -> None:
        super().__init__()
        self.num_features = num_features
        self.spatial_dims = spatial_dims
        dims = (num_features,) + (1,) * spatial_dims
        self.gamma = nn.Parameter(torch.empty(dims))
        self.beta = nn.Parameter(torch.empty(dims))
        # torch.empty leaves arbitrary memory in the parameters; initialize
        # here so the module is usable on its own, not only when a parent
        # model re-initializes it.
        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Reset to the identity transform (``gamma = 1``, ``beta = 0``)."""
        nn.init.ones_(self.gamma)
        nn.init.zeros_(self.beta)

    def forward(self, x: Tensor) -> Tensor:
        return x * self.gamma + self.beta

    def extra_repr(self) -> str:
        return f"{self.num_features}"


class Affine1d(_Affine):
    """Per-channel affine transform with ``(C, 1)`` parameters.

    Broadcasts against inputs whose last two dimensions are ``(C, L)``,
    e.g. ``(N, C, L)``.
    """

    def __init__(self, num_features: int) -> None:
        super().__init__(num_features, spatial_dims=1)


class Affine2d(_Affine):
    """Per-channel affine transform with ``(C, 1, 1)`` parameters.

    Broadcasts against inputs whose last three dimensions are ``(C, H, W)``,
    e.g. ``(N, C, H, W)``.
    """

    def __init__(self, num_features: int) -> None:
        super().__init__(num_features, spatial_dims=2)
def conv_blk(
    in_planes: int,
    out_planes: int,
    kernel_size: int,
    stride: int = 1,
    padding: int = 0,
    dilation: int = 1,
    groups: int = 1,
    norm_type: Optional[str] = "batch",
    act_layer: Optional[Callable[..., nn.Module]] = nn.ReLU
) -> nn.Sequential:
    """Build a bias-free ``Conv2d`` block with optional normalization and activation.

    Args:
        in_planes: input channels of the convolution.
        out_planes: output channels of the convolution.
        kernel_size, stride, padding, dilation, groups: forwarded to ``nn.Conv2d``.
        norm_type: normalization scheme:

            * ``"batch"``: ``BatchNorm2d(out_planes, affine=True)`` after the conv.
            * ``"split"``: ``BatchNorm2d(in_planes, affine=False)`` on the conv
              *input*, plus ``Affine2d(out_planes)`` after the conv.
            * ``"affine"``: ``Affine2d(out_planes)`` after the conv only.
            * ``None``: no normalization.
        act_layer: activation class, instantiated as ``act_layer(inplace=True)``
            and appended last; ``None`` for no activation.

    Returns:
        ``nn.Sequential`` of ``[pre-norm?] -> conv -> [post-norm?] -> [act?]``.

    Raises:
        ValueError: if ``norm_type`` is not one of the values above. Previously
            an unknown string silently produced a block with no normalization.
    """
    if norm_type not in (None, "batch", "split", "affine"):
        raise ValueError(
            f"Unsupported norm_type {norm_type!r}; "
            "expected 'batch', 'split', 'affine' or None"
        )

    layers = []
    # pre-conv norm: "split" normalizes the input without learnable affine
    if norm_type == "split":
        layers.append(nn.BatchNorm2d(in_planes, affine=False))
    # convolution (bias is redundant when a norm/affine follows)
    layers.append(nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size,
        stride=stride,
        padding=padding,
        dilation=dilation,
        groups=groups,
        bias=False
    ))
    # post-conv norm
    if norm_type in ("split", "affine"):
        layers.append(Affine2d(out_planes))
    elif norm_type == "batch":
        layers.append(nn.BatchNorm2d(out_planes, affine=True))
    # activation
    if act_layer:
        layers.append(act_layer(inplace=True))
    return nn.Sequential(*layers)
class Bottleneck(nn.Module):
    """1x1 -> 3x3 -> 1x1 residual bottleneck built from ``conv_blk``.

    The stride is applied in the 3x3 convolution (``conv2``), as in
    torchvision's ResNet v1.5. The original paper ("Deep residual learning
    for image recognition", https://arxiv.org/abs/1512.03385) places it in
    the first 1x1 convolution. NVIDIA reports v1.5 improves accuracy:
    https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.

    Args:
        in_planes: input channels.
        planes: bottleneck base width; the output has ``planes * expansion`` channels.
        stride: stride of the 3x3 convolution.
        groups: groups of the 3x3 convolution.
        base_width: per-group width; the inner width is
            ``int(planes * base_width / 64) * groups``.
        downsample: optional module applied to the shortcut so its shape
            matches the residual branch.
        norm_type: forwarded to ``conv_blk``.
        act_layer: activation class (instantiated with ``inplace=True``), or
            ``None`` for no activation.
    """

    expansion: int = 4

    def __init__(
        self,
        in_planes: int,
        planes: int,
        stride: int = 1,
        groups: int = 1,
        base_width: int = 64,
        downsample: Optional[nn.Module] = None,
        norm_type: Optional[str] = "batch",
        act_layer: Optional[Callable[..., nn.Module]] = nn.ReLU
    ) -> None:
        super().__init__()
        self.stride = stride
        self.downsample = downsample
        width = int(planes * (base_width / 64.0)) * groups
        # 1x1 reduce
        self.conv1 = conv_blk(
            in_planes,
            width,
            1,
            norm_type=norm_type,
            act_layer=act_layer
        )
        # 3x3, carries the stride and the groups
        self.conv2 = conv_blk(
            width,
            width,
            3,
            stride=stride,
            padding=1,
            groups=groups,
            norm_type=norm_type,
            act_layer=act_layer
        )
        # 1x1 expand; no activation here because it is applied after the
        # residual sum in forward().
        self.conv3 = conv_blk(
            width,
            planes * self.expansion,
            1,
            norm_type=norm_type,
            act_layer=None
        )
        # act_layer is Optional (conv_blk accepts None), so fall back to an
        # identity instead of calling None.
        self.relu = act_layer(inplace=True) if act_layer is not None else nn.Identity()

    def forward(self, x: Tensor) -> Tensor:
        identity = x
        out = self.conv1(x)
        out = self.conv2(out)
        out = self.conv3(out)
        if self.downsample is not None:
            identity = self.downsample(identity)
        out += identity
        out = self.relu(out)
        return out
class GResNet(nn.Module):
    """ResNet-style classifier with configurable normalization and activation.

    Layout: a stem (``"tiered"``: three 3x3 convs, no max-pool), four stages
    of ``block`` each beginning with a stride-2 block, then global average
    pooling, flatten and a linear classifier.

    Args:
        block: residual block class; its ``expansion`` sets stage output widths.
        layers: number of blocks in each of the four stages.
        num_classes: classifier output size.
        zero_init_residual: zero the scale of the last normalization layer in
            each ``Bottleneck`` residual branch so each block starts out as an
            identity mapping. With ``norm_type=None`` there is no such layer
            and nothing is zeroed.
        groups: groups for each block's 3x3 conv and the downsample projection.
        width_per_group: per-group width (scales bottleneck width relative to 64).
        in_planes: stem output channels.
        norm_type: forwarded to ``conv_blk`` (``"batch"``, ``"split"``,
            ``"affine"`` or ``None``).
        act_layer: activation class, instantiated with ``inplace=True``.
    """

    def __init__(
        self,
        block: Type[Union[Bottleneck]],
        layers: Tuple[int],
        num_classes: int = 1000,
        zero_init_residual: bool = False,
        groups: int = 1,
        width_per_group: int = 64,
        in_planes: int = 128,
        norm_type: Optional[str] = "batch",
        act_layer: Optional[Callable[..., nn.Module]] = nn.ReLU
    ) -> None:
        super().__init__()
        _log_api_usage_once(self)
        self._groups = groups
        # Running channel count: starts as the stem width and is advanced by
        # _make_layer after each stage.
        self._in_planes = in_planes
        self._norm_type = norm_type
        self._act_layer = act_layer
        self.base_width = width_per_group
        self.stem = self._make_stem(stem_type="tiered", use_pool=False)
        # Every stage downsamples (stride 2), including layer1, since the stem
        # has no max-pool.
        self.layer1 = self._make_layer(block, 64, layers[0], stride=2)
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        self.classifier = self._make_classifier(512 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, (nn.Conv1d, nn.Conv2d, nn.Linear)):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
            elif isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d)):
                if m.affine:
                    nn.init.constant_(m.weight, 1)
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, (Affine1d, Affine2d)):
                nn.init.constant_(m.gamma, 1)
                nn.init.constant_(m.beta, 0)

        # Zero-initialize the last norm in each residual branch, so that the
        # branch starts with zeros and each residual block behaves like an
        # identity. This improves the model by 0.2~0.3% according to
        # https://arxiv.org/abs/1706.02677
        # Bottleneck has no ``bn3`` (the old lookup always raised
        # AttributeError). Its conv3 is a conv_blk built with act_layer=None,
        # so its last module is the post-conv norm: BatchNorm2d for "batch",
        # Affine2d for "split"/"affine", or the bare Conv2d when norm_type is
        # None (left untouched).
        if zero_init_residual:
            for m in self.modules():
                if not isinstance(m, Bottleneck):
                    continue
                last = m.conv3[-1]
                if isinstance(last, nn.BatchNorm2d) and last.weight is not None:
                    nn.init.zeros_(last.weight)
                elif isinstance(last, Affine2d):
                    nn.init.zeros_(last.gamma)

    def _make_stem(
        self,
        stem_type: str = "tiered",
        use_pool: bool = False
    ) -> nn.Sequential:
        """Build the input stem, ending at ``self._in_planes`` channels.

        ``"tiered"``: three 3x3 convs (stride 2 on the first) widening
        3 -> in_planes/4 -> in_planes/2 -> in_planes. Any other value: one
        7x7 stride-2 conv. The first conv always uses ``norm_type="affine"``.
        ``use_pool`` appends a 3x3 stride-2 max-pool.
        """
        layers = []
        if stem_type == "tiered":
            layers.append(conv_blk(
                3,
                self._in_planes // 4,
                3,
                stride=2,
                padding=1,
                norm_type="affine",
                act_layer=self._act_layer
            ))
            layers.append(conv_blk(
                self._in_planes // 4,
                self._in_planes // 2,
                3,
                padding=1,
                norm_type=self._norm_type,
                act_layer=self._act_layer
            ))
            layers.append(conv_blk(
                self._in_planes // 2,
                self._in_planes,
                3,
                padding=1,
                norm_type=self._norm_type,
                act_layer=self._act_layer
            ))
        else:
            layers.append(conv_blk(
                3,
                self._in_planes,
                7,
                stride=2,
                padding=3,
                norm_type="affine",
                act_layer=self._act_layer
            ))
        if use_pool:
            layers.append(nn.MaxPool2d(kernel_size=3, stride=2, padding=1))
        return nn.Sequential(*layers)

    def _make_layer(
        self,
        block: Type[Bottleneck],
        planes: int,
        blocks: int,
        stride: int = 1
    ) -> nn.Sequential:
        """Build one stage of ``blocks`` blocks and advance ``self._in_planes``.

        Only the first block is strided. When the shape changes, the first
        block's shortcut is an average pool (kernel = stride) followed by a
        1x1 conv projection.
        """
        downsample = None
        if stride != 1 or self._in_planes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.AvgPool2d(stride, stride=stride, ceil_mode=True),
                conv_blk(
                    self._in_planes,
                    planes * block.expansion,
                    1,
                    stride=1,
                    groups=self._groups,
                    norm_type=self._norm_type,
                    act_layer=None
                )
            )
        layers = []
        layers.append(
            block(
                self._in_planes,
                planes,
                stride=stride,
                downsample=downsample,
                groups=self._groups,
                base_width=self.base_width,
                norm_type=self._norm_type,
                act_layer=self._act_layer
            )
        )
        self._in_planes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(
                block(
                    self._in_planes,
                    planes,
                    groups=self._groups,
                    base_width=self.base_width,
                    norm_type=self._norm_type,
                    act_layer=self._act_layer
                )
            )
        return nn.Sequential(*layers)

    def _make_classifier(
        self,
        num_features: int,
        num_classes: Optional[int] = 1000
    ) -> nn.Sequential:
        """Global average pool -> flatten -> linear head."""
        layers = []
        layers.append(nn.AdaptiveAvgPool2d((1, 1)))
        layers.append(nn.Flatten())
        layers.append(nn.Linear(num_features, num_classes))
        return nn.Sequential(*layers)

    def _forward_impl(self, x: Tensor) -> Tensor:
        x = self.stem(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.classifier(x)
        return x

    def forward(self, x: Tensor) -> Tensor:
        return self._forward_impl(x)
def _gresnet(
    block: Type[Bottleneck],
    layers: Tuple[int],
    weights: Optional[WeightsEnum],
    progress: bool,
    **kwargs: Any,
) -> GResNet:
    """Shared builder: construct a ``GResNet`` and optionally load weights.

    When ``weights`` is given, ``num_classes`` in ``kwargs`` is pinned to the
    number of categories in ``weights.meta`` before construction. The state
    dict is then fetched with ``weights.get_state_dict(progress=progress)``
    and loaded into the model.
    """
    pretrained = weights is not None
    if pretrained:
        # The checkpoint's classifier width is fixed by its category list.
        _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))
    model = GResNet(block, layers, **kwargs)
    if not pretrained:
        return model
    state_dict = weights.get_state_dict(progress=progress)
    model.load_state_dict(state_dict)
    return model
# this is not quite ResNet-D (the stem is different)
#@register_model()
#def resnetd50(*, weights: Optional = None, progress: bool = True, **kwargs: Any) -> GResNet:
# return _gresnet(
# Bottleneck,
# [3, 4, 6, 3],
# weights,
# progress,
# **kwargs
# )
@register_model()
def resnete50(*, weights: Optional[WeightsEnum] = None, progress: bool = True, **kwargs: Any) -> GResNet:
    """GResNet with ResNet-50 stage layout (3, 4, 6, 3) and SiLU activations.

    Uses the default ``norm_type`` of ``GResNet`` unless overridden in ``kwargs``.

    Args:
        weights: optional pretrained weights. When given, ``num_classes`` is
            taken from ``weights.meta["categories"]`` and the state dict is loaded.
        progress: forwarded to ``weights.get_state_dict``.
        **kwargs: forwarded to ``GResNet``. ``act_layer`` is fixed to
            ``nn.SiLU`` by this builder. A conflicting value is rejected by
            ``_ovewrite_named_param`` (torchvision raises ``ValueError``;
            confirm for this package).
    """
    # Pin the activation through kwargs. Passing it as an extra keyword next
    # to **kwargs raised TypeError ("multiple values for keyword argument")
    # whenever the caller also supplied act_layer.
    _ovewrite_named_param(kwargs, "act_layer", nn.SiLU)
    return _gresnet(
        Bottleneck,
        (3, 4, 6, 3),
        weights,
        progress,
        **kwargs
    )
@register_model()
def gresnet50(*, weights: Optional[WeightsEnum] = None, progress: bool = True, **kwargs: Any) -> GResNet:
    """GResNet with ResNet-50 stage layout (3, 4, 6, 3), SiLU and ``"split"`` norm.

    Args:
        weights: optional pretrained weights. When given, ``num_classes`` is
            taken from ``weights.meta["categories"]`` and the state dict is loaded.
        progress: forwarded to ``weights.get_state_dict``.
        **kwargs: forwarded to ``GResNet``. ``act_layer`` (``nn.SiLU``) and
            ``norm_type`` (``"split"``) are fixed by this builder. Conflicting
            values are rejected by ``_ovewrite_named_param`` (torchvision
            raises ``ValueError``; confirm for this package).
    """
    # Pin the variant's settings through kwargs. Passing them as extra
    # keywords next to **kwargs raised TypeError ("multiple values for keyword
    # argument") whenever the caller also supplied them.
    _ovewrite_named_param(kwargs, "act_layer", nn.SiLU)
    _ovewrite_named_param(kwargs, "norm_type", "split")
    return _gresnet(
        Bottleneck,
        (3, 4, 6, 3),
        weights,
        progress,
        **kwargs
    )
@register_model()
def gresnet101(*, weights: Optional[WeightsEnum] = None, progress: bool = True, **kwargs: Any) -> GResNet:
    """GResNet with ResNet-101 stage layout (3, 4, 23, 3), SiLU and ``"split"`` norm.

    Args:
        weights: optional pretrained weights. When given, ``num_classes`` is
            taken from ``weights.meta["categories"]`` and the state dict is loaded.
        progress: forwarded to ``weights.get_state_dict``.
        **kwargs: forwarded to ``GResNet``. ``act_layer`` (``nn.SiLU``) and
            ``norm_type`` (``"split"``) are fixed by this builder. Conflicting
            values are rejected by ``_ovewrite_named_param`` (torchvision
            raises ``ValueError``; confirm for this package).
    """
    # Pin the variant's settings through kwargs. Passing them as extra
    # keywords next to **kwargs raised TypeError ("multiple values for keyword
    # argument") whenever the caller also supplied them.
    _ovewrite_named_param(kwargs, "act_layer", nn.SiLU)
    _ovewrite_named_param(kwargs, "norm_type", "split")
    return _gresnet(
        Bottleneck,
        (3, 4, 23, 3),
        weights,
        progress,
        **kwargs
    )
@register_model()
def gresnet152(*, weights: Optional[WeightsEnum] = None, progress: bool = True, **kwargs: Any) -> GResNet:
    """GResNet with ResNet-152 stage layout (3, 8, 36, 3), SiLU and ``"split"`` norm.

    Args:
        weights: optional pretrained weights. When given, ``num_classes`` is
            taken from ``weights.meta["categories"]`` and the state dict is loaded.
        progress: forwarded to ``weights.get_state_dict``.
        **kwargs: forwarded to ``GResNet``. ``act_layer`` (``nn.SiLU``) and
            ``norm_type`` (``"split"``) are fixed by this builder. Conflicting
            values are rejected by ``_ovewrite_named_param`` (torchvision
            raises ``ValueError``; confirm for this package).
    """
    # Pin the variant's settings through kwargs. Passing them as extra
    # keywords next to **kwargs raised TypeError ("multiple values for keyword
    # argument") whenever the caller also supplied them.
    _ovewrite_named_param(kwargs, "act_layer", nn.SiLU)
    _ovewrite_named_param(kwargs, "norm_type", "split")
    return _gresnet(
        Bottleneck,
        (3, 8, 36, 3),
        weights,
        progress,
        **kwargs
    )
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment