@innat
Created June 13, 2022 11:26
import numpy as np
import tensorflow as tf
from tensorflow.keras import optimizers
from tensorflow.keras.optimizers import schedules


class WarmupLearningRateSchedule(optimizers.schedules.LearningRateSchedule):
    """Learning rate schedule that combines several decay strategies
    (exponential, cosine, linear, constant, cosine with restarts)
    with an optional linear warm-up phase."""

    def __init__(
        self,
        initial_lr,
        steps_per_epoch=None,
        lr_decay_type="exponential",
        decay_factor=0.97,
        decay_epochs=2.4,
        total_steps=None,
        warmup_epochs=5,
        minimal_lr=0,
    ):
        super().__init__()
        self.initial_lr = initial_lr
        self.steps_per_epoch = steps_per_epoch
        self.lr_decay_type = lr_decay_type
        self.decay_factor = decay_factor
        self.decay_epochs = decay_epochs
        self.total_steps = total_steps
        self.warmup_epochs = warmup_epochs
        self.minimal_lr = minimal_lr

    def __call__(self, step):
        if self.lr_decay_type == "exponential":
            # Staircase exponential decay: multiply the rate by decay_factor
            # every decay_epochs worth of steps.
            assert self.steps_per_epoch is not None
            decay_steps = self.steps_per_epoch * self.decay_epochs
            lr = schedules.ExponentialDecay(
                self.initial_lr, decay_steps, self.decay_factor, staircase=True
            )(step)
        elif self.lr_decay_type == "cosine":
            # Cosine annealing from initial_lr down to 0 over total_steps.
            assert self.total_steps is not None
            lr = (
                0.5
                * self.initial_lr
                * (1 + tf.cos(np.pi * tf.cast(step, tf.float32) / self.total_steps))
            )
        elif self.lr_decay_type == "linear":
            # Linear decay from initial_lr down to 0 over total_steps.
            assert self.total_steps is not None
            lr = (1.0 - tf.cast(step, tf.float32) / self.total_steps) * self.initial_lr
        elif self.lr_decay_type == "constant":
            lr = self.initial_lr
        elif self.lr_decay_type == "cosine_restart":
            # Cosine decay with warm restarts every decay_epochs worth of steps.
            assert self.steps_per_epoch is not None
            decay_steps = self.steps_per_epoch * self.decay_epochs
            lr = schedules.CosineDecayRestarts(self.initial_lr, decay_steps)(step)
        else:
            raise ValueError("Unknown lr_decay_type: %s" % self.lr_decay_type)

        # Never let the learning rate fall below minimal_lr (if set).
        if self.minimal_lr:
            lr = tf.math.maximum(lr, self.minimal_lr)

        # Linear warm-up: ramp from 0 to initial_lr over the first
        # warmup_epochs * steps_per_epoch steps, then hand over to the
        # decay schedule selected above.
        if self.warmup_epochs:
            warmup_steps = int(self.warmup_epochs * self.steps_per_epoch)
            warmup_lr = (
                self.initial_lr
                * tf.cast(step, tf.float32)
                / tf.cast(warmup_steps, tf.float32)
            )
            lr = tf.cond(step < warmup_steps, lambda: warmup_lr, lambda: lr)
        return lr

    def get_config(self):
        # Allows the schedule to be serialized along with the optimizer/model.
        return {
            "initial_lr": self.initial_lr,
            "steps_per_epoch": self.steps_per_epoch,
            "lr_decay_type": self.lr_decay_type,
            "decay_factor": self.decay_factor,
            "decay_epochs": self.decay_epochs,
            "total_steps": self.total_steps,
            "warmup_epochs": self.warmup_epochs,
            "minimal_lr": self.minimal_lr,
        }
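
For reference, a minimal usage sketch of how a schedule like this can be wired into a Keras optimizer. The dataset size, batch size, epoch count, and the choice of Adam are hypothetical placeholders for illustration, not values from the gist; any tf.keras optimizer that accepts a LearningRateSchedule for its learning_rate argument would work the same way.

# Hypothetical training setup (assumed numbers, not from the gist).
num_train_examples = 50_000   # assumed dataset size
batch_size = 128              # assumed batch size
epochs = 90                   # assumed training length
steps_per_epoch = num_train_examples // batch_size

lr_schedule = WarmupLearningRateSchedule(
    initial_lr=1e-3,
    steps_per_epoch=steps_per_epoch,
    lr_decay_type="cosine",
    total_steps=steps_per_epoch * epochs,
    warmup_epochs=5,
)

# Keras calls the schedule with the current step, so the warm-up and decay
# are applied automatically as training progresses.
optimizer = tf.keras.optimizers.Adam(learning_rate=lr_schedule)
# model.compile(optimizer=optimizer, loss=..., metrics=[...])
# model.fit(train_ds, epochs=epochs)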