A ConvNet for the 2020s (1D version)

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.init import trunc_normal_


def drop_path(
    x, drop_prob: float = 0.0, training: bool = False, scale_by_keep: bool = True
):
    """Drop paths (Stochastic Depth) per sample (when applied in the main path of residual blocks).

    This is the same as the DropConnect impl I created for EfficientNet, etc. networks; however,
    the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
    See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
    changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
    'survival rate' as the argument.
    """
    if drop_prob == 0.0 or not training:
        return x
    keep_prob = 1 - drop_prob
    shape = (x.shape[0],) + (1,) * (
        x.ndim - 1
    )  # work with tensors of any rank, not just 2D ConvNets
    random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
    if keep_prob > 0.0 and scale_by_keep:
        random_tensor.div_(keep_prob)
    return x * random_tensor
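

# Illustrative sketch (not part of the original gist): drop_path zeroes the residual
# branch for a random subset of samples and rescales the survivors by 1/keep_prob, so
# the expected output is unchanged. Call this helper manually to see the effect.
def _demo_drop_path():
    branch_out = torch.ones(4, 8, 16)  # (N, C, L) residual-branch output
    dropped = drop_path(branch_out, drop_prob=0.5, training=True)
    # Each sample is now either all zeros or all 2.0 (== 1 / keep_prob).
    print(dropped[:, 0, 0])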


class DropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in the main path of residual blocks)."""

    def __init__(self, drop_prob: float = 0., scale_by_keep: bool = True):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob
        self.scale_by_keep = scale_by_keep

    def forward(self, x):
        return drop_path(x, self.drop_prob, self.training, self.scale_by_keep)

    def extra_repr(self):
        return f'drop_prob={round(self.drop_prob, 3):0.3f}'
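

# Illustrative sketch (not part of the original gist): DropPath is the module form used by
# Block below; it is a no-op in eval mode and behaves like drop_path above in train mode.
def _demo_drop_path_module():
    dp = DropPath(drop_prob=0.2)
    dp.eval()
    x = torch.randn(4, 8, 16)
    print(torch.equal(dp(x), x))  # expected: True in eval mode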


class Block(nn.Module):
    r""" ConvNeXt Block (1D). There are two equivalent implementations:
    (1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, L)
    (2) DwConv -> Permute to (N, L, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back
    We use (2) as we find it slightly faster in PyTorch

    Args:
        dim (int): Number of input channels.
        drop_path (float): Stochastic depth rate. Default: 0.0
        layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.
    """

    def __init__(self, dim, drop_path=0., layer_scale_init_value=1e-6):
        super().__init__()
        self.dwconv = nn.Conv1d(dim, dim, kernel_size=7, padding=3, groups=dim)  # depthwise conv
        self.norm = LayerNorm(dim, eps=1e-6)
        self.pwconv1 = nn.Linear(dim, 4 * dim)  # pointwise/1x1 convs, implemented with linear layers
        self.act = nn.GELU()
        self.pwconv2 = nn.Linear(4 * dim, dim)
        self.gamma = nn.Parameter(layer_scale_init_value * torch.ones(dim),
                                  requires_grad=True) if layer_scale_init_value > 0 else None
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()

    def forward(self, x):
        input = x
        x = self.dwconv(x)
        x = x.permute(0, 2, 1)  # (N, C, L) -> (N, L, C)
        x = self.norm(x)
        x = self.pwconv1(x)
        x = self.act(x)
        x = self.pwconv2(x)
        if self.gamma is not None:
            x = self.gamma * x
        x = x.permute(0, 2, 1)  # (N, L, C) -> (N, C, L)
        x = input + self.drop_path(x)
        return x
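

# Illustrative sketch (not part of the original gist): a Block is shape-preserving on
# (N, C, L) tensors, so each stage can chain an arbitrary number of them. Call manually.
def _demo_block():
    block = Block(dim=96, drop_path=0.1)
    x = torch.randn(2, 96, 256)  # (N, C, L)
    y = block(x)
    print(y.shape)  # torch.Size([2, 96, 256]), same shape as the input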


class ConvNeXt(nn.Module):
    r""" ConvNeXt (1D)
    A 1D PyTorch impl of : `A ConvNet for the 2020s` -
      https://arxiv.org/pdf/2201.03545.pdf

    Args:
        in_chans (int): Number of input channels. Default: 2
        num_classes (int): Number of classes for classification head. Default: 43
        depths (tuple(int)): Number of blocks at each stage. Default: [3, 3, 9, 3]
        dims (tuple(int)): Feature dimension at each stage. Default: [96, 192, 384, 768]
        drop_path_rate (float): Stochastic depth rate. Default: 0.
        layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.
        head_init_scale (float): Init scaling value for classifier weights and biases. Default: 1.
    """

    def __init__(self, in_chans=2, num_classes=43,
                 depths=[3, 3, 9, 3], dims=[96, 192, 384, 768], drop_path_rate=0.,
                 layer_scale_init_value=1e-6, head_init_scale=1.,
                 ):
        super().__init__()

        self.downsample_layers = nn.ModuleList()  # stem and 3 intermediate downsampling conv layers
        stem = nn.Sequential(
            nn.Conv1d(in_chans, dims[0], kernel_size=4, stride=4),
            # nn.BatchNorm1d(dims[0])
            LayerNorm(dims[0], eps=1e-6, data_format="channels_first")
        )
        self.downsample_layers.append(stem)
        for i in range(3):
            downsample_layer = nn.Sequential(
                # nn.BatchNorm1d(dims[i]),
                LayerNorm(dims[i], eps=1e-6, data_format="channels_first"),
                nn.Conv1d(dims[i], dims[i + 1], kernel_size=2, stride=2),
            )
            self.downsample_layers.append(downsample_layer)

        self.stages = nn.ModuleList()  # 4 feature resolution stages, each consisting of multiple residual blocks
        dp_rates = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
        cur = 0
        for i in range(4):
            stage = nn.Sequential(
                *[Block(dim=dims[i], drop_path=dp_rates[cur + j],
                        layer_scale_init_value=layer_scale_init_value) for j in range(depths[i])]
            )
            self.stages.append(stage)
            cur += depths[i]

        self.norm = nn.LayerNorm(dims[-1], eps=1e-6)  # final norm layer
        self.head = nn.Linear(dims[-1], num_classes)

        self.apply(self._init_weights)
        self.head.weight.data.mul_(head_init_scale)
        self.head.bias.data.mul_(head_init_scale)

    def _init_weights(self, m):
        if isinstance(m, (nn.Conv1d, nn.Linear)):
            trunc_normal_(m.weight, std=.02)
            nn.init.constant_(m.bias, 0)

    def forward_features(self, x):
        for i in range(4):
            x = self.downsample_layers[i](x)
            x = self.stages[i](x)
        # global average pooling over the length dimension, (N, C, L) -> (N, C),
        # followed by the final norm layer
        return self.norm(x.mean([-1]))

    def forward(self, x):
        x = self.forward_features(x)
        x = self.head(x)
        return x
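

# Illustrative sketch (not part of the original gist): building a smaller 1D ConvNeXt.
# The stem downsamples the sequence by 4 and each of the three downsample layers by a
# further 2, so the input length shrinks 32x before global average pooling. The in_chans,
# num_classes, depths, and dims values here are arbitrary choices for the sketch.
def _demo_convnext_custom():
    model = ConvNeXt(in_chans=1, num_classes=10,
                     depths=[2, 2, 4, 2], dims=[48, 96, 192, 384])
    x = torch.randn(8, 1, 512)  # (N, in_chans, L); 512 -> 16 positions at the last stage
    logits = model(x)
    print(logits.shape)  # torch.Size([8, 10])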


class LayerNorm(nn.Module):
    r""" LayerNorm that supports two data formats: channels_last (default) or channels_first.
    The ordering of the dimensions in the inputs. channels_last corresponds to inputs with
    shape (batch_size, length, channels) while channels_first corresponds to inputs
    with shape (batch_size, channels, length).
    """

    def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.bias = nn.Parameter(torch.zeros(normalized_shape))
        self.eps = eps
        self.data_format = data_format
        if self.data_format not in ["channels_last", "channels_first"]:
            raise NotImplementedError
        self.normalized_shape = (normalized_shape,)

    def forward(self, x):
        if self.data_format == "channels_last":
            return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
        elif self.data_format == "channels_first":
            u = x.mean(1, keepdim=True)
            s = (x - u).pow(2).mean(1, keepdim=True)
            x = (x - u) / torch.sqrt(s + self.eps)
            x = self.weight[:, None] * x + self.bias[:, None]
            return x
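

# Illustrative sketch (not part of the original gist): the channels_first branch above
# should match PyTorch's nn.LayerNorm applied after permuting (N, C, L) -> (N, L, C).
def _demo_layernorm_equivalence():
    x = torch.randn(2, 8, 16)  # (N, C, L)
    cf = LayerNorm(8, data_format="channels_first")
    cl = nn.LayerNorm(8, eps=1e-6)
    out_cf = cf(x)
    out_cl = cl(x.permute(0, 2, 1)).permute(0, 2, 1)
    print(torch.allclose(out_cf, out_cl, atol=1e-5))  # expected: True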


# NOTE: these are the official 2D ImageNet checkpoints; their Conv2d weights do not match
# the Conv1d layers of this 1D model, so load_state_dict below will fail without adaptation.
model_urls = {
    "convnext_tiny_1k": "https://dl.fbaipublicfiles.com/convnext/convnext_tiny_1k_224_ema.pth",
    "convnext_small_1k": "https://dl.fbaipublicfiles.com/convnext/convnext_small_1k_224_ema.pth",
    "convnext_base_1k": "https://dl.fbaipublicfiles.com/convnext/convnext_base_1k_224_ema.pth",
    "convnext_large_1k": "https://dl.fbaipublicfiles.com/convnext/convnext_large_1k_224_ema.pth",
    "convnext_base_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_base_22k_224.pth",
    "convnext_large_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_large_22k_224.pth",
    "convnext_xlarge_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_xlarge_22k_224.pth",
}


def convnext_tiny(pretrained=False, **kwargs):
    model = ConvNeXt(depths=[3, 3, 9, 3], dims=[96, 192, 384, 768], **kwargs)
    if pretrained:
        url = model_urls['convnext_tiny_1k']
        checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu", check_hash=True)
        model.load_state_dict(checkpoint["model"])
    return model


def convnext_small(pretrained=False, **kwargs):
    model = ConvNeXt(depths=[3, 3, 27, 3], dims=[96, 192, 384, 768], **kwargs)
    if pretrained:
        url = model_urls['convnext_small_1k']
        checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu")
        model.load_state_dict(checkpoint["model"])
    return model


def convnext_base(pretrained=False, in_22k=False, **kwargs):
    # NOTE: depths here differ from the official ConvNeXt-Base ([3, 3, 27, 3]).
    model = ConvNeXt(depths=[1, 1, 3, 1], dims=[128, 256, 512, 1024], **kwargs)
    if pretrained:
        url = model_urls['convnext_base_22k'] if in_22k else model_urls['convnext_base_1k']
        checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu")
        model.load_state_dict(checkpoint["model"])
    return model


def convnext_large(pretrained=False, in_22k=False, **kwargs):
    model = ConvNeXt(depths=[3, 3, 27, 3], dims=[192, 384, 768, 1536], **kwargs)
    if pretrained:
        url = model_urls['convnext_large_22k'] if in_22k else model_urls['convnext_large_1k']
        checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu")
        model.load_state_dict(checkpoint["model"])
    return model


def convnext_xlarge(pretrained=False, in_22k=False, **kwargs):
    model = ConvNeXt(depths=[3, 3, 27, 3], dims=[256, 512, 1024, 2048], **kwargs)
    if pretrained:
        assert in_22k, "only ImageNet-22K pre-trained ConvNeXt-XL is available; please set in_22k=True"
        url = model_urls['convnext_xlarge_22k']
        checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu")
        model.load_state_dict(checkpoint["model"])
    return model


def main():
    conv = convnext_tiny()
    a = conv(torch.zeros((16, 2, 1024)))
    print(conv)
    print(a.shape)  # expected: torch.Size([16, 43])


if __name__ == '__main__':
    # freeze_support()
    main()