Created
July 3, 2019 14:54
-
-
Save chenyaofo/1e8467caeeeda1182c17ca5978618185 to your computer and use it in GitHub Desktop.
The core implementation of "The Shallow End: Empowering Shallower Deep-Convolutional Networks through Auxiliary Outputs"
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
import typing | |
import functools | |
import torch.nn as nn | |
def intermediate_output_hook(module, input, output, intermediate_output_store: list):
    """Forward-hook body that records a module's output.

    Meant to be bound with ``functools.partial(intermediate_output_hook,
    intermediate_output_store=some_list)`` and passed to
    ``nn.Module.register_forward_hook``. ``module`` and ``input`` are ignored;
    ``output`` is appended to ``intermediate_output_store``.

    Returns ``None``, so the hooked module's output is not replaced.
    """
    store = intermediate_output_store
    store.append(output)
def _check_entrypoints(backbone, entrypoints): | |
complete_entrypoints = set([name for name, _ in backbone.named_modules()]) | |
expected_entrypoints = set(entrypoints) | |
if not expected_entrypoints.issubset(complete_entrypoints): | |
raise ValueError("The entrypoints({}) do not exist in backbone.".format( | |
",".join(expected_entrypoints - complete_entrypoints) | |
)) | |
class AuxNetHead(nn.Module):
    """Auxiliary classifier: adaptive average pooling to 1x1, then a linear layer.

    Args:
        in_features: Number of channels of the incoming feature map
            (presumably shaped (N, C, H, W) -- confirm against the backbone).
        out_features: Number of output logits.
    """

    def __init__(self, in_features, out_features):
        super().__init__()
        # Attribute names are kept stable: they determine the state_dict keys.
        self.adaptive_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(in_features, out_features)

    def forward(self, input):
        """Pool ``input`` spatially, flatten each sample, and apply ``fc``."""
        pooled = self.adaptive_pool(input)
        batch = pooled.size(0)
        return self.fc(pooled.view(batch, -1))
class AuxNet(nn.Module):
    """Wrap a backbone with auxiliary classifiers attached to intermediate modules.

    A forward hook is registered on each backbone submodule named in
    ``entrypoints_with_channels``. During :meth:`forward` each captured output
    is fed to its own classifier, built as ``Classifier(channels, num_classes)``.

    Args:
        backbone: Network whose own output is returned last by :meth:`forward`.
        entrypoints_with_channels: Maps a submodule name (as yielded by
            ``backbone.named_modules()``) to the channel count passed to its
            classifier. The dict order fixes the order of the auxiliary
            outputs.
        num_classes: ``out_features`` of every auxiliary classifier.
        Classifier: Factory ``(in_features, out_features) -> nn.Module``.

    Raises:
        ValueError: If an entrypoint is not a submodule name of ``backbone``.
    """

    def __init__(self, backbone: nn.Module,
                 entrypoints_with_channels: typing.Dict[str, int],
                 num_classes, Classifier=AuxNetHead):
        super(AuxNet, self).__init__()
        self.backbone = backbone
        self.entrypoints_with_channels = entrypoints_with_channels
        self.num_classes = num_classes
        self.Classifier = Classifier
        self._aux_classifiers = nn.ModuleList()
        # Outputs captured during the current forward pass, keyed by entrypoint
        # name, so each is matched to its own classifier regardless of the
        # order in which the backbone executes its submodules.
        self._intermediate_outputs = {}
        self._hooks = []

        _check_entrypoints(self.backbone, self.entrypoints_with_channels.keys())
        self._register_intermediate_hooks()
        self._create_aux_classifiers()

    def _save_intermediate_output(self, name, _module, _inputs, output):
        """Forward-hook body: remember ``output`` as the feature for entrypoint ``name``.

        If the module runs more than once in a pass, the last output wins.
        Returns ``None`` so the backbone's computation is not altered.
        """
        self._intermediate_outputs[name] = output

    def _register_intermediate_hooks(self):
        """Attach a forward hook to each backbone submodule named as an entrypoint."""
        for name, module in self.backbone.named_modules():
            if name in self.entrypoints_with_channels:
                self._hooks.append(
                    module.register_forward_hook(
                        functools.partial(self._save_intermediate_output, name)
                    )
                )

    def _create_aux_classifiers(self):
        """Build one auxiliary classifier per entrypoint, in dict order."""
        for channels in self.entrypoints_with_channels.values():
            self._aux_classifiers.append(
                self.Classifier(channels, self.num_classes)
            )

    def _clean_intermediate_outputs(self):
        """Forget all captured intermediate outputs."""
        self._intermediate_outputs.clear()

    def _remove_hooks(self):
        """Detach every forward hook this wrapper registered on the backbone."""
        for hook in self._hooks:
            hook.remove()
        # Drop the stale handles so a later call does not revisit them.
        self._hooks.clear()

    def forward(self, *args, **kwargs):
        """Run the backbone and every auxiliary classifier.

        Arguments are forwarded unchanged to ``self.backbone``.

        Returns:
            ``[aux_0, ..., aux_{k-1}, output]``: one auxiliary prediction per
            entrypoint, in ``entrypoints_with_channels`` order, followed by the
            backbone's own output. If the hooks were removed with
            :meth:`_remove_hooks`, only ``[output]`` is returned.

        Raises:
            RuntimeError: If, with hooks attached, an entrypoint module did not
                run during this pass (its classifier would have no input).
        """
        self._clean_intermediate_outputs()
        try:
            output = self.backbone(*args, **kwargs)
            aux_outputs = []
            if self._hooks:
                for name, classifier in zip(self.entrypoints_with_channels,
                                            self._aux_classifiers):
                    if name not in self._intermediate_outputs:
                        raise RuntimeError(
                            "Entrypoint '{}' produced no output during forward; "
                            "cannot evaluate its auxiliary classifier.".format(name)
                        )
                    aux_outputs.append(classifier(self._intermediate_outputs[name]))
        finally:
            # Release references to the captured feature maps (and any autograd
            # graph behind them) instead of holding them until the next call.
            self._clean_intermediate_outputs()
        return [*aux_outputs, output]
def autoconfig_auxnet(backbone: nn.Module,
                      entrypoints: typing.Iterable[str],
                      Classifier=AuxNetHead,
                      test_size=(1, 3, 224, 224)):
    """Build an :class:`AuxNet` by probing ``backbone`` with one random input.

    A single forward pass on ``torch.rand(test_size)`` records ``shape[1]`` of
    each entrypoint's output (the classifier's input channels) and the second
    dimension of the backbone's 2-D output (the number of classes). The probe
    runs in eval mode without gradient tracking. The backbone's original
    train/eval mode is restored and the temporary hooks are removed, even if
    the probe fails.

    Args:
        backbone: Network to wrap.
        entrypoints: Submodule names (from ``backbone.named_modules()``) to
            attach auxiliary classifiers to. Any iterable is accepted,
            including a one-shot generator. The classifiers follow this order.
        Classifier: Factory ``(in_features, out_features) -> nn.Module``.
        test_size: Shape of the random probe tensor. The tensor is created on
            the device of the backbone's first parameter, or on CPU if the
            backbone has no parameters.

    Returns:
        A new :class:`AuxNet` wrapping ``backbone``.

    Raises:
        ValueError: If an entrypoint is not a submodule of ``backbone``, or the
            backbone output is not 2-D.
        RuntimeError: If an entrypoint module did not run during the probe.
    """
    # Materialize once: the names are traversed several times below.
    entrypoints = list(entrypoints)
    _check_entrypoints(backbone, entrypoints)
    wanted = set(entrypoints)

    # Outputs keyed by name so channels are matched to the right entrypoint
    # regardless of the order in which the backbone runs its submodules.
    features = {}

    def _record(name, _module, _inputs, output):
        features[name] = output

    hooks = [
        module.register_forward_hook(functools.partial(_record, name))
        for name, module in backbone.named_modules()
        if name in wanted
    ]
    training = backbone.training
    backbone.eval()
    try:
        reference = next(backbone.parameters(), None)
        device = reference.device if reference is not None else None
        with torch.no_grad():
            output = backbone(torch.rand(test_size, device=device))
    finally:
        for hook in hooks:
            hook.remove()
        backbone.train(mode=training)

    _, num_classes = output.shape
    missing = [name for name in entrypoints if name not in features]
    if missing:
        raise RuntimeError(
            "The entrypoints({}) produced no output during the probe forward pass.".format(
                ",".join(missing)
            )
        )
    entrypoints_with_channels = {
        name: features[name].shape[1] for name in entrypoints
    }
    return AuxNet(backbone, entrypoints_with_channels, num_classes, Classifier=Classifier)
class AuxCriterion(object):
    """Apply a single criterion to every prediction produced by an :class:`AuxNet`.

    Args:
        criterion: Callable ``(prediction, targets) -> loss``, e.g. an
            ``nn.Module`` loss such as ``nn.CrossEntropyLoss()``.
    """

    def __init__(self, criterion):
        self.criterion = criterion

    def __call__(self, inputs, targets):
        """Compute one loss per prediction in ``inputs``, all against ``targets``.

        Returns:
            An :class:`AuxLoss` whose losses follow the order of ``inputs``.
        """
        # Build a list, not a generator: AuxLoss.backward calls len() and
        # indexes the losses.
        return AuxLoss([self.criterion(prediction, targets) for prediction in inputs])
class AuxLoss(object):
    """Container for the per-output losses of one :class:`AuxNet` forward pass.

    Args:
        losses: Iterable of loss tensors. It is copied into a list, so a
            one-shot generator is accepted.
    """

    def __init__(self, losses):
        self.losses = list(losses)

    def backward(self):
        """Call ``backward()`` on every loss, accumulating gradients.

        Losses are processed from last to first. Every call except the final
        one (index 0) passes ``retain_graph=True`` so that graph portions shared
        between the losses are still available to the later calls.
        """
        for index in reversed(range(len(self.losses))):
            self.losses[index].backward(retain_graph=index != 0)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment