@tmwatchanan
Created June 5, 2019 17:07
A workaround for UpSampling2D, an operation currently unsupported on TPU.

import tensorflow as tf
from tensorflow.keras.layers import Lambda


def nearest_upsampling(data, scale):
    """Nearest neighbor upsampling implementation.

    Args:
      data: A float32 tensor of size [batch, height_in, width_in, channels].
      scale: An integer multiple to scale resolution of input data.

    Returns:
      data_up: A float32 tensor of size
        [batch, height_in*scale, width_in*scale, channels].
    """
    with tf.name_scope('nearest_upsampling'):
        # h, w and c must be known statically; only the batch size may be None.
        bs, h, w, c = data.get_shape().as_list()
        bs = -1 if bs is None else bs
        # Use reshape to quickly upsample the input. The nearest pixel is selected
        # implicitly via broadcasting.
        data = tf.reshape(data, [bs, h, 1, w, 1, c]) * tf.ones(
            [1, 1, scale, 1, scale, 1], dtype=data.dtype)
        return tf.reshape(data, [bs, h * scale, w * scale, c])


# Wrap the function in a Lambda layer and use it in place of UpSampling2D(size=(2, 2)).
# `layer` is the output tensor of the preceding layer in the model.
layer = Lambda(lambda x: nearest_upsampling(x, 2))(layer)
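The reshape-and-broadcast trick can be checked without a TPU. The sketch below assumes NumPy is available and uses a hypothetical helper, `nearest_upsampling_np`, a NumPy port of the function above written only for illustration. It compares the result with repeating every row and column `scale` times, which is what UpSampling2D does with its default nearest-neighbor interpolation.

```python
import numpy as np


def nearest_upsampling_np(data, scale):
    """Hypothetical NumPy port of nearest_upsampling, for illustration only."""
    bs, h, w, c = data.shape
    # Insert two singleton axes, then broadcast against ones so each pixel
    # is copied `scale` times along height and `scale` times along width.
    up = data.reshape(bs, h, 1, w, 1, c) * np.ones(
        (1, 1, scale, 1, scale, 1), dtype=data.dtype)
    # Merge (h, scale) and (w, scale) back into the spatial axes.
    return up.reshape(bs, h * scale, w * scale, c)


x = np.arange(4, dtype=np.float32).reshape(1, 2, 2, 1)  # [[0, 1], [2, 3]]
y = nearest_upsampling_np(x, 2)

# Reference: repeat rows, then columns (nearest-neighbor upsampling).
expected = x.repeat(2, axis=1).repeat(2, axis=2)
print(y[0, :, :, 0])
# [[0. 0. 1. 1.]
#  [0. 0. 1. 1.]
#  [2. 2. 3. 3.]
#  [2. 2. 3. 3.]]
```

Because only reshapes and an elementwise multiply are involved, the TensorFlow version avoids the resize op altogether, which is why it can stand in for UpSampling2D on TPU.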