This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class _Transition(nn.Sequential):
    """DenseNet transition layer placed between two dense blocks.

    Runs BatchNorm -> ReLU -> 1x1 convolution to map ``num_input_features``
    channels to ``num_output_features`` channels. It then applies a 2x2
    average pool with stride 2, which halves the spatial resolution. This is
    the transition layer from "Densely Connected Convolutional Networks"
    (Huang et al.).

    Args:
        num_input_features (int): channel count of the incoming feature map
            (input is expected in the (N, C, H, W) layout used by
            ``nn.BatchNorm2d`` / ``nn.Conv2d``).
        num_output_features (int): channel count produced by the 1x1 conv.
    """

    def __init__(self, num_input_features, num_output_features):
        super(_Transition, self).__init__()
        self.add_module('norm', nn.BatchNorm2d(num_input_features))
        # In-place ReLU overwrites the normalized tensor instead of allocating a new one.
        self.add_module('relu', nn.ReLU(inplace=True))
        # The 1x1 kernel changes only the channel count; spatial size is untouched here.
        self.add_module('conv', nn.Conv2d(num_input_features, num_output_features,
                                          kernel_size=1, stride=1, bias=False))
        # Downsampling between dense blocks. Without this pool the spatial
        # resolution would stay constant through the whole network.
        self.add_module('pool', nn.AvgPool2d(kernel_size=2, stride=2))
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class DenseNet(nn.Module): | |
r"""Densenet-BC model class, based on | |
`"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>` | |
Args: | |
growth_rate (int) - how many filters to add each layer (`k` in paper) | |
block_config (list of 3 or 4 ints) - how many layers in each pooling block | |
num_init_features (int) - the number of filters to learn in the first convolution layer | |
bn_size (int) - multiplicative factor for number of bottle neck layers | |
(i.e. bn_size * k features in the bottleneck layer) | |
drop_rate (float) - dropout rate after each dense layer |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class _DenseBlock(nn.Module): | |
def __init__(self, num_layers, num_input_features, bn_size, growth_rate, drop_rate): | |
super(_DenseBlock, self).__init__() | |
for i in range(num_layers): | |
layer = _DenseLayer( | |
num_input_features + i * growth_rate, |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def _bn_function_factory(norm, relu, conv):
    """Build a closure running the concat -> norm -> relu -> conv pipeline.

    Args:
        norm: callable applied first, to the concatenated tensor.
        relu: callable applied to the output of ``norm``.
        conv: callable applied to the output of ``relu``; its result is what
            the closure returns.

    Returns:
        A function ``bn_function(*inputs)``. It concatenates its positional
        tensor arguments along dimension 1, feeds the result through
        ``norm``, ``relu`` and ``conv`` in that order, and returns the
        output of ``conv``.
    """
    # NOTE(review): wrapping the layers in a closure presumably lets callers
    # pass it to something like torch.utils.checkpoint -- confirm at call site.
    def bn_function(*inputs):
        merged = torch.cat(inputs, dim=1)
        activated = relu(norm(merged))
        return conv(activated)

    return bn_function