Created
April 18, 2023 13:04
-
-
Save pythonlessons/9158d2a042c5877717e522f6c6038629 to your computer and use it in GitHub Desktop.
gan_introduction
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class ResultsCallback(tf.keras.callbacks.Callback):
    """A callback that saves generated images after each epoch.

    After every epoch a fixed noise batch is passed through
    ``self.model.generator``, the outputs are tiled into one grid image that
    is written to ``<results_path>/results/img_<epoch>.png``, and a resized
    copy of the grid is kept in memory. When training ends the stored grids
    are written as ``<results_path>/results/output.gif``.
    """

    def __init__(
        self,
        noise_dim: int,
        results_path: str,
        examples_to_generate: int = 16,
        grid_size: tuple = (4, 4),
        spacing: int = 5,
        gif_size: tuple = (416, 416),
        duration: float = 0.1,
    ):
        """Initializes the ResultsCallback class.

        Args:
            noise_dim (int): The dimensionality of the noise vector that is
                inputted to the generator.
            results_path (str): The base directory; a ``results``
                sub-directory is created inside it and receives all output.
            examples_to_generate (int, optional): The number of images to
                generate each epoch. Defaults to 16. If this is smaller than
                ``grid_size[0] * grid_size[1]`` the unused grid cells stay
                black; extra images beyond the grid capacity are not drawn.
            grid_size (tuple, optional): Grid layout as ``(rows, columns)``.
                Defaults to (4, 4).
            spacing (int, optional): The gap, in pixels, between neighbouring
                images in the grid. Defaults to 5.
            gif_size (tuple, optional): Size each grid is resized to before
                being stored as a gif frame; passed to ``cv2.resize`` as
                ``dsize``. Defaults to (416, 416).
            duration (float, optional): Per-frame duration forwarded to
                ``imageio.mimsave``. Defaults to 0.1.
                NOTE(review): the unit (seconds vs. milliseconds) depends on
                the installed imageio version/plugin -- confirm.
        """
        super().__init__()

        # One fixed noise batch is reused every epoch so that successive
        # grids show the same latent points evolving over training.
        self.seed = tf.random.normal([examples_to_generate, noise_dim])
        self.results = []
        self.results_path = os.path.join(results_path, 'results')
        self.grid_size = grid_size
        self.spacing = spacing
        self.gif_size = gif_size
        self.duration = duration

        # create the results directory if it doesn't exist
        os.makedirs(self.results_path, exist_ok=True)

    def save_pred(self, epoch: int, results: list) -> None:
        """Saves the generated images as a grid and stores a gif frame.

        Args:
            epoch (int): The current epoch; used in the output file name.
            results (list): A sequence of uint8 images that all share the
                shape of ``results[0]``, which must be 3-dimensional.
                Only the first ``rows * columns`` images are drawn.
        """
        # len() instead of truthiness: `results` is normally a NumPy array,
        # whose truth value is ambiguous.
        if len(results) == 0:
            return

        rows, cols = self.grid_size[0], self.grid_size[1]
        # Axis 0 of an image array runs along the rows (height), axis 1 along
        # the columns (width).
        cell_h, cell_w, channels = results[0].shape

        # Black canvas large enough for every cell plus the gaps between them.
        grid = np.zeros(
            (
                rows * cell_h + (rows - 1) * self.spacing,
                cols * cell_w + (cols - 1) * self.spacing,
                channels,
            ),
            dtype=np.uint8,
        )

        # Slicing to the grid capacity means a short batch leaves the
        # trailing cells black instead of raising IndexError.
        for index, image in enumerate(results[:rows * cols]):
            row, col = divmod(index, cols)
            top = row * (cell_h + self.spacing)
            left = col * (cell_w + self.spacing)
            grid[top:top + cell_h, left:left + cell_w] = image

        # save the image
        # NOTE(review): cv2.imwrite treats 3-channel data as BGR; if the
        # generator emits RGB the PNG will have red/blue swapped -- confirm.
        cv2.imwrite(os.path.join(self.results_path, f'img_{epoch}.png'), grid)

        # save image to memory resized to gif size
        self.results.append(
            cv2.resize(grid, self.gif_size, interpolation=cv2.INTER_AREA)
        )

    def on_epoch_end(self, epoch: int, logs: dict = None) -> None:
        """Executes at the end of each epoch.

        Generates images from the fixed seed, maps them to uint8 and hands
        them to ``save_pred``.

        Args:
            epoch (int): Index of the epoch that just finished.
            logs (dict, optional): Metrics supplied by Keras; unused here.
        """
        predictions = self.model.generator(self.seed, training=False)

        # presumably the generator output lies in [-1, 1] (e.g. tanh) so this
        # maps to [0, 255] -- TODO confirm. Clipping keeps any out-of-range
        # value from wrapping around when cast to uint8.
        scaled = (predictions * 127.5 + 127.5).numpy()
        predictions_uint8 = np.clip(scaled, 0, 255).astype(np.uint8)

        self.save_pred(epoch, predictions_uint8)

    def on_train_end(self, logs=None) -> None:
        """Executes at the end of training.

        Writes every stored grid frame to ``output.gif`` in the results
        directory. Nothing is written when no frame was collected.

        Args:
            logs (dict, optional): Metrics supplied by Keras; unused here.
        """
        # No epoch finished (or nothing was generated): there is no frame to
        # encode, so skip writing instead of failing inside imageio.
        if not self.results:
            return

        # mimsave accepts plain NumPy arrays, so the frames are passed as-is.
        imageio.mimsave(
            os.path.join(self.results_path, 'output.gif'),
            self.results,
            duration=self.duration,
        )
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment