softargmax_tensorflow
import numpy as np
import tensorflow as tf


def tf_softargmax(x, batch, min=0., max=1.):
    """
    "soft-argmax" is a differentiable version of the "argmax" function [Kendall et al. 2017].
    URL: https://arxiv.org/abs/1703.04309
    It first computes a softmax over the spatial extent of each channel of a convolutional
    feature map, then computes the expected 2D position of the point of maximal activation
    for each channel, yielding a set of feature keypoints
    (e.g. [[E[x1], E[y1]], [E[x2], E[y2]], ..., [E[xN], E[yN]]]) per channel and batch element.
    :param x: A tensor, shape=(Batch, Height, Width, Channel).
    :param batch: A scalar tensor holding the batch size, dtype=tf.int32, shape=().
        It is used to tile the meshgrid tensors across the batch.
    :param min: The coordinate value assigned to index 0 of each spatial axis.
    :param max: The coordinate value assigned to the last index of each spatial axis.
    :return: A tensor, shape=(Batch, Channel, 2).
    """
    # Compute the softmax over each channel's spatial extent.
    h, w, ch = int(x.shape[1]), int(x.shape[2]), int(x.shape[3])  # x.shape = (b, h, w, ch)
    meshgrid_range = [min, max]
    tmp = tf.transpose(x, perm=(0, 3, 1, 2))  # (b, ch, h, w)
    features = tf.reshape(tmp, shape=(-1, h * w))  # (b*ch, h*w)
    softmax = tf.nn.softmax(features)  # (b*ch, h*w)
    softmax = tf.reshape(softmax, [-1, h, w])  # (b*ch, h, w)
    # Create the coordinate meshgrid. With indexing='xy', posx varies along the width
    # axis and posy along the height axis, and both have shape (h, w) to match the
    # softmax map. (indexing='ij' would yield (w, h) grids and swap the two
    # coordinates; it only happens to work for square inputs.)
    posx, posy = tf.meshgrid(tf.lin_space(meshgrid_range[0], meshgrid_range[1], num=w),
                             tf.lin_space(meshgrid_range[0], meshgrid_range[1], num=h),
                             indexing='xy')  # each (h, w)
    # Tile across batch*channel so the grids line up with the (b*ch, h, w) softmax.
    pos_x_expanded = tf.tile(tf.expand_dims(posx, axis=0), [batch * ch, 1, 1])  # (b*ch, h, w)
    pos_y_expanded = tf.tile(tf.expand_dims(posy, axis=0), [batch * ch, 1, 1])  # (b*ch, h, w)
    # Match the meshgrid dtype to the softmax dtype before multiplying.
    pos_x_expanded = tf.cast(pos_x_expanded, softmax.dtype)
    pos_y_expanded = tf.cast(pos_y_expanded, softmax.dtype)
    # Expected coordinates E[x] and E[y] under each channel's softmax distribution.
    expected_x = tf.reduce_sum(tf.multiply(pos_x_expanded, softmax), axis=[1, 2])  # (b*ch,)
    expected_x = tf.reshape(expected_x, shape=(batch, ch, 1))  # (b, ch, 1)
    expected_y = tf.reduce_sum(tf.multiply(pos_y_expanded, softmax), axis=[1, 2])  # (b*ch,)
    expected_y = tf.reshape(expected_y, shape=(batch, ch, 1))  # (b, ch, 1)
    out = tf.concat([expected_x, expected_y], axis=2)  # (b, ch, 2); the 2 holds E[x], E[y]
    return out
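

# The following NumPy reference is not part of the original gist: it is a minimal
# sketch of the same soft-argmax, useful for cross-checking the TensorFlow graph
# on small inputs. The name `np_softargmax` is introduced here for illustration only.
def np_softargmax(x, min_coord=0., max_coord=1.):
    """NumPy reference: x has shape (batch, h, w, ch); returns (batch, ch, 2)."""
    b, h, w, ch = x.shape
    flat = x.transpose(0, 3, 1, 2).reshape(b * ch, h * w)
    flat = flat - flat.max(axis=1, keepdims=True)  # subtract max for numerical stability
    sm = np.exp(flat) / np.exp(flat).sum(axis=1, keepdims=True)  # spatial softmax
    sm = sm.reshape(b * ch, h, w)
    pos_x = np.linspace(min_coord, max_coord, w)  # coordinate along the width axis
    pos_y = np.linspace(min_coord, max_coord, h)  # coordinate along the height axis
    e_x = (sm * pos_x[None, None, :]).sum(axis=(1, 2))  # E[x] per (batch, channel)
    e_y = (sm * pos_y[None, :, None]).sum(axis=(1, 2))  # E[y] per (batch, channel)
    return np.stack([e_x, e_y], axis=1).reshape(b, ch, 2)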
# Usage example
if __name__ == '__main__':
    arr1_old = np.random.randint(0, 100, (4, 4)) * 1.0
    arr1 = arr1_old[:, :, np.newaxis]
    arr2_old = np.random.randint(0, 100, (4, 4)) * 1.0
    arr2 = arr2_old[:, :, np.newaxis]
    print("array example1\n", arr1_old)
    print("array example2\n", arr2_old)
    arg_ndarray = np.array([arr1, arr2])  # shape (2, 4, 4, 1)
    batch_num = len(arg_ndarray)
    plh = tf.placeholder(dtype=tf.float32, shape=(None, 4, 4, 1))
    b_plh = tf.placeholder(dtype=tf.int32, shape=())
    p = tf_softargmax(plh, b_plh, min=0., max=3.)
    init = tf.global_variables_initializer()
    # Launch the graph.
    sess = tf.Session()
    sess.run(init)
    print(sess.run(p, feed_dict={plh: arg_ndarray, b_plh: batch_num}))
    # Prints an array of shape (2, 1, 2): one [E[x_ch1], E[y_ch1]] pair for the
    # single channel of each of the two batch elements.
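
The gist targets TensorFlow 1.x (placeholders and sessions). As a hedged sketch of how the same computation might look under TensorFlow 2.x eager execution, where the batch size can be read from the tensor at run time so the `batch` placeholder is unnecessary, something like the following could work; `tf2_softargmax` is a name introduced here and is not part of the original:

import tensorflow as tf  # TF 2.x


def tf2_softargmax(x, min_coord=0., max_coord=1.):
    # x: float32 tensor of shape (batch, h, w, ch) with statically known h, w, ch.
    h, w, ch = x.shape[1], x.shape[2], x.shape[3]
    sm = tf.nn.softmax(tf.reshape(tf.transpose(x, (0, 3, 1, 2)), (-1, h * w)))
    sm = tf.reshape(sm, (-1, h, w))  # (b*ch, h, w)
    pos_x = tf.linspace(min_coord, max_coord, w)  # (w,), varies along width
    pos_y = tf.linspace(min_coord, max_coord, h)  # (h,), varies along height
    e_x = tf.reduce_sum(sm * pos_x[None, None, :], axis=[1, 2])  # (b*ch,)
    e_y = tf.reduce_sum(sm * pos_y[None, :, None], axis=[1, 2])  # (b*ch,)
    return tf.reshape(tf.stack([e_x, e_y], axis=1), (-1, ch, 2))  # (b, ch, 2)

Broadcasting against the 1D coordinate vectors replaces the explicit tf.tile of the TF1 version; the result is the same (b, ch, 2) tensor of expected keypoint positions.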