Hourglass PyTorch
""" | |
Hourglass network, the backbone part. | |
Implemented according to the CornetNet paper. The ArXiv version does not have the backbone description. | |
Reference: | |
- https://link.springer.com/article/10.1007/s11263-019-01204-1 | |
- https://sci-hub.se/https://link.springer.com/article/10.1007/s11263-019-01204-1 (I won't use 40EUR much money *just* to read a paper. Where I live one can buy 10 books with that.) | |
- https://arxiv.org/abs/1808.01244 (not really related to this implementation, just in case someone is curious about CornetNet) | |
Public API: | |
- Hourglass104 | |
- build_hourglass | |
""" | |
from typing import List

from torch import nn


def ConvBR(*args, relu=True, **kwargs):
    """Conv2d, BatchNorm, (maybe) ReLU."""
    conv = nn.Conv2d(*args, **kwargs)
    norm = nn.BatchNorm2d(conv.out_channels)
    act = nn.ReLU(inplace=True) if relu else nn.Identity()
    return nn.Sequential(conv, norm, act)


class ResidualBlock(nn.Module):
    """
    This is actually the Bottleneck Block from the ResNet paper.
    It can be replaced with other residual block variants.
    """

    def __init__(self, in_channels, out_channels, stride: int = 1, reduction: int = 4):
        super().__init__()
        hid_channels = in_channels // reduction
        # 1x1 reduce -> 3x3 (optionally strided) -> 1x1 expand
        self.conv = nn.Sequential(
            *ConvBR(in_channels, hid_channels, 1),
            *ConvBR(hid_channels, hid_channels, 3, stride, padding=1),
            *ConvBR(hid_channels, out_channels, 1),
        )
        # Project the skip path whenever the shape changes
        if in_channels != out_channels or stride != 1:
            self.skip = nn.Conv2d(in_channels, out_channels, 1, stride=stride)
        else:
            self.skip = nn.Identity()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.stride = stride

    def __repr__(self):
        """Trust me, this makes the network repr much more readable."""
        return f"ResidualBlock({self.in_channels}, {self.out_channels}, stride={self.stride})"

    def forward(self, x):
        return self.conv(x) + self.skip(x)


class HGLCore(nn.Module):
    """The core of the hourglass, see `build_hourglass`."""

    def __init__(self, hidden_size: int, num_layers: int = 4):
        super().__init__()
        layers = [ResidualBlock(hidden_size, hidden_size) for _ in range(num_layers)]
        self.core = nn.Sequential(*layers)
        self.hidden_size = hidden_size

    def forward(self, x):
        return self.core(x)

    def get_hidden_size(self):
        return self.hidden_size


class HGLCoreAtn(nn.Module):
    """The core of the hourglass, see `build_hourglass`.

    Since the spatial resolution at the core is so low,
    it *might* be a good idea to use a transformer encoder here.
    """

    def __init__(self, hidden_size: int, num_layers: int = 4):
        super().__init__()
        tfm = nn.TransformerEncoderLayer(hidden_size, 4, batch_first=True)
        self.core = nn.TransformerEncoder(tfm, num_layers)
        self.hidden_size = hidden_size

    def forward(self, x):
        B, C, H, W = x.shape
        # (B, C, H, W) -> (B, H*W, C): every spatial position becomes a token
        x = x.flatten(-2).transpose(-2, -1)
        x = self.core(x)
        # (B, H*W, C) -> (B, C, H, W)
        x = x.transpose(-2, -1).reshape(B, C, H, W)
        return x

    def get_hidden_size(self):
        return self.hidden_size
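

# A rough shape sketch for HGLCoreAtn (not part of the original gist):
# `hidden_size` must be divisible by the hard-coded head count of 4, and the
# attention cost grows with (H*W)**2, so this core is meant for the small
# feature map at the centre of the hourglass. For example (with torch
# imported separately, since this file only imports torch.nn):
#
#   core = HGLCoreAtn(512)
#   x = torch.randn(1, 512, 2, 2)
#   y = core(x)  # -> torch.Size([1, 512, 2, 2])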


class HGLLevel(nn.Module):
    """The wrapper layer around the core (or around the previous level) of the hourglass.

    See `build_hourglass` for details.
    """

    def __init__(self, core: nn.Module, in_channels: int):
        super().__init__()
        self.in_channels = in_channels
        hidden_size = core.get_hidden_size()
        # Downsample by 2 and project to the core's channel count
        self.pre_core = nn.Sequential(
            ResidualBlock(in_channels, hidden_size, stride=2),
            ResidualBlock(hidden_size, hidden_size),
        )
        self.core = core
        # Project back and upsample to the original resolution
        self.post_core = nn.Sequential(
            ResidualBlock(hidden_size, hidden_size),
            ResidualBlock(hidden_size, in_channels),
            nn.Upsample(scale_factor=2, mode="bilinear"),
        )
        # Full-resolution skip branch
        self.skip = nn.Sequential(
            ResidualBlock(in_channels, in_channels),
            ResidualBlock(in_channels, in_channels),
        )

    def get_hidden_size(self):
        return self.in_channels

    def forward(self, x):
        skip = self.skip(x)
        x = self.pre_core(x)
        x = self.core(x)
        x = self.post_core(x)
        x = x + skip
        return x


def build_hourglass(
    hidden_sizes: List[int] = [256, 256, 256, 384, 384, 512],
    attention_core: bool = False,
):
    """Build an hourglass module.

    The idea is to wrap the levels of the hourglass one by one.
    The innermost level is referred to as the `core`.
    We wrap the core with some projection layers and a skip connection:
    ```
    wrapper(hourglass-core, x):
        skip <- skip-connection(x)
        x <- input-projection(x)
        x <- hourglass-core(x)
        x <- output-projection(x)
        x <- x + skip
        return x
    ```
    The first core is `HGLCore`, the wrapper layer is `HGLLevel`.
    After wrapping a core, the result becomes a core itself.
    This function basically does the following:
    ```
    hourglass = HGLLevel(...HGLLevel(HGLCore(...))...)
    ```

    Args:
        hidden_sizes (List[int]):
            The channels for each level of the hourglass.
            The last entry in `hidden_sizes` will be the hidden size
            at the middle of the hourglass.
        attention_core (bool):
            If true, `HGLCoreAtn` will be used instead of `HGLCore`.
            Default: false.
    """
    # The real hourglass is the nested network we build along the way
    hourglass = None
    # Which core to use
    if attention_core:
        Core = HGLCoreAtn
    else:
        Core = HGLCore
    # Build from the innermost level outward
    for i, hidden in enumerate(reversed(hidden_sizes)):
        if i == 0:
            hourglass = Core(hidden)
        else:
            hourglass = HGLLevel(hourglass, hidden)
    return hourglass
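

# A rough usage sketch (not part of the original gist): with the default
# `hidden_sizes` (six entries), the hourglass keeps the channel count of its
# input (the first entry, 256) and its spatial size, provided H and W are
# divisible by 2**5 = 32 so the five stride-2 levels round-trip exactly.
# For example (again assuming torch is imported separately):
#
#   hgl = build_hourglass()
#   x = torch.randn(1, 256, 64, 64)
#   y = hgl(x)  # -> torch.Size([1, 256, 64, 64])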


class Hourglass104(nn.Module):
    """Hourglass104 network.

    This module only returns the final and the middle feature maps,
    in that order, and not the predictions.

    Side note:
        Since this implementation uses the Bottleneck block, it's actually
        Hourglass 138.
        As for which type of block to use, I think that, in the end, it
        doesn't even matter: one can swap out the basic building block for an
        equivalent one and still have the same overall architecture.
        I named this thing 104 because that's the name used in the paper.
        Maybe the generalized version (in which blocks can be swapped) should
        be called something else.

    Args:
        hidden_sizes (List[int]):
            Input to `build_hourglass`. The first hidden size
            will be the number of channels of the output feature maps.
            Default: [256, 384, 384, 384, 512, 512].
            By changing this value, you can adjust the computational cost
            of the model and maybe save some FLOPs.
        attention_core (bool):
            Whether to use `HGLCoreAtn`. Default: `False`.
    """

    def __init__(
        self,
        hidden_sizes: List[int] = [256, 384, 384, 384, 512, 512],
        attention_core: bool = False,
    ):
        super().__init__()
        h0 = hidden_sizes[0]
        # Stem: downsample the input by 4 before the hourglasses
        self.stem = nn.Sequential(
            *ConvBR(3, h0 // 2, 7, 2, padding=3),
            *ConvBR(h0 // 2, h0, 3, 2, padding=1),
        )
        # First hourglass
        self.hourglass_1 = nn.Sequential(
            build_hourglass(hidden_sizes, attention_core),
            ConvBR(h0, h0, 3, padding=1),
        )
        self.skip_1 = ConvBR(h0, h0, 1, relu=False)
        self.project_1 = ConvBR(h0, h0, 1, relu=False)
        # Second hourglass
        self.hourglass_2 = nn.Sequential(
            nn.ReLU(True),
            ConvBR(h0, h0, 3, padding=1),
            build_hourglass(hidden_sizes, attention_core),
            ConvBR(h0, h0, 3, padding=1),
        )

    def forward(self, x):
        # Stem
        x = self.stem(x)
        outputs = []
        # First hourglass forward
        skip = self.skip_1(x)
        x = self.hourglass_1(x)
        outputs.insert(0, x)
        x = self.project_1(x) + skip
        # Second hourglass forward
        x = self.hourglass_2(x)
        outputs.insert(0, x)
        # Return the two feature maps: [final, intermediate]
        return outputs
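

if __name__ == "__main__":
    # Minimal smoke test (not part of the original gist): build the default
    # Hourglass104 and push a dummy batch through it. The stem downsamples by
    # 4 and each hourglass adds five more stride-2 levels, so the input side
    # should be divisible by 128 for the shapes to line up exactly.
    import torch

    model = Hourglass104()
    x = torch.randn(1, 3, 256, 256)
    final, middle = model(x)
    # Both feature maps have hidden_sizes[0] = 256 channels at 1/4 resolution
    print(final.shape, middle.shape)  # twice torch.Size([1, 256, 64, 64])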