Last active
June 3, 2022 12:13
-
-
Save WittmannF/c55ed82d27248d18799e2be324a79473 to your computer and use it in GitHub Desktop.
Learning Rate Finder as a Keras Callback
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os

import keras.backend as K
import matplotlib.pyplot as plt
import numpy as np
from keras.callbacks import Callback
class LRFinder(Callback):
    """Keras callback that sweeps the learning rate to help pick a good one.

    While ``fit`` runs, the optimizer's learning rate is stepped through a
    geometric progression from ``min_lr`` to ``max_lr``. A momentum-smoothed
    training loss is recorded for every learning rate tried, and a
    loss-vs-learning-rate curve (log-scaled x axis) is shown when training
    ends. Training is stopped early once the smoothed loss exceeds
    ``stop_multiplier`` times the best loss seen so far.

    Up-to-date version: https://github.com/WittmannF/LRFinder

    Example of usage:

        from keras.models import Sequential
        from keras.layers import Flatten, Dense
        from keras.datasets import fashion_mnist
        !git clone https://github.com/WittmannF/LRFinder.git
        from LRFinder.keras_callback import LRFinder

        # 1. Input Data
        (X_train, y_train), (X_test, y_test) = fashion_mnist.load_data()
        mean, std = X_train.mean(), X_train.std()
        X_train, X_test = (X_train-mean)/std, (X_test-mean)/std

        # 2. Define and Compile Model
        model = Sequential([Flatten(),
                            Dense(512, activation='relu'),
                            Dense(10, activation='softmax')])
        model.compile(loss='sparse_categorical_crossentropy',
                      metrics=['accuracy'], optimizer='sgd')

        # 3. Fit using Callback
        lr_finder = LRFinder(min_lr=1e-4, max_lr=1)
        model.fit(X_train, y_train, batch_size=128, callbacks=[lr_finder],
                  epochs=2)
    """

    # File used to snapshot the initial weights while the sweep is running.
    _WEIGHTS_PATH = 'tmp.hdf5'

    def __init__(self, min_lr, max_lr, mom=0.9, stop_multiplier=None,
                 reload_weights=True, batches_lr_update=5):
        """Configure the sweep.

        Args:
            min_lr: First (smallest) learning rate of the sweep.
            max_lr: Last (largest) learning rate of the sweep.
            mom: Weight given to the previously recorded loss when smoothing
                the current batch loss (0 disables smoothing).
            stop_multiplier: Training stops once the smoothed loss exceeds
                ``best_loss * stop_multiplier``. When ``None`` it is derived
                from ``mom`` as ``-20 * mom / 3 + 10``.
            reload_weights: If true, the initial weights are saved when
                training begins and restored before every learning-rate
                change and again when training ends.
            batches_lr_update: Number of batches trained between two
                consecutive learning-rate changes.

        Raises:
            ValueError: If ``batches_lr_update`` is smaller than 1.
        """
        super(LRFinder, self).__init__()
        if batches_lr_update < 1:
            raise ValueError("batches_lr_update must be >= 1")
        self.min_lr = min_lr
        self.max_lr = max_lr
        self.mom = mom
        self.reload_weights = reload_weights
        self.batches_lr_update = batches_lr_update
        if stop_multiplier is None:
            # Heavier smoothing reacts more slowly to divergence, so it gets
            # a tighter threshold: 4 if mom=0.9, 10 if mom=0.
            self.stop_multiplier = -20 * self.mom / 3 + 10
        else:
            self.stop_multiplier = stop_multiplier

    def on_train_begin(self, logs=None):
        """Build the learning-rate schedule and reset the sweep state."""
        params = self.params
        try:
            # Ceiling division: the final, partial batch of every epoch is
            # trained on as well, so flooring the total would leave the
            # schedule shorter than the number of batches actually seen.
            batches_per_epoch = -(-params['samples'] // params['batch_size'])
        except (KeyError, TypeError):
            # 'samples'/'batch_size' are missing or None (e.g. generator
            # input); fall back to the explicit step count.
            batches_per_epoch = params['steps']
        n_iterations = batches_per_epoch * params['epochs']

        self.learning_rates = np.geomspace(
            self.min_lr, self.max_lr,
            num=n_iterations // self.batches_lr_update + 1)
        self.losses = []
        self.iteration = 0
        self.best_loss = 0  # overwritten by the first batch

        if self.reload_weights:
            self.model.save_weights(self._WEIGHTS_PATH)

    def on_batch_end(self, batch, logs=None):
        """Record the smoothed loss and advance the learning rate."""
        logs = logs or {}
        loss = logs.get('loss')

        if self.iteration != 0:  # Make loss smoother using momentum
            loss = self.losses[-1] * self.mom + loss * (1 - self.mom)

        if self.iteration == 0 or loss < self.best_loss:
            self.best_loss = loss

        # Each learning rate is evaluated over `batches_lr_update` batches.
        if self.iteration % self.batches_lr_update == 0:
            step = self.iteration // self.batches_lr_update
            if step >= len(self.learning_rates):
                # Schedule exhausted (more batches ran than were estimated):
                # end the sweep instead of indexing past the last rate.
                self.model.stop_training = True
            else:
                if self.reload_weights:
                    self.model.load_weights(self._WEIGHTS_PATH)
                lr = self.learning_rates[step]
                K.set_value(self.model.optimizer.lr, lr)
                # One loss per learning rate keeps `losses` aligned with
                # `learning_rates` for plotting.
                self.losses.append(loss)

        if loss > self.best_loss * self.stop_multiplier:  # Stop criteria
            self.model.stop_training = True

        self.iteration += 1

    def on_train_end(self, logs=None):
        """Restore the initial weights and plot loss against learning rate."""
        if self.reload_weights:
            self.model.load_weights(self._WEIGHTS_PATH)
            try:
                # Best-effort cleanup of the temporary snapshot; a failure to
                # delete it must not prevent the plot from being shown.
                os.remove(self._WEIGHTS_PATH)
            except OSError:
                pass

        plt.figure(figsize=(12, 6))
        plt.plot(self.learning_rates[:len(self.losses)], self.losses)
        plt.xlabel("Learning Rate")
        plt.ylabel("Loss")
        plt.xscale('log')
        plt.show()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
There's a version with more updates here: https://github.com/WittmannF/LRFinder