Skip to content

Instantly share code, notes, and snippets.

@pythonlessons
Created May 30, 2023 08:45
Show Gist options
  • Select an option

  • Save pythonlessons/6f7c1399989f172b6618406fbfb68595 to your computer and use it in GitHub Desktop.

Select an option

Save pythonlessons/6f7c1399989f172b6618406fbfb68595 to your computer and use it in GitHub Desktop.
wgan_gp
class ResultsCallback(tf.keras.callbacks.Callback):
    """Callback for generating and saving images during training.

    At the end of every epoch the wrapped model's ``generator`` is run on a
    fixed noise batch. The outputs are tiled into a grid, written to
    ``<output_path>/results/img_<epoch>.png``, and a resized copy is kept in
    memory. When training ends, the collected frames are written to
    ``<output_path>/results/output.gif``.

    NOTE(review): assumes ``self.model`` exposes ``generator`` and
    ``discriminator`` sub-models (as accessed below) — confirm against the
    GAN model class this callback is attached to.

    Args:
        noise_dim: Length of each noise vector fed to the generator.
        output_path: Root directory for results and (optionally) saved models.
        examples_to_generate: Number of noise vectors in the fixed seed batch;
            must be at least ``grid_size[0] * grid_size[1]``.
        grid_size: (rows, cols) of the image grid.
        spacing: Pixels of black padding between neighbouring grid cells.
        gif_size: Target size of each GIF frame, passed to ``cv2.resize``
            (OpenCV order: width, height).
        duration: Per-frame duration forwarded unchanged to
            ``imageio.mimsave``. NOTE(review): its units depend on the
            installed imageio version/plugin (seconds vs. milliseconds) — confirm.
        save_model: If True, save ``discriminator.h5`` and ``generator.h5``
            under ``<output_path>/model`` at the end of every epoch.

    Raises:
        ValueError: If ``examples_to_generate`` is too small to fill the grid.
    """

    def __init__(
        self,
        noise_dim: int,
        output_path: str,
        examples_to_generate: int = 16,
        grid_size: tuple = (4, 4),
        spacing: int = 5,
        gif_size: tuple = (416, 416),
        duration: float = 0.1,
        save_model: bool = True,
    ) -> None:
        super().__init__()
        cells = grid_size[0] * grid_size[1]
        if examples_to_generate < cells:
            # save_plt takes one generated image per grid cell; fail fast here
            # instead of with an IndexError at the end of the first epoch.
            raise ValueError(
                f"examples_to_generate ({examples_to_generate}) must be >= "
                f"grid_size[0] * grid_size[1] ({cells})"
            )
        # Fixed noise so every epoch's grid shows the same latent points.
        self.seed = tf.random.normal([examples_to_generate, noise_dim])
        self.results = []  # resized grid frames (OpenCV channel order) for the GIF
        self.output_path = output_path
        self.results_path = os.path.join(output_path, "results")
        self.grid_size = grid_size
        self.spacing = spacing
        self.gif_size = gif_size
        self.duration = duration
        self.save_model = save_model
        os.makedirs(self.results_path, exist_ok=True)

    def save_plt(self, epoch: int, results: np.ndarray):
        """Tile generated images into a grid, save it as PNG, and keep a GIF frame.

        Args:
            epoch: Epoch index, used in the output file name.
            results: uint8 image batch; each ``results[k]`` has shape
                (rows, cols, channels). Multi-channel images are treated as RGB.
        """
        rows, cols, channels = results[0].shape
        grid_rows, grid_cols = self.grid_size
        step_r = rows + self.spacing
        step_c = cols + self.spacing
        # Black canvas large enough for all cells plus the gaps between them.
        grid = np.zeros(
            (
                grid_rows * rows + (grid_rows - 1) * self.spacing,
                grid_cols * cols + (grid_cols - 1) * self.spacing,
                channels,
            ),
            dtype=np.uint8,
        )
        for i in range(grid_rows):
            for j in range(grid_cols):
                top = i * step_r
                left = j * step_c
                grid[top:top + rows, left:left + cols] = results[i * grid_cols + j]
        if channels > 1:
            # OpenCV expects BGR. A single-channel grid needs no conversion,
            # and cvtColor(RGB2BGR) would reject it.
            grid = cv2.cvtColor(grid, cv2.COLOR_RGB2BGR)
        cv2.imwrite(os.path.join(self.results_path, f"img_{epoch}.png"), grid)
        # Keep a downsized copy in memory for the GIF written in on_train_end.
        self.results.append(cv2.resize(grid, self.gif_size, interpolation=cv2.INTER_AREA))

    def on_epoch_end(self, epoch: int, logs: dict = None):
        """Render the fixed-seed samples and optionally save both sub-models."""
        predictions = self.model.generator(self.seed, training=False)
        # Map generator output to [0, 255]. This assumes the generator emits
        # values in [-1, 1] (e.g. a tanh output) — TODO confirm. Clip so values
        # outside that range saturate instead of wrapping in the uint8 cast.
        scaled = (predictions * 127.5 + 127.5).numpy()
        predictions_uint8 = np.clip(scaled, 0, 255).astype(np.uint8)
        self.save_plt(epoch, predictions_uint8)
        if self.save_model:
            models_path = os.path.join(self.output_path, "model")
            os.makedirs(models_path, exist_ok=True)
            # Same file names every epoch, so the latest save overwrites the previous one.
            self.model.discriminator.save(os.path.join(models_path, "discriminator.h5"))
            self.model.generator.save(os.path.join(models_path, "generator.h5"))

    def on_train_end(self, logs: dict = None):
        """Write all collected epoch frames to ``output.gif``."""
        if not self.results:
            # No epoch finished, so there are no frames to write.
            return
        # Frames are in OpenCV (BGR) order; reverse the channel axis to get RGB.
        # cv2.resize drops the singleton axis of grayscale frames, so 2-D frames
        # are kept as-is (reversing their last axis would mirror them).
        frames = [
            np.ascontiguousarray(frame[..., ::-1]) if frame.ndim == 3 else frame
            for frame in self.results
        ]
        imageio.mimsave(
            os.path.join(self.results_path, "output.gif"), frames, duration=self.duration
        )
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment