#
# 🙏 to prior art: http://puzzlemusa.com/2018/05/14/learning-rate-finder-using-keras/
#
# Example:
#     finder = find_lr(model, 1e-10, 1e-2)
#     finder.plot()

import numpy as np

from keras import backend as K
from keras.callbacks import Callback

from bokeh.plotting import figure, show
from bokeh.io import output_notebook
from bokeh.models import HoverTool


# Note: the default `stepsize` and the generators below assume notebook
# globals `train_df`, `batch_size`, `train_gen` and `val_gen` already exist.
def find_lr(model, start_lr, end_lr, stepsize=len(train_df) // batch_size):
    finder = LRFinder(start_lr, end_lr, stepsize)
    # Snapshot the weights so the LR sweep doesn't disturb real training.
    weights = model.get_weights()
    try:
        model.fit_generator(
            generator=train_gen,
            validation_data=val_gen,
            epochs=1, verbose=1, callbacks=[finder],
        )
    finally:
        # Restore the original weights even if the sweep stopped early.
        model.set_weights(weights)
    return finder


class LRFinder(Callback):
    def __init__(self, start_lr, end_lr, stepsize, beta=.98):
        super().__init__()
        self.start_lr = start_lr
        self.end_lr = end_lr
        self.stepsize = stepsize
        self.beta = beta  # smoothing factor for the moving average of the loss
        # Per-batch multiplier that takes the LR from start_lr to end_lr
        # in `stepsize` batches.
        self.lr_mult = (end_lr / start_lr) ** (1 / stepsize)

    def on_train_begin(self, logs=None):
        self.best_loss = 1e9
        self.avg_loss = 0
        self.losses, self.smoothed_losses, self.lrs, self.iterations = [], [], [], []
        self.iteration = 0
        K.set_value(self.model.optimizer.lr, self.start_lr)

    def on_batch_end(self, batch, logs=None):
        logs = logs or {}
        loss = logs.get('loss')
        self.iteration += 1

        # Exponential moving average of the loss, with bias correction.
        self.avg_loss = self.beta * self.avg_loss + (1 - self.beta) * loss
        smoothed_loss = self.avg_loss / (1 - self.beta ** self.iteration)

        # Stop the sweep once the loss explodes.
        if self.iteration > 1 and smoothed_loss > self.best_loss * 4:
            self.model.stop_training = True
            return
        if smoothed_loss < self.best_loss or self.iteration == 1:
            self.best_loss = smoothed_loss

        # Increase the LR exponentially and record this step.
        lr = self.start_lr * (self.lr_mult ** self.iteration)
        self.losses.append(loss)
        self.smoothed_losses.append(smoothed_loss)
        self.lrs.append(lr)
        self.iterations.append(self.iteration)
        K.set_value(self.model.optimizer.lr, lr)

    def plot(self, lskip=10, rskip=10):
        # Trim the noisy ends of the sweep before plotting.
        lrs = self.lrs[lskip:-rskip]
        losses = self.smoothed_losses[lskip:-rskip]

        output_notebook()
        p = figure(title='Learning Rate Finder', x_axis_label='LR', y_axis_label='Loss')
        p.line(lrs, losses)
        p.add_tools(
            HoverTool(
                show_arrow=False,
                line_policy='next',
                tooltips=[('LR', '$data_x'), ('Loss', '$data_y')],
            )
        )
        show(p)

        # Report the 15 LRs with the lowest smoothed loss.
        best_idxs = np.argpartition(losses, 15)[:15]
        best_lrs = np.take(lrs, best_idxs)
        print(f"Best LRs: {best_lrs}")