Skip to content

Instantly share code, notes, and snippets.

@WittmannF
Last active June 3, 2022 12:13
Show Gist options
  • Save WittmannF/c55ed82d27248d18799e2be324a79473 to your computer and use it in GitHub Desktop.
Save WittmannF/c55ed82d27248d18799e2be324a79473 to your computer and use it in GitHub Desktop.
Learning Rate Finder as a Keras Callback
from keras.callbacks import Callback
import keras.backend as K
import numpy as np
import matplotlib.pyplot as plt
class LRFinder(Callback):
    """Learning-rate range finder implemented as a Keras callback.

    While the model trains, the optimizer's learning rate is stepped through
    a geometric sequence running from ``min_lr`` to ``max_lr``; a new value
    is applied every ``batches_lr_update`` batches. Each time the learning
    rate changes, a momentum-smoothed loss is recorded. Training is stopped
    early once the smoothed loss exceeds ``stop_multiplier`` times the best
    (lowest) smoothed loss seen so far. When training ends, the recorded
    losses are plotted against the learning rates on a log-scaled x axis.

    Up-to-date version: https://github.com/WittmannF/LRFinder

    Args:
        min_lr: First (smallest) learning rate of the sweep.
        max_lr: Last (largest) learning rate of the sweep.
        mom: Smoothing factor applied to the loss; ``0`` disables smoothing.
        stop_multiplier: Training stops when the smoothed loss is greater
            than ``best_loss * stop_multiplier``. When ``None`` it is
            derived from ``mom`` as ``-20 * mom / 3 + 10`` (4 for
            ``mom=0.9``, 10 for ``mom=0``).
        reload_weights: If true, the model weights are saved to
            ``'tmp.hdf5'`` when training begins and loaded back every time
            the learning rate changes and once more when training ends.
        batches_lr_update: Number of batches between learning-rate changes.

    Example of usage::

        from keras.models import Sequential
        from keras.layers import Flatten, Dense
        from keras.datasets import fashion_mnist
        !git clone https://github.com/WittmannF/LRFinder.git
        from LRFinder.keras_callback import LRFinder

        # 1. Input Data
        (X_train, y_train), (X_test, y_test) = fashion_mnist.load_data()
        mean, std = X_train.mean(), X_train.std()
        X_train, X_test = (X_train-mean)/std, (X_test-mean)/std

        # 2. Define and Compile Model
        model = Sequential([Flatten(),
                            Dense(512, activation='relu'),
                            Dense(10, activation='softmax')])
        model.compile(loss='sparse_categorical_crossentropy',
                      metrics=['accuracy'], optimizer='sgd')

        # 3. Fit using Callback
        lr_finder = LRFinder(min_lr=1e-4, max_lr=1)
        model.fit(X_train, y_train, batch_size=128, callbacks=[lr_finder],
                  epochs=2)
    """

    def __init__(self, min_lr, max_lr, mom=0.9, stop_multiplier=None,
                 reload_weights=True, batches_lr_update=5):
        """Store the sweep configuration; see the class docstring for args."""
        # Let the base Callback set up its own attributes first.
        super().__init__()
        self.min_lr = min_lr
        self.max_lr = max_lr
        self.mom = mom
        self.reload_weights = reload_weights
        self.batches_lr_update = batches_lr_update
        if stop_multiplier is None:
            # Heavier smoothing reacts more slowly to divergence, so use a
            # tighter threshold: 4 if mom=0.9, 10 if mom=0.
            self.stop_multiplier = -20 * self.mom / 3 + 10
        else:
            self.stop_multiplier = stop_multiplier

    def on_train_begin(self, logs=None):
        """Build the learning-rate schedule and reset the sweep state.

        Args:
            logs: Unused; accepted for compatibility with the Callback API.
        """
        params = self.params
        try:
            # The last batch of an epoch may be partial, so round up per
            # epoch. Flooring the overall total can undercount the batches
            # that actually run and leave the schedule too short.
            steps_per_epoch = -(-params['samples'] // params['batch_size'])
        except (KeyError, TypeError):
            # 'samples'/'batch_size' missing or None (e.g. generator input):
            # fall back to the explicit step count.
            steps_per_epoch = params['steps']
        n_iterations = steps_per_epoch * params['epochs']

        # One learning rate per group of `batches_lr_update` batches.
        self.learning_rates = np.geomspace(
            self.min_lr, self.max_lr,
            num=n_iterations // self.batches_lr_update + 1)
        self.losses = []
        self.iteration = 0
        self.best_loss = 0
        if self.reload_weights:
            # Snapshot the starting weights so they can be restored later.
            self.model.save_weights('tmp.hdf5')

    def on_batch_end(self, batch, logs=None):
        """Smooth the batch loss, advance the schedule, check the stop rule.

        Args:
            batch: Index of the batch that just finished (unused).
            logs: Mapping of batch metrics; the ``'loss'`` entry is read.
        """
        logs = logs or {}
        loss = logs.get('loss')

        if self.iteration != 0:
            # Make the loss smoother by blending it with the last recorded
            # (already smoothed) value.
            loss = self.losses[-1] * self.mom + loss * (1 - self.mom)

        if self.iteration == 0 or loss < self.best_loss:
            self.best_loss = loss

        if self.iteration % self.batches_lr_update == 0:
            # Move on to the next learning rate; each one is evaluated over
            # `batches_lr_update` batches.
            if self.reload_weights:
                self.model.load_weights('tmp.hdf5')

            lr = self.learning_rates[self.iteration // self.batches_lr_update]
            K.set_value(self.model.optimizer.lr, lr)

            # One recorded loss per learning rate keeps `losses` aligned
            # with `learning_rates` for the final plot.
            self.losses.append(loss)

        if loss > self.best_loss * self.stop_multiplier:
            # Stop criterion: the loss has blown up relative to the best one.
            self.model.stop_training = True

        self.iteration += 1

    def on_train_end(self, logs=None):
        """Restore the saved weights (if enabled) and plot loss vs. LR.

        Args:
            logs: Unused; accepted for compatibility with the Callback API.
        """
        if self.reload_weights:
            self.model.load_weights('tmp.hdf5')

        plt.figure(figsize=(12, 6))
        plt.plot(self.learning_rates[:len(self.losses)], self.losses)
        plt.xlabel("Learning Rate")
        plt.ylabel("Loss")
        plt.xscale('log')
        plt.show()
@WittmannF
Copy link
Author

There's a version with more updates here: https://github.com/WittmannF/LRFinder

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment