@kazimuth
Last active December 23, 2017 03:18
More Efficient Dense Layers
'''More efficient dense layers for DenseNets.
author: James Gilles.'''
import torch
from torch import nn, autograd

class RepeatedConcat(nn.Module):
    '''Hack for faster DenseNets.
    Allocates O(n) memory for n concatenated layers,
    instead of O(n^2).
    '''
    def __init__(self):
        super().__init__()
        # there's no better way to tell if a module should be cuda :/
        self.cuda_tracker = nn.Parameter(torch.Tensor([1]))

    def reset(self, shape):
        '''reset() should be called at the beginning of a forward pass,
        with a shape describing the necessary size of the final output tensor.
        shape should be (batch, channels, height, width).'''
        self.shape = shape
        if self.cuda_tracker.is_cuda:
            self.data = torch.cuda.FloatTensor(*self.shape, device=self.cuda_tracker.get_device())
        else:
            self.data = torch.FloatTensor(*self.shape)
        self.n = 0
        self.slices = []

    def assign_next(self, n, var):
        '''Concatenate another block of data.
        'var' should be an autograd.Variable of size (batch, n, height, width).'''
        assert n == var.shape[1]
        self.data[:, self.n:self.n+n] = var.data
        self.slices.append(var)
        self.n += n

    def get_so_far(self):
        '''Return all of the blocks of data so far, concatenated.
        Doesn't allocate memory, but works with autograd.'''
        return _Fuser.apply(self.data, *self.slices)

    def __repr__(self):
        return 'Bundle[{}]'.format(self.shape)
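
# Illustrative usage sketch of RepeatedConcat (added for exposition; not part of the
# original gist). The shapes below are arbitrary, and tensors are wrapped in
# autograd.Variable to match the 0.3-era API this file assumes.
def _repeated_concat_example():
    concat = RepeatedConcat()
    x = autograd.Variable(torch.randn(2, 8, 4, 4))   # (batch, channels, height, width)
    y = autograd.Variable(torch.randn(2, 4, 4, 4))   # 4 more channels to append
    concat.reset((2, 8 + 4, 4, 4))                   # preallocate the final buffer once
    concat.assign_next(8, x)                         # copy x into channels [0, 8)
    concat.assign_next(4, y)                         # copy y into channels [8, 12)
    return concat.get_so_far()                       # view of the buffer, shape (2, 12, 4, 4)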

class _Fuser(autograd.Function):
    '''Hack for faster DenseNets.'''
    @staticmethod
    def forward(ctx, data, *args):
        '''Invariant: the tensors passed as extra arguments *must* be slices, in order,
        of the 'data' tensor. This way, we can "concatenate" them by simply returning
        the original block of data.
        '''
        slices = tuple(a.shape[1] for a in args)
        ctx.slices = slices
        return data[:, :sum(slices)]

    @staticmethod
    def backward(ctx, grad):
        '''grad contains the gradients for all of our slices.
        We can simply slice it to give the correct gradients for the input variables.'''
        slices = ctx.slices
        results = [None]  # no gradient flows back to the preallocated 'data' tensor
        n = 0
        for s in slices:
            results.append(grad[:, n:n+s])
            n += s  # advance the channel offset to the start of the next slice
        return tuple(results)
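
# Sketch of the invariant _Fuser relies on (added for exposition; the names and sizes
# here are assumptions, not from the original gist): the extra arguments must be
# consecutive channel slices of 'data', so returning a prefix of 'data' is equivalent
# to torch.cat over those slices, without allocating a new buffer.
def _fuser_invariant_example():
    data = torch.randn(2, 12, 4, 4)
    a = autograd.Variable(data[:, 0:8])        # first 8 channels, sharing data's storage
    b = autograd.Variable(data[:, 8:12])       # next 4 channels, also a view
    fused = _Fuser.apply(data, a, b)           # no copy: just data[:, :12]
    reference = torch.cat([a, b], dim=1)       # what a naive DenseNet would allocate
    assert (fused.data - reference.data).abs().max() == 0
    return fused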

class DenseBlock(nn.Sequential):
    '''A bottleneck unit: a 1x1 conv expands to bn_size * growth_rate channels,
    then a 3x3 conv produces growth_rate new channels.'''
    def __init__(self, in_channels, growth_rate, bn_size, nonlin):
        super().__init__(
            nn.Conv2d(in_channels, bn_size * growth_rate, kernel_size=1, stride=1, bias=False),
            nn.BatchNorm2d(bn_size * growth_rate),
            nonlin(bn_size * growth_rate),
            nn.Conv2d(bn_size * growth_rate, growth_rate, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(growth_rate),
            nonlin(growth_rate)
            # you could put another BatchNorm here if you really wanted
        )

class DenseLayer(nn.Module):
    '''Similar to the _DenseLayer provided by torchvision, but with far fewer allocations.'''
    def __init__(self, in_channels, num_layers, growth_rate,
                 bn_size=4, nonlin=lambda channels: nn.PReLU(channels)):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = in_channels + growth_rate * num_layers
        self.growth_rate = growth_rate
        self.concat = RepeatedConcat()
        self.blocks = nn.ModuleList([
            DenseBlock(in_channels + i * growth_rate, growth_rate, bn_size, nonlin)
            for i in range(num_layers)
        ])

    def forward(self, x):
        assert x.shape[1] == self.in_channels, (x.shape[1], self.in_channels)
        self.concat.reset((x.shape[0], self.out_channels, x.shape[2], x.shape[3]))
        self.concat.assign_next(x.shape[1], x)
        for block in self.blocks:
            sofar = self.concat.get_so_far()
            output = block(sofar)
            self.concat.assign_next(self.growth_rate, output)
        assert self.concat.n == self.out_channels, (self.concat.n, self.out_channels)
        return self.concat.get_so_far()
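
# Hedged end-to-end sketch (added for illustration; not part of the original gist):
# build a DenseLayer and run a forward pass on random data. The channel counts and
# spatial size are arbitrary; the input is wrapped in autograd.Variable to match the
# 0.3-era API used above.
if __name__ == '__main__':
    layer = DenseLayer(in_channels=16, num_layers=4, growth_rate=12)
    x = autograd.Variable(torch.randn(8, 16, 32, 32))
    out = layer(x)
    # 16 input channels plus 4 blocks of 12 new channels each = 64 output channels,
    # all living in one preallocated buffer instead of num_layers nested concatenations.
    assert out.shape[1] == layer.out_channels == 16 + 4 * 12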