Created
May 30, 2023 08:45
-
-
Save pythonlessons/6f7c1399989f172b6618406fbfb68595 to your computer and use it in GitHub Desktop.
wgan_gp
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class ResultsCallback(tf.keras.callbacks.Callback):
    """Keras callback that visualises generator progress during GAN training.

    A fixed batch of noise vectors is sampled once in ``__init__`` so that the
    images produced at successive epochs are directly comparable. At the end
    of every epoch the generator is run on that batch, and the outputs are
    tiled into a grid. The grid is written to
    ``<output_path>/results/img_<epoch>.png`` and a resized copy is kept in
    memory. When training ends, the collected frames are written to
    ``<output_path>/results/output.gif``. Optionally, the generator and
    discriminator are saved to ``<output_path>/model`` after every epoch.

    NOTE(review): assumes ``self.model`` exposes ``generator`` and
    ``discriminator`` sub-models, and that the generator outputs values in
    ``[-1, 1]`` (e.g. a tanh output layer) — confirm against the model class.
    """

    def __init__(
        self,
        noise_dim: int,
        output_path: str,
        examples_to_generate: int = 16,
        grid_size: tuple = (4, 4),
        spacing: int = 5,
        gif_size: tuple = (416, 416),
        duration: float = 0.1,
        save_model: bool = True,
    ) -> None:
        """Create the callback and the results directory.

        Args:
            noise_dim: Length of each latent vector fed to the generator.
            output_path: Root directory for results and saved models.
            examples_to_generate: Number of fixed noise vectors to sample.
            grid_size: ``(rows, cols)`` of the image grid. Cells beyond the
                number of generated images are left black; images beyond
                ``rows * cols`` are not drawn.
            spacing: Pixels of black padding between neighbouring cells.
            gif_size: ``(width, height)`` each GIF frame is resized to. This
                is the ``dsize`` order that ``cv2.resize`` expects.
            duration: Per-frame duration forwarded to ``imageio.mimsave``
                (its unit depends on the installed imageio version/plugin —
                TODO confirm).
            save_model: If True, save both networks after every epoch.
        """
        super().__init__()
        # Fixed latent batch: reusing it every epoch makes the frames comparable.
        self.seed = tf.random.normal([examples_to_generate, noise_dim])
        # Grid frames in OpenCV BGR order, resized to ``gif_size``.
        self.results = []
        self.output_path = output_path
        self.results_path = os.path.join(output_path, "results")
        self.grid_size = grid_size
        self.spacing = spacing
        self.gif_size = gif_size
        self.duration = duration
        self.save_model = save_model
        os.makedirs(self.results_path, exist_ok=True)

    def save_plt(self, epoch: int, results: np.ndarray):
        """Tile ``results`` into a grid, save it as PNG and keep a GIF frame.

        Args:
            epoch: Epoch index, used in the output file name.
            results: uint8 images of shape ``(N, height, width, channels)``,
                RGB channel order when ``channels`` is 3.
        """
        rows, cols = self.grid_size
        height, width, channels = results[0].shape
        step_y = height + self.spacing
        step_x = width + self.spacing

        # Build the grid with ``spacing`` black pixels between cells.
        grid = np.zeros(
            (
                rows * height + (rows - 1) * self.spacing,
                cols * width + (cols - 1) * self.spacing,
                channels,
            ),
            dtype=np.uint8,
        )
        # Fill cells row-major. Slicing to rows*cols, and enumerating only the
        # images that exist, avoids IndexError when fewer images than cells
        # were generated.
        for index, image in enumerate(results[: rows * cols]):
            row, col = divmod(index, cols)
            top, left = row * step_y, col * step_x
            grid[top:top + height, left:left + width] = image

        # OpenCV writes BGR. cvtColor(RGB2BGR) rejects 1-channel input, so
        # grayscale grids are written as-is.
        if channels != 1:
            grid = cv2.cvtColor(grid, cv2.COLOR_RGB2BGR)
        cv2.imwrite(os.path.join(self.results_path, f"img_{epoch}.png"), grid)

        # Keep a resized copy for the GIF.
        # (cv2.resize squeezes a 1-channel image to 2-D.)
        self.results.append(
            cv2.resize(grid, self.gif_size, interpolation=cv2.INTER_AREA)
        )

    def on_epoch_end(self, epoch: int, logs: dict = None):
        """Render the fixed noise batch and optionally checkpoint both models."""
        predictions = self.model.generator(self.seed, training=False)
        # Map [-1, 1] back to [0, 255]. Clip before the uint8 cast so values
        # slightly outside the range saturate instead of wrapping around.
        predictions_uint8 = np.clip(
            predictions.numpy() * 127.5 + 127.5, 0, 255
        ).astype(np.uint8)
        self.save_plt(epoch, predictions_uint8)

        if self.save_model:
            models_path = os.path.join(self.output_path, "model")
            os.makedirs(models_path, exist_ok=True)
            self.model.discriminator.save(os.path.join(models_path, "discriminator.h5"))
            self.model.generator.save(os.path.join(models_path, "generator.h5"))

    def on_train_end(self, logs: dict = None):
        """Write all collected frames to ``<results_path>/output.gif``."""
        if not self.results:
            # No epoch completed, so there is nothing to animate.
            return
        # Frames are stored in BGR order, so reverse the channel axis to get
        # RGB. 2-D (grayscale) frames have no channel axis; reversing their
        # last axis would mirror them horizontally, so leave them untouched.
        frames = [
            np.ascontiguousarray(frame[..., ::-1]) if frame.ndim == 3 else frame
            for frame in self.results
        ]
        imageio.mimsave(
            os.path.join(self.results_path, "output.gif"),
            frames,
            duration=self.duration,
        )
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment