Last active: June 29, 2019 04:35
-
-
Save mukulkhanna/1200979cc91e54f2e638872fd4680560 to your computer and use it in GitHub Desktop.
Dense layer code from https://github.com/gpleiss/efficient_densenet_pytorch.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def _bn_function_factory(norm, relu, conv):
    """Build a closure that runs concat -> norm -> relu -> conv over its inputs.

    Packaging the bottleneck as a standalone callable over the raw input
    tensors presumably mirrors the efficient_densenet_pytorch project, where
    such a closure can be recomputed under ``torch.utils.checkpoint``. This
    file calls it directly. NOTE(review): confirm whether a
    memory-efficient (checkpointed) mode is wanted.

    Args:
        norm: Normalization module/callable applied to the concatenated input.
        relu: Activation applied to the normalized tensor.
        conv: Module/callable applied last.

    Returns:
        A function ``bn_function(*inputs)``. It concatenates ``inputs`` along
        dim 1 and returns ``conv(relu(norm(concatenated)))``. ``torch.cat``
        raises if ``inputs`` is empty.
    """
    def bn_function(*inputs):
        concated_features = torch.cat(inputs, 1)
        # The activation only ever sees norm's output, never the caller's
        # tensors directly, so an in-place relu is applied to that output.
        return conv(relu(norm(concated_features)))

    return bn_function
class _DenseLayer(nn.Module):
    """One DenseNet layer: BN-ReLU-Conv1x1 bottleneck, then BN-ReLU-Conv3x3.

    The layer receives every preceding feature map, concatenates them along
    dim 1 and produces ``growth_rate`` new channels. Only the new channels
    are returned. Concatenating them with the inputs is left to the caller.

    Args:
        num_input_features: Channel count of the concatenated inputs, i.e.
            the sum of the dim-1 sizes of ``*prev_features`` in ``forward``.
        growth_rate: Number of channels this layer outputs.
        bn_size: Bottleneck width multiplier. The 1x1 convolution outputs
            ``bn_size * growth_rate`` channels.
        drop_rate: Dropout probability for the new features. Dropout is
            skipped entirely when this is not greater than 0.
    """

    def __init__(self, num_input_features, growth_rate, bn_size, drop_rate):
        super(_DenseLayer, self).__init__()
        # The names passed to add_module become the state_dict keys
        # ('norm1.weight', 'conv1.weight', ...), so keep them stable.
        self.add_module('norm1', nn.BatchNorm2d(num_input_features))
        self.add_module('relu1', nn.ReLU(inplace=True))
        self.add_module('conv1', nn.Conv2d(num_input_features, bn_size * growth_rate,
                                           kernel_size=1, stride=1, bias=False))
        self.add_module('norm2', nn.BatchNorm2d(bn_size * growth_rate))
        self.add_module('relu2', nn.ReLU(inplace=True))
        # padding=1 with a 3x3 kernel and stride 1 keeps the spatial size.
        self.add_module('conv2', nn.Conv2d(bn_size * growth_rate, growth_rate,
                                           kernel_size=3, stride=1, padding=1, bias=False))
        self.drop_rate = drop_rate

    def forward(self, *prev_features):
        """Compute this layer's new feature maps.

        Args:
            *prev_features: Feature tensors from all preceding layers. They
                are concatenated along dim 1 before the bottleneck.

        Returns:
            The ``growth_rate``-channel output of the 3x3 convolution, with
            dropout applied in training mode when ``drop_rate > 0``.
        """
        bn_function = _bn_function_factory(self.norm1, self.relu1, self.conv1)
        bottleneck_output = bn_function(*prev_features)
        # relu2 is in-place, but it only sees norm2's output tensor.
        new_features = self.conv2(self.relu2(self.norm2(bottleneck_output)))
        if self.drop_rate > 0:
            new_features = F.dropout(new_features, p=self.drop_rate, training=self.training)
        return new_features
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.