Created June 5, 2019 17:07
A workaround for UpSampling2D, an operation that is currently unsupported on TPU.
import tensorflow as tf
from tensorflow.keras.layers import Lambda


def nearest_upsampling(data, scale):
    """Nearest-neighbor upsampling implementation.

    Args:
      data: A float32 tensor of size [batch, height_in, width_in, channels].
        Height, width and channels must be statically known.
      scale: An integer multiple by which to scale the resolution of the input data.

    Returns:
      data_up: A float32 tensor of size
        [batch, height_in*scale, width_in*scale, channels].
    """
    with tf.name_scope('nearest_upsampling'):
        bs, h, w, c = data.get_shape().as_list()
        # An unknown (None) batch size becomes -1 so tf.reshape can infer it.
        bs = -1 if bs is None else bs
        # Use reshape to quickly upsample the input. The nearest pixel is selected
        # implicitly via broadcasting: [bs, h, 1, w, 1, c] * [1, 1, scale, 1, scale, 1]
        # yields [bs, h, scale, w, scale, c], i.e. every input pixel is copied
        # into a scale x scale block.
        data = tf.reshape(data, [bs, h, 1, w, 1, c]) * tf.ones(
            [1, 1, scale, 1, scale, 1], dtype=data.dtype)
        return tf.reshape(data, [bs, h * scale, w * scale, c])


# Wrap the nearest-neighbor upsampling function in a Lambda layer in place of
# UpSampling2D(size=2). `layer` is the output tensor of the preceding layer.
layer = Lambda(lambda x: nearest_upsampling(x, 2))(layer)
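A quick sanity check, sketched under the assumption of TensorFlow 2.x running eagerly on CPU (the gist does not state a version, and this does not exercise a TPU). It compares the reshape-and-broadcast trick against the stock `tf.keras.layers.UpSampling2D` layer it is meant to replace, whose default interpolation is nearest-neighbor. The function is repeated so the block runs on its own; the 2x2 test input is an arbitrary illustrative choice.

```python
import numpy as np
import tensorflow as tf


def nearest_upsampling(data, scale):
    """Same reshape-and-broadcast upsampling as in the gist, repeated so this block is self-contained."""
    with tf.name_scope('nearest_upsampling'):
        bs, h, w, c = data.get_shape().as_list()
        bs = -1 if bs is None else bs
        data = tf.reshape(data, [bs, h, 1, w, 1, c]) * tf.ones(
            [1, 1, scale, 1, scale, 1], dtype=data.dtype)
        return tf.reshape(data, [bs, h * scale, w * scale, c])


# A 1x2x2x1 input holding the distinct values 0..3, so copied pixels are easy to spot.
x = tf.constant(np.arange(4, dtype=np.float32).reshape(1, 2, 2, 1))

ours = nearest_upsampling(x, 2)
# Reference result from the built-in layer (channels_last, nearest interpolation by default).
reference = tf.keras.layers.UpSampling2D(size=2)(x)

# Each input pixel should appear as a 2x2 block in the 4x4 output.
print(ours[0, :, :, 0].numpy())
max_diff = float(tf.reduce_max(tf.abs(ours - reference)))
print("max abs difference:", max_diff)
```

If the two results match, the Lambda-wrapped function can stand in for `UpSampling2D(size=2)` in a model, provided the spatial dimensions of its input are static.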