Last active
November 23, 2021 16:06
-
-
Save zaccharieramzi/f7dd5f0e34691d0a987e1e50b694ac03 to your computer and use it in GitHub Desktop.
This Keras callback writes the outputs of an image-to-image model to TensorBoard. Here it is used for a denoiser, but it could also be adapted for segmentation, super-resolution, etc.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""Inspired by https://stackoverflow.com/a/49363251/4332585""" | |
import io | |
from keras.callbacks import Callback | |
import numpy as np | |
from PIL import Image | |
from skimage.util import img_as_ubyte | |
import tensorflow as tf | |
def make_image(tensor):
    """Convert the first image of a numpy batch to a ``tf.Summary.Image`` protobuf.

    The image is min-max rescaled to [0, 1], converted to 8-bit with
    ``img_as_ubyte`` and PNG-encoded.

    Adapted from https://github.com/lanpa/tensorboard-pytorch/

    Args:
        tensor: array-like of shape (batch, height, width, channel). Only
            ``tensor[0]`` is rendered; ``channel`` is stored as the
            protobuf ``colorspace`` field.

    Returns:
        tf.Summary.Image: the PNG-encoded image with height, width and
        colorspace set.
    """
    tensor = np.asarray(tensor)
    _, height, width, channel = tensor.shape
    # Work in float so the in-place ops below also work for integer inputs
    # (in-place true division on an integer array raises a TypeError).
    image_array = tensor[0].astype(np.float64)
    image_array -= image_array.min()
    value_range = image_array.max()
    # A constant image has zero range: dividing would produce NaNs, so it
    # is left as all zeros instead.
    if value_range > 0:
        image_array /= value_range
    image_array = img_as_ubyte(image_array)
    # PIL wants a 2-D (H, W) array for single-channel images. Drop only the
    # channel axis; a bare np.squeeze would also drop a height/width of 1.
    if channel == 1:
        image_array = image_array[..., 0]
    # The context manager closes the buffer even if encoding fails.
    with io.BytesIO() as output:
        Image.fromarray(image_array).save(output, format='PNG')
        image_string = output.getvalue()
    return tf.Summary.Image(
        height=height,
        width=width,
        colorspace=channel,
        encoded_image_string=image_string,
    )
class TensorBoardImage(Callback):
    """Keras callback that logs an image-to-image model's output to TensorBoard.

    When training begins, ``image`` is written once under the tag
    'Original Image' at step 0. At the end of every epoch, the model is run
    on ``noisy_image`` and its prediction is written under 'Denoised Image'
    with the epoch index as the step.

    Args:
        log_dir: directory the event file is written to.
        image: reference batch; only its first element is logged (see
            ``make_image``, which expects (batch, height, width, channel)).
        noisy_image: batch passed to ``model.predict_on_batch`` each epoch.
    """

    def __init__(self, log_dir, image, noisy_image):
        super().__init__()
        self.log_dir = log_dir
        self.image = image
        self.noisy_image = noisy_image
        # Event-file writer; opened in ``set_model``, closed in ``on_train_end``.
        self.writer = None

    def set_model(self, model):
        """Attach ``model`` and make sure an event-file writer is open."""
        # Delegate to Keras so the model is stored the way the installed
        # version expects (``model`` may be a read-only property).
        super().set_model(model)
        # set_model can be called several times (e.g. evaluate() then fit());
        # reuse an open writer instead of leaking a new one per call.
        if self.writer is None:
            self.writer = tf.summary.FileWriter(
                self.log_dir, filename_suffix='images'
            )

    def on_train_begin(self, logs=None):
        """Log the reference image once, at step 0."""
        self.write_image(self.image, 'Original Image', 0)

    def on_train_end(self, logs=None):
        """Close the event-file writer so a later fit opens a fresh one."""
        if self.writer is not None:
            self.writer.close()
            self.writer = None

    def write_image(self, image, tag, epoch):
        """Encode ``image`` with ``make_image`` and write it under ``tag`` at step ``epoch``."""
        image_summary = make_image(image)
        summary = tf.Summary(value=[tf.Summary.Value(tag=tag, image=image_summary)])
        self.writer.add_summary(summary, epoch)
        self.writer.flush()

    def on_epoch_end(self, epoch, logs=None):
        """Run the model on the noisy batch and log its prediction for this epoch."""
        # ``logs`` defaults to None rather than a shared mutable dict.
        denoised_image = self.model.predict_on_batch(self.noisy_image)
        self.write_image(denoised_image, 'Denoised Image', epoch)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment