LearningRateFinder for Keras
import numpy as np
import matplotlib.pyplot as plt
from keras import backend as K
from keras.callbacks import Callback


class LearningRateFinder(Callback):
    '''
    This callback implements a learning rate finder (LRF).
    At the start of each epoch, the model weights are reset to their
    initial values and the next learning rate from the given list is
    applied. On training end, the training loss is plotted against the
    learning rate. One may choose a learning rate for a model based on
    the resulting graph, selecting a value slightly before the minimal
    training loss.

    # Example
        lrf = LearningRateFinder([0.0001, 0.0005, 0.001, 0.005, 0.01, 0.05])
        model.fit(x_train, y_train, epochs=6, batch_size=128, callbacks=[lrf])

    # Arguments
        lrs: list of learning rates to try, one per epoch
    '''
    def __init__(self, lrs):
        super(LearningRateFinder, self).__init__()
        self.index = 0
        self.learningRateList = lrs
        self.losses = []
        self.lrs = []

    def on_train_begin(self, logs=None):
        # Remember the initial weights so every learning rate is tried
        # from the same starting point.
        self.initial_weights = self.model.get_weights()

    def on_epoch_begin(self, epoch, logs=None):
        if not hasattr(self.model.optimizer, 'lr'):
            raise ValueError('Optimizer must have a "lr" attribute.')
        # Reset the weights, then apply the next learning rate in the
        # list, wrapping around if there are more epochs than rates.
        self.model.set_weights(self.initial_weights)
        lr = self.learningRateList[self.index]
        self.index = (self.index + 1) % len(self.learningRateList)
        K.set_value(self.model.optimizer.lr, lr)
        print('\nEpoch %05d: LearningRateFinder changing learning rate to %s.'
              % (epoch + 1, lr))

    def on_epoch_end(self, epoch, logs=None):
        logs = logs or {}
        # Record the learning rate used this epoch and the resulting loss.
        lr = float(K.get_value(self.model.optimizer.lr))
        self.losses.append(logs.get('loss'))
        self.lrs.append(lr)

    def on_train_end(self, logs=None):
        # Plot training loss against learning rate.
        plt.plot(self.lrs, self.losses)
        plt.ylabel('loss')
        plt.xlabel('learning rate')
        plt.show()
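For reference, here is a minimal usage sketch, assuming the standalone Keras 2.x API that the callback targets. The toy model, the random data, and the learning-rate list below are illustrative assumptions, not part of the gist.

import numpy as np
from keras.models import Sequential
from keras.layers import Dense

# Hypothetical data: 1024 random samples with 20 features, binary labels.
x_train = np.random.rand(1024, 20)
y_train = np.random.randint(0, 2, size=(1024, 1))

# A small illustrative model; any compiled Keras model works.
model = Sequential([
    Dense(32, activation='relu', input_shape=(20,)),
    Dense(1, activation='sigmoid'),
])
model.compile(optimizer='sgd', loss='binary_crossentropy')

lrs = [0.0001, 0.0005, 0.001, 0.005, 0.01, 0.05]
lrf = LearningRateFinder(lrs)
# One epoch per candidate rate, since the callback advances through
# the list at each epoch boundary.
model.fit(x_train, y_train, epochs=len(lrs), batch_size=128, callbacks=[lrf])

The plot produced at the end shows one loss point per learning rate; per the docstring, pick a value slightly before the rate at which the training loss is lowest.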