Last active
July 25, 2019 10:27
-
-
Save ajbrock/4ec90294edbe77bf8feae89334e45422 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
### | |
# Situationally faster dilated convolutions through subpixel reshapes | |
# A Brock, 2016 | |
# | |
# Script adapted from https://github.com/soumith/convnet-benchmarks/blob/master/theano/pylearn2_benchmark.py by Jan Schluter. | |
# | |
# Outputs of this script from my tests on a GTX 980 are available here: http://pastebin.com/JRBY4Qnf | |
# | |
# Outputs of this script from my tests on a Titan X are available here: http://pastebin.com/0zJ8Uvg0 | |
# | |
# This script benchmarks the forward and backward (wrt inputs and weights) passes of my implementation of dilated convolutions | |
# through subpixel reshapes and compares it to the lasagne implementation. The basic idea is that a subpixel downsample operation into an additional spatial dimension | |
# (i.e. where the elements are reordered from [1,2,3,4,5,6,7,8] to [[1,3,5,7],[2,4,6,8]]) allows one to perform dilated convolution | |
# (AKA atrous convolution, convolution with holes, convolution when you're on shrooms) without having to use a dilated filter. | |
# This implementation only works out-of-the-box when the spatial dimensions and dilated filter size line up nicely, but a little | |
# bit of clever zero-padding could probably overcome this. | |
# | |
# I've found that this can be significantly faster when dealing with high-dimensional layers (a high number of input channels and a high
# number of output filters), particularly during the backward pass. I suspect the optimal solution to this issue is to integrate dilated convolutions
# into a library's im2col function, but this workaround could be useful given how computationally expensive dilated convs are.
# | |
# I also suspect that, barring such an implementation, this hack could be made much faster if someone could figure out a faster
# version of my subpixel layer, which I'm pretty sure is the most expensive part (though I haven't bothered to profile it, so I
# can't say for certain). Right now it uses r^2 set_subtensor calls, which is faster than anything else I've tried,
# including two different reshape/concatenate methods, and an advanced indexing method.
## Subpixel Reshape | |
import os | |
import sys | |
import numpy as np | |
import math | |
import theano | |
import theano.tensor as T | |
import lasagne | |
from lasagne.layers.dnn import Conv2DDNNLayer as C2D | |
import lasagne.layers as ll | |
# Switch Theano onto the (old) CUDA backend when it was not started on a GPU
# device, and run everything in single precision.
if not theano.config.device.startswith('gpu'):
    import theano.sandbox.cuda
    theano.sandbox.cuda.use('gpu')
theano.config.floatX = 'float32'

# CUDA events (via pycuda) are the preferred timer.  When pycuda cannot be
# imported, `pycuda` is bound to None and `time` is imported so that
# time_run() can fall back to wall-clock timing.
try:
    import theano.misc.pycuda_init
    import pycuda.driver
except ImportError:
    import time
    pycuda = None
    print("Note: pycuda not available, no timing via CUDA events possible")

import theano

# Probe for cuDNN; only a note is printed when it is missing.
try:
    import theano.sandbox.cuda.dnn
    if not theano.sandbox.cuda.dnn.dnn_available():
        del theano.sandbox.cuda.dnn
        raise ImportError
except (ImportError, NameError):
    print("Note: cuDNN not available")

# The pylearn2 cuda-convnet wrapper is not imported by this script; the name
# is only stubbed out.
FilterActs = None
from theano.sandbox.cuda.basic_ops import gpu_contiguous
# Every result line printed to stdout is mirrored into this text file.
f1 = open('./benchmark_TITAN_stack.txt', 'w')
# The handle is written to by code further down the script, so it cannot be
# scoped with a `with` block here.  Register the close at interpreter exit so
# the buffered results are flushed instead of relying on garbage collection.
import atexit
atexit.register(f1.close)

number = 10  # nb of steps in loop to average over
repeat = 1  # nb of trials to pick the minimum of
def di(ni, no, iw, ih, scale):
    """Build the configuration dict for one benchmark run.

    Args:
        ni: number of input channels.
        no: number of output channels (filters).
        iw: input width.
        ih: input height.
        scale: value stored under the 'scale' key (used by the script as the
            dilation / subpixel factor).

    Returns:
        A new dict with the given values plus the fixed settings of this
        benchmark: 3x3 kernel ('kw', 'kh'), batch size 128 ('bs') and unit
        strides ('dw', 'dh').
    """
    config = dict(kw=3, kh=3, bs=128, dw=1, dh=1)
    config.update(ni=ni, no=no, iw=iw, ih=ih, scale=scale)
    return config
def _build_default_runs():
    """Expand the table of default benchmark settings into run dicts.

    Each table row is (input channels, output channels, square input sizes,
    scale); one run is produced per input size, in table order.
    """
    table = (
        (3, 128, (64, 32, 16), 2),
        (3, 128, (64, 32, 16), 4),
        (128, 128, (64, 32, 16), 2),
        (128, 256, (64, 32, 16), 2),
        (256, 256, (64, 32, 16), 2),
        (128, 256, (64, 32, 16), 4),
        (256, 512, (32, 16, 8), 2),
        (512, 512, (16, 8, 4), 2),
        (512, 1024, (16, 8, 4), 2),
    )
    return [di(n_in, n_out, side, side, factor)
            for n_in, n_out, sides, factor in table
            for side in sides]


# Default list of configurations; may be replaced from the command line below.
runs = _build_default_runs()
def time_run(fn):
    """Time `fn` and return the best average duration of a single call.

    `fn` is called once untimed as a warm-up.  Then `repeat` trials are run,
    each calling `fn` `number` times; the per-call average of every trial is
    recorded and the minimum is returned.

    When `pycuda` is available, CUDA events are used for timing (the event
    interval is divided by 1e3 -- presumably converting milliseconds to
    seconds; confirm against the pycuda docs).  Otherwise wall-clock time is
    taken around explicit device synchronisations.
    """
    def call_repeatedly():
        for _ in range(number):
            fn()

    fn()  # warm-up call, not timed
    timings = []
    if not pycuda:
        for _ in range(repeat):
            theano.sandbox.cuda.synchronize()
            t0 = time.time()
            call_repeatedly()
            theano.sandbox.cuda.synchronize()
            timings.append((time.time() - t0) / number)
    else:
        theano.sandbox.cuda.synchronize()
        ev_start = pycuda.driver.Event()
        ev_stop = pycuda.driver.Event()
        for _ in range(repeat):
            ev_start.record()
            call_repeatedly()
            ev_stop.record()
            ev_stop.synchronize()
            timings.append(ev_start.time_till(ev_stop) / 1e3 / number)
    return min(timings)
def print_graph(fn):
    """Dump Theano's debugprint of `fn` when PRINT_GRAPH is enabled.

    Nothing happens unless the PRINT_GRAPH environment variable is set to a
    non-zero integer.  The dump is wrapped in ANSI escape codes so that it
    shows up as blue text on a terminal.
    """
    enabled = int(os.environ.get('PRINT_GRAPH', 0))
    if not enabled:
        return
    print('\033[1;34m')
    theano.printing.debugprint(fn)
    print('\033[1;m')
def benchmark_three_ways(name, sharedX, sharedY, sharedW, X, Y, gW, gX, mode=None):
    """Benchmark the forward pass and both backward passes of one graph.

    Three argument-less Theano functions are compiled and timed in turn.
    Each substitutes `sharedX` for the symbolic input `X` and stores its
    result through an update so that the computation cannot be optimised
    away:

        fprop          sharedY <- Y
        bprop inputs   sharedX <- gX
        bprop weights  sharedW <- gW

    For every pass one result line (time multiplied by 1000 and truncated to
    an int) is printed and also written to the results file `f1`.  A pass
    that raises is reported as FAILED with the first line of the error
    message, and the remaining passes are still run.

    Args:
        name: label of the benchmark, used in output and function names.
        sharedX, sharedY, sharedW: shared variables for input, output
            (gradient) and weights.
        X: symbolic input variable.
        Y: symbolic output expression.
        gW, gX: symbolic gradients w.r.t. weights and input.
        mode: optional Theano compilation mode.
    """
    row_format = '{: <50} ==> {: <13} ==> {: >7}'
    passes = (
        ('fprop', sharedY, Y),
        ('bprop inputs', sharedX, gX),
        ('bprop weights', sharedW, gW),
    )
    for label, target, expression in passes:
        try:
            compiled = theano.function([], [],
                                       givens=[(X, sharedX)],
                                       updates=[(target, expression)],
                                       mode=mode,
                                       name=name + " " + label)
            elapsed = time_run(compiled)
            row = row_format.format(name, label, int(elapsed * 1000))
            print(row)
            f1.write(row + '\n')
            print_graph(compiled)
            del compiled
        except Exception as e:
            first_line = str(e).split('\n', 1)[0]
            print('%s %s: FAILED %s' % (name, label, first_line))
    print('')
def parse_custom_config(s):
    """Parse a custom benchmark configuration string into a run dict.

    The string has the format ``iAxBxC,kDxExF,bG,sHxJ,dS`` where
    A: input channels, B: input width, C: input height,
    D: output channels, E: kernel width, F: kernel height, G: batchsize,
    H: horizontal stride, J: vertical stride, S: scale
    (the b, s and d parts are optional).

    Defaults are batch size 128, unit strides and scale 2.  The 'scale'
    entry is always present because the benchmark loop reads it for every
    run; without it a custom configuration fails with a KeyError.

    Args:
        s: configuration string, e.g. ``"i3x80x15,k32x3x7,b256"``.

    Returns:
        dict with the keys 'bs', 'dw', 'dh', 'scale' plus whatever the given
        parts define ('ni', 'iw', 'ih', 'no', 'kw', 'kh').

    Raises:
        KeyError: if a part starts with an unknown prefix letter.
        ValueError: if a value is not an integer.
    """
    run = {'bs': 128, 'dw': 1, 'dh': 1, 'scale': 2}
    defs = {'i': ['ni', 'iw', 'ih'],
            'k': ['no', 'kw', 'kh'],
            'b': ['bs'],
            's': ['dw', 'dh'],
            'd': ['scale']}
    for part in s.split(','):
        if not part:
            # tolerate empty segments such as a trailing comma
            continue
        prefix = part[0]
        values = [int(v) for v in part[1:].split('x')]
        run.update(zip(defs[prefix], values))
    return run
if len(sys.argv) > 1:
    cli_args = sys.argv[1:]
    picked = []
    # Arguments not starting with 'i' select default runs by their 1-based
    # index (i.e., 1 2 5).
    for arg in cli_args:
        if arg[0] != 'i':
            picked.append(runs[int(arg) - 1])
    # Arguments starting with 'i' are custom configurations
    # (e.g., i3x80x15,k32x3x7,b256); they are appended after the indexed runs.
    for arg in cli_args:
        if arg[0] == 'i':
            picked.append(parse_custom_config(arg))
    runs = picked
class SubpixelLayer(lasagne.layers.Layer):
    """Subpixel upsampling: interleave groups of channels into space.

    Input channel ``r*x + y + k*r*r`` is written to output channel ``k`` at
    the spatial positions ``[x::r, y::r]``, so the spatial size grows by a
    factor of ``r`` in both directions.

    Args:
        incoming: the layer feeding into this one.
        r: upsampling factor.
        c: number of output channels (assumes the input has ``c * r**2``
            channels -- TODO confirm against callers).
    """

    def __init__(self, incoming, r, c, **kwargs):
        super(SubpixelLayer, self).__init__(incoming, **kwargs)
        self.r = r
        self.c = c

    def get_output_shape_for(self, input_shape):
        factor = self.r
        return (input_shape[0], self.c,
                factor * input_shape[2], factor * input_shape[3])

    def get_output_for(self, input, deterministic=False, **kwargs):
        factor = self.r
        _, channels, height, width = self.output_shape
        # The batch size is taken symbolically from the input.
        result = T.zeros((input.shape[0], channels, height, width))
        # One set_subtensor per spatial offset, r*r in total.
        for row in xrange(factor):
            for col in xrange(factor):
                source = input[:, factor * row + col::factor * factor, :, :]
                result = T.set_subtensor(
                    result[:, :, row::factor, col::factor], source)
        return result
class SubpixelDownsampleLayer(lasagne.layers.Layer):
    """Subpixel downsampling into an extra (third) axis.

    A 4D input is turned into a 5D output whose axis 2 has length ``r**2``:
    slice ``r*x + y`` of that axis holds the strided sub-image
    ``input[:, :, x::r, y::r]``.  The two spatial sizes are divided by ``r``
    (integer division).

    Args:
        incoming: the layer feeding into this one.
        r: downsampling factor.
    """

    def __init__(self, incoming, r, **kwargs):
        super(SubpixelDownsampleLayer, self).__init__(incoming, **kwargs)
        self.r = r

    def get_output_shape_for(self, input_shape):
        factor = self.r
        return (input_shape[0], input_shape[1], factor ** 2,
                input_shape[2] // factor, input_shape[3] // factor)

    def get_output_for(self, input, deterministic=False, **kwargs):
        factor = self.r
        _, channels, depth, height, width = self.output_shape
        # The batch size is taken symbolically from the input.
        result = T.zeros((input.shape[0], channels, depth, height, width))
        for row in xrange(factor):
            for col in xrange(factor):
                sub_image = input[:, :, row::factor, col::factor]
                result = T.set_subtensor(
                    result[:, :, factor * row + col, :, :], sub_image)
        return result
class Subpixel3DLayer(lasagne.layers.Layer):
    """Subpixel upsampling from a 5D input back to a 4D output.

    Slice ``r*x + y`` of input axis 2 is written to the spatial positions
    ``[x::r, y::r]`` of the output, so the last two input axes grow by a
    factor of ``r``.

    Args:
        incoming: the layer feeding into this one (5D output).
        r: upsampling factor.
        c: number of output channels (assumes this equals the size of input
            axis 1 -- TODO confirm against callers).
    """

    def __init__(self, incoming, r, c, **kwargs):
        super(Subpixel3DLayer, self).__init__(incoming, **kwargs)
        self.r = r
        self.c = c

    def get_output_shape_for(self, input_shape):
        factor = self.r
        return (input_shape[0], self.c,
                factor * input_shape[3], factor * input_shape[4])

    def get_output_for(self, input, deterministic=False, **kwargs):
        factor = self.r
        _, channels, height, width = self.output_shape
        # The batch size is taken symbolically from the input.
        result = T.zeros((input.shape[0], channels, height, width))
        # One set_subtensor per spatial offset, r*r in total.
        for row in xrange(factor):
            for col in xrange(factor):
                plane = input[:, :, factor * row + col, :, :]
                result = T.set_subtensor(
                    result[:, :, row::factor, col::factor], plane)
        return result
# Similar to the 3D subpixel layer, except with the extra dims stacked into the batch dimension | |
class SubpixelBatchLayer(lasagne.layers.Layer):
    """Subpixel upsampling from sub-images stacked along the batch axis.

    The input batch is treated as ``r**2`` consecutive blocks of equal size;
    block ``r*x + y`` is written to the spatial positions ``[x::r, y::r]`` of
    the output.  The batch size therefore shrinks by ``r**2`` while both
    spatial sizes grow by ``r``.  This undoes the stacking performed by
    SubpixelBatchDSL / SubpixelBatchStackDSL.

    Args:
        incoming: the layer feeding into this one.
        r: upsampling factor.
    """

    def __init__(self, incoming, r, **kwargs):
        super(SubpixelBatchLayer, self).__init__(incoming, **kwargs)
        self.r = r

    def get_output_shape_for(self, input_shape):
        # Derive every entry from the `input_shape` argument (not from
        # self.input_shape) so the result is right for whatever shape the
        # caller asks about.  An unknown (None) batch size stays unknown.
        stacked = input_shape[0]
        batch = None if stacked is None else stacked // self.r ** 2
        return (batch, input_shape[1],
                self.r * input_shape[2], self.r * input_shape[3])

    def get_output_for(self, input, deterministic=False, **kwargs):
        # Symbolic size of one batch block.
        batch_size = input.shape[0] // self.r ** 2
        out = T.zeros((batch_size, self.output_shape[1],
                       self.output_shape[2], self.output_shape[3]))
        for x in xrange(self.r):
            for y in xrange(self.r):
                block = self.r * x + y
                out = T.set_subtensor(
                    out[:, :, x::self.r, y::self.r],
                    input[batch_size * block:batch_size * (block + 1), :, :, :])
        return out
class SubpixelBatchDSL(lasagne.layers.Layer):
    """Subpixel downsampling that stacks the sub-images along the batch axis.

    The strided sub-image ``input[:, :, x::r, y::r]`` is written to batch
    block ``r*x + y`` of the output, so the batch size grows by ``r**2`` and
    both spatial sizes are divided by ``r`` (integer division).  The output
    is assembled with ``r**2`` set_subtensor calls.

    Args:
        incoming: the layer feeding into this one.
        r: downsampling factor.
    """

    def __init__(self, incoming, r, **kwargs):
        super(SubpixelBatchDSL, self).__init__(incoming, **kwargs)
        self.r = r

    def get_output_shape_for(self, input_shape):
        factor = self.r
        return (input_shape[0] * factor ** 2, input_shape[1],
                input_shape[2] // factor, input_shape[3] // factor)

    def get_output_for(self, input, deterministic=False, **kwargs):
        factor = self.r
        _, channels, height, width = self.output_shape
        result = T.zeros((input.shape[0] * factor ** 2, channels, height, width))
        # Symbolic size of one batch block.
        block_len = input.shape[0]
        for row in xrange(factor):
            for col in xrange(factor):
                index = factor * row + col
                sub_image = input[:, :, row::factor, col::factor]
                result = T.set_subtensor(
                    result[block_len * index:block_len * (index + 1), :, :, :],
                    sub_image)
        return result
class SubpixelBatchStackDSL(lasagne.layers.Layer):
    """Subpixel downsampling into the batch axis, built with a concatenate.

    The ``r**2`` strided sub-images ``input[:, :, x::r, y::r]`` are
    concatenated along the first axis, with ``x`` as the outer and ``y`` as
    the inner index.

    Args:
        incoming: the layer feeding into this one.
        r: downsampling factor.
    """

    def __init__(self, incoming, r, **kwargs):
        super(SubpixelBatchStackDSL, self).__init__(incoming, **kwargs)
        self.r = r

    def get_output_shape_for(self, input_shape):
        factor = self.r
        return (input_shape[0] * factor ** 2, input_shape[1],
                input_shape[2] // factor, input_shape[3] // factor)

    def get_output_for(self, input, deterministic=False, **kwargs):
        factor = self.r
        sub_images = []
        for row in xrange(factor):
            for col in xrange(factor):
                sub_images.append(input[:, :, row::factor, col::factor])
        return T.concatenate(sub_images)
# Dilated conv2d layer with theano's im2col implementation of filter dilation | |
class IM2COLDCD(lasagne.layers.conv.BaseConvLayer):
    """2D convolution layer that passes a filter dilation to the convolution.

    Behaves like a regular Lasagne 2D convolution layer, except that
    ``dilation`` is forwarded as ``filter_dilation`` to the convolution
    function (``T.nnet.conv2d`` by default).

    Args:
        incoming: the layer feeding into this one.
        num_filters: number of output filters.
        filter_size: spatial size of the filters.
        dilation: pair of dilation factors, one per spatial axis.
        convolution: the convolution function to call.
        The remaining arguments are handed to the base class unchanged.
    """

    def __init__(self, incoming, num_filters, filter_size,dilation, stride=(1, 1),
                 pad=0, untie_biases=False,
                 W=lasagne.init.GlorotUniform(), b=lasagne.init.Constant(0.),
                 nonlinearity=lasagne.nonlinearities.rectify, flip_filters=False,
                 convolution=T.nnet.conv2d, **kwargs):
        # Stored before the base constructor runs -- presumably because the
        # base class calls get_W_shape() while creating W; confirm against
        # Lasagne's BaseConvLayer.
        self.dilation = dilation
        super(IM2COLDCD, self).__init__(incoming, num_filters, filter_size,
                                        stride, pad, untie_biases, W, b,
                                        nonlinearity, flip_filters, n=2,
                                        **kwargs)
        self.convolution = convolution

    def get_W_shape(self):
        # NOTE(review): the spatial extent reported here is
        # filter_size + dilation, which is not the undilated filter size --
        # confirm that this is what the convolution expects as filter shape.
        rows = self.filter_size[0] + self.dilation[0]
        cols = self.filter_size[1] + self.dilation[1]
        return (self.num_filters, self.input_shape[1], rows, cols)

    def convolve(self, input, **kwargs):
        # Lasagne's 'same' padding is passed on as border mode 'half'.
        if self.pad == 'same':
            border_mode = 'half'
        else:
            border_mode = self.pad
        return self.convolution(input, self.W,
                                self.input_shape, self.get_W_shape(),
                                subsample=self.stride,
                                border_mode=border_mode,
                                filter_flip=self.flip_filters,
                                filter_dilation=self.dilation)
# Allow specifying benchmarks to skip via a SKIP environment variable: a
# comma-separated list of benchmark names, compared case-insensitively,
# e.g. SKIP="normal dcd,im2col dcd".
skip_tests = [_token.strip() for _token in os.environ.get("SKIP", '').lower().split(',')]

for run in runs:
    # params for run:
    # (input channels, output channels, kernel width, kernel height,
    #  batchsize, image width, image height, horizontal stride)
    ni, no, kw, kh = run['ni'], run['no'], run['kw'], run['kh']
    bs, iw, ih, dw = run['bs'], run['iw'], run['ih'], run['dw']
    # Custom command-line configurations may not carry a scale; fall back to 2
    # instead of failing with a KeyError.
    scale = run.get('scale', 2)

    config_line = ' '.join(str(item) for item in (
        'CONFIG: input =', ni, 'x', iw, 'x', ih,
        '* ker =', ni, 'x', no, 'x', kw, 'x', kh,
        '( bs =', bs, ', stride =', dw, 'scale =', scale, ')'))
    print(config_line)
    f1.write(config_line + '\n')

    input_shape = (bs, ni, ih, iw)
    # NOTE: kw/kh only determine the shape of sharedW; the layers below are
    # all declared with a fixed 3x3 filter_size.
    filter_shape = (no, ni, kh, kw)
    # Every network below pads so that its output keeps the spatial size of
    # the input; the buffer holding the output (and serving as the known
    # output gradient) is allocated with that shape from the start.
    output_shape = (bs, no, ih, iw)
    try:
        sharedX = theano.shared(np.random.randn(*input_shape).astype('float32'), name='sharedX')
        sharedY = theano.shared(np.random.randn(*output_shape).astype('float32'), name='sharedY')
        sharedW = theano.shared(np.random.randn(*filter_shape).astype('float32'), name='sharedW')
    except MemoryError as e:
        print("SKIPPING config due to the memory error below")
        print(e)
        continue

    X = theano.tensor.tensor4('X')
    nf = no
    l_in = ll.InputLayer(input_shape)

    # 1) Subpixel reshape into an extra axis, 1x3x3 3D convolution, reshape back.
    s0 = SubpixelDownsampleLayer(l_in, r=scale)
    sc = lasagne.layers.dnn.Conv3DDNNLayer(incoming=s0,
                                           num_filters=nf,
                                           filter_size=[1, 3, 3],
                                           stride=[1, 1, 1],
                                           pad=[0, 1, 1],
                                           W=sharedW.dimshuffle(0, 1, 'x', 2, 3),
                                           b=None,
                                           nonlinearity=None,
                                           name='subpixel_conv')
    s1 = Subpixel3DLayer(sc, r=scale, c=nf)

    # 2) Lasagne's DilatedConv2DLayer on an explicitly padded input.
    d1 = lasagne.layers.DilatedConv2DLayer(incoming=lasagne.layers.PadLayer(incoming=l_in, width=(scale, scale)),
                                           num_filters=nf,
                                           filter_size=[3, 3],
                                           dilation=(scale, scale),
                                           W=sharedW.dimshuffle(1, 0, 2, 3),
                                           b=None,
                                           nonlinearity=None,
                                           name='dilated_conv')

    # 3) Theano conv2d with filter_dilation.
    dcd1 = IM2COLDCD(incoming=l_in,
                     num_filters=nf,
                     filter_size=[3, 3],
                     dilation=(scale, scale),
                     W=sharedW,
                     b=None,
                     pad='same',
                     nonlinearity=None,
                     name='i2mcol')

    # 4) Subpixel reshape into the batch axis (concatenate variant).
    BatchStackDSL = SubpixelBatchStackDSL(l_in, r=scale)
    BatchStackConv = lasagne.layers.dnn.Conv2DDNNLayer(incoming=BatchStackDSL,
                                                       num_filters=nf,
                                                       filter_size=[3, 3],
                                                       stride=[1, 1],
                                                       pad='same',
                                                       W=sharedW,
                                                       b=None,
                                                       nonlinearity=None,
                                                       name='BatchStackConv')
    BatchStackOut = SubpixelBatchLayer(BatchStackConv, r=scale)

    # 5) Subpixel reshape into the batch axis (set_subtensor variant).
    BatchDSL = SubpixelBatchDSL(l_in, r=scale)
    BatchConv = lasagne.layers.dnn.Conv2DDNNLayer(incoming=BatchDSL,
                                                  num_filters=nf,
                                                  filter_size=[3, 3],
                                                  stride=[1, 1],
                                                  pad='same',
                                                  W=sharedW,
                                                  b=None,
                                                  nonlinearity=None,
                                                  name='BatchConv')
    BatchOut = SubpixelBatchLayer(BatchConv, r=scale)

    # Benchmarks run in this order; each entry is (label, output layer).
    benchmarks = [
        ('Subpixel DCD', s1),
        ('Normal DCD', d1),
        ('IM2COL DCD', dcd1),
        ('Batchroll DCD', BatchOut),
        ('BatchStack DCD', BatchStackOut),
    ]
    for bench_name, out_layer in benchmarks:
        if bench_name.lower() in skip_tests:
            continue
        Y = ll.get_output(out_layer, X)
        # sharedY plays the role of the gradient arriving at the output.
        gW = theano.grad(None, wrt=sharedW, known_grads={Y: sharedY})
        gX = theano.grad(None, wrt=X, known_grads={Y: sharedY})
        benchmark_three_ways(
            bench_name,
            sharedX, sharedY, sharedW, X, Y, gW, gX)

    # Drop the references to the shared buffers before the next configuration
    # allocates its own.
    del sharedX
    del sharedY
    del sharedW
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment