More Efficient Dense Layers
'''More efficient dense layers for DenseNets.
author: James Gilles.'''
import torch
from torch import nn, autograd


class RepeatedConcat(nn.Module):
    '''Hack for faster DenseNets.
    Allocates O(n) memory for n concatenated layers,
    instead of O(n^2).
    '''
    def __init__(self):
        super().__init__()
        # there's no better way to tell if a module should be cuda :/
        self.cuda_tracker = nn.Parameter(torch.Tensor([1]))

    def reset(self, shape):
        '''reset() should be called at the beginning of a forward pass,
        with a shape describing the necessary size of the final output tensor.
        shape should be (batch, channels, height, width).'''
        self.shape = shape
        if self.cuda_tracker.is_cuda:
            self.data = torch.cuda.FloatTensor(*self.shape, device=self.cuda_tracker.get_device())
        else:
            self.data = torch.FloatTensor(*self.shape)
        self.n = 0
        self.slices = []

    def assign_next(self, n, var):
        '''Concatenate another block of data.
        'var' should be an autograd.Variable of size (batch, n, height, width).'''
        assert n == var.shape[1]
        self.data[:, self.n:self.n + n] = var.data
        self.slices.append(var)
        self.n += n

    def get_so_far(self):
        '''Return all of the blocks of data so far, concatenated.
        Doesn't allocate memory, but works with autograd.'''
        return _Fuser.apply(self.data, *self.slices)

    def __repr__(self):
        return 'Bundle[{}]'.format(self.shape)
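
# Rough usage sketch of the calling pattern above (names and shapes here are
# arbitrary examples, assuming 0.3-era PyTorch Variables):
#
#   concat = RepeatedConcat()
#   concat.reset((2, 8, 16, 16))                      # room for 8 channels total
#   a = autograd.Variable(torch.randn(2, 3, 16, 16))  # first 3 channels
#   b = autograd.Variable(torch.randn(2, 5, 16, 16))  # next 5 channels
#   concat.assign_next(3, a)
#   concat.assign_next(5, b)
#   both = concat.get_so_far()                        # shape (2, 8, 16, 16): a view of
#                                                     # the preallocated buffer, with no
#                                                     # torch.cat copy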

class _Fuser(autograd.Function):
    '''Hack for faster DenseNets.'''
    @staticmethod
    def forward(ctx, data, *args):
        '''Invariant: the tensors that are passed as extra arguments *must* be slices, in order,
        of the 'data' tensor. This way, we can "concatenate" them by simply returning the original block of data.
        '''
        slices = tuple(a.shape[1] for a in args)
        ctx.slices = slices
        return data[:, :sum(slices)]

    @staticmethod
    def backward(ctx, grad):
        '''grad contains the gradients for all of our slices.
        We can simply slice it to give the correct gradients for the input variables.'''
        slices = ctx.slices
        # None for the raw `data` tensor, which needs no gradient.
        results = [None]
        n = 0
        for s in slices:
            results.append(grad[:, n:n + s])
            n += s
        return tuple(results)
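
# Worked example of the invariant above (a sketch; widths chosen arbitrarily):
# with two extra arguments of 3 and 5 channels occupying data[:, 0:3] and
# data[:, 3:8], forward() returns data[:, :8] without copying, and backward()
# splits the incoming gradient into (None, grad[:, 0:3], grad[:, 3:8]) --
# None for the raw `data` tensor, then one gradient slice per input, in order.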

class DenseBlock(nn.Sequential):
    '''Bottleneck block (1x1 conv then 3x3 conv) producing `growth_rate` new channels.'''
    def __init__(self, in_channels, growth_rate, bn_size, nonlin):
        super().__init__(
            nn.Conv2d(in_channels, bn_size * growth_rate, kernel_size=1, stride=1, bias=False),
            nn.BatchNorm2d(bn_size * growth_rate),
            nonlin(bn_size * growth_rate),
            nn.Conv2d(bn_size * growth_rate, growth_rate, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(growth_rate),
            nonlin(growth_rate)
            # you could put another BatchNorm here if you really wanted
        )

class DenseLayer(nn.Module):
    '''Similar to the _DenseLayer provided by torchvision, but with far fewer allocations.'''
    def __init__(self, in_channels, num_layers, growth_rate,
                 bn_size=4, nonlin=lambda channels: nn.PReLU(channels)):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = in_channels + growth_rate * num_layers
        self.growth_rate = growth_rate
        self.concat = RepeatedConcat()
        self.blocks = nn.ModuleList([
            DenseBlock(in_channels + i * growth_rate, growth_rate, bn_size, nonlin)
            for i in range(num_layers)
        ])

    def forward(self, x):
        assert x.shape[1] == self.in_channels, (x.shape[1], self.in_channels)
        self.concat.reset((x.shape[0], self.out_channels, x.shape[2], x.shape[3]))
        self.concat.assign_next(x.shape[1], x)
        for block in self.blocks:
            sofar = self.concat.get_so_far()
            output = block(sofar)
            self.concat.assign_next(self.growth_rate, output)
        assert self.concat.n == self.out_channels, (self.concat.n, self.out_channels)
        return self.concat.get_so_far()
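
# Minimal smoke test: a sketch only, assuming 0.3-era PyTorch with
# autograd.Variable; the layer sizes and input shape below are arbitrary.
if __name__ == '__main__':
    layer = DenseLayer(in_channels=16, num_layers=4, growth_rate=12)
    x = autograd.Variable(torch.randn(2, 16, 32, 32))
    out = layer(x)
    # 16 input channels + 4 blocks * 12 new channels each = 64 output channels.
    assert out.shape == (2, 64, 32, 32), out.shape
    # Gradients flow back through _Fuser to every block.
    out.sum().backward()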