@JossWhittle · Created June 30, 2020 12:02
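
A tf.keras LearningRateSchedule that ramps the learning rate linearly from initial_learning_rate * alpha up to initial_learning_rate over the first warmup_steps, then follows a cosine decay back down to initial_learning_rate * alpha over the remainder of total_steps.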
import tensorflow as tf


class LinearWarmUpAndCosineDecay(tf.keras.optimizers.schedules.LearningRateSchedule):
    """Linear warmup from (initial_learning_rate * alpha) to initial_learning_rate
    over warmup_steps, followed by cosine decay back down to
    (initial_learning_rate * alpha) over the remaining steps."""

    def __init__(self, initial_learning_rate, warmup_steps, total_steps, alpha, name=None):
        super(LinearWarmUpAndCosineDecay, self).__init__()
        self.name = name
        self.warmup_steps = warmup_steps
        self.total_steps = total_steps
        self.alpha = alpha
        self.initial_learning_rate = initial_learning_rate
        # The schedule starts and ends at this fraction of the peak learning rate.
        self.min_learning_rate = self.initial_learning_rate * self.alpha
        # The cosine decay runs over whatever steps remain after warmup.
        self.cosine_decay_fn = tf.keras.optimizers.schedules.CosineDecay(
            self.initial_learning_rate, (self.total_steps - self.warmup_steps), alpha=self.alpha,
            name='CosineDecayAfterLinearWarmup')

    def linear_warmup(self, step):
        # Interpolate from min_learning_rate at step 0 to initial_learning_rate at warmup_steps.
        return self.min_learning_rate + (step * ((self.initial_learning_rate - self.min_learning_rate) / self.warmup_steps))

    def cosine_decay(self, step):
        # Offset the step so the cosine decay begins at zero when warmup ends.
        return self.cosine_decay_fn(step - self.warmup_steps)

    def __call__(self, step):
        # Cast so the arithmetic works when the optimizer passes an integer step tensor.
        step = tf.cast(step, tf.float32)
        return tf.where(tf.less(step, self.warmup_steps), self.linear_warmup(step), self.cosine_decay(step))

    def get_config(self):
        return {
            'initial_learning_rate': self.initial_learning_rate,
            'warmup_steps': self.warmup_steps,
            'total_steps': self.total_steps,
            'alpha': self.alpha,
            'name': self.name,
        }
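
For reference, a minimal usage sketch; the step counts and learning rate below are illustrative placeholders, not values from the gist. The schedule object is passed directly to a Keras optimizer, which calls it with the current training step:

# Illustrative hyperparameters: warm up over the first 1,000 of 10,000 steps,
# starting and finishing at 1% of the peak learning rate (alpha=0.01).
schedule = LinearWarmUpAndCosineDecay(
    initial_learning_rate=1e-3, warmup_steps=1000, total_steps=10000, alpha=0.01)
optimizer = tf.keras.optimizers.Adam(learning_rate=schedule)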