Last active
November 4, 2017 16:33
-
-
Save c0nn3r/d0fcb7edb57c405cb707afeca28a8329 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch.utils.model_zoo as model_zoo | |
from torchvision.models.resnet import BasicBlock, Bottleneck, ResNet | |
# Download locations of the ImageNet-pretrained torchvision ResNet checkpoints,
# keyed by architecture name (looked up by the *_features constructors).
model_urls = dict(
    resnet18='https://download.pytorch.org/models/resnet18-5c106cde.pth',
    resnet34='https://download.pytorch.org/models/resnet34-333f7ec4.pth',
    resnet50='https://download.pytorch.org/models/resnet50-19c8e357.pth',
    resnet101='https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
    resnet152='https://download.pytorch.org/models/resnet152-b121ed2d.pth',
)
class BasicBlockFeatures(BasicBlock):
    '''
    A BasicBlock that also returns the raw output of its second conv layer.
    '''

    def forward(self, x):
        """Run the residual block and expose its conv2 activation.

        Args:
            x: the block input, or a tuple whose first element is the block
                input. The tuple form lets these blocks be chained inside an
                ``nn.Sequential``, where each receives the ``(out, features)``
                pair produced by the previous one.

        Returns:
            A pair ``(out, conv2_rep)``. ``out`` is the block output after
            the residual addition and final ReLU. ``conv2_rep`` is the output
            of ``conv2`` taken *before* ``bn2`` is applied.
        """
        # Drop the feature half of the pair handed over by a previous block.
        inp = x[0] if isinstance(x, tuple) else x

        hidden = self.relu(self.bn1(self.conv1(inp)))
        conv2_rep = self.conv2(hidden)

        # bn2 returns a new tensor, so the in-place add below leaves
        # conv2_rep untouched.
        out = self.bn2(conv2_rep)
        shortcut = inp if self.downsample is None else self.downsample(inp)
        out += shortcut
        return self.relu(out), conv2_rep
class BottleneckFeatures(Bottleneck):
    '''
    A Bottleneck that also returns the raw output of its last conv layer.
    '''

    def forward(self, x):
        """Run the bottleneck block and expose its conv3 activation.

        Args:
            x: the block input, or a tuple whose first element is the block
                input. The tuple form lets these blocks be chained inside an
                ``nn.Sequential``, where each receives the ``(out, features)``
                pair produced by the previous one.

        Returns:
            A pair ``(out, conv3_rep)``. ``out`` is the block output after
            the residual addition and final ReLU. ``conv3_rep`` is the output
            of ``conv3`` taken *before* ``bn3`` is applied.
        """
        # Drop the feature half of the pair handed over by a previous block.
        inp = x[0] if isinstance(x, tuple) else x

        hidden = self.relu(self.bn1(self.conv1(inp)))
        hidden = self.relu(self.bn2(self.conv2(hidden)))
        conv3_rep = self.conv3(hidden)

        # bn3 returns a new tensor, so the in-place add below leaves
        # conv3_rep untouched.
        out = self.bn3(conv3_rep)
        shortcut = inp if self.downsample is None else self.downsample(inp)
        out += shortcut
        return self.relu(out), conv3_rep
class ResNetFeatures(ResNet):
    '''
    A ResNet that returns features instead of classification.
    '''

    def forward(self, x):
        """Return one feature tensor per residual stage.

        The classification head is never applied here. The stem runs first
        (conv1 -> bn1 -> relu -> maxpool), then layer1..layer4. Each stage must
        return a 2-item ``(out, features)`` pair. This holds when the model is
        built with ``BasicBlockFeatures`` / ``BottleneckFeatures``, whose
        features are the last block's final conv output taken before its
        BatchNorm.

        Args:
            x: the image batch fed to the stem.

        Returns:
            Tuple ``(c2, c3, c4, c5)`` holding the features from layer1,
            layer2, layer3 and layer4 respectively. The final stage's main
            output is discarded.
        """
        x = self.maxpool(self.relu(self.bn1(self.conv1(x))))

        collected = []
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            x, feat = stage(x)
            collected.append(feat)
        return tuple(collected)
def resnet50_features(pretrained=False, **kwargs):
    """Constructs a feature-extracting ResNet-50 model.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        **kwargs: forwarded unchanged to the ``ResNetFeatures`` constructor.

    Returns:
        ResNetFeatures: a network built from ``BottleneckFeatures`` blocks
        with stage depths ``[3, 4, 6, 3]``.
    """
    stage_depths = [3, 4, 6, 3]
    net = ResNetFeatures(BottleneckFeatures, stage_depths, **kwargs)
    if not pretrained:
        return net

    # NOTE(review): this loads the full checkpoint with the default strict
    # matching, so kwargs that change parameter shapes would presumably make
    # loading fail -- confirm against callers.
    checkpoint = model_zoo.load_url(model_urls['resnet50'])
    net.load_state_dict(checkpoint)
    return net
def resnet101_features(pretrained=False, **kwargs):
    """Constructs a feature-extracting ResNet-101 model.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        **kwargs: forwarded unchanged to the ``ResNetFeatures`` constructor.

    Returns:
        ResNetFeatures: a network built from ``BottleneckFeatures`` blocks
        with stage depths ``[3, 4, 23, 3]``.
    """
    stage_depths = [3, 4, 23, 3]
    net = ResNetFeatures(BottleneckFeatures, stage_depths, **kwargs)
    if not pretrained:
        return net

    # NOTE(review): this loads the full checkpoint with the default strict
    # matching, so kwargs that change parameter shapes would presumably make
    # loading fail -- confirm against callers.
    checkpoint = model_zoo.load_url(model_urls['resnet101'])
    net.load_state_dict(checkpoint)
    return net
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment