Skip to content

Instantly share code, notes, and snippets.

@ajbrock
Last active March 24, 2017 08:44
Show Gist options
  • Save ajbrock/cc083da8c274ff25acdc318b3249b972 to your computer and use it in GitHub Desktop.
Save ajbrock/cc083da8c274ff25acdc318b3249b972 to your computer and use it in GitHub Desktop.
import theano
import theano.tensor as T
import lasagne
import numpy as np
import time
# Subpixel Upsample Layer using Set_subtensor
# This layer uses a set of r^2 set_subtensor calls to reorganize the tensor in a subpixel-layer upscaling style
# as done in the ESPCN magic pony paper for super-resolution. There is almost certainly a more efficient way to do this,
# but I haven't figured it out at the moment and this seems to be fast enough.
class SubpixelLayer(lasagne.layers.Layer):
    """Sub-pixel upsampling layer built from ``r * r`` ``set_subtensor`` writes.

    Rearranges an input of shape ``(batch, c * r * r, height, width)`` into an
    output of shape ``(batch, c, r * height, r * width)``.  Output channel
    ``k`` is assembled from input channels ``k * r * r`` through
    ``(k + 1) * r * r - 1``; within that group, channel ``r * x + y`` supplies
    the output pixels at row offset ``x`` and column offset ``y`` of every
    ``r``-by-``r`` cell.

    Parameters
    ----------
    incoming
        Passed unchanged to ``lasagne.layers.Layer``.
    r : int
        Upscaling factor applied to both spatial axes.
    c : int
        Number of output channels; the input is expected to carry
        ``c * r * r`` channels.
    **kwargs
        Forwarded to ``lasagne.layers.Layer``.
    """

    def __init__(self, incoming, r, c, **kwargs):
        super(SubpixelLayer, self).__init__(incoming, **kwargs)
        self.r = r
        self.c = c

    def get_output_shape_for(self, input_shape):
        """Return ``(batch, c, r * height, r * width)``.

        A spatial dimension that is ``None`` (unknown) in ``input_shape``
        stays ``None`` in the result instead of raising a ``TypeError``.
        """
        rows, cols = input_shape[2], input_shape[3]
        return (input_shape[0],
                self.c,
                None if rows is None else self.r * rows,
                None if cols is None else self.r * cols)

    def get_output_for(self, input, deterministic=False, **kwargs):
        """Build the symbolic upsampled tensor for ``input``.

        ``deterministic`` and ``**kwargs`` are accepted for interface
        compatibility and are not used.
        """
        r = self.r
        out_rows, out_cols = self.output_shape[2], self.output_shape[3]
        # Use the statically known size when available; otherwise fall back
        # to the symbolic shape so layers with unknown spatial dims work.
        if out_rows is None:
            out_rows = r * input.shape[2]
        if out_cols is None:
            out_cols = r * input.shape[3]
        out = T.zeros((input.shape[0], self.c, out_rows, out_cols))
        # One strided write per (row offset, column offset) pair.  ``range``
        # (rather than ``xrange``) keeps this working on Python 2 and 3.
        for x in range(r):
            for y in range(r):
                out = T.set_subtensor(out[:, :, x::r, y::r],
                                      input[:, r * x + y::r * r, :, :])
        return out
# Subpixel Upsample Layer using Reshapes
class SubpixelLayer2(lasagne.layers.Layer):
    """Sub-pixel upsampling layer implemented with reshapes.

    Rearranges an input of shape ``(batch, c * r * r, height, width)`` into an
    output of shape ``(batch, c, r * height, r * width)`` such that::

        out[n, k, h * r + i, w * r + j] == input[n, k * r * r + i * r + j, h, w]

    Parameters
    ----------
    incoming
        Passed unchanged to ``lasagne.layers.Layer``.
    r : int
        Upscaling factor applied to both spatial axes.
    c : int
        Number of output channels; the input is expected to carry
        ``c * r * r`` channels.
    **kwargs
        Forwarded to ``lasagne.layers.Layer``.
    """

    def __init__(self, incoming, r, c, **kwargs):
        super(SubpixelLayer2, self).__init__(incoming, **kwargs)
        self.r = r
        self.c = c

    def get_output_shape_for(self, input_shape):
        """Return ``(batch, c, r * height, r * width)``.

        A spatial dimension that is ``None`` (unknown) in ``input_shape``
        stays ``None`` in the result instead of raising a ``TypeError``.
        """
        rows, cols = input_shape[2], input_shape[3]
        return (input_shape[0],
                self.c,
                None if rows is None else self.r * rows,
                None if cols is None else self.r * cols)

    def get_output_for(self, input, deterministic=False, **kwargs):
        """Build the symbolic upsampled tensor for ``input``.

        The rearrangement is expressed as reshape -> dimshuffle -> reshape,
        so the number of graph nodes does not depend on the image size or on
        the number of channels, and spatial dimensions may be unknown at
        graph-construction time.  ``deterministic`` and ``**kwargs`` are
        accepted for interface compatibility and are not used.
        """
        r, c = self.r, self.c
        batch = input.shape[0]
        rows, cols = self.input_shape[2], self.input_shape[3]
        # Use the statically known size when available; otherwise fall back
        # to the symbolic shape.
        if rows is None:
            rows = input.shape[2]
        if cols is None:
            cols = input.shape[3]
        # Split the channel axis into (output channel, row offset i,
        # column offset j): shape (batch, c, i, j, rows, cols).
        grouped = T.reshape(input, (batch, c, r, r, rows, cols))
        # Put each offset right after the spatial axis it subdivides:
        # shape (batch, c, rows, i, cols, j).
        interleaved = grouped.dimshuffle(0, 1, 4, 2, 5, 3)
        # Merging (rows, i) and (cols, j) yields row index h * r + i and
        # column index w * r + j.
        return T.reshape(interleaved, (batch, c, rows * r, cols * r))
# Subpixel Upsample Layer with inc_subtensor
# This layer uses a set of r^2 inc_subtensor calls to reorganize the tensor in a subpixel-layer upscaling style
# as done in the ESPCN magic pony paper for super-resolution. There is almost certainly a more efficient way to do this,
# but I haven't figured it out at the moment and this seems to be fast enough.
class SubpixelLayer3(lasagne.layers.Layer):
    """Sub-pixel upsampling layer built from ``r * r`` ``inc_subtensor`` calls.

    Rearranges an input of shape ``(batch, c * r * r, height, width)`` into an
    output of shape ``(batch, c, r * height, r * width)``.  Output channel
    ``k`` is assembled from input channels ``k * r * r`` through
    ``(k + 1) * r * r - 1``; within that group, channel ``r * x + y`` supplies
    the output pixels at row offset ``x`` and column offset ``y`` of every
    ``r``-by-``r`` cell.  The increments are applied to a zero tensor and
    each strided slice is incremented exactly once.

    Parameters
    ----------
    incoming
        Passed unchanged to ``lasagne.layers.Layer``.
    r : int
        Upscaling factor applied to both spatial axes.
    c : int
        Number of output channels; the input is expected to carry
        ``c * r * r`` channels.
    **kwargs
        Forwarded to ``lasagne.layers.Layer``.
    """

    def __init__(self, incoming, r, c, **kwargs):
        super(SubpixelLayer3, self).__init__(incoming, **kwargs)
        self.r = r
        self.c = c

    def get_output_shape_for(self, input_shape):
        """Return ``(batch, c, r * height, r * width)``.

        A spatial dimension that is ``None`` (unknown) in ``input_shape``
        stays ``None`` in the result instead of raising a ``TypeError``.
        """
        rows, cols = input_shape[2], input_shape[3]
        return (input_shape[0],
                self.c,
                None if rows is None else self.r * rows,
                None if cols is None else self.r * cols)

    def get_output_for(self, input, deterministic=False, **kwargs):
        """Build the symbolic upsampled tensor for ``input``.

        ``deterministic`` and ``**kwargs`` are accepted for interface
        compatibility and are not used.
        """
        r = self.r
        out_rows, out_cols = self.output_shape[2], self.output_shape[3]
        # Use the statically known size when available; otherwise fall back
        # to the symbolic shape so layers with unknown spatial dims work.
        if out_rows is None:
            out_rows = r * input.shape[2]
        if out_cols is None:
            out_cols = r * input.shape[3]
        out = T.zeros((input.shape[0], self.c, out_rows, out_cols))
        # One strided increment per (row offset, column offset) pair.
        # ``range`` (rather than ``xrange``) works on Python 2 and 3.
        for x in range(r):
            for y in range(r):
                out = T.inc_subtensor(out[:, :, x::r, y::r],
                                      input[:, r * x + y::r * r, :, :])
        return out
# Simple test: build the three layers on one shared input layer, compile one
# Theano function per layer, then time n forward calls of each on the same
# random batch.
l_in = lasagne.layers.InputLayer(shape=(128, 12, 32, 32))
s1 = SubpixelLayer(l_in, r=2, c=3)
s2 = SubpixelLayer2(l_in, r=2, c=3)
s3 = SubpixelLayer3(l_in, r=2, c=3)

# Symbolic float32 4-D input fed to all three compiled functions.
X = T.TensorType('float32', [False] * 4)('X')
fs1 = theano.function([X], lasagne.layers.get_output(s1, X, deterministic=True))
fs2 = theano.function([X], lasagne.layers.get_output(s2, X, deterministic=True))
fs3 = theano.function([X], lasagne.layers.get_output(s3, X, deterministic=True))

print('Testing subpixel layer 1...')
x = np.float32(np.random.randn(128, 12, 32, 32))
s1_start = time.time()
n = 1000
# ``range`` instead of ``xrange`` so the script also runs on Python 3.
for i in range(n):
    q = fs1(x)
s1_end = time.time() - s1_start  # elapsed seconds, despite the name
print('Time for '+str(n)+' subpixel 1 calls is ' + str(s1_end) + ' seconds.')

print('Testing subpixel layer 2...')
s2_start = time.time()
for i in range(n):
    q = fs2(x)
s2_end = time.time() - s2_start
print('Time for '+str(n)+' subpixel 2 calls is ' + str(s2_end) + ' seconds.')

print('Testing subpixel layer 3...')
s3_start = time.time()
for i in range(n):
    q = fs3(x)
s3_end = time.time() - s3_start
print('Time for '+str(n)+' subpixel 3 calls is ' + str(s3_end) + ' seconds.')
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment