@pythonlessons
Created May 30, 2023 08:45
wgan_gp
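A Keras callback that linearly decays the generator and discriminator learning rates over a fixed number of epochs (clamped at a minimum value) and optionally logs both rates to TensorBoard.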
import tensorflow as tf


class LRScheduler(tf.keras.callbacks.Callback):
    """Learning rate scheduler for WGAN-GP: decays both optimizers' rates linearly."""

    def __init__(self, decay_epochs: int, tb_callback=None, min_lr: float = 0.00001):
        super(LRScheduler, self).__init__()
        self.decay_epochs = decay_epochs  # epochs over which the linear decay runs
        self.min_lr = min_lr  # floor the learning rate never drops below
        self.tb_callback = tb_callback  # optional tf.keras.callbacks.TensorBoard instance
        self.compiled = False

    def on_epoch_end(self, epoch, logs=None):
        # Capture the initial learning rates once, at the end of the first epoch
        if not self.compiled:
            self.generator_lr = self.model.generator_opt.lr.numpy()
            self.discriminator_lr = self.model.discriminator_opt.lr.numpy()
            self.compiled = True

        if epoch < self.decay_epochs:
            # Scale each initial rate linearly toward zero, clamped at min_lr
            new_g_lr = max(self.generator_lr * (1 - (epoch / self.decay_epochs)), self.min_lr)
            self.model.generator_opt.lr.assign(new_g_lr)

            new_d_lr = max(self.discriminator_lr * (1 - (epoch / self.decay_epochs)), self.min_lr)
            self.model.discriminator_opt.lr.assign(new_d_lr)

            print(f"Learning rate generator: {new_g_lr}, discriminator: {new_d_lr}")

            # Log the learning rates to TensorBoard (reaches into the callback's private writer)
            if self.tb_callback is not None:
                writer = self.tb_callback._writers.get('train')  # train-phase summary writer
                if writer is not None:
                    with writer.as_default():
                        tf.summary.scalar('generator_lr', data=new_g_lr, step=epoch)
                        tf.summary.scalar('discriminator_lr', data=new_d_lr, step=epoch)
                    writer.flush()
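
For context, the callback reads its learning rates from model attributes named generator_opt and discriminator_opt, so it expects a custom tf.keras.Model subclass that exposes its two optimizers under exactly those names. Below is a minimal, hypothetical wiring sketch; WGAN_GP, generator, discriminator, dataset, the compile() signature, and the log directory are assumptions for illustration, not part of the gist:

# Hypothetical usage: WGAN_GP is any tf.keras.Model subclass whose compile()
# stores the two optimizers as self.generator_opt and self.discriminator_opt,
# the attribute names LRScheduler relies on.
tb_callback = tf.keras.callbacks.TensorBoard(log_dir="logs/wgan_gp")

gan = WGAN_GP(generator, discriminator)  # assumed custom model
gan.compile(
    generator_opt=tf.keras.optimizers.Adam(2e-4, beta_1=0.5, beta_2=0.9),
    discriminator_opt=tf.keras.optimizers.Adam(2e-4, beta_1=0.5, beta_2=0.9),
)

# Pass the TensorBoard callback into LRScheduler so it can reuse its writer;
# listing tb_callback first ensures its writers are set up during training.
gan.fit(
    dataset,
    epochs=100,
    callbacks=[tb_callback, LRScheduler(decay_epochs=100, tb_callback=tb_callback)],
)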