@KyanainsGate
Last active June 3, 2019 06:10
softargmax_tensorflow
import numpy as np
import tensorflow as tf


def tf_softargmax(x, batch, min=0., max=1.):
    """
    "soft-argmax" is a differentiable version of the "argmax" function [Kendall et al. 2017].
    URL: https://arxiv.org/abs/1703.04309
    It first computes a softmax over the spatial extent of each channel of a convolutional feature map,
    then computes the expected 2D position of the point of maximal activation for each channel,
    resulting in a set of feature keypoints (e.g. [[E[x1], E[y1]], [E[x2], E[y2]], ..., [E[xN], E[yN]]])
    per channel and batch.
    :param x: A tensor, shape=(Batch, Height, Width, Channel)
    :param batch: A scalar tensor holding the batch size, dtype=tf.int32, shape=().
                  It is used to tile the coordinate meshgrids along the batch dimension.
    :param min: The coordinate value assigned to index 0
    :param max: The coordinate value assigned to the last index of each image side
    :return: A tensor, shape=(Batch, Channel, 2)
    """
    # Softmax over the spatial extent of each channel
    _, h, w, ch = x.shape[0], x.shape[1], x.shape[2], x.shape[3]  # x.shape = (b, h, w, ch)
    meshgrid_range = [min, max]
    tmp = tf.transpose(x, perm=(0, 3, 1, 2))   # (b, ch, h, w)
    features = tf.reshape(tmp, shape=(-1, h * w))
    softmax = tf.nn.softmax(features)          # (b*ch, h*w)
    softmax = tf.reshape(softmax, [-1, h, w])  # (b*ch, h, w)
    # Coordinate meshgrids: pos_x varies along the width axis, pos_y along the height axis
    posx, posy = tf.meshgrid(tf.lin_space(meshgrid_range[0], meshgrid_range[1], num=w),
                             tf.lin_space(meshgrid_range[0], meshgrid_range[1], num=h),
                             indexing='xy')  # each (h, w)
    pos_x_expanded = tf.expand_dims(posx, axis=0)                        # (1, h, w)
    pos_x_expanded = tf.tile(pos_x_expanded, [batch * int(ch), 1, 1])    # (b*ch, h, w)
    pos_y_expanded = tf.expand_dims(posy, axis=0)                        # (1, h, w)
    pos_y_expanded = tf.tile(pos_y_expanded, [batch * int(ch), 1, 1])    # (b*ch, h, w)
    if pos_x_expanded.dtype == tf.float64:
        pos_x_expanded = tf.cast(pos_x_expanded, tf.float32)
        pos_y_expanded = tf.cast(pos_y_expanded, tf.float32)
    # Expected coordinates: sum of (coordinate * probability) over the spatial extent
    expected_x = tf.reduce_sum(tf.multiply(pos_x_expanded, softmax), axis=[1, 2], keepdims=True)  # (b*ch, 1, 1)
    expected_x = tf.reshape(expected_x, shape=(batch, ch, 1))  # (b, ch, 1)
    expected_y = tf.reduce_sum(tf.multiply(pos_y_expanded, softmax), axis=[1, 2], keepdims=True)  # (b*ch, 1, 1)
    expected_y = tf.reshape(expected_y, shape=(batch, ch, 1))  # (b, ch, 1)
    out = tf.concat([expected_x, expected_y], axis=2)  # (b, ch, 2); the last axis holds E[x] and E[y]
    return out
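

# A minimal NumPy sketch of the same computation, useful for sanity-checking the
# TensorFlow graph above: spatial softmax per channel, then the expectation of the
# pixel coordinates. The helper name `np_softargmax` and its layout are illustrative
# assumptions (not from the gist); it assumes the same (b, h, w, ch) input layout.
def np_softargmax(x, min=0., max=1.):
    # x: ndarray, shape (b, h, w, ch); returns ndarray of shape (b, ch, 2)
    b, h, w, ch = x.shape
    xs = np.linspace(min, max, w)        # coordinate of each column
    ys = np.linspace(min, max, h)        # coordinate of each row
    pos_x, pos_y = np.meshgrid(xs, ys)   # each (h, w)
    feat = x.transpose(0, 3, 1, 2).reshape(b * ch, h * w)
    e = np.exp(feat - feat.max(axis=1, keepdims=True))          # numerically stable softmax
    prob = (e / e.sum(axis=1, keepdims=True)).reshape(b * ch, h, w)
    ex = (prob * pos_x).sum(axis=(1, 2))  # E[x] per (batch, channel)
    ey = (prob * pos_y).sum(axis=(1, 2))  # E[y] per (batch, channel)
    return np.stack([ex, ey], axis=1).reshape(b, ch, 2)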
# Usage example
if __name__ == '__main__':
    arr1_old = np.random.randint(0, 100, (4, 4)) * 1.0
    arr1 = arr1_old[:, :, np.newaxis]
    arr2_old = np.random.randint(0, 100, (4, 4)) * 1.0
    arr2 = arr2_old[:, :, np.newaxis]
    print("array example1\n", arr1_old)
    print("array example2\n", arr2_old)
    arg_ndarray = np.array([arr1, arr2])  # shape (2, 4, 4, 1)
    batch_num = len(arg_ndarray)
    plh = tf.placeholder(dtype=tf.float32, shape=(None, 4, 4, 1))
    b_plh = tf.placeholder(dtype=tf.int32, shape=())
    p = tf_softargmax(plh, b_plh, min=0., max=3.)
    init = tf.global_variables_initializer()
    # Launch the graph.
    sess = tf.Session()
    sess.run(init)
    print(sess.run(p, feed_dict={plh: arg_ndarray, b_plh: batch_num}))
    # Prints a (2, 1, 2) array: [[[E[x]_batch1, E[y]_batch1]], [[E[x]_batch2, E[y]_batch2]]]
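    # Optional: compare with the illustrative NumPy sketch `np_softargmax` defined above;
    # the two outputs should agree to within floating-point tolerance.
    print(np_softargmax(arg_ndarray, min=0., max=3.))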