import theano.tensor as T
import lasagne


def reflect_pad(x, width, batch_ndim=1):
"""
Pad a tensor with a constant value.
Parameters
----------
x : tensor
width : int, iterable of int, or iterable of tuple
Padding width. If an int, pads each axis symmetrically with the same
amount in the beginning and end. If an iterable of int, defines the
symmetric padding width separately for each axis. If an iterable of
tuples of two ints, defines a seperate padding width for each beginning
and end of each axis.
batch_ndim : integer
Dimensions before the value will not be padded.
"""
    # Idea: flip the tensor to grab the reflected rows/columns, or
    # equivalently just slice with negative steps (done below).
    input_shape = x.shape
    input_ndim = x.ndim

    output_shape = list(input_shape)
    indices = [slice(None) for _ in output_shape]

    if isinstance(width, int):
        widths = [width] * (input_ndim - batch_ndim)
    else:
        widths = width

    for k, w in enumerate(widths):
        try:
            l, r = w
        except TypeError:
            l = r = w
        output_shape[k + batch_ndim] += l + r
        indices[k + batch_ndim] = slice(l, l + input_shape[k + batch_ndim])

    # Create the (zero-initialised) output array
    out = T.zeros(output_shape)

    # Vertical reflections: rows 1..width mirrored above the top edge,
    # rows -(width+1)..-2 mirrored below the bottom edge
    out = T.set_subtensor(out[:, :, :width, width:-width],
                          x[:, :, width:0:-1, :])
    out = T.set_subtensor(out[:, :, -width:, width:-width],
                          x[:, :, -2:-(2 + width):-1, :])

    # Place x in the centre of out
    # (equivalently: out = T.set_subtensor(out[tuple(indices)], x))
    out = T.set_subtensor(out[:, :, width:-width, width:-width], x)

    # Horizontal reflections: read from `out` so the corners are
    # filled from the already-reflected rows
    out = T.set_subtensor(out[:, :, :, :width],
                          out[:, :, :, (2 * width):width:-1])
    out = T.set_subtensor(out[:, :, :, -width:],
                          out[:, :, :, -(width + 2):-(2 * width + 2):-1])
    return out
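

# A minimal sanity check (an assumption, not part of the original gist):
# for an integer width, reflect_pad should agree with numpy's
# np.pad(..., mode='reflect') applied to the two trailing axes. The helper
# name and shapes below are hypothetical, assuming Theano is installed.
def _check_reflect_pad(width=2):
    import numpy as np
    import theano
    x_sym = T.tensor4('x')
    pad_fn = theano.function([x_sym], reflect_pad(x_sym, width, batch_ndim=2))
    x_val = np.random.rand(2, 3, 8, 8).astype(theano.config.floatX)
    ref = np.pad(x_val, ((0, 0), (0, 0), (width, width), (width, width)),
                 mode='reflect')
    assert np.allclose(pad_fn(x_val), ref)
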
class ReflectLayer(lasagne.layers.Layer):
    """Lasagne layer that applies reflect_pad to its input."""

    def __init__(self, incoming, width, batch_ndim=2, **kwargs):
        super(ReflectLayer, self).__init__(incoming, **kwargs)
        self.width = width
        self.batch_ndim = batch_ndim

    def get_output_shape_for(self, input_shape):
        output_shape = list(input_shape)

        if isinstance(self.width, int):
            widths = [self.width] * (len(input_shape) - self.batch_ndim)
        else:
            widths = self.width

        for k, w in enumerate(widths):
            # Leave symbolic (None) axes unchanged
            if output_shape[k + self.batch_ndim] is None:
                continue
            try:
                l, r = w
            except TypeError:
                l = r = w
            output_shape[k + self.batch_ndim] += l + r

        return tuple(output_shape)

    def get_output_for(self, input, **kwargs):
        return reflect_pad(input, self.width, self.batch_ndim)
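

# Usage sketch (an assumption, not from the original gist): place
# ReflectLayer before a convolution with pad=0 so the convolution keeps
# the spatial size without zero-padding artifacts at the borders. The
# helper name and shapes here are illustrative only.
def _build_example_net():
    from lasagne.layers import InputLayer, Conv2DLayer
    net = InputLayer(shape=(None, 3, 32, 32))
    net = ReflectLayer(net, width=1)  # pad rows/cols by 1 via reflection
    net = Conv2DLayer(net, num_filters=16, filter_size=3, pad=0)
    return net  # output spatial size stays 32x32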