Skip to content

Instantly share code, notes, and snippets.

@NMZivkovic
Created August 30, 2019 14:39
Show Gist options
  • Save NMZivkovic/026804359807bee8141317ae6f9b2158 to your computer and use it in GitHub Desktop.
Save NMZivkovic/026804359807bee8141317ae6f9b2158 to your computer and use it in GitHub Desktop.
class Schedule(LearningRateSchedule):
    """Learning-rate schedule with a linear warmup followed by rsqrt decay.

    For a given ``step`` the returned rate is::

        rsqrt(num_neurons) * min(rsqrt(step), step * warmup_steps ** -1.5)

    The second term of the ``min`` grows linearly with ``step`` and the first
    decays as ``step ** -0.5``; the two are equal at ``step == warmup_steps``.

    Args:
        num_neurons: Scaling size; stored as a float32 tensor and applied as
            ``rsqrt(num_neurons)``.
        warmup_steps: Number of steps over which the rate ramps up before the
            rsqrt decay takes over.
    """

    def __init__(self, num_neurons, warmup_steps=4000):
        super().__init__()
        # Cast once so the rsqrt in __call__ always sees a floating-point value.
        self.num_neurons = tf.cast(num_neurons, tf.float32)
        self.warmup_steps = warmup_steps

    def __call__(self, step):
        """Return the learning rate for ``step`` as a float32 tensor."""
        # tf.math.rsqrt rejects integer dtypes, and callers may hand in an
        # integer step counter, so normalise to float32 before any math.
        step = tf.cast(step, tf.float32)
        decay = tf.math.rsqrt(step)
        warmup = step * (self.warmup_steps ** -1.5)
        return tf.math.rsqrt(self.num_neurons) * tf.math.minimum(decay, warmup)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment