Skip to content

Instantly share code, notes, and snippets.

@mukulkhanna
Created June 29, 2019 04:42
Show Gist options
  • Save mukulkhanna/0783bc22b0b8f826ef92e1e7455c3075 to your computer and use it in GitHub Desktop.
Save mukulkhanna/0783bc22b0b8f826ef92e1e7455c3075 to your computer and use it in GitHub Desktop.
class _DenseBlock(nn.Module):
    """Sequence of ``_DenseLayer`` children with concatenated outputs.

    The constructor registers ``num_layers`` child modules named
    ``denselayer1`` .. ``denselayer<num_layers>``.  The forward pass hands
    every tensor produced so far to the next layer and finally concatenates
    the input and all layer outputs along dimension 1.

    Args:
        num_layers: Number of ``_DenseLayer`` children to create.
        num_input_features: Size passed to the first layer; layer ``i``
            (0-based) is constructed with
            ``num_input_features + i * growth_rate``.
        bn_size: Forwarded unchanged to every ``_DenseLayer``.
        growth_rate: Forwarded unchanged to every ``_DenseLayer``; also the
            per-layer increment of the first constructor argument.
        drop_rate: Forwarded unchanged to every ``_DenseLayer``.

    NOTE(review): ``_DenseLayer`` is defined outside this block.  The
    arithmetic above presumes each layer emits ``growth_rate`` features on
    dimension 1 and that its ``forward`` accepts the previous feature maps
    as separate positional arguments -- confirm against its definition.
    """

    def __init__(self, num_layers, num_input_features, bn_size, growth_rate, drop_rate):
        super(_DenseBlock, self).__init__()
        for i in range(num_layers):
            layer = _DenseLayer(
                num_input_features + i * growth_rate,
                growth_rate=growth_rate,
                bn_size=bn_size,
                drop_rate=drop_rate,
            )
            # Child names are 1-based: denselayer1, denselayer2, ...
            self.add_module('denselayer%d' % (i + 1), layer)

    def forward(self, init_features):
        """Run every child layer and concatenate all features on dim 1.

        Args:
            init_features: Input tensor of the block.

        Returns:
            ``torch.cat`` over ``[init_features, out_1, ..., out_k]`` along
            dimension 1, where ``out_j`` is the output of the j-th child.
        """
        features = [init_features]
        # Only the modules are needed, so iterate children() rather than
        # named_children(); the order of iteration is the same.
        for layer in self.children():
            # Each layer receives all earlier feature maps, unpacked as
            # separate positional arguments.
            new_features = layer(*features)
            features.append(new_features)
        return torch.cat(features, 1)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment