Skip to content

Instantly share code, notes, and snippets.

@ajbrock
Last active July 25, 2019 10:27
Show Gist options
  • Save ajbrock/4ec90294edbe77bf8feae89334e45422 to your computer and use it in GitHub Desktop.
Save ajbrock/4ec90294edbe77bf8feae89334e45422 to your computer and use it in GitHub Desktop.
###
# Situationally faster dilated convolutions through subpixel reshapes
# A Brock, 2016
#
# Script adapted from https://github.com/soumith/convnet-benchmarks/blob/master/theano/pylearn2_benchmark.py by Jan Schluter.
#
# Outputs of this script from my tests on a GTX 980 are available here: http://pastebin.com/JRBY4Qnf
#
# Outputs of this script from my tests on a Titan X are available here: http://pastebin.com/0zJ8Uvg0
#
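# This script benchmarks the forward and backward (wrt inputs and weights) passes of my implementation of dilated convolutions
# through subpixel reshapes and compares it to the lasagne implementation (DilatedConv2DLayer), Theano's native filter_dilation
# (im2col) convolution, and two variants that stack the subpixel phases into the batch dimension. The basic idea is that a
# subpixel downsample operation into an additional spatial dimension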
# (i.e. where the elements are reordered from [1,2,3,4,5,6,7,8] to [[1,3,5,7],[2,4,6,8]]) allows one to perform dilated convolution
# (AKA atrous convolution, convolution with holes, convolution when you're on shrooms) without having to use a dilated filter.
# This implementation only works out-of-the-box when the spatial dimensions and dilated filter size line up nicely (in
# particular, the input height and width must be divisible by the dilation factor), but a little
# bit of clever zero-padding could probably overcome this.
#
# I've found that this can be significantly faster when dealing with high-dimensional layers (a high number of input channels
# and a high number of output filters), particularly during the backward pass. I suspect the optimal solution to this issue is
# to integrate dilated convolutions into a library's im2col function, but this workaround could be useful given how
# computationally expensive dilated convs are.
#
# I also suspect that, barring such an implementation, this hack could be made much faster if someone could figure out a faster
# version of my subpixel layer, which I'm pretty sure is the most expensive part (though I haven't bothered to profile it, so I
# can't say for certain). Right now it uses r^2 set_subtensor calls, which is faster than anything else I've tried,
# including two different reshape/concatenate methods, and an advanced indexing method.
## Subpixel Reshape
import os
import sys
import numpy as np
import math
import theano
import theano.tensor as T
import lasagne
from lasagne.layers.dnn import Conv2DDNNLayer as C2D
import lasagne.layers as ll
# Select the GPU up front: the code below imports and calls the old
# theano.sandbox.cuda backend (cuDNN layers, cuda.synchronize in time_run).
if not theano.config.device.startswith('gpu'):
    import theano.sandbox.cuda
    theano.sandbox.cuda.use('gpu')
theano.config.floatX = 'float32'
# pycuda is optional: if it imports, time_run() times with CUDA events;
# otherwise ``time`` is imported here and wall-clock timing is used instead.
try:
    import theano.misc.pycuda_init
    import pycuda.driver
except ImportError:
    print "Note: pycuda not available, no timing via CUDA events possible"
    import time
    pycuda = None
import theano
# Probe for cuDNN; if it is not usable the dnn module attribute is removed
# again and only a note is printed.
try:
    import theano.sandbox.cuda.dnn
    if not theano.sandbox.cuda.dnn.dnn_available():
        del theano.sandbox.cuda.dnn
        raise ImportError
except (ImportError, NameError):
    print "Note: cuDNN not available"
# The pylearn2 cuda-convnet wrapper check is disabled; FilterActs is kept as a
# None placeholder and is not referenced elsewhere in this script as shown.
# try:
# from pylearn2.sandbox.cuda_convnet.filter_acts import FilterActs
# except ImportError:
FilterActs = None
# print "Note: pylearn2's cuda-convnet wrapper not available"
# else:
from theano.sandbox.cuda.basic_ops import gpu_contiguous
# Results file: benchmark results and CONFIG lines are also written here.
f1=open('./benchmark_TITAN_stack.txt', 'w')
number = 10 # nb of steps in loop to average over
repeat = 1 # nb of trials to pick the minimum of
def di(ni, no, iw, ih, scale):
    """Return the configuration dict for one dilated-convolution benchmark run.

    Only the channel counts (``ni`` in, ``no`` out), the input width/height
    and the dilation factor ``scale`` vary; the kernel is fixed at 3x3, the
    batch size at 128 and the stride at 1x1.
    """
    return dict(ni=ni, no=no,
                kw=3, kh=3,
                iw=iw, ih=ih,
                bs=128,
                dw=1, dh=1,
                scale=scale)
# Default benchmark grid. Each spec is
# (input channels, output filters, square input sizes, dilation factor) and
# expands to one run per size, in the order listed. A generator expression is
# used so the loop variables do not leak into the module namespace (Python 2).
runs = list(di(n_in, n_out, size, size, dilation)
            for n_in, n_out, sizes, dilation in (
                (3, 128, (64, 32, 16), 2),
                (3, 128, (64, 32, 16), 4),
                (128, 128, (64, 32, 16), 2),
                (128, 256, (64, 32, 16), 2),
                (256, 256, (64, 32, 16), 2),
                (128, 256, (64, 32, 16), 4),
                (256, 512, (32, 16, 8), 2),
                (512, 512, (16, 8, 4), 2),
                (512, 1024, (16, 8, 4), 2),
            )
            for size in sizes)
def time_run(fn):
    """Return the best per-call time of ``fn`` in seconds.

    ``fn`` is called once untimed as a warm-up. Then ``repeat`` trials each
    time ``number`` back-to-back calls, and the minimum per-call average is
    returned. Timing uses CUDA events when pycuda is available. Otherwise it
    uses wall-clock time, bracketed by device synchronization.
    """
    fn()  # warm-up call, not timed
    trial_times = []
    if pycuda:
        theano.sandbox.cuda.synchronize()
        ev_start, ev_end = pycuda.driver.Event(), pycuda.driver.Event()
        for _trial in range(repeat):
            ev_start.record()
            for _call in range(number):
                fn()
            ev_end.record()
            ev_end.synchronize()
            # time_till reports milliseconds; convert to seconds per call
            trial_times.append(ev_start.time_till(ev_end) / 1e3 / number)
    else:
        for _trial in range(repeat):
            theano.sandbox.cuda.synchronize()
            t0 = time.time()
            for _call in range(number):
                fn()
            theano.sandbox.cuda.synchronize()
            trial_times.append((time.time() - t0) / number)
    return min(trial_times)
def print_graph(fn):
if int(os.environ.get('PRINT_GRAPH', 0)):
# debugprint of graph (in blue text)
print '\033[1;34m'
theano.printing.debugprint(fn)
print '\033[1;m'
def benchmark_three_ways(name, sharedX, sharedY, sharedW, X, Y, gW, gX, mode=None):
# benchmark fprop
try:
fprop = theano.function([], [],
givens=[(X, sharedX)],
updates=[(sharedY, Y)],
mode=mode,
name=name + " fprop")
tm = time_run(fprop)
print '{: <50} ==> {: <13} ==> {: >7}'.format(name, 'fprop', int(tm*1000))
print >>f1, '{: <50} ==> {: <13} ==> {: >7}'.format(name, 'fprop', int(tm*1000))
# f1.write('{: <50} ==> {: <13} ==> {: >7}'.format(name, 'fprop', int(tm*1000)))
print_graph(fprop)
del fprop
except Exception, e:
print name, 'fprop: FAILED', str(e).split('\n', 1)[0]
# benchmark bprop wrt input
try:
bprop = theano.function([], [],
# the nvidia wrapper need this (in fact could be optional for subsample==(1, 1)
givens=[(X, sharedX)],
updates=[(sharedX, gX)],
mode=mode,
name=name + " bprop inputs")
tm = time_run(bprop)
print '{: <50} ==> {: <13} ==> {: >7}'.format(name, 'bprop inputs', int(tm*1000))
print >>f1, '{: <50} ==> {: <13} ==> {: >7}'.format(name, 'bprop inputs', int(tm*1000))
# f1.write('{: <50} ==> {: <13} ==> {: >7}'.format(name, 'bprop inputs', int(tm*1000)))
print_graph(bprop)
del bprop
except Exception, e:
print name, 'bprop inputs: FAILED', str(e).split('\n', 1)[0]
# benchmark bprop wrt weights
try:
bprop = theano.function([], [],
givens=[(X, sharedX)],
updates=[(sharedW, gW)],
mode=mode,
name=name + " bprop weights")
tm = time_run(bprop)
print '{: <50} ==> {: <13} ==> {: >7}'.format(name, 'bprop weights', int(tm*1000))
print >>f1, '{: <50} ==> {: <13} ==> {: >7}'.format(name, 'bprop weights', int(tm*1000))
# f1.write('{: <50} ==> {: <13} ==> {: >7}'.format(name, 'bprop weights', int(tm*1000)))
print_graph(bprop)
del bprop
except Exception, e:
print name, 'bprop weights: FAILED', str(e).split('\n', 1)[0]
print ''
def parse_custom_config(s):
    """Parse a custom benchmark configuration string into a run dict.

    The string is a comma-separated list of parts. Each part is a one-letter
    prefix followed by 'x'-separated integers::

        iAxBxC,kDxExF,bG,sHxJ,rK

    A: input channels, B: input width, C: input height,
    D: output channels, E: kernel width, F: kernel height, G: batch size,
    H: horizontal stride, J: vertical stride, K: dilation factor ('scale').
    The b, s and r parts are optional and default to 128, 1x1 and 2.
    Empty parts (e.g. from a trailing comma) are ignored.

    Raises:
        ValueError: on an unknown prefix or a non-integer value.
    """
    # 'scale' must always be present: the benchmark loop reads run['scale'].
    run = {'bs': 128, 'dw': 1, 'dh': 1, 'scale': 2}
    defs = {'i': ['ni', 'iw', 'ih'],
            'k': ['no', 'kw', 'kh'],
            'b': ['bs'],
            's': ['dw', 'dh'],
            'r': ['scale']}
    for part in s.split(','):
        part = part.strip()
        if not part:
            continue
        prefix = part[0]
        if prefix not in defs:
            raise ValueError('unknown prefix %r in config part %r (expected one of %s)'
                             % (prefix, part, ', '.join(sorted(defs))))
        args = [int(a) for a in part[1:].split('x')]
        run.update(zip(defs[prefix], args))
    return run
if len(sys.argv) > 1:
    # Command-line arguments either select predefined runs by 1-based index
    # (e.g. ``1 2 5``) or give custom configurations, which start with 'i'
    # (e.g. ``i3x80x15,k32x3x7,b256``; see parse_custom_config). Selected
    # predefined runs come first, followed by the custom ones.
    _cli_args = sys.argv[1:]
    _selected_runs = []
    for _arg in _cli_args:
        if _arg.startswith('i'):
            continue
        _idx = int(_arg)
        # Reject 0 and negative indices: they would otherwise silently wrap
        # around to the end of the list via Python's negative indexing.
        if not 1 <= _idx <= len(runs):
            raise ValueError('run index %d out of range (valid: 1-%d)' % (_idx, len(runs)))
        _selected_runs.append(runs[_idx - 1])
    runs = _selected_runs + [parse_custom_config(_arg)
                             for _arg in _cli_args if _arg.startswith('i')]
class SubpixelLayer(lasagne.layers.Layer):
    """Subpixel upsampling: fold groups of channels into an r-times larger image.

    Input channels ``k, k + r*r, k + 2*r*r, ...`` (with ``k = r*x + y``) are
    written to pixel offset ``(x, y)`` of every r x r output cell. This turns
    ``(b, ?, h, w)`` into ``(b, c, r*h, r*w)``. The input is assumed to have
    c*r*r channels (not checked here).
    """
    def __init__(self, incoming, r, c, **kwargs):
        super(SubpixelLayer, self).__init__(incoming, **kwargs)
        self.r = r
        self.c = c

    def get_output_shape_for(self, input_shape):
        r = self.r
        return (input_shape[0], self.c, r * input_shape[2], r * input_shape[3])

    def get_output_for(self, input, deterministic=False, **kwargs):
        r = self.r
        _, n_ch, out_h, out_w = self.output_shape
        result = T.zeros((input.shape[0], n_ch, out_h, out_w))
        # phase enumerates (row, col) offsets row-major, i.e. phase == r*row + col
        for phase in xrange(r * r):
            row, col = divmod(phase, r)
            result = T.set_subtensor(result[:, :, row::r, col::r],
                                     input[:, phase::r * r, :, :])
        return result
class SubpixelDownsampleLayer(lasagne.layers.Layer):
    """Subpixel downsampling into a new phase axis.

    Each r x r cell of a ``(b, c, h, w)`` input is split into r*r phases and
    stacked along a new third axis, giving ``(b, c, r*r, h//r, w//r)``.
    Phase ``r*x + y`` holds ``input[:, :, x::r, y::r]``. The slices only fit
    the output when h and w are divisible by r.
    """
    def __init__(self, incoming, r, **kwargs):
        super(SubpixelDownsampleLayer, self).__init__(incoming, **kwargs)
        self.r = r

    def get_output_shape_for(self, input_shape):
        r = self.r
        return (input_shape[0], input_shape[1], r ** 2,
                input_shape[2] // r, input_shape[3] // r)

    def get_output_for(self, input, deterministic=False, **kwargs):
        r = self.r
        _, n_ch, n_phase, out_h, out_w = self.output_shape
        result = T.zeros((input.shape[0], n_ch, n_phase, out_h, out_w))
        for phase in xrange(r * r):
            row, col = divmod(phase, r)
            result = T.set_subtensor(result[:, :, phase, :, :],
                                     input[:, :, row::r, col::r])
        return result
class Subpixel3DLayer(lasagne.layers.Layer):
    """Re-interleave an explicit phase axis back into the spatial dimensions.

    This is the inverse of SubpixelDownsampleLayer. An input of shape
    ``(b, ?, r*r, h, w)`` becomes ``(b, c, r*h, r*w)``, with phase ``r*x + y``
    written to pixel offset ``(x, y)`` of every r x r cell. The input's
    second axis is assumed to have size ``c``.
    """
    def __init__(self, incoming, r, c, **kwargs):
        super(Subpixel3DLayer, self).__init__(incoming, **kwargs)
        self.r = r
        self.c = c

    def get_output_shape_for(self, input_shape):
        r = self.r
        return (input_shape[0], self.c, r * input_shape[3], r * input_shape[4])

    def get_output_for(self, input, deterministic=False, **kwargs):
        r = self.r
        _, n_ch, out_h, out_w = self.output_shape
        result = T.zeros((input.shape[0], n_ch, out_h, out_w))
        for phase in xrange(r * r):
            row, col = divmod(phase, r)
            result = T.set_subtensor(result[:, :, row::r, col::r],
                                     input[:, :, phase, :, :])
        return result
# Similar to the 3D subpixel layer, except with the extra dims stacked into the batch dimension
class SubpixelBatchLayer(lasagne.layers.Layer):
    """Unstack r*r batch blocks back into r-times larger images.

    The input batch of size N is read as r*r consecutive blocks of
    ``B = N // r**2`` samples. Block ``r*x + y`` is written to pixel offset
    ``(x, y)`` of every r x r cell, so ``(N, c, h, w)`` becomes
    ``(N // r**2, c, r*h, r*w)``. This undoes SubpixelBatchDSL /
    SubpixelBatchStackDSL.
    """
    def __init__(self, incoming, r, **kwargs):
        super(SubpixelBatchLayer, self).__init__(incoming, **kwargs)
        self.r = r

    def get_output_shape_for(self, input_shape):
        # Derive everything from the ``input_shape`` argument (not
        # self.input_shape), so shape queries for other inputs are correct.
        # An unknown (None) batch size stays unknown.
        r = self.r
        batch = input_shape[0]
        out_batch = None if batch is None else batch // r ** 2
        return (out_batch, input_shape[1], r * input_shape[2], r * input_shape[3])

    def get_output_for(self, input, deterministic=False, **kwargs):
        r = self.r
        batch_size = input.shape[0] // r ** 2
        out = T.zeros((batch_size, self.output_shape[1],
                       self.output_shape[2], self.output_shape[3]))
        for x in xrange(r):
            for y in xrange(r):
                block = r * x + y
                out = T.set_subtensor(
                    out[:, :, x::r, y::r],
                    input[batch_size * block:batch_size * (block + 1), :, :, :])
        return out
class SubpixelBatchDSL(lasagne.layers.Layer):
    """Subpixel downsampling into the batch dimension.

    ``(N, c, h, w)`` becomes ``(N*r*r, c, h//r, w//r)``. Batch block
    ``r*x + y`` (N samples each) holds ``input[:, :, x::r, y::r]``. The slices
    only fit when h and w are divisible by r.
    """
    def __init__(self, incoming, r, **kwargs):
        super(SubpixelBatchDSL, self).__init__(incoming, **kwargs)
        self.r = r

    def get_output_shape_for(self, input_shape):
        # An unknown (None) batch size stays unknown instead of raising
        # TypeError on None * int.
        r = self.r
        batch = input_shape[0]
        return (None if batch is None else batch * r ** 2, input_shape[1],
                input_shape[2] // r, input_shape[3] // r)

    def get_output_for(self, input, deterministic=False, **kwargs):
        r = self.r
        batch_size = input.shape[0]
        out = T.zeros((batch_size * r ** 2, self.output_shape[1],
                       self.output_shape[2], self.output_shape[3]))
        for x in xrange(r):
            for y in xrange(r):
                block = r * x + y
                out = T.set_subtensor(
                    out[batch_size * block:batch_size * (block + 1), :, :, :],
                    input[:, :, x::r, y::r])
        return out
class SubpixelBatchStackDSL(lasagne.layers.Layer):
    """Subpixel downsampling into the batch dimension via one concatenate.

    ``(N, c, h, w)`` becomes ``(N*r*r, c, h//r, w//r)``, with batch block
    ``r*x + y`` holding ``input[:, :, x::r, y::r]``. This is the same layout
    SubpixelBatchDSL builds with r*r set_subtensor calls.
    """
    def __init__(self, incoming, r, **kwargs):
        super(SubpixelBatchStackDSL, self).__init__(incoming, **kwargs)
        self.r = r

    def get_output_shape_for(self, input_shape):
        # An unknown (None) batch size stays unknown instead of raising
        # TypeError on None * int.
        r = self.r
        batch = input_shape[0]
        return (None if batch is None else batch * r ** 2, input_shape[1],
                input_shape[2] // r, input_shape[3] // r)

    def get_output_for(self, input, deterministic=False, **kwargs):
        r = self.r
        # x-major block order, matching SubpixelBatchLayer's unstacking.
        return T.concatenate([input[:, :, x::r, y::r]
                              for x in xrange(r) for y in xrange(r)], axis=0)
# Dilated conv2d layer with theano's im2col implementation of filter dilation
class IM2COLDCD(lasagne.layers.conv.BaseConvLayer):
    """2D convolution layer using Theano's native ``filter_dilation``.

    Arguments follow lasagne's BaseConvLayer (note that ``flip_filters``
    defaults to False here). ``dilation`` is the (rows, cols) dilation
    factor; a single int is used for both axes. ``convolution`` is the
    Theano conv function to call (``T.nnet.conv2d`` by default).
    """
    def __init__(self, incoming, num_filters, filter_size, dilation, stride=(1, 1),
                 pad=0, untie_biases=False,
                 W=lasagne.init.GlorotUniform(), b=lasagne.init.Constant(0.),
                 nonlinearity=lasagne.nonlinearities.rectify, flip_filters=False,
                 convolution=T.nnet.conv2d, **kwargs):
        if isinstance(dilation, int):
            dilation = (dilation, dilation)
        self.dilation = tuple(dilation)
        super(IM2COLDCD, self).__init__(incoming, num_filters, filter_size,
                                        stride, pad, untie_biases, W, b,
                                        nonlinearity, flip_filters, n=2,
                                        **kwargs)
        self.convolution = convolution

    def get_W_shape(self):
        # With filter_dilation, Theano takes the *undilated* filter tensor in
        # the usual forward-convolution layout
        # (num_filters, num_input_channels, rows, cols); the dilation is
        # applied inside the op. This shape is used both to create a default W
        # and as the filter_shape hint passed to the convolution.
        num_input_channels = self.input_shape[1]
        return (self.num_filters, num_input_channels) + tuple(self.filter_size)

    def convolve(self, input, **kwargs):
        # 'same' maps to Theano's 'half' border mode, which accounts for the
        # dilated filter extent.
        border_mode = 'half' if self.pad == 'same' else self.pad
        conved = self.convolution(input, self.W,
                                  self.input_shape, self.get_W_shape(),
                                  subsample=self.stride,
                                  border_mode=border_mode,
                                  filter_flip=self.flip_filters,
                                  filter_dilation=self.dilation)
        return conved
# allow specifying benchmarks to skip via a SKIP environment variable
skip_tests = os.environ.get("SKIP", '').lower().split(',')
for run in runs:
# params for run:
# (input channels, output channels, kernel width, kernel height, batchsize, image width, image height, horizontal stride, vertical stride)
ni, no, kw, kh, bs, iw, ih, dw, dh = run['ni'], run['no'], run['kw'], run['kh'], run['bs'], run['iw'], run['ih'], run['dw'], run['dh']
print 'CONFIG: input =', ni, 'x', iw, 'x', ih, '* ker =', ni, 'x', no, 'x', kw, 'x', kh, '( bs =', bs, ', stride =', dw, 'scale =',run['scale'],')'
print >>f1, 'CONFIG: input =', ni, 'x', iw, 'x', ih, '* ker =', ni, 'x', no, 'x', kw, 'x', kh, '( bs =', bs, ', stride =', dw, 'scale =',run['scale'],')'
# f1.write(('CONFIG: input =', ni, 'x', iw, 'x', ih, '* ker =', ni, 'x', no, 'x', kw, 'x', kh, '( bs =', bs, ', stride =', dw, 'scale =',run['scale'],')'))
ops = 2 # ops per point
mode = theano.compile.get_default_mode()
# benchmark Theano legacy convolution
# Mimic THEANO_FLAGS=optimizer_excluding=conv_gemm:conv_dnn
input_shape = (bs, ni, ih, iw)
filter_shape = (no, ni, kh, kw)
try:
sharedX = theano.shared(np.random.randn(*input_shape).astype('float32'), name='sharedX')
sharedY = theano.shared(np.random.randn(bs, no, (ih-kh)/dh+1, (iw-kw)/dw+1).astype('float32'), name='sharedY')
sharedW = theano.shared(np.random.randn(*filter_shape).astype('float32'), name='sharedW')
except MemoryError, e:
print "SKIPPING config due to the memory error below"
print e
continue
X = theano.tensor.tensor4('X')
scale = run['scale']
nf = no
l_in = ll.InputLayer(input_shape)
s0 = SubpixelDownsampleLayer(l_in,r=scale)
sc = lasagne.layers.dnn.Conv3DDNNLayer(incoming = s0,
num_filters = nf,
filter_size = [1,3,3],
stride = [1,1,1],
pad = [0,1,1],
W = sharedW.dimshuffle(0,1,'x',2,3),
b = None,
nonlinearity = None,
name = 'subpixel_conv')
s1 = Subpixel3DLayer(sc,r=scale,c=nf)
d1 = lasagne.layers.DilatedConv2DLayer(incoming = lasagne.layers.PadLayer(incoming = l_in, width=(scale,scale)),
num_filters = nf,
filter_size = [3,3],
dilation=(scale,scale),
W = sharedW.dimshuffle(1,0,2,3),
b = None,
nonlinearity = None,
name = 'dilated_conv')
dcd1 = IM2COLDCD(incoming = l_in,
num_filters = nf,
filter_size = [3,3],
dilation=(scale,scale),
W = sharedW,
b = None,
pad = 'same',
nonlinearity = None,
name = 'i2mcol')
BatchStackDSL = SubpixelBatchStackDSL(l_in,r=scale)
BatchStackConv = lasagne.layers.dnn.Conv2DDNNLayer(incoming = BatchStackDSL,
num_filters = nf,
filter_size = [3,3],
stride = [1,1],
pad = 'same',
W = sharedW,
b = None,
nonlinearity = None,
name = 'BatchStackConv')
BatchStackOut = SubpixelBatchLayer(BatchStackConv,r=scale)
BatchDSL = SubpixelBatchDSL(l_in,r=scale)
BatchConv = lasagne.layers.dnn.Conv2DDNNLayer(incoming = BatchDSL,
num_filters = nf,
filter_size = [3,3],
stride = [1,1],
pad = 'same',
W = sharedW,
b = None,
nonlinearity = None,
name = 'BatchConv')
BatchOut = SubpixelBatchLayer(BatchConv,r=scale)
Y = ll.get_output(s1,X)
gW = theano.grad(None, wrt=sharedW, known_grads={Y: sharedY})
gX = theano.grad(None, wrt=X, known_grads={Y: sharedY})
benchmark_three_ways(
'Subpixel DCD',
sharedX, sharedY, sharedW, X, Y, gW, gX)
Y = ll.get_output(d1,X)
gW = theano.grad(None, wrt=sharedW, known_grads={Y: sharedY})
gX = theano.grad(None, wrt=X, known_grads={Y: sharedY})
benchmark_three_ways(
'Normal DCD',
sharedX, sharedY, sharedW, X, Y, gW, gX)
Y = ll.get_output(dcd1,X)
gW = theano.grad(None, wrt=sharedW, known_grads={Y: sharedY})
gX = theano.grad(None, wrt=X, known_grads={Y: sharedY})
benchmark_three_ways(
'IM2COL DCD',
sharedX, sharedY, sharedW, X, Y, gW, gX)
Y = ll.get_output(BatchOut,X)
gW = theano.grad(None, wrt=sharedW, known_grads={Y: sharedY})
gX = theano.grad(None, wrt=X, known_grads={Y: sharedY})
benchmark_three_ways(
'Batchroll DCD',
sharedX, sharedY, sharedW, X, Y, gW, gX)
Y = ll.get_output(BatchStackOut,X)
gW = theano.grad(None, wrt=sharedW, known_grads={Y: sharedY})
gX = theano.grad(None, wrt=X, known_grads={Y: sharedY})
benchmark_three_ways(
'BatchStack DCD',
sharedX, sharedY, sharedW, X, Y, gW, gX)
del sharedX
del sharedY
del sharedW
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment