import torch.nn as nn


class Residual(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.layer = nn.Sequential(
            nn.Conv2d(dim, dim, 7, 1, 3, groups=dim),  # depth-wise 7x7 conv
            nn.BatchNorm2d(dim),
            nn.Conv2d(dim, dim * 4, 1),  # point-wise expansion (inverted bottleneck)
            nn.ReLU(),
            nn.Conv2d(dim * 4, dim, 1),  # point-wise projection back to dim
        )

    def forward(self, x):
        return self.layer(x) + x
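
# Quick shape check (a minimal sketch, assuming the 56x56 resolution of
# the first stage): the residual block preserves channel count and
# spatial size, so the skip connection adds up elementwise.
if __name__ == "__main__":
    import torch
    assert Residual(64)(torch.randn(1, 64, 56, 56)).shape == (1, 64, 56, 56)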


def block(ins, outs, repeats):
    layers = [Residual(ins) for _ in range(repeats)]
    if ins != outs:
        # stage transition: normalize, then downsample spatially by 2
        # while changing the channel count
        layers.extend([
            nn.BatchNorm2d(ins),
            nn.Conv2d(ins, outs, 2, 2),
        ])
    return nn.Sequential(*layers)
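
# Quick shape check (a minimal sketch): when channel counts differ, the
# trailing stride-2 conv halves the resolution; when they match (as in
# the final 512 -> 512 stage) the resolution is left unchanged.
if __name__ == "__main__":
    import torch
    assert block(64, 128, 2)(torch.randn(1, 64, 56, 56)).shape == (1, 128, 28, 28)
    assert block(512, 512, 2)(torch.randn(1, 512, 7, 7)).shape == (1, 512, 7, 7)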


def convnet18(outs):
    features = nn.Sequential(
        nn.Conv2d(3, 64, 4, 4),  # e.g. (3, 224, 224) -> (64, 56, 56)
        block(64, 128, 2),
        block(128, 256, 2),
        block(256, 512, 2),
        block(512, 512, 2),  # remaining (512, 7, 7)
        nn.AdaptiveAvgPool2d(1),  # remaining (512, 1, 1)
        nn.Flatten(),
    )
    return nn.Sequential(features, nn.Linear(512, outs))


# This network is a modification of resnet18 inspired by the
# observations reported in https://arxiv.org/abs/2201.03545
# (A ConvNet for the 2020s). We try to stay close to resnet18's
# parameter count of 11 million (achieved 12 million).
#
# Comparison to resnet18 (following A ConvNet for the 2020s):
#   - stronger stemming (a single stride-4 conv vs a stride-2 conv
#     followed by stride-2 max-pooling)
#   - wider kernels (7x7 vs 3x3)
#   - fewer normalization layers
#   - fewer activation layers
#   - depth-wise convs (much faster to compute!)
#   - inverted residual in the point-wise convs (more expressive)
#
# Differences to A ConvNet for the 2020s:
#   - no GELU activation (slow to compute; plain ReLU instead)
#   - no transpose/layernorm (slow to compute)
#   - fewer repeats (exploding parameter count, slow to compute)
#   - no stochastic depth (diminishing returns with fewer layers)
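
# End-to-end sanity check (a minimal sketch, assuming a 224x224 RGB
# input and 1000 classes, matching the shape comments above): runs a
# forward pass and prints the parameter count for comparison with
# resnet18's ~11 million.
if __name__ == "__main__":
    import torch
    model = convnet18(outs=1000)
    logits = model(torch.randn(2, 3, 224, 224))
    assert logits.shape == (2, 1000)
    n_params = sum(p.numel() for p in model.parameters())
    print(f"parameters: {n_params / 1e6:.1f} million")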