Keras Optax Schedules
import tensorflow as tf
import tensorflow.experimental.numpy as tnp
from absl import logging
from typing import Dict, Union, Optional, Sequence
from tensorflow import keras
from tensorflow.keras.optimizers.schedules import LearningRateSchedule

# Schedules ported from Optax
# https://github.com/deepmind/optax/blob/master/optax/_src/schedule.py


class ConstantSchedule(LearningRateSchedule):
    """Constructs a constant schedule.

    Args:
      value: value to be held constant throughout.

    Returns:
      schedule: A function that maps step counts to values.
    """

    def __init__(
        self,
        value: Union[float, int],
    ):
        self.value = value

    def __call__(self, step):
        return tf.constant(self.value, shape=tf.convert_to_tensor(step).shape)


class PolynomialSchedule(LearningRateSchedule):
    """Constructs a schedule with polynomial transition from init to end value.

    Args:
      init_value: initial value for the scalar to be annealed.
      end_value: end value of the scalar to be annealed.
      power: the power of the polynomial used to transition from init to end.
      transition_steps: number of steps over which annealing takes place,
        the scalar starts changing at `transition_begin` steps and completes
        the transition by `transition_begin + transition_steps` steps.
        If `transition_steps <= 0`, then the entire annealing process is
        disabled and the value is held fixed at `init_value`.
      transition_begin: must be non-negative. After how many steps to start
        annealing (before this many steps the scalar value is held fixed at
        `init_value`).

    Returns:
      schedule: A function that maps step counts to values.
    """

    def __init__(
        self,
        init_value,
        end_value,
        power,
        transition_steps: int,
        transition_begin: int = 0,
    ):
        self.init_value = init_value
        self.end_value = end_value
        self.power = power
        self.transition_steps = transition_steps
        self.transition_begin = transition_begin
        if self.transition_steps <= 0:
            logging.info(
                'A polynomial schedule was set with a non-positive '
                '`transition_steps` value; this results in a constant '
                'schedule with value `init_value`.')
        if transition_begin < 0:
            logging.info(
                'A polynomial schedule was set with a negative '
                '`transition_begin` value; this will result in '
                '`transition_begin` falling back to `0`.')
            self.transition_begin = 0

    def __call__(self, step):
        if self.transition_steps <= 0:
            return self.init_value
        step = tnp.clip(
            step - self.transition_begin,
            0,
            self.transition_steps,
        )
        frac = 1 - step / self.transition_steps
        return ((self.init_value - self.end_value) * (frac**self.power) +
                self.end_value)
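
# Usage sketch (illustrative, not part of the original gist): a linear decay
# from 1.0 to 0.1 over 100 steps.
#
#   schedule = PolynomialSchedule(init_value=1.0, end_value=0.1, power=1,
#                                 transition_steps=100)
#   schedule(0)    # 1.0
#   schedule(50)   # 0.55
#   schedule(100)  # 0.1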


class LinearSchedule(PolynomialSchedule):
    """Constructs a `PolynomialSchedule` with `power=1`."""

    def __init__(
        self,
        init_value: float,
        end_value: float,
        transition_steps: int,
        transition_begin: int = 0,
    ):
        super().__init__(
            init_value=init_value,
            end_value=end_value,
            transition_steps=transition_steps,
            transition_begin=transition_begin,
            power=1,
        )


class PiecewiseConstantSchedule(LearningRateSchedule):
    """Returns a function which implements a piecewise constant schedule.

    Args:
      init_value: An initial value `init_v`.
      boundaries_and_scales: A map from boundaries `b_i` to non-negative
        scaling factors `f_i`. For any step count `s`, the schedule returns
        `init_v` scaled by the product of all factors `f_i` such that
        `b_i` < `s`.

    Returns:
      schedule: A function that maps step counts to values.
    """

    def __init__(
        self,
        init_value: float,
        boundaries_and_scales: Optional[Dict[int, float]] = None,
    ):
        if boundaries_and_scales is not None:
            all_positive = all(scale >= 0.0
                               for scale in boundaries_and_scales.values())
            if not all_positive:
                raise ValueError(
                    '`PiecewiseConstantSchedule` expects non-negative scale '
                    'factors')
        self.init_value = init_value
        self.boundaries_and_scales = boundaries_and_scales

    def __call__(self, step):
        # Cast so the arithmetic below also works for the integer step
        # counters that Keras optimizers pass in.
        step = tf.cast(step, tf.float32)
        v = self.init_value
        if self.boundaries_and_scales is not None:
            for threshold, scale in sorted(self.boundaries_and_scales.items()):
                indicator = tf.maximum(0., tnp.sign(threshold - step))
                v = v * indicator + (1 - indicator) * scale * v
        return v
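
# Usage sketch (illustrative, not part of the original gist): start at 0.1
# and halve the value at steps 1000 and 2000.
#
#   schedule = PiecewiseConstantSchedule(
#       init_value=0.1, boundaries_and_scales={1000: 0.5, 2000: 0.5})
#   schedule(500)   # 0.1
#   schedule(1500)  # 0.05
#   schedule(2500)  # 0.025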


class ExponentialDecaySchedule(LearningRateSchedule):
    """Constructs a schedule with either continuous or discrete exponential decay.

    This function applies an exponential decay function to a provided initial
    value. The function returns the decayed value as follows:

    ```
    decayed_value = init_value * decay_rate ^ (count / transition_steps)
    ```

    If the argument `staircase` is `True`, then `count / transition_steps` is
    an integer division and the decayed value follows a staircase function.

    Args:
      init_value: the initial learning rate.
      transition_steps: must be positive. See the decay computation above.
      decay_rate: must not be zero. The decay rate.
      transition_begin: must be non-negative. After how many steps to start
        annealing (before this many steps the scalar value is held fixed at
        `init_value`).
      staircase: if `True`, decay the values at discrete intervals.
      end_value: the value at which the exponential decay stops. When
        `decay_rate` < 1, `end_value` is treated as a lower bound, otherwise
        as an upper bound. Has no effect when `decay_rate` = 0.

    Returns:
      schedule: A function that maps step counts to values.
    """

    def __init__(
        self,
        init_value: float,
        transition_steps: int,
        decay_rate: float,
        transition_begin: int = 0,
        staircase: bool = False,
        end_value: Optional[float] = None,
    ):
        if transition_steps <= 0:
            logging.info(
                'An exponential schedule was set with a non-positive '
                '`transition_steps` value; this will result in a constant '
                'schedule with value `init_value`.')
        if decay_rate == 0:
            logging.info(
                'An exponential schedule was set with a zero `decay_rate` '
                'value; this will result in a constant schedule with value '
                '`init_value`.')
        if transition_begin < 0:
            logging.info(
                'An exponential schedule was set with a negative '
                '`transition_begin` value; this will result in '
                '`transition_begin` falling back to `0`.')
            # Reset the local so the assignment below does not undo the
            # fallback.
            transition_begin = 0
        if end_value is not None:
            self.clip_fn = tnp.maximum if decay_rate < 1.0 else tnp.minimum
        self.init_value = init_value
        self.transition_steps = transition_steps
        self.decay_rate = decay_rate
        self.transition_begin = transition_begin
        self.staircase = staircase
        self.end_value = end_value

    def __call__(self, step):
        # Guard against the degenerate settings warned about above; without
        # this, the division and power below would not yield the constant
        # schedule the log messages promise.
        if self.transition_steps <= 0 or self.decay_rate == 0:
            return self.init_value
        decreased_count = step - self.transition_begin
        p = decreased_count / self.transition_steps
        if self.staircase:
            p = tnp.floor(p)
        decayed_value = tnp.where(
            decreased_count <= 0,
            self.init_value,
            self.init_value * tnp.power(self.decay_rate, p),
        )
        if self.end_value is not None:
            decayed_value = self.clip_fn(decayed_value, self.end_value)
        return decayed_value
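
# Usage sketch (illustrative, not part of the original gist): halve the rate
# every 1000 steps, but never drop below 1e-4.
#
#   schedule = ExponentialDecaySchedule(
#       init_value=1e-2, transition_steps=1000, decay_rate=0.5,
#       end_value=1e-4)
#   schedule(0)     # 1e-2
#   schedule(1000)  # 5e-3
#   schedule(2000)  # 2.5e-3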


class CosineDecaySchedule(LearningRateSchedule):
    """Returns a function which implements cosine learning rate decay.

    For more details see: https://arxiv.org/abs/1608.03983

    Args:
      init_value: An initial value `init_v`.
      decay_steps: Positive integer - the number of steps over which the
        decay is applied.
      alpha: Float. The minimum value of the multiplier used to adjust the
        learning rate.

    Returns:
      schedule: A function that maps step counts to values.
    """

    def __init__(
        self,
        init_value: float,
        decay_steps: int,
        alpha: float = 0.0,
    ):
        if not decay_steps > 0:
            raise ValueError(
                '`CosineDecaySchedule` requires positive `decay_steps`!')
        self.init_value = init_value
        self.decay_steps = decay_steps
        self.alpha = alpha

    def __call__(self, step):
        step = tnp.minimum(step, self.decay_steps)
        cosine_decay = 0.5 * (1 + tnp.cos(tnp.pi * step / self.decay_steps))
        decayed = (1 - self.alpha) * cosine_decay + self.alpha
        return self.init_value * decayed
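
# Usage sketch (illustrative, not part of the original gist): anneal from
# 1e-3 down to alpha * init_value = 1e-4 over 10000 steps, then hold.
#
#   schedule = CosineDecaySchedule(init_value=1e-3, decay_steps=10000,
#                                  alpha=0.1)
#   schedule(0)      # 1e-3
#   schedule(10000)  # 1e-4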


def _linear_interpolate(start: float, end: float, pct: float):
    return (end - start) * pct + start


def _cosine_interpolate(start: float, end: float, pct: float):
    return end + (start - end) / 2.0 * (tnp.cos(tnp.pi * pct) + 1)


class PiecewiseInterpolateSchedule(LearningRateSchedule):
    """Returns a function which implements a piecewise interpolated schedule.

    Args:
      interpolate_type: 'linear' or 'cosine', specifying the interpolation
        strategy.
      init_value: An initial value `init_v`.
      boundaries_and_scales: A map from boundaries `b_i` to non-negative
        scaling factors `f_i`. At boundary step `b_i`, the schedule returns
        `init_v` scaled by the product of all factors `f_j` such that
        `b_j` < `b_i`. The values in between each boundary will be
        interpolated as per `interpolate_type`.

    Returns:
      schedule: A function that maps step counts to values.
    """

    def __init__(
        self,
        interpolate_type: str,
        init_value: float,
        boundaries_and_scales: Optional[Dict[int, float]] = None,
    ):
        self.interpolate_type = interpolate_type
        self.init_value = init_value
        self.boundaries_and_scales = boundaries_and_scales
        if interpolate_type == 'linear':
            self.interpolate_fn = _linear_interpolate
        elif interpolate_type == 'cosine':
            self.interpolate_fn = _cosine_interpolate
        else:
            raise ValueError(
                "`interpolate_type` must be either 'cosine' or 'linear'")
        if boundaries_and_scales is not None:
            self.boundaries, self.scales = zip(
                *sorted(boundaries_and_scales.items()))
            if not all(scale >= 0. for scale in self.scales):
                raise ValueError(
                    '`PiecewiseInterpolateSchedule` expects non-negative '
                    'scale factors')
        else:
            self.boundaries, self.scales = (), ()
        self.bounds = tnp.stack((0, ) + self.boundaries)
        self.values = tnp.cumprod(tnp.stack((self.init_value, ) + self.scales))
        self.interval_sizes = (self.bounds[1:] - self.bounds[:-1])

    def __call__(self, step):
        # Do the masked arithmetic in the dtype of the schedule values; the
        # original int8/bool masks trip over TF's strict dtype rules.
        dtype = self.values.dtype
        step = tf.cast(step, dtype)
        bounds = tf.cast(self.bounds, dtype)
        indicator = (tf.cast(bounds[:-1] <= step, dtype) *
                     tf.cast(step < bounds[1:], dtype))
        pct = (step - bounds[:-1]) / tf.cast(self.interval_sizes, dtype)
        interp_vals = self.interpolate_fn(
            self.values[:-1],
            self.values[1:],
            pct,
        )
        return (tnp.dot(indicator, interp_vals) +
                tf.cast(bounds[-1] <= step, dtype) * self.values[-1])
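
# Usage sketch (illustrative, not part of the original gist): interpolate
# linearly from 1.0 down to 0.1 over the first 100 steps, then hold.
#
#   schedule = PiecewiseInterpolateSchedule(
#       interpolate_type='linear', init_value=1.0,
#       boundaries_and_scales={100: 0.1})
#   schedule(50)   # 0.55
#   schedule(200)  # 0.1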


class LinearOneCycleSchedule(PiecewiseInterpolateSchedule):

    def __init__(
        self,
        transition_steps: int,
        peak_value: float,
        pct_start: float = 0.3,
        pct_final: float = 0.85,
        div_factor: float = 25.0,
        final_div_factor: float = 1e4,
    ):
"""Returns a function which implements the onecycle learning rate schedule. | |
This function uses a linear annealing strategy. | |
For more details see: https://arxiv.org/abs/1708.07120 | |
Args: | |
transition_steps: Number of steps over which annealing takes place. | |
peak_value: Maximum value attained by schedule at pct_start percent | |
of the cycle (in number of steps). | |
pct_start: The percentage of the cycle (in number of steps) spent | |
increasing the learning rate. | |
pct_final: The percentage of the cycle (in number of steps) spent | |
increasing to peak_value then decreasing back to init_value. | |
div_factor: Determines the initial value via init_value = | |
peak_value / div_factor | |
final_div_factor: Determines the final value via final_value = | |
init_value / final_div_factor | |
Returns: | |
schedule: A function that maps step counts to values. | |
""" | |
if transition_steps <= 0: | |
raise ValueError( | |
'A linear onecycle schedule was set with a non-positive ' | |
'`transition_steps`') | |
super().__init__( | |
interpolate_type='linear', | |
init_value=peak_value / div_factor, | |
boundaries_and_scales={ | |
int(pct_start * transition_steps): div_factor, | |
int(pct_final * transition_steps): 1. / div_factor, | |
transition_steps: 1. / final_div_factor | |
}, | |
) | |


class CosineOneCycleSchedule(PiecewiseInterpolateSchedule):
    """Returns a function which implements the onecycle learning rate schedule.

    This function uses a cosine annealing strategy.

    For more details see: https://arxiv.org/abs/1708.07120

    Args:
      transition_steps: Number of steps over which annealing takes place.
      peak_value: Maximum value attained by schedule at pct_start percent
        of the cycle (in number of steps).
      pct_start: The percentage of the cycle (in number of steps) spent
        increasing the learning rate.
      div_factor: Determines the initial value via
        init_value = peak_value / div_factor
      final_div_factor: Determines the final value via
        final_value = init_value / final_div_factor

    Returns:
      schedule: A function that maps step counts to values.
    """

    def __init__(
        self,
        transition_steps: int,
        peak_value: float,
        pct_start: float = 0.3,
        div_factor: float = 25.0,
        final_div_factor: float = 1e4,
    ):
        if transition_steps <= 0:
            raise ValueError(
                'A cosine onecycle schedule was set with a non-positive '
                '`transition_steps`')
        super().__init__(
            interpolate_type='cosine',
            init_value=peak_value / div_factor,
            boundaries_and_scales={
                int(pct_start * transition_steps): div_factor,
                int(transition_steps): 1. / (div_factor * final_div_factor),
            },
        )
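
# Usage sketch (illustrative, not part of the original gist): a one-cycle
# policy peaking at 1e-2. With the defaults, init_value = 1e-2 / 25 = 4e-4
# and the peak is reached at step int(0.3 * 10000) = 3000.
#
#   schedule = CosineOneCycleSchedule(transition_steps=10000, peak_value=1e-2)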


class JoinedSchedule(LearningRateSchedule):
    """Sequentially apply multiple schedules.

    Args:
      schedules: A list of `LearningRateSchedule` callables. Each schedule
        will receive a step count indicating the number of steps since the
        previous boundary transition.
      boundaries: A list of integers (of length one less than schedules) that
        indicate when to transition between schedules.

    Returns:
      schedule: A function that maps step counts to values.
    """

    def __init__(
        self,
        schedules: Sequence[LearningRateSchedule],
        boundaries: Sequence[int],
    ):
        self.schedules = schedules
        self.boundaries = boundaries

    def __call__(self, step):
        lr = self.schedules[0](step)
        for boundary, schedule in zip(self.boundaries, self.schedules[1:]):
            lr = tf.where(step < boundary, lr, schedule(step - boundary))
        return lr
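
# Usage sketch (illustrative, not part of the original gist): hold a constant
# rate for 500 steps, then hand off to a cosine decay. The second schedule
# sees steps counted from the boundary, not from 0.
#
#   schedule = JoinedSchedule(
#       schedules=[ConstantSchedule(1e-3),
#                  CosineDecaySchedule(init_value=1e-3, decay_steps=5000)],
#       boundaries=[500])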


class WarmupCosineDecaySchedule(JoinedSchedule):
    """Linear warmup followed by cosine decay.

    Args:
      init_value: Initial value for the scalar to be annealed.
      peak_value: Peak value for scalar to be annealed at end of warmup.
      warmup_steps: Positive integer, the length of the linear warmup.
      decay_steps: Positive integer, the total length of the schedule. Note
        that this includes the warmup time, so the number of steps during
        which cosine annealing is applied is `decay_steps - warmup_steps`.
      end_value: End value of the scalar to be annealed.

    Returns:
      schedule: A function that maps step counts to values.
    """

    def __init__(
        self,
        init_value: float,
        peak_value: float,
        warmup_steps: int,
        decay_steps: int,
        end_value: float = 0.0,
    ):
        schedules = [
            LinearSchedule(init_value=init_value,
                           end_value=peak_value,
                           transition_steps=warmup_steps),
            CosineDecaySchedule(init_value=peak_value,
                                decay_steps=decay_steps - warmup_steps,
                                alpha=end_value / peak_value),
        ]
        super().__init__(
            schedules=schedules,
            boundaries=[warmup_steps],
        )


class WarmupExponentialDecaySchedule(JoinedSchedule):
    """Linear warmup followed by exponential decay.

    Args:
      init_value: Initial value for the scalar to be annealed.
      peak_value: Peak value for scalar to be annealed at end of warmup.
      warmup_steps: Positive integer, the length of the linear warmup.
      transition_steps: must be positive. See the decay computation in
        `ExponentialDecaySchedule`.
      decay_rate: must not be zero. The decay rate.
      transition_begin: must be non-negative. After how many steps to start
        annealing (before this many steps the scalar value is held fixed at
        `init_value`).
      staircase: if `True`, decay the values at discrete intervals.
      end_value: the value at which the exponential decay stops. When
        `decay_rate` < 1, `end_value` is treated as a lower bound, otherwise
        as an upper bound. Has no effect when `decay_rate` = 0.

    Returns:
      schedule: A function that maps step counts to values.
    """

    def __init__(
        self,
        init_value: float,
        peak_value: float,
        warmup_steps: int,
        transition_steps: int,
        decay_rate: float,
        transition_begin: int = 0,
        staircase: bool = False,
        end_value: Optional[float] = None,
    ):
        schedules = [
            LinearSchedule(init_value=init_value,
                           end_value=peak_value,
                           transition_steps=warmup_steps),
            ExponentialDecaySchedule(init_value=peak_value,
                                     transition_steps=transition_steps,
                                     decay_rate=decay_rate,
                                     transition_begin=transition_begin,
                                     staircase=staircase,
                                     end_value=end_value),
        ]
        super().__init__(
            schedules=schedules,
            boundaries=[warmup_steps],
        )
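

# Minimal smoke test (illustrative, not part of the original gist): sample a
# warmup + cosine schedule at a few steps, then attach it to a Keras
# optimizer the usual way.
if __name__ == '__main__':
    schedule = WarmupCosineDecaySchedule(
        init_value=1e-5,
        peak_value=1e-3,
        warmup_steps=100,
        decay_steps=1000,
    )
    for s in (0, 50, 100, 500, 1000):
        print(f'step {s}: lr = {float(schedule(s)):.2e}')
    # Any of the schedules above can be passed wherever Keras accepts a
    # LearningRateSchedule:
    optimizer = keras.optimizers.Adam(learning_rate=schedule)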