class WarmupLearningRateSchedule(optimizers.schedules.LearningRateSchedule):
    """Learning-rate schedule combining several decay types with linear warm-up.

    Supported ``lr_decay_type`` values:
      * ``"exponential"``     — staircase exponential decay (needs ``steps_per_epoch``).
      * ``"cosine"``          — half-cosine decay to 0 (needs ``total_steps``).
      * ``"linear"``          — linear decay to 0 (needs ``total_steps``).
      * ``"constant"``        — fixed ``initial_lr``.
      * ``"cosine_restart"``  — cosine decay with restarts (needs ``steps_per_epoch``).

    If ``warmup_epochs`` > 0, the first ``warmup_epochs * steps_per_epoch`` steps
    ramp the rate linearly from 0 to ``initial_lr`` (so ``steps_per_epoch`` is then
    required regardless of decay type). If ``minimal_lr`` is non-zero, the decayed
    rate is clipped from below at that value.
    """

    def __init__(
        self,
        initial_lr,
        steps_per_epoch=None,
        lr_decay_type="exponential",
        decay_factor=0.97,
        decay_epochs=2.4,
        total_steps=None,
        warmup_epochs=5,
        minimal_lr=0,
    ):
        """Store the schedule configuration.

        Args:
            initial_lr: Peak learning rate reached at the end of warm-up.
            steps_per_epoch: Optimizer steps per epoch; required for
                "exponential"/"cosine_restart" decay and whenever warm-up is on.
            lr_decay_type: One of "exponential", "cosine", "linear",
                "constant", "cosine_restart".
            decay_factor: Multiplicative factor for exponential decay.
            decay_epochs: Epochs between exponential decay steps (also sets the
                first restart period for "cosine_restart").
            total_steps: Total training steps; required for "cosine"/"linear".
            warmup_epochs: Length of the linear warm-up phase in epochs
                (0 disables warm-up).
            minimal_lr: Lower bound applied to the decayed rate (0 disables).
        """
        super().__init__()
        self.initial_lr = initial_lr
        self.steps_per_epoch = steps_per_epoch
        self.lr_decay_type = lr_decay_type
        self.decay_factor = decay_factor
        self.decay_epochs = decay_epochs
        self.total_steps = total_steps
        self.warmup_epochs = warmup_epochs
        self.minimal_lr = minimal_lr

    def _require(self, attr_name):
        """Raise ValueError if the named config attribute was left as None.

        Explicit raises replace the original ``assert`` checks, which are
        silently stripped under ``python -O``.
        """
        if getattr(self, attr_name) is None:
            raise ValueError(
                "%s must be set for lr_decay_type=%r"
                % (attr_name, self.lr_decay_type)
            )

    def __call__(self, step):
        """Return the learning rate for the given (scalar) global ``step``."""
        if self.lr_decay_type == "exponential":
            self._require("steps_per_epoch")
            decay_steps = self.steps_per_epoch * self.decay_epochs
            lr = schedules.ExponentialDecay(
                self.initial_lr, decay_steps, self.decay_factor, staircase=True
            )(step)
        elif self.lr_decay_type == "cosine":
            self._require("total_steps")
            lr = (
                0.5
                * self.initial_lr
                * (1 + tf.cos(np.pi * tf.cast(step, tf.float32) / self.total_steps))
            )
        elif self.lr_decay_type == "linear":
            self._require("total_steps")
            lr = (1.0 - tf.cast(step, tf.float32) / self.total_steps) * self.initial_lr
        elif self.lr_decay_type == "constant":
            lr = self.initial_lr
        elif self.lr_decay_type == "cosine_restart":
            # Originally used with no check; a None steps_per_epoch produced an
            # opaque TypeError here.
            self._require("steps_per_epoch")
            decay_steps = self.steps_per_epoch * self.decay_epochs
            lr = tf.keras.experimental.CosineDecayRestarts(
                self.initial_lr, decay_steps
            )(step)
        else:
            # Was `assert False`, which under -O left `lr` unbound (NameError).
            raise ValueError("Unknown lr_decay_type : %s" % self.lr_decay_type)
        if self.minimal_lr:
            lr = tf.math.maximum(lr, self.minimal_lr)
        if self.warmup_epochs:
            # Warm-up is on by default, so steps_per_epoch is required here even
            # for decay types that don't otherwise use it.
            self._require("steps_per_epoch")
            warmup_steps = int(self.warmup_epochs * self.steps_per_epoch)
            warmup_lr = (
                self.initial_lr
                * tf.cast(step, tf.float32)
                / tf.cast(warmup_steps, tf.float32)
            )
            lr = tf.cond(step < warmup_steps, lambda: warmup_lr, lambda: lr)
        return lr

    def get_config(self):
        """Return the constructor kwargs so Keras can serialize the schedule."""
        return {
            "initial_lr": self.initial_lr,
            "steps_per_epoch": self.steps_per_epoch,
            "lr_decay_type": self.lr_decay_type,
            "decay_factor": self.decay_factor,
            "decay_epochs": self.decay_epochs,
            "total_steps": self.total_steps,
            "warmup_epochs": self.warmup_epochs,
            "minimal_lr": self.minimal_lr,
        }