Skip to content

Instantly share code, notes, and snippets.

@pythonlessons
Created April 18, 2023 13:04
Show Gist options
  • Select an option

  • Save pythonlessons/9158d2a042c5877717e522f6c6038629 to your computer and use it in GitHub Desktop.

Select an option

Save pythonlessons/9158d2a042c5877717e522f6c6038629 to your computer and use it in GitHub Desktop.
gan_introduction
class ResultsCallback(tf.keras.callbacks.Callback):
    """A Keras callback that saves a grid of generated images after each epoch.

    The same noise batch (``self.seed``, sampled once in ``__init__``) is fed to
    ``self.model.generator`` at the end of every epoch, so consecutive grids show
    how the generator's output evolves. Each grid is written as a PNG and a
    resized copy is kept in memory; at the end of training the kept copies are
    written out as an animated GIF.
    """

    def __init__(
        self,
        noise_dim: int,
        results_path: str,
        examples_to_generate: int = 16,
        grid_size: tuple = (4, 4),
        spacing: int = 5,
        gif_size: tuple = (416, 416),
        duration: float = 0.1,
    ):
        """Initializes the ResultsCallback class.

        Args:
            noise_dim (int): Dimensionality of the noise vector fed to the generator.
            results_path (str): Base directory. Outputs go to its ``results``
                subdirectory, which is created if it does not exist.
            examples_to_generate (int, optional): Number of images generated per
                epoch. Defaults to 16.
            grid_size (tuple, optional): (rows, cols) of the saved image grid.
                Cells beyond the number of generated images are left black;
                images beyond rows * cols are not shown. Defaults to (4, 4).
            spacing (int, optional): Gap in pixels between grid cells. Defaults to 5.
            gif_size (tuple, optional): (width, height) each grid is resized to
                for the GIF, as expected by ``cv2.resize``. Defaults to (416, 416).
            duration (float, optional): Per-frame duration passed to
                ``imageio.mimsave``. Defaults to 0.1.
        """
        super().__init__()
        # Sampled once so every epoch renders the same latent points.
        self.seed = tf.random.normal([examples_to_generate, noise_dim])
        # Resized grids, one per epoch; become the GIF frames in on_train_end.
        self.results = []
        self.results_path = os.path.join(results_path, 'results')
        self.grid_size = grid_size
        self.spacing = spacing
        self.gif_size = gif_size
        self.duration = duration
        # Create the results directory if it doesn't exist.
        os.makedirs(self.results_path, exist_ok=True)

    def save_pred(self, epoch: int, results: list) -> None:
        """Saves the generated images as a PNG grid and keeps a resized copy for the GIF.

        Args:
            epoch (int): Current epoch index, used in the PNG filename.
            results (list): Sequence of uint8 images, each of shape
                (height, width, channels). Does nothing if it is empty.
        """
        import warnings

        if len(results) == 0:
            # Nothing was generated (e.g. examples_to_generate == 0).
            return

        rows, cols = self.grid_size
        # NumPy images are laid out as (height, width, channels).
        h, w, c = results[0].shape
        step_y = h + self.spacing
        step_x = w + self.spacing
        grid = np.zeros(
            (rows * h + (rows - 1) * self.spacing, cols * w + (cols - 1) * self.spacing, c),
            dtype=np.uint8,
        )
        # Fill row-major; stop at rows * cols so surplus images are ignored
        # and missing ones leave black cells instead of raising IndexError.
        for idx, image in enumerate(results[: rows * cols]):
            i, j = divmod(idx, cols)
            grid[i * step_y:i * step_y + h, j * step_x:j * step_x + w] = image

        png_path = os.path.join(self.results_path, f'img_{epoch}.png')
        # NOTE(review): cv2.imwrite treats 3-channel data as BGR while imageio
        # (used for the GIF) treats it as RGB, so for colour images one of the
        # two outputs has swapped red/blue channels. Harmless for 1-channel
        # images; confirm the generator's channel order.
        if not cv2.imwrite(png_path, grid):
            warnings.warn(f'cv2.imwrite failed to write {png_path}')

        # Keep a resized copy in memory as a GIF frame (dsize is (width, height)).
        self.results.append(cv2.resize(grid, self.gif_size, interpolation=cv2.INTER_AREA))

    def on_epoch_end(self, epoch: int, logs: dict = None) -> None:
        """Renders the fixed seed through the generator and saves the result."""
        predictions = self.model.generator(self.seed, training=False)
        # The affine map sends [-1, 1] to [0, 255]. Clip before casting: converting
        # out-of-range floats to uint8 is platform-dependent in NumPy (it may wrap).
        scaled = (predictions * 127.5 + 127.5).numpy()
        predictions_uint8 = np.clip(scaled, 0, 255).astype(np.uint8)
        self.save_pred(epoch, predictions_uint8)

    def on_train_end(self, logs=None) -> None:
        """Writes the per-epoch frames collected in ``self.results`` to ``output.gif``."""
        if not self.results:
            # No epoch finished, so there are no frames; mimsave would fail.
            return
        # NOTE(review): the unit of `duration` depends on the installed imageio
        # version (seconds in the legacy v2 GIF writer, milliseconds when GIFs
        # are written through the Pillow plugin) — confirm for your environment.
        imageio.mimsave(
            os.path.join(self.results_path, 'output.gif'),
            self.results,
            duration=self.duration,
        )
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment