Last active November 25, 2022 19:36
jimfleming/c1adfdb0f526465c99409cc143dea97b
A utility function for TensorFlow that maps a grayscale image to a matplotlib colormap for use with TensorBoard image summaries.
import matplotlib
import numpy as np
import tensorflow as tf


def colorize(value, vmin=None, vmax=None, cmap=None):
    """
    A utility function for TensorFlow that maps a grayscale image to a matplotlib
    colormap for use with TensorBoard image summaries.

    By default it will normalize the input value to the range 0..1 before mapping
    to a grayscale colormap. Values outside vmin..vmax are clipped.

    Arguments:
      - value: 2D Tensor of shape [height, width] or 3D Tensor of shape
        [height, width, 1].
      - vmin: the minimum value of the range used for normalization.
        (Default: value minimum)
      - vmax: the maximum value of the range used for normalization.
        (Default: value maximum)
      - cmap: the name of a colormap registered with matplotlib, e.g. 'viridis'.
        (Default: 'gray')

    Example usage:
    ```
    output = tf.random.uniform(shape=[256, 256, 1])
    output_color = colorize(output, vmin=0.0, vmax=1.0, cmap='viridis')
    writer = tf.summary.create_file_writer('logs')
    with writer.as_default():
        # image summaries expect a 4D [batch, height, width, channels] tensor
        tf.summary.image('output', output_color[tf.newaxis], step=0)
    ```

    Returns a 3D tensor of shape [height, width, 3].
    """
    value = tf.cast(value, tf.float32)

    # normalize to 0..1; divide_no_nan maps a constant image to 0 instead of NaN
    vmin = tf.reduce_min(value) if vmin is None else vmin
    vmax = tf.reduce_max(value) if vmax is None else vmax
    value = tf.math.divide_no_nan(value - vmin, vmax - vmin)
    value = tf.clip_by_value(value, 0.0, 1.0)

    # drop size-1 dimensions, e.g. [height, width, 1] -> [height, width]
    value = tf.squeeze(value)

    # quantize to integer indices 0..255
    indices = tf.cast(tf.round(value * 255), tf.int32)

    # gather: look up an RGB color for every index
    # (matplotlib.colormaps requires matplotlib >= 3.5)
    cm = matplotlib.colormaps[cmap if cmap is not None else 'gray']
    # sample 256 evenly spaced RGBA colors and keep RGB; unlike `cm.colors`,
    # this also works for LinearSegmentedColormaps such as 'gray'
    colors = tf.constant(cm(np.linspace(0.0, 1.0, 256))[:, :3], dtype=tf.float32)
    value = tf.gather(colors, indices)

    return value
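To see the normalize, quantize and gather steps in isolation, here is a minimal NumPy sketch of the same technique. `colorize_np` and `gray_lut` are hypothetical names, and the lookup table is a plain black-to-white ramp standing in for matplotlib's 'gray' colormap. `np.rint` rounds half to even like `tf.round`, and fancy indexing `lut[indices]` plays the role of `tf.gather(colors, indices)`.

```python
import numpy as np


def colorize_np(value, lut, vmin=None, vmax=None):
    """Hypothetical NumPy counterpart of `colorize`.

    `lut` is any lookup table of shape [256, C], e.g. 256 RGB triples.
    """
    value = np.asarray(value, dtype=np.float32)
    vmin = value.min() if vmin is None else vmin
    vmax = value.max() if vmax is None else vmax
    span = vmax - vmin
    # normalize to 0..1; a constant input maps to 0 instead of dividing by zero
    if span == 0:
        value = np.zeros_like(value)
    else:
        value = (value - vmin) / span
    value = np.clip(value, 0.0, 1.0)
    # quantize to integer indices 0..255 (round half to even, like tf.round)
    indices = np.rint(value * 255).astype(np.int64)
    # gather: the result has shape value.shape + lut.shape[1:]
    return lut[indices]


# hypothetical grayscale lookup table: 256 evenly spaced gray levels as RGB
gray_lut = np.repeat(np.linspace(0.0, 1.0, 256)[:, None], 3, axis=1)

img = np.array([[0.0, 5.0], [10.0, 2.5]])
out = colorize_np(img, gray_lut)
print(out.shape)  # (2, 2, 3)
```

The midpoint 5.0 normalizes to 0.5, which quantizes to 127.5 and rounds to index 128, so a value exactly halfway between two levels lands on the even one.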
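If you remove the line `value = tf.squeeze(value)`, it works on arbitrary input shapes and returns a tensor of shape `original_shape + [3]`. This could be useful for writing whole batches: a batch of shape [N, height, width] becomes [N, height, width, 3]. Note that a trailing channel axis of size 1 is then kept, so [N, height, width, 1] becomes [N, height, width, 1, 3].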
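The shape behavior described above follows from how a lookup-table gather works: indexing a [256, 3] table with an index array of shape S yields S + [3]. The small NumPy sketch below, standing in for `tf.gather` and using a hypothetical grayscale table, shows why a batch only needs its trailing channel axis dropped. Slicing with `[..., 0]` does that, while a bare squeeze also removes a batch axis of size 1.

```python
import numpy as np

# hypothetical 256-entry grayscale lookup table, shape [256, 3]
lut = np.repeat(np.linspace(0.0, 1.0, 256)[:, None], 3, axis=1)

# quantized indices for a batch of 4 single-channel 8x8 images
batch = np.random.default_rng(0).integers(0, 256, size=(4, 8, 8, 1))

# indexing a [256, 3] table with an index array of shape S yields S + (3,),
# just like tf.gather(colors, indices) along axis 0
with_channel = lut[batch]       # shape (4, 8, 8, 1, 3)

# drop only the trailing channel axis to get the [N, H, W, 3] layout
# that image summaries expect
rgb_batch = lut[batch[..., 0]]  # shape (4, 8, 8, 3)

# a bare squeeze would also drop a batch axis of size 1
single = batch[:1]
print(np.squeeze(single).shape)   # (8, 8)
print(lut[single[..., 0]].shape)  # (1, 8, 8, 3)
```

In the TensorFlow function, the same `value[..., 0]` slice could replace `tf.squeeze` when the input is known to carry a trailing channel axis; that substitution is an assumption, not part of the original gist.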
Very helpful, thanks a lot. Came here from PyTorch. This works with NumPy/torch arrays and exports an (h, w, 4) int8 array instead.
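A minimal NumPy sketch of a port along those lines, assuming matplotlib >= 3.5 (`colorize_rgba` is a hypothetical name, and this is not the commenter's code). It lets matplotlib's `Normalize` do the rescaling and calls the colormap with `bytes=True`, which returns RGBA values as `uint8`, so a 2D input yields an (h, w, 4) array.

```python
import numpy as np
from matplotlib import colormaps
from matplotlib.colors import Normalize


def colorize_rgba(value, vmin=None, vmax=None, cmap="gray"):
    """Hypothetical NumPy counterpart of `colorize` returning uint8 RGBA.

    A torch tensor can be passed as `t.detach().cpu().numpy()`.
    """
    value = np.asarray(value, dtype=np.float32)
    vmin = float(value.min()) if vmin is None else vmin
    vmax = float(value.max()) if vmax is None else vmax
    # Normalize maps vmin..vmax to 0..1 (clip=True clamps values outside it);
    # when vmin == vmax every value maps to 0
    norm = Normalize(vmin=vmin, vmax=vmax, clip=True)
    # bytes=True makes the colormap return uint8 RGBA instead of floats
    return colormaps[cmap](norm(value), bytes=True)


image = np.array([[0.0, 1.0], [2.0, 3.0]])
rgba = colorize_rgba(image, cmap="viridis")
print(rgba.shape, rgba.dtype)  # (2, 2, 4) uint8
```

Keeping the alpha channel matches what image writers such as PIL accept for RGBA; drop it with `rgba[..., :3]` if an RGB image is needed.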