Last active
March 2, 2019 08:30
-
-
Save wassname/168b2167ac18c0671f41ae9f6fb86e66 to your computer and use it in GitHub Desktop.
A learning rate scheduler for pytorch which interpolates on log or linear scales
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np
from torch.optim.lr_scheduler import _LRScheduler
class InterpolatingScheduler(_LRScheduler):
    """A scheduler that interpolates the learning rate between given anchor points.

    Args:
        optimizer: wrapped pytorch optimizer
        steps: list or array with the x coordinates (epoch numbers) of the
            interpolated values
        lrs: list or array with the learning rates corresponding to the steps
        scale: one of ['linear', 'log'], the scale on which to interpolate.
            Log is useful since learning rates operate on a logarithmic scale.
        last_epoch: index of the last epoch; -1 means training starts fresh

    Raises:
        ValueError: if `scale` is not 'linear' or 'log'.

    Usage:
        fc = nn.Linear(1, 1)
        optimizer = optim.Adam(fc.parameters())
        lr_scheduler = InterpolatingScheduler(optimizer, steps=[0, 100, 400], lrs=[1e-6, 1e-4, 1e-8], scale='log')
    """

    def __init__(self, optimizer, steps, lrs, scale='log', last_epoch=-1):
        # Fail fast: the base-class __init__ calls step() -> get_lr(), so a bad
        # scale would otherwise surface as an error deep inside super().__init__.
        if scale not in ('linear', 'log'):
            raise ValueError("scale should be one of ['linear', 'log']")
        self.scale = scale
        self.steps = steps
        self.lrs = lrs
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        # np.interp clamps outside [steps[0], steps[-1]] to the endpoint lrs.
        x = [self.last_epoch]
        if self.scale == 'linear':
            y = np.interp(x, self.steps, self.lrs)
        elif self.scale == 'log':
            # Interpolate log(lr) so the schedule is linear in log-space.
            y = np.interp(x, self.steps, np.log(self.lrs))
            y = np.exp(y)
        else:
            raise ValueError("scale should be one of ['linear', 'log']")
        # One lr per param group, as a native Python float (not a numpy scalar).
        return [float(y[0]) for _ in self.base_lrs]
# Example of use: plot the schedule produced by InterpolatingScheduler.
import matplotlib.pyplot as plt  # fix: plt was used below but never imported
import torch.optim as optim
from torch import nn

fc = nn.Linear(1, 1)
optimizer = optim.Adam(fc.parameters())
lr_scheduler = InterpolatingScheduler(optimizer, steps=[0, 100, 400, 800], lrs=[1e-6, 1e-4, 1e-8, 1e-9], scale='log')

# Plot the lr schedule by probing get_lr() at many epochs without stepping
# the optimizer: set last_epoch directly, then restore it afterwards.
x = np.linspace(0, 1000, 6000)
y = []
for xx in x:
    lr_scheduler.last_epoch = xx
    y.append(lr_scheduler.get_lr()[0])
lr_scheduler.last_epoch = -1  # restore the scheduler's initial state

plt.figure()
plt.plot(x, y)
plt.title('InterpolatingScheduler')
plt.yscale('log')
plt.xlabel('epoch')
plt.ylabel('lr')
plt.show()
Author
wassname
commented
Aug 12, 2018
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment