Learning Rate Finder in PyTorch

Notes

  • You'll need to bring your own functions to initialize a fresh optimizer, build dataloaders, return a loss function, etc.
  • You'll want to look through the cell that runs the LR finder and consider adjusting the following (a sketch of such a cell follows this list):
    • which parameters are tested, e.g. beta1 and wd, which might not be valid arguments for your optimizer. I recommend varying only one or two parameters at a time.
    • which parameter values are tested, e.g. beta1 in [0.85, 0.95] and wd=0.1. The LR finder is run 3 times for each combination of parameter values, so I recommend restricting to 4 combinations at a time and repeating as necessary.
    • which range of learning rates is tested, e.g. start_lr=1e-6 and end_lr=1e-3. I recommend starting with a wide range for a small initial test, e.g. 1e-6 to 1e0, and then narrowing to the useful range for further tests.
    • how many steps are taken across this range, e.g. steps=100. I recommend roughly 50 steps per order of magnitude, but in general fewer steps run faster, so choose the lowest value that gives you useful results.
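
For concreteness, here is a minimal sketch of such a cell. The helpers make_model(), make_dataloader(), and make_loss_fn() are hypothetical stand-ins for the functions you bring yourself, and AdamW is used purely as an example optimizer:

    # make_model(), make_dataloader(), make_loss_fn(): hypothetical helpers you supply.
    results = {}
    for beta1 in [0.85, 0.95]:   # vary only one or two parameters at a time
        wd = 0.1
        for trial in range(3):   # the finder is run 3x per combination
            model = make_model()
            optimizer = torch.optim.AdamW(model.parameters(), betas=(beta1, 0.999), weight_decay=wd)
            results[(beta1, wd, trial)] = find_lr(
                model, make_dataloader(), optimizer, make_loss_fn(),
                start_lr=1e-6, end_lr=1e-3, steps=100)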

TODO

  • Remove ignite dependency
# See https://sgugger.github.io/how-do-you-find-a-good-learning-rate.html
import copy

import torch
import ignite
import numpy as np
import pandas as pd

def serialize_state(model, optimizer):
    """Snapshot model/optimizer state so it can be restored after an LR sweep.

    Usage: orig_state = serialize_state(model, optimizer)
    """
    # Deep-copy the state dicts: training updates parameters in place, so
    # without a copy the snapshot would alias (and track) the live tensors.
    return {
        'model_training': model.training,
        'model_state': copy.deepcopy(model.state_dict()),
        'optim_state': copy.deepcopy(optimizer.state_dict()),
    }


def restore_state(model, optimizer, state):
    """Restore a snapshot produced by serialize_state().

    Usage: restore_state(model, optimizer, orig_state)
    """
    model.train(state['model_training'])
    model.load_state_dict(state['model_state'])
    optimizer.load_state_dict(state['optim_state'])
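
# A minimal sketch of the intended workflow (assuming `model`, `optimizer`,
# `dataloader`, and `loss_fn` are already defined): find_lr() below mutates the
# model's weights and the optimizer's state, so snapshot before and restore after.
#
#   orig_state = serialize_state(model, optimizer)
#   history = find_lr(model, dataloader, optimizer, loss_fn,
#                     start_lr=1e-6, end_lr=1e-3, steps=100)
#   restore_state(model, optimizer, orig_state)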

def update_lrs(optimizer, lrs):
    # Broadcast a scalar LR (or a per-group array of LRs) across all param groups.
    lrs = np.broadcast_to(lrs, len(optimizer.param_groups))
    for group, lr in zip(optimizer.param_groups, lrs):
        group['lr'] = lr

def smooth_curve(vals, beta):
    # Exponential moving average with bias correction (as in Adam), so early
    # values aren't dragged toward zero: a constant sequence stays constant.
    avg_val = 0
    smoothed = []
    for i, v in enumerate(vals):
        avg_val = beta * avg_val + (1 - beta) * v
        smoothed.append(avg_val / (1 - beta ** (i + 1)))
    return smoothed

def find_lr(model, dataloader, optimizer, loss_fn, start_lr=1e-5, end_lr=10, steps='auto', linear=False, beta=0.98,
            device=torch.device('cuda' if torch.cuda.is_available() else 'cpu'), **kwargs):
    """Helps you find an optimal learning rate for a model.

    It uses the technique developed in the 2015 paper
    `Cyclical Learning Rates for Training Neural Networks`_, where
    we simply keep increasing the learning rate from a very small value
    until the loss stops decreasing.

    Args:
        start_lr (float / numpy array): The minimum learning rate to try.
            Passing in a numpy array allows you to specify a learning rate
            per param group of the optimizer.
        end_lr (float): The maximum learning rate to try.
        steps (int, optional): How many steps to take while incrementing the LR.
            Defaults to the length of the dataloader.

    Examples:
        As training moves us closer to the optimal weights for a model,
        the optimal learning rate will be smaller. We can take advantage of
        that knowledge and provide find_lr() with a starting learning rate
        1000x smaller than the model's current learning rate, as such:

        >> find_lr(model, dataloader, optimizer, loss_fn, start_lr=lr/1000)

        >> lrs = np.array([1e-4, 1e-3, 1e-2])
        >> find_lr(model, dataloader, optimizer, loss_fn, start_lr=lrs/1000)

    Notes:
        find_lr() may finish before going through every batch of examples if
        the loss "explodes" enough.

    .. _Cyclical Learning Rates for Training Neural Networks:
        http://arxiv.org/abs/1506.01186
    """
    trainer = ignite.engine.create_supervised_trainer(model, optimizer, loss_fn, device=device)
    num_batches = steps if isinstance(steps, int) else len(dataloader)
    lrs = {}
    if linear:
        lrs['queue'] = np.linspace(start_lr, end_lr, num=num_batches)
    else:
        # Default: space the candidate LRs evenly on a log scale.
        lrs['queue'] = np.logspace(np.log10(start_lr), np.log10(end_lr), num=num_batches)
    lrs['current'] = None
    lrs['history'] = pd.DataFrame([], columns=['lr', 'loss'])

    def step_lr(optimizer):
        # Pop the next LR off the queue and apply it to every param group.
        lrs['current'], lrs['queue'] = lrs['queue'][0], lrs['queue'][1:]
        update_lrs(optimizer, lrs['current'])

    def record_lr_loss(loss):
        record = {'lr': lrs['current'], 'loss': loss}
        # DataFrame.append() was removed in pandas 2.0; use pd.concat instead.
        lrs['history'] = pd.concat([lrs['history'], pd.DataFrame([record])], ignore_index=True)

    def terminate_on_loss_explosion(trainer):
        # Stop early once the smoothed loss exceeds 4x its minimum so far.
        smoothed = smooth_curve(lrs['history']['loss'].tolist(), beta)
        if smoothed[-1] > 4 * np.array(smoothed).min():
            print(f'Terminating: Loss is exploding ({smoothed[-1]} > 4 * {np.array(smoothed).min()}).')
            trainer.terminate()

    def terminate_on_empty_queue(trainer):
        if len(lrs['queue']) == 0:
            print('Terminating: Reached the end of the LR queue.')
            trainer.terminate()

    trainer.add_event_handler(ignite.engine.Events.ITERATION_STARTED, lambda trainer: step_lr(optimizer))
    trainer.add_event_handler(ignite.engine.Events.ITERATION_COMPLETED, lambda trainer: record_lr_loss(trainer.state.output))
    trainer.add_event_handler(ignite.engine.Events.ITERATION_COMPLETED, lambda trainer: terminate_on_loss_explosion(trainer))
    trainer.add_event_handler(ignite.engine.Events.ITERATION_COMPLETED, lambda trainer: terminate_on_empty_queue(trainer))

    # bind_epoch_tqdm is a local sibling module (not included in this gist) that
    # attaches a tqdm progress bar to the trainer; bring your own or remove it.
    from .bind_epoch_tqdm import bind_epoch_tqdm
    bind_epoch_tqdm(trainer, desc_fn=lambda trainer: f"lr={lrs['history'].tail(1)['lr'].tolist()[-1]:.3E} loss={lrs['history'].tail(1)['loss'].tolist()[-1]:.3f}")

    # max_epochs > 1 lets the sweep span multiple passes when steps > len(dataloader);
    # one of the termination handlers above ends the run before then.
    trainer.run(dataloader, max_epochs=10)
    lrs['history']['loss_smoothed'] = smooth_curve(lrs['history']['loss'].tolist(), beta)
    return lrs['history']
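
# A minimal sketch of reading the results (matplotlib assumed to be installed):
# plot the smoothed loss against the LR on a log axis and, following the
# sgugger post linked above, pick an LR in the region where the smoothed loss
# is still clearly decreasing, somewhat before its minimum.
#
#   import matplotlib.pyplot as plt
#   history = find_lr(model, dataloader, optimizer, loss_fn, start_lr=1e-6, end_lr=1e0)
#   plt.plot(history['lr'], history['loss_smoothed'])
#   plt.xscale('log')
#   plt.xlabel('learning rate')
#   plt.ylabel('smoothed loss')
#   plt.show()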