Learning Rate Finder as a Keras Callback
from keras.callbacks import Callback
import keras.backend as K
import numpy as np
import matplotlib.pyplot as plt


class LRFinder(Callback):
    """
    Up-to-date version: https://github.com/WittmannF/LRFinder

    Example of usage:

        from keras.models import Sequential
        from keras.layers import Flatten, Dense
        from keras.datasets import fashion_mnist
        !git clone https://github.com/WittmannF/LRFinder.git
        from LRFinder.keras_callback import LRFinder

        # 1. Input Data
        (X_train, y_train), (X_test, y_test) = fashion_mnist.load_data()
        mean, std = X_train.mean(), X_train.std()
        X_train, X_test = (X_train-mean)/std, (X_test-mean)/std

        # 2. Define and Compile Model
        model = Sequential([Flatten(),
                            Dense(512, activation='relu'),
                            Dense(10, activation='softmax')])
        model.compile(loss='sparse_categorical_crossentropy',
                      metrics=['accuracy'], optimizer='sgd')

        # 3. Fit using Callback
        lr_finder = LRFinder(min_lr=1e-4, max_lr=1)
        model.fit(X_train, y_train, batch_size=128,
                  callbacks=[lr_finder], epochs=2)
    """

    def __init__(self, min_lr, max_lr, mom=0.9, stop_multiplier=None,
                 reload_weights=True, batches_lr_update=5):
        self.min_lr = min_lr
        self.max_lr = max_lr
        self.mom = mom  # Momentum used to smooth the recorded losses
        self.reload_weights = reload_weights
        self.batches_lr_update = batches_lr_update
        if stop_multiplier is None:
            # 4 if mom=0.9, 10 if mom=0
            self.stop_multiplier = -20*self.mom/3 + 10
        else:
            self.stop_multiplier = stop_multiplier

    def on_train_begin(self, logs={}):
        p = self.params
        try:
            n_iterations = p['epochs']*p['samples']//p['batch_size']
        except KeyError:  # Some Keras versions report 'steps' instead of 'samples'
            n_iterations = p['steps']*p['epochs']

        # Learning rates spaced evenly on a log scale (geometric progression)
        self.learning_rates = np.geomspace(self.min_lr, self.max_lr,
                                           num=n_iterations//self.batches_lr_update+1)
        self.losses = []
        self.iteration = 0
        self.best_loss = 0
        if self.reload_weights:
            self.model.save_weights('tmp.hdf5')

    def on_batch_end(self, batch, logs={}):
        loss = logs.get('loss')

        if self.iteration != 0:  # Smooth the loss with an exponential moving average
            loss = self.losses[-1]*self.mom + loss*(1-self.mom)

        if self.iteration == 0 or loss < self.best_loss:
            self.best_loss = loss

        # Move to the next learning rate every `batches_lr_update` batches
        # (each lr is evaluated over that many batches, 5 by default)
        if self.iteration % self.batches_lr_update == 0:
            if self.reload_weights:  # Reset weights so each lr starts from the same point
                self.model.load_weights('tmp.hdf5')

            lr = self.learning_rates[self.iteration//self.batches_lr_update]
            K.set_value(self.model.optimizer.lr, lr)

        self.losses.append(loss)

        # Stop criteria: halt once the smoothed loss diverges
        if loss > self.best_loss*self.stop_multiplier:
            self.model.stop_training = True

        self.iteration += 1

    def on_train_end(self, logs=None):
        if self.reload_weights:  # Restore the original weights
            self.model.load_weights('tmp.hdf5')

        plt.figure(figsize=(12, 6))
        plt.plot(self.learning_rates[:len(self.losses)], self.losses)
        plt.xlabel("Learning Rate")
        plt.ylabel("Loss")
        plt.xscale('log')
        plt.show()
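Once training finishes, the callback keeps the raw data behind the plot, so a learning rate can also be chosen programmatically. Here is a minimal sketch, assuming the lr_finder instance from the usage example above; the divide-by-ten heuristic is a common convention, not something the gist prescribes:

    import numpy as np

    # After model.fit(...) with the callback, the schedule and the
    # smoothed losses are stored on the instance. A common heuristic
    # (an assumption here, not part of the gist) is to train with a
    # rate about one order of magnitude below the loss minimum.
    lrs = np.array(lr_finder.learning_rates[:len(lr_finder.losses)])
    losses = np.array(lr_finder.losses)
    best_lr = lrs[np.argmin(losses)]
    suggested_lr = best_lr / 10
    print(f"loss minimum at lr={best_lr:.2e}, suggested lr={suggested_lr:.2e}")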
Three main differences from existing implementations:
- The number of iterations is inferred automatically from the number of batches (i.e., the finder always runs over full epochs)
- The learning rates are spaced evenly on a log scale (a geometric progression) using np.geomspace
- Training stops automatically once the smoothed loss exceeds stop_multiplier times the lowest loss seen so far (4 by default with mom=0.9, 10 with mom=0); see the sketch after this list
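To make the schedule and the stop test concrete, here is the same arithmetic outside the callback; the values min_lr=1e-4, max_lr=1, and n_iterations=2345 are illustrative assumptions, not values taken from the gist:

    import numpy as np

    # Geometric learning-rate schedule, as built in on_train_begin:
    # one lr per group of `batches_lr_update` batches.
    min_lr, max_lr = 1e-4, 1
    n_iterations, batches_lr_update = 2345, 5
    learning_rates = np.geomspace(min_lr, max_lr,
                                  num=n_iterations//batches_lr_update+1)
    print(learning_rates[:3])   # ~[1.00e-04, 1.02e-04, 1.04e-04]
    print(learning_rates[-1])   # 1.0

    # Stop criterion, as applied in on_batch_end: with mom=0.9 the
    # default stop_multiplier is -20*0.9/3 + 10 = 4, so training halts
    # once the smoothed loss exceeds 4x the best smoothed loss so far.
    mom = 0.9
    stop_multiplier = -20*mom/3 + 10
    best_loss, current_loss = 0.35, 1.5
    if current_loss > best_loss*stop_multiplier:
        print("stop_training would be set to True")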
There's a version with more updates here: https://github.com/WittmannF/LRFinder