"""Learning rate and EMA warmup schedulers for PyTorch."""
import warnings
from torch import optim

class InverseLR(optim.lr_scheduler._LRScheduler):
    """Implements an inverse decay learning rate schedule with an optional exponential
    warmup. When last_epoch=-1, sets initial lr as lr.

    1 / gamma is the number of steps/epochs required for the learning rate to decay to
    (1 / 2)**power of its original value.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        gamma (float): Multiplicative factor of learning rate decay. Default: 1.
        power (float): Exponential factor of learning rate decay. Default: 1.
        warmup (float): Exponential warmup factor (0 <= warmup < 1, 0 to disable).
            Default: 0.
        final_lr (float): The final learning rate. Default: 0.
        last_epoch (int): The index of last epoch. Default: -1.
        verbose (bool): If ``True``, prints a message to stdout for
            each update. Default: ``False``.
    """

    def __init__(self, optimizer, gamma=1., power=1., warmup=0., final_lr=0.,
                 last_epoch=-1, verbose=False):
        self.gamma = gamma
        self.power = power
        if not 0. <= warmup < 1:
            raise ValueError('Invalid value for warmup')
        self.warmup = warmup
        self.final_lr = final_lr
        super().__init__(optimizer, last_epoch, verbose)

    def get_lr(self):
        if not self._get_lr_called_within_step:
            warnings.warn("To get the last learning rate computed by the scheduler, "
                          "please use `get_last_lr()`.")

        return self._get_closed_form_lr()
    def _get_closed_form_lr(self):
        # Exponentiate by last_epoch + 1 so the warmup factor is nonzero at the
        # first step (with warmup disabled, 0. ** 0 == 1. would otherwise zero
        # out the initial learning rate).
        warmup = 1 - self.warmup ** (self.last_epoch + 1)
        lr_mult = (1 + self.gamma * self.last_epoch) ** -self.power
        return [warmup * max(self.final_lr, base_lr * lr_mult)
                for base_lr in self.base_lrs]
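

# Usage sketch (not part of the original gist): InverseLR drops in like any
# PyTorch LR scheduler, stepped once per optimizer step. The toy model and the
# gamma/warmup values below are illustrative assumptions, not recommendations.
def _demo_inverse_lr():
    import torch

    model = torch.nn.Linear(4, 4)
    opt = optim.Adam(model.parameters(), lr=1e-3)
    # With gamma=1e-4 and power=1, the lr decays to half its base value after
    # 1 / gamma = 10000 steps; warmup=0.99 adds a short exponential warmup.
    sched = InverseLR(opt, gamma=1e-4, power=1., warmup=0.99)
    for step in range(5):
        opt.zero_grad()
        model(torch.randn(2, 4)).sum().backward()
        opt.step()
        sched.step()
        print(step, sched.get_last_lr())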


class EMAWarmup:
    """Implements an EMA warmup using an inverse decay schedule.

    If gamma=1 and power=1, implements a simple average. gamma=1, power=2/3 are
    good values for models you plan to train for a million or more steps (reaches
    decay factor 0.999 at 31.6K steps, 0.9999 at 1M steps), gamma=1, power=3/4
    for models you plan to train for less (reaches decay factor 0.999 at 10K
    steps, 0.9999 at 215.4K steps).

    Args:
        gamma (float): Multiplicative factor of EMA warmup. Default: 1.
        power (float): Exponential factor of EMA warmup. Default: 1.
        min_value (float): The minimum EMA decay rate. Default: 0.
        max_value (float): The maximum EMA decay rate. Default: 1.
        start_at (int): The epoch to start averaging at. Default: 0.
        last_epoch (int): The index of last epoch. Default: 0.
    """

    def __init__(self, gamma=1., power=1., min_value=0., max_value=1., start_at=0,
                 last_epoch=0):
        self.gamma = gamma
        self.power = power
        self.min_value = min_value
        self.max_value = max_value
        self.start_at = start_at
        self.last_epoch = last_epoch

    def state_dict(self):
        """Returns the state of the class as a :class:`dict`."""
        return dict(self.__dict__.items())

    def load_state_dict(self, state_dict):
        """Loads the class's state.

        Args:
            state_dict (dict): scheduler state. Should be an object returned
                from a call to :meth:`state_dict`.
        """
        self.__dict__.update(state_dict)
    def get_value(self):
        """Gets the current EMA decay rate."""
        epoch = max(0, self.last_epoch - self.start_at)
        value = 1 - (1 + self.gamma * epoch) ** -self.power
        # epoch is clamped to >= 0 above, so test last_epoch against start_at
        # directly to return a decay rate of 0 before averaging begins.
        return 0. if self.last_epoch < self.start_at else min(
            self.max_value, max(self.min_value, value))

    def step(self):
        """Updates the step count."""
        self.last_epoch += 1
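

# Usage sketch (not part of the original gist): EMAWarmup only schedules the
# decay rate; maintaining the averaged copy of the model is up to the caller.
# The in-place lerp update below is one common way to apply it and is an
# assumption, not something the class does itself. With gamma=1, power=2/3,
# the decay 1 - (1 + n) ** (-2 / 3) reaches 0.999 at n = 1000 ** 1.5 ~ 31.6K
# steps, matching the docstring above.
def _demo_ema_warmup():
    import copy

    import torch

    model = torch.nn.Linear(4, 4)
    model_ema = copy.deepcopy(model)
    ema = EMAWarmup(gamma=1., power=2 / 3, max_value=0.9999)
    for step in range(5):
        # ... train `model` for one step here ...
        decay = ema.get_value()
        with torch.no_grad():
            for p, p_ema in zip(model.parameters(), model_ema.parameters()):
                p_ema.lerp_(p, 1 - decay)  # p_ema <- decay * p_ema + (1 - decay) * p
        ema.step()
        print(step, decay)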