#
# 🙏 to prior art: http://puzzlemusa.com/2018/05/14/learning-rate-finder-using-keras/
#
# Example:
# finder = find_lr(model, 1e-10, 1e-2)
# finder.plot()
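#
# How it works (a comment sketch of the math the code below already uses):
# the learning rate is swept exponentially over `stepsize` batches,
#
#     lr_i = start_lr * (end_lr / start_lr) ** (i / stepsize)
#
# while the loss is tracked with a bias-corrected exponential moving average,
#
#     avg_i      = beta * avg_{i-1} + (1 - beta) * loss_i
#     smoothed_i = avg_i / (1 - beta ** i)
#
# Plotting smoothed loss against lr shows where the loss starts falling
# fastest, which is the usual region to pick a learning rate from.
#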
import numpy as np

from keras import backend as K
from keras.callbacks import Callback

from bokeh.plotting import figure, show
from bokeh.io import output_notebook
from bokeh.models import HoverTool
def find_lr(model, start_lr, end_lr, stepsize=None):
    """Sweep the learning rate from start_lr to end_lr over one epoch,
    recording the loss at each batch, then restore the model's weights."""
    if stepsize is None:
        # Default to one epoch's worth of batches. `train_df`, `batch_size`,
        # `train_gen`, and `val_gen` are assumed to be defined in the
        # surrounding notebook. (Computed lazily here rather than in the
        # signature, so they don't need to exist at definition time.)
        stepsize = len(train_df) // batch_size

    finder = LRFinder(start_lr, end_lr, stepsize)
    weights = model.get_weights()
    try:
        model.fit_generator(
            generator=train_gen,
            validation_data=val_gen,
            epochs=1, verbose=1, callbacks=[finder],
        )
    finally:
        # The sweep deliberately destabilizes training, so always put the
        # original weights back, even if the fit is interrupted.
        model.set_weights(weights)
    return finder
class LRFinder(Callback):
    """Keras callback that grows the learning rate exponentially each batch
    and tracks an exponentially smoothed loss (an LR range test)."""

    def __init__(self, start_lr, end_lr, stepsize, beta=.98):
        super().__init__()
        self.start_lr = start_lr
        self.end_lr = end_lr
        self.stepsize = stepsize
        self.beta = beta  # smoothing factor for the running average of the loss
        # Per-batch multiplier so that lr reaches end_lr after `stepsize` steps.
        self.lr_mult = (end_lr / start_lr) ** (1 / stepsize)
    def on_train_begin(self, logs=None):
        self.best_loss = 1e9
        self.avg_loss = 0
        self.losses, self.smoothed_losses, self.lrs, self.iterations = [], [], [], []
        self.iteration = 0
        K.set_value(self.model.optimizer.lr, self.start_lr)
    def on_batch_end(self, batch, logs=None):
        logs = logs or {}
        loss = logs.get('loss')
        self.iteration += 1

        # Exponential moving average of the loss, with bias correction
        # (the same correction Adam applies to its moment estimates).
        self.avg_loss = self.beta * self.avg_loss + (1 - self.beta) * loss
        smoothed_loss = self.avg_loss / (1 - self.beta ** self.iteration)

        # Stop early once the loss explodes past 4x the best seen so far.
        if self.iteration > 1 and smoothed_loss > self.best_loss * 4:
            self.model.stop_training = True
            return

        if smoothed_loss < self.best_loss or self.iteration == 1:
            self.best_loss = smoothed_loss

        # Grow the learning rate exponentially and record everything.
        lr = self.start_lr * (self.lr_mult ** self.iteration)
        self.losses.append(loss)
        self.smoothed_losses.append(smoothed_loss)
        self.lrs.append(lr)
        self.iterations.append(self.iteration)
        K.set_value(self.model.optimizer.lr, lr)
    def plot(self, lskip=10, rskip=10):
        """Plot smoothed loss vs. learning rate, trimming the noisy ends."""
        end = -rskip or None  # handle rskip=0, where [:-0] would be empty
        lrs = self.lrs[lskip:end]
        losses = self.smoothed_losses[lskip:end]

        output_notebook()
        p = figure(title='Learning Rate Finder', x_axis_label='LR', y_axis_label='Loss')
        p.line(lrs, losses)
        p.add_tools(
            HoverTool(
                show_arrow=False,
                line_policy='next',
                tooltips=[('LR', '$data_x'), ('Loss', '$data_y')],
            )
        )
        show(p)

        # Report the learning rates at the (up to) 15 lowest smoothed losses.
        k = min(15, len(losses) - 1)
        best_idxs = np.argpartition(losses, k)[:k]
        best_lrs = np.take(lrs, best_idxs)
        print(f"Best LRs: {best_lrs}")