"""Learning rate and EMA warmup schedulers for PyTorch.""" | |
import warnings | |
from torch import optim | |
class InverseLR(optim.lr_scheduler._LRScheduler): | |
"""Implements an inverse decay learning rate schedule with an optional exponential | |
warmup. When last_epoch=-1, sets initial lr as lr. | |
1 / gamma is the number of steps/epochs required for the learning rate to decay to | |
(1 / 2)**power of its original value. | |
Args: | |
optimizer (Optimizer): Wrapped optimizer. | |
gamma (float): Multiplicative factor of learning rate decay. Default: 1. | |
power (float): Exponential factor of learning rate decay. Default: 1. | |
warmup (float): Exponential warmup factor (0 <= warmup < 1, 0 to disable) | |
Default: 0. | |
final_lr (float): The final learning rate. Default: 0. | |
last_epoch (int): The index of last epoch. Default: -1. | |
verbose (bool): If ``True``, prints a message to stdout for | |
each update. Default: ``False``. | |
""" | |
def __init__(self, optimizer, gamma=1., power=1., warmup=0., final_lr=0., | |
last_epoch=-1, verbose=False): | |
self.gamma = gamma | |
self.power = power | |
if not 0. <= warmup < 1: | |
raise ValueError('Invalid value for warmup') | |
self.warmup = warmup | |
self.final_lr = final_lr | |
super().__init__(optimizer, last_epoch, verbose) | |
def get_lr(self): | |
if not self._get_lr_called_within_step: | |
warnings.warn("To get the last learning rate computed by the scheduler, " | |
"please use `get_last_lr()`.") | |
return self._get_closed_form_lr() | |
def _get_closed_form_lr(self): | |
warmup = 1 - self.warmup ** self.last_epoch | |
lr_mult = (1 + self.gamma * self.last_epoch) ** -self.power | |
return [warmup * max(self.final_lr, base_lr * lr_mult) | |
for base_lr in self.base_lrs] | |
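
# A minimal usage sketch (not part of the original gist): attach InverseLR to an
# optimizer and call step() once per training step. With gamma=g and power=p the
# closed-form multiplier at step t is (1 + g * t) ** -p, additionally scaled by the
# exponential warmup factor 1 - warmup ** t. `model`, `loader`, and `train_step`
# below are placeholder names.
#
#     opt = optim.AdamW(model.parameters(), lr=1e-4)
#     sched = InverseLR(opt, gamma=1e-4, power=1., warmup=0.99)
#     for batch in loader:
#         train_step(opt, batch)  # forward/backward and opt.step()
#         sched.step()            # advance the inverse decay schedule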

class EMAWarmup:
    """Implements an EMA warmup using an inverse decay schedule.

    If gamma=1 and power=1, implements a simple average. gamma=1, power=2/3 are
    good values for models you plan to train for a million or more steps (reaches decay
    factor 0.999 at 31.6K steps, 0.9999 at 1M steps), gamma=1, power=3/4 for models
    you plan to train for less (reaches decay factor 0.999 at 10K steps, 0.9999 at
    215.4K steps).

    Args:
        gamma (float): Multiplicative factor of EMA warmup. Default: 1.
        power (float): Exponential factor of EMA warmup. Default: 1.
        min_value (float): The minimum EMA decay rate. Default: 0.
        max_value (float): The maximum EMA decay rate. Default: 1.
        start_at (int): The epoch to start averaging at. Default: 0.
        last_epoch (int): The index of last epoch. Default: 0.
    """

    def __init__(self, gamma=1., power=1., min_value=0., max_value=1., start_at=0,
                 last_epoch=0):
        self.gamma = gamma
        self.power = power
        self.min_value = min_value
        self.max_value = max_value
        self.start_at = start_at
        self.last_epoch = last_epoch

    def state_dict(self):
        """Returns the state of the class as a :class:`dict`."""
        return dict(self.__dict__.items())

    def load_state_dict(self, state_dict):
        """Loads the class's state.

        Args:
            state_dict (dict): EMA warmup state. Should be an object returned
                from a call to :meth:`state_dict`.
        """
        self.__dict__.update(state_dict)

    def get_value(self):
        """Gets the current EMA decay rate."""
        epoch = max(0, self.last_epoch - self.start_at)
        value = 1 - (1 + self.gamma * epoch) ** -self.power
        return 0. if epoch < 0 else min(self.max_value, max(self.min_value, value))

    def step(self):
        """Updates the step count."""
        self.last_epoch += 1
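
# A minimal usage sketch (not part of the original gist): use EMAWarmup to ramp up
# the decay rate of a parameter EMA during training. `model`, `model_ema`, `opt`,
# `loader`, and `train_step` are placeholder names; power=2/3 follows the long-run
# recommendation in the docstring above.
#
#     import torch
#     ema_sched = EMAWarmup(power=2 / 3)
#     for batch in loader:
#         train_step(opt, batch)
#         decay = ema_sched.get_value()
#         with torch.no_grad():
#             for p_ema, p in zip(model_ema.parameters(), model.parameters()):
#                 # p_ema <- decay * p_ema + (1 - decay) * p
#                 p_ema.mul_(decay).add_(p, alpha=1 - decay)
#         ema_sched.step()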