Last active
June 28, 2017 05:04
-
-
Save daviddao/c05cc5913c815c1e2154 to your computer and use it in GitHub Desktop.
Tensorflow Spatial Transformer
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import tensorflow as tf | |
def transformer(U, theta, downsample_factor=1, name='SpatialTransformer', **kwargs):
    """Spatial Transformer Layer

    Implements a spatial transformer layer as described in [1]_.
    Based on [2]_ and edited by David Dao for Tensorflow.

    Parameters
    ----------
    U : Tensor
        Input feature map of shape [num_batch, height, width, num_channels],
        e.g. the output of a convolutional net. Sampled as float32.
    theta : Tensor
        Output of the localisation network, shape [num_batch, 6]. Each row
        is reshaped to the 2x3 matrix [[a, b, t_x], [c, d, t_y]] and applied
        to the normalised output coordinates (x_t, y_t, 1), x_t, y_t in
        [-1, 1], to obtain the source sampling coordinates.
    downsample_factor : float
        A value of 1 keeps the original size of the image.
        Values larger than 1 downsample the image; values below 1 upsample
        it. Example: for height = 100, width = 200 and
        downsample_factor = 2 the output image is 50 x 100.
    name : str, optional
        Name of the variable scope the layer's ops are created in.
    **kwargs
        Accepted but ignored.

    Returns
    -------
    Tensor
        float32 tensor of shape [num_batch, floor(height / downsample_factor),
        floor(width / downsample_factor), num_channels], bilinearly sampled
        from ``U``.

    References
    ----------
    .. [1] Spatial Transformer Networks
        Max Jaderberg, Karen Simonyan, Andrew Zisserman, Koray Kavukcuoglu
        Submitted on 5 Jun 2015
    .. [2] https://github.com/skaae/transformer_network/blob/master/transformerlayer.py

    Notes
    -----
    Normalised coordinate -1 maps to the first pixel centre and +1 to the
    last pixel centre along each axis, so with ``downsample_factor=1`` the
    identity ``theta`` reproduces the input (up to float rounding). Sample
    points that fall outside the image take the value of the nearest border
    pixel.

    This code uses pre-1.0 TensorFlow API names (``tf.pack``,
    ``tf.batch_matmul``, ``tf.concat(axis, values)``); TensorFlow 1.0
    renamed them to ``tf.stack``, ``tf.matmul`` and ``tf.concat(values, axis)``.

    To initialize the network to the identity transform, initialize
    ``theta`` to::

        identity = np.array([[1., 0., 0.],
                             [0., 1., 0.]])
        identity = identity.flatten()
        theta = tf.Variable(initial_value=identity)
    """

    def _repeat(x, n_repeats):
        """Repeat every element of 1-D int32 ``x`` ``n_repeats`` times in place.

        [a, b] -> [a, a, ..., b, b, ...] (outer product with a row of ones,
        then flattened row-major).
        """
        with tf.variable_scope('_repeat'):
            rep = tf.transpose(
                tf.expand_dims(tf.ones(shape=tf.pack([n_repeats, ])), 1), [1, 0])
            rep = tf.cast(rep, 'int32')
            x = tf.matmul(tf.reshape(x, (-1, 1)), rep)
            return tf.reshape(x, [-1])

    def _interpolate(im, x, y, downsample_factor):
        """Bilinearly sample ``im`` at flat normalised coordinates ``x``, ``y``.

        ``x``/``y`` hold num_batch * out_height * out_width coordinates in
        [-1, 1] (batch-major). Returns a float32 tensor of shape
        [num_batch * out_height * out_width, channels].
        """
        with tf.variable_scope('_interpolate'):
            # constants
            num_batch = tf.shape(im)[0]
            height = tf.shape(im)[1]
            width = tf.shape(im)[2]
            channels = tf.shape(im)[3]

            x = tf.cast(x, 'float32')
            y = tf.cast(y, 'float32')
            height_f = tf.cast(height, 'float32')
            width_f = tf.cast(width, 'float32')
            out_height = tf.cast(height_f // downsample_factor, 'int32')
            out_width = tf.cast(width_f // downsample_factor, 'int32')
            zero = tf.zeros([], dtype='int32')
            max_y = tf.cast(height - 1, 'int32')
            max_x = tf.cast(width - 1, 'int32')

            # Scale indices from [-1, 1] to [0, width - 1] / [0, height - 1] so
            # that -1 and +1 land exactly on the first and last pixel centres.
            # (Scaling by width/height would push +1 one pixel past the image.)
            x = (x + 1.0) * (width_f - 1.0) / 2.0
            y = (y + 1.0) * (height_f - 1.0) / 2.0

            # Corners of the 2x2 neighbourhood around each sample point. The
            # interpolation weights must use these UNCLIPPED corners: with
            # clipped corners, x0 == x1 at the border and the weights cancel
            # to zero, blacking out the last row/column of the image.
            x0_f = tf.floor(x)
            y0_f = tf.floor(y)
            x1_f = x0_f + 1.0
            y1_f = y0_f + 1.0

            # Clip only the gather indices; out-of-range samples therefore
            # repeat the nearest border pixel.
            x0 = tf.clip_by_value(tf.cast(x0_f, 'int32'), zero, max_x)
            x1 = tf.clip_by_value(tf.cast(x1_f, 'int32'), zero, max_x)
            y0 = tf.clip_by_value(tf.cast(y0_f, 'int32'), zero, max_y)
            y1 = tf.clip_by_value(tf.cast(y1_f, 'int32'), zero, max_y)

            # Flat indices into the [num_batch * height * width, channels] image.
            dim2 = width
            dim1 = width * height
            base = _repeat(tf.range(num_batch) * dim1, out_height * out_width)
            base_y0 = base + y0 * dim2
            base_y1 = base + y1 * dim2
            idx_a = base_y0 + x0
            idx_b = base_y1 + x0
            idx_c = base_y0 + x1
            idx_d = base_y1 + x1

            # use indices to lookup pixels in the flat image and restore channels dim
            im_flat = tf.reshape(im, tf.pack([-1, channels]))
            im_flat = tf.cast(im_flat, 'float32')
            Ia = tf.gather(im_flat, idx_a)
            Ib = tf.gather(im_flat, idx_b)
            Ic = tf.gather(im_flat, idx_c)
            Id = tf.gather(im_flat, idx_d)

            # and finally calculate interpolated values
            wa = tf.expand_dims((x1_f - x) * (y1_f - y), 1)
            wb = tf.expand_dims((x1_f - x) * (y - y0_f), 1)
            wc = tf.expand_dims((x - x0_f) * (y1_f - y), 1)
            wd = tf.expand_dims((x - x0_f) * (y - y0_f), 1)
            output = tf.add_n([wa * Ia, wb * Ib, wc * Ic, wd * Id])
            return output

    def _meshgrid(height, width):
        """Return a [3, height * width] float32 grid of rows (x_t, y_t, 1)."""
        with tf.variable_scope('_meshgrid'):
            # This should be equivalent to:
            #  x_t, y_t = np.meshgrid(np.linspace(-1, 1, width),
            #                         np.linspace(-1, 1, height))
            #  ones = np.ones(np.prod(x_t.shape))
            #  grid = np.vstack([x_t.flatten(), y_t.flatten(), ones])
            x_t = tf.matmul(
                tf.ones(shape=tf.pack([height, 1])),
                tf.transpose(tf.expand_dims(tf.linspace(-1.0, 1.0, width), 1), [1, 0]))
            y_t = tf.matmul(
                tf.expand_dims(tf.linspace(-1.0, 1.0, height), 1),
                tf.ones(shape=tf.pack([1, width])))

            x_t_flat = tf.reshape(x_t, (1, -1))
            y_t_flat = tf.reshape(y_t, (1, -1))

            ones = tf.ones_like(x_t_flat)
            grid = tf.concat(0, [x_t_flat, y_t_flat, ones])
            return grid

    def _transform(theta, input_dim, downsample_factor):
        """Warp ``input_dim`` by the per-example affine matrices in ``theta``."""
        with tf.variable_scope('_transform'):
            num_batch = tf.shape(input_dim)[0]
            height = tf.shape(input_dim)[1]
            width = tf.shape(input_dim)[2]
            num_channels = tf.shape(input_dim)[3]
            theta = tf.reshape(theta, (-1, 2, 3))
            theta = tf.cast(theta, 'float32')

            # grid of (x_t, y_t, 1), eq (1) in ref [1]
            height_f = tf.cast(height, 'float32')
            width_f = tf.cast(width, 'float32')
            out_height = tf.cast(height_f // downsample_factor, 'int32')
            out_width = tf.cast(width_f // downsample_factor, 'int32')
            grid = _meshgrid(out_height, out_width)
            # Replicate the [3, out_height * out_width] grid once per example.
            grid = tf.reshape(grid, [-1])
            grid = tf.tile(grid, tf.pack([num_batch]))
            grid = tf.reshape(grid, tf.pack([num_batch, 3, -1]))

            # Transform A x (x_t, y_t, 1)^T -> (x_s, y_s)
            T_g = tf.batch_matmul(theta, grid)
            x_s = tf.slice(T_g, [0, 0, 0], [-1, 1, -1])
            y_s = tf.slice(T_g, [0, 1, 0], [-1, 1, -1])
            x_s_flat = tf.reshape(x_s, [-1])
            y_s_flat = tf.reshape(y_s, [-1])

            input_transformed = _interpolate(
                input_dim, x_s_flat, y_s_flat,
                downsample_factor)

            output = tf.reshape(
                input_transformed,
                tf.pack([num_batch, out_height, out_width, num_channels]))
            return output

    with tf.variable_scope(name):
        output = _transform(theta, U, downsample_factor)
        return output
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import tensorflow as tf | |
from spatial_transformer import transformer | |
from scipy import ndimage | |
import numpy as np | |
import matplotlib.pyplot as plt | |
def conv2d(x, n_filters,
           k_h=5, k_w=5,
           stride_h=2, stride_w=2,
           stddev=0.02,
           activation=lambda x: x,
           bias=True,
           padding='SAME',
           name="Conv2D"):
    """2D Convolution with options for kernel size, stride, and init deviation.

    Parameters
    ----------
    x : Tensor
        Input tensor to convolve; its static last dimension is used as the
        number of input channels.
    n_filters : int
        Number of filters to apply.
    k_h : int, optional
        Kernel height.
    k_w : int, optional
        Kernel width.
    stride_h : int, optional
        Stride in rows.
    stride_w : int, optional
        Stride in cols.
    stddev : float, optional
        Standard deviation of the truncated-normal initializer used for both
        the kernel and the bias.
    activation : callable, optional
        Nonlinearity applied to the (biased) convolution output. Defaults to
        the identity.
    bias : bool, optional
        Whether to add a learned per-filter bias.
    padding : str, optional
        'SAME' or 'VALID'
    name : str, optional
        Variable scope to use.

    Returns
    -------
    x : Tensor
        ``activation(conv2d(x, w) [+ b])``.
    """
    with tf.variable_scope(name):
        w = tf.get_variable(
            'w', [k_h, k_w, x.get_shape()[-1], n_filters],
            initializer=tf.truncated_normal_initializer(stddev=stddev))
        conv = tf.nn.conv2d(
            x, w, strides=[1, stride_h, stride_w, 1], padding=padding)
        if bias:
            b = tf.get_variable(
                'b', [n_filters],
                initializer=tf.truncated_normal_initializer(stddev=stddev))
            conv = conv + b
        # Apply the requested nonlinearity (identity by default), as linear()
        # does; the argument was documented but never used before.
        return activation(conv)
def linear(x, n_units, scope=None, stddev=0.02,
           activation=lambda x: x):
    """Bias-free fully-connected layer computing ``activation(x @ W)``.

    Parameters
    ----------
    x : Tensor
        2-D input; the static size of its second dimension sets the number
        of rows of the weight matrix.
    n_units : int
        Number of output units (columns of the weight matrix).
    scope : str, optional
        Variable scope name; "Linear" when not given.
    stddev : float, optional
        Standard deviation of the random-normal weight initializer.
    activation : callable, optional
        Function applied to the matrix product. Defaults to the identity.

    Returns
    -------
    x : Tensor
        The activated fully-connected output.
    """
    n_inputs = x.get_shape().as_list()[1]
    with tf.variable_scope(scope or "Linear"):
        weights = tf.get_variable(
            "Matrix",
            [n_inputs, n_units],
            dtype=tf.float32,
            initializer=tf.random_normal_initializer(stddev=stddev))
        projected = tf.matmul(x, weights)
        return activation(projected)
# %% | |
def weight_variable(shape):
    '''Helper function to create a weight variable initialized to all zeros.

    Note that all-zero weights give every unit in a layer the same gradient;
    use a random initializer where symmetry between units must be broken.

    Parameters
    ----------
    shape : list
        Size of weight variable

    Returns
    -------
    tf.Variable
        A trainable float32 variable of the given shape, filled with zeros.
    '''
    initial = tf.zeros(shape)
    return tf.Variable(initial)
# %% | |
def bias_variable(shape):
    '''Helper function to create a bias variable with small random values.

    Values are drawn from a normal distribution with mean 0.0 and
    standard deviation 0.01.

    Parameters
    ----------
    shape : list
        Size of bias variable

    Returns
    -------
    tf.Variable
        A trainable variable of the given shape.
    '''
    initial = tf.random_normal(shape, mean=0.0, stddev=0.01)
    return tf.Variable(initial)
# Preprocessing
# Load one RGB image, scale it to [0, 1] and stack three copies of it to
# simulate a batch. The example image is 1200 (height) x 1600 (width) x 3.
# NOTE: scipy.ndimage.imread was deprecated in SciPy 1.0 and removed in 1.2;
# on newer SciPy, plt.imread('cat.jpg') also returns a uint8 array for JPEGs.
im = ndimage.imread('cat.jpg')
im = im / 255.
height, width, channels = im.shape
im = im.reshape(1, height, width, channels).astype('float32')

# Simulate batch
num_batch = 3
batch = np.concatenate([im] * num_batch, axis=0)
n_inputs = height * width * channels

# The image batch is fed through this placeholder at run time.
x = tf.placeholder(tf.float32, [None, height, width, channels])

# Create localisation network and spatial transformer
with tf.variable_scope('spatial_transformer_0'):
    # %% Fully-connected localisation layer producing the 6 affine parameters.
    # Zero weights plus the bias below make theta independent of the input at
    # initialisation: theta = [[0.5, 0, 0], [0, 0.5, 0]] samples source
    # coordinates in [-0.5, 0.5], i.e. the central region of the image.
    n_fc = 6
    W_fc1 = tf.Variable(tf.zeros([n_inputs, n_fc]), name='W_fc1')
    initial = np.array([[0.5, 0, 0], [0, 0.5, 0]], dtype='float32').flatten()
    b_fc1 = tf.Variable(initial_value=initial, name='b_fc1')
    x_flatten = tf.reshape(x, [-1, n_inputs])
    h_fc1 = tf.matmul(x_flatten, W_fc1) + b_fc1
    h_trans = transformer(x, h_fc1, downsample_factor=2)

# Run session
with tf.Session() as sess:
    sess.run(tf.initialize_all_variables())
    y = sess.run(h_trans, feed_dict={x: batch})

plt.imshow(y[0])
plt.show()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment