Skip to content

Instantly share code, notes, and snippets.

@KyanainsGate
Last active June 10, 2019 08:25
Show Gist options
  • Save KyanainsGate/7871532038395ee344a16faac4fd024d to your computer and use it in GitHub Desktop.
Save KyanainsGate/7871532038395ee344a16faac4fd024d to your computer and use it in GitHub Desktop.
def broadcast_1d_to_2dtile(x, h_and_w_list, name=None):
    """Tile a batch of vectors over a 2-D spatial grid.

    Converts a (batch, dim) tensor into an image-like (batch, h, w, dim)
    tensor in which every spatial position holds a copy of that batch
    element's vector, i.e. ``out[b, i, j, :] == x[b, :]``. This is used in
    "Neural scene representation and rendering" (S. M. Ali Eslami et al.,
    Science 2018) to feed the viewpoint vector "v" to a CNN.

    Example (batch=1, dim=3, h_and_w_list=[2, 2])::

        x   = [[1, 2, 3]]
        out = [[[[1, 2, 3], [1, 2, 3]],
                [[1, 2, 3], [1, 2, 3]]]]      # shape (1, 2, 2, 3)

    :param x: A rank-2 tensor of shape (batch, dim). Neither the batch size
        nor ``dim`` needs to be statically known.
    :param h_and_w_list: Output spatial size as a two-element sequence [h, w].
    :param name: Optional name scope; defaults to 'broadcast_to_2d'.
    :return: A tensor of shape (batch, h, w, dim) with the dtype of ``x``.
    """
    scope_name = 'broadcast_to_2d' if name is None else name
    with tf.name_scope(scope_name):
        h, w = h_and_w_list
        # (batch, dim) -> (batch, 1, 1, dim): insert singleton spatial axes.
        x = tf.expand_dims(tf.expand_dims(x, axis=1), axis=1)
        # Repeat along the two spatial axes only. Unlike reshaping with the
        # static x.shape[1], this needs no known `dim`, so it also works when
        # the feature size is only known at run time.
        return tf.tile(x, [1, h, w, 1])
# Example to use
if __name__ == '__main__':
    # ndarray: a batch of 2 random 4-dim vectors (integer values as floats,
    # so they are exactly representable in float32).
    arr3 = np.random.randint(0, 100, (2, 4)) * 1.0
    batch_num = len(arr3)
    print(batch_num)
    print("array example2\n", arr3)
    # placeholder and define session
    plh = tf.placeholder(dtype=tf.float32, shape=(None, 4),)
    bro = broadcast_1d_to_2dtile(plh, [8, 8])
    # Launch the graph. The context manager closes the session (freeing its
    # resources) when done; the graph holds no variables, so no initializer
    # is needed.
    with tf.Session() as sess:
        bcm_ = sess.run(bro, feed_dict={plh: arr3})
    print('_____________')
    print(bcm_.shape)
    print(bcm_[0, :, :, 0])
    print(bcm_[0, :, :, 1])
    print(bcm_[0, :, :, 3])
    print(bcm_[1, :, :, 2])
    # Sanity check: every spatial cell must hold a copy of its input vector.
    # np.testing raises even under `python -O`, unlike a bare assert.
    np.testing.assert_array_equal(
        bcm_, np.broadcast_to(arr3[:, np.newaxis, np.newaxis, :], bcm_.shape))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment