|
# See https://sgugger.github.io/how-do-you-find-a-good-learning-rate.html

import torch
import ignite
import numpy as np
import pandas as pd
|
|
# Usage:
#     orig_state = serialize_state(model, optimizer)
# def serialize_state(model, optimizer):
#     return {
#         'model_training': model.training,
#         'model_state': model.state_dict(),
#         'optim_state': optimizer.state_dict(),
#     }
#
# Usage:
#     restore_state(model, optimizer, orig_state)
# def restore_state(model, optimizer, state):
#     model.train(state['model_training'])
#     model.load_state_dict(state['model_state'])
#     optimizer.load_state_dict(state['optim_state'])
|
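# If the helpers above are restored, one way to keep the LR sweep below from
# clobbering the caller's model/optimizer might look like this (sketch only;
# `model`, `dataloader`, `optimizer`, and `loss_fn` are assumed to exist):
#     orig_state = serialize_state(model, optimizer)
#     history = find_lr(model, dataloader, optimizer, loss_fn)
#     restore_state(model, optimizer, orig_state)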
|
def update_lrs(optimizer, lrs):
    """Set the learning rate of each param group; a scalar is broadcast to all groups."""
    lrs = np.broadcast_to(lrs, len(optimizer.param_groups))
    for group, lr in zip(optimizer.param_groups, lrs):
        group['lr'] = lr
|
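# For example, assuming an optimizer with two param groups, both calls below are valid:
#     update_lrs(optimizer, 1e-3)                    # one LR shared by every group
#     update_lrs(optimizer, np.array([1e-4, 1e-3]))  # a separate LR per group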
|
|
def smooth_curve(vals, beta):
    """Exponentially weighted moving average of `vals`, with bias correction."""
    avg_val = 0
    smoothed = []
    for i, v in enumerate(vals):
        avg_val = beta * avg_val + (1 - beta) * v
        smoothed.append(avg_val / (1 - beta**(i + 1)))
    return smoothed
|
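# Rough sanity check (values are approximate):
#     smooth_curve([1.0, 0.9, 0.8], beta=0.98)  # -> [1.0, ~0.949, ~0.899]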
|
|
def find_lr(model, dataloader, optimizer, loss_fn, start_lr=1e-5, end_lr=10, steps='auto', linear=False, beta=0.98,
            device=torch.device('cuda' if torch.cuda.is_available() else 'cpu'), **kwargs):
    """Helps you find an optimal learning rate for a model.

    It uses the technique developed in the 2015 paper
    `Cyclical Learning Rates for Training Neural Networks`_, where
    we simply keep increasing the learning rate from a very small value
    until the loss stops decreasing.

    Args:
        start_lr (float/numpy array): The learning rate to start from. Passing in a
            numpy array lets you specify a separate learning rate for each of the
            optimizer's param groups.
        end_lr (float): The maximum learning rate to try.
        steps (int, optional): How many steps to take while incrementing the LR.
            Defaults to the length of the dataloader.
        linear (bool): Increase the LR on a linear scale instead of a log scale.
        beta (float): Smoothing factor for the exponentially weighted loss average.

    Examples:
        As training moves us closer to the optimal weights for a model,
        the optimal learning rate will be smaller. We can take advantage of
        that knowledge and provide find_lr() with a starting learning rate
        1000x smaller than the model's current learning rate, like so:

        >> find_lr(model, dataloader, optimizer, loss_fn, start_lr=lr / 1000)

        >> lrs = np.array([1e-4, 1e-3, 1e-2])
        >> find_lr(model, dataloader, optimizer, loss_fn, start_lr=lrs / 1000)

    Notes:
        find_lr() may finish before going through every batch of examples if
        the loss "explodes" enough.

    .. _Cyclical Learning Rates for Training Neural Networks:
        http://arxiv.org/abs/1506.01186
    """
|
    trainer = ignite.engine.create_supervised_trainer(model, optimizer, loss_fn, device=device)
    num_batches = steps if isinstance(steps, int) else len(dataloader)
|
    lrs = {}
    if linear:
        lrs['queue'] = np.linspace(start_lr, end_lr, num=num_batches)
    else:
        lrs['queue'] = np.logspace(np.log10(start_lr), np.log10(end_lr), num=num_batches)
    lrs['current'] = None
    lrs['history'] = pd.DataFrame([], columns=['lr', 'loss'])
|
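    # With the defaults (start_lr=1e-5, end_lr=10) the queue is log-spaced, e.g.
    # np.logspace(np.log10(1e-5), np.log10(10), num=7) -> [1e-5, 1e-4, ..., 1e0, 1e1].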
|
|
    def step_lr(optimizer):
        # Pop the next LR off the queue and apply it to the optimizer's param groups.
        lrs['current'], lrs['queue'] = lrs['queue'][0], lrs['queue'][1:]
        update_lrs(optimizer, lrs['current'])
|
    def record_lr_loss(loss):
        record = {}
        record['lr'] = lrs['current']
        record['loss'] = loss
        # prev_moving_avg = lrs['history'].tail(1)['loss_moving_avg'].tolist()[-1] if len(lrs['history']) > 0 else 0
        # record['loss_moving_avg'] = beta * prev_moving_avg + (1-beta) * record['loss']
        # batch_num = len(lrs['history']) + 1
        # record['loss_smoothed'] = record['loss_moving_avg'] / (1 - beta**batch_num)
        lrs['history'] = pd.concat([lrs['history'], pd.DataFrame([record])], ignore_index=True)
|
    def terminate_on_loss_explosion(trainer):
        smoothed = smooth_curve(lrs['history']['loss'].tolist(), beta)
        best = min(smoothed)
        if smoothed[-1] > 4 * best:
            print(f'Terminating: Loss is exploding ({smoothed[-1]} > 4 * {best}).')
            trainer.terminate()
|
    def terminate_on_empty_queue(trainer):
        if len(lrs['queue']) == 0:
            print('Terminating: Reached end of dataloader or max batches.')
            trainer.terminate()
|
    trainer.add_event_handler(ignite.engine.Events.ITERATION_STARTED, lambda trainer: step_lr(optimizer))
    trainer.add_event_handler(ignite.engine.Events.ITERATION_COMPLETED, lambda trainer: record_lr_loss(trainer.state.output))
    trainer.add_event_handler(ignite.engine.Events.ITERATION_COMPLETED, lambda trainer: terminate_on_loss_explosion(trainer))
    trainer.add_event_handler(ignite.engine.Events.ITERATION_COMPLETED, lambda trainer: terminate_on_empty_queue(trainer))
|
    from .bind_epoch_tqdm import bind_epoch_tqdm
    bind_epoch_tqdm(trainer, desc_fn=lambda trainer: f"lr={lrs['history'].tail(1)['lr'].tolist()[-1]:.3E} loss={lrs['history'].tail(1)['loss'].tolist()[-1]:.3f}")
|
    # max_epochs is only an upper bound; the handlers above terminate the run once
    # the LR queue is exhausted or the loss explodes.
    trainer.run(dataloader, max_epochs=10)

    lrs['history']['loss_smoothed'] = smooth_curve(lrs['history']['loss'].tolist(), beta)

    return lrs['history']
|
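# Usage (sketch; assumes `model`, `dataloader`, `optimizer`, and `loss_fn` are already set up):
#     history = find_lr(model, dataloader, optimizer, loss_fn)
#     # Plot the smoothed loss against the LR on a log axis and pick a value somewhat
#     # below the minimum, where the curve is still clearly descending:
#     history.plot(x='lr', y='loss_smoothed', logx=True)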
|