-
-
Save jeremyjordan/5a222e04bb78c242f5763ad40626c452 to your computer and use it in GitHub Desktop.
from keras.callbacks import Callback | |
import keras.backend as K | |
import numpy as np | |
class SGDRScheduler(Callback):
    '''Cosine annealing learning rate scheduler with periodic restarts.

    Within a cycle the learning rate follows half a cosine wave from `max_lr`
    down to `min_lr`. When a cycle ends the schedule restarts at `max_lr`
    (scaled by `lr_decay`) and the next cycle is `mult_factor` times longer.

    Cycles are counted in epochs seen by this callback, not by the absolute
    epoch index, so a run resumed at a later epoch still restarts after
    `cycle_length` epochs.

    # Usage
        ```python
            schedule = SGDRScheduler(min_lr=1e-5,
                                     max_lr=1e-2,
                                     steps_per_epoch=np.ceil(epoch_size/batch_size),
                                     lr_decay=0.9,
                                     cycle_length=5,
                                     mult_factor=1.5)
            model.fit(X_train, Y_train, epochs=100, callbacks=[schedule])
        ```

    # Arguments
        min_lr: The lower bound of the learning rate range for the experiment.
        max_lr: The upper bound of the learning rate range for the experiment.
        steps_per_epoch: Number of mini-batches in the dataset. Calculated as `np.ceil(epoch_size/batch_size)`.
        lr_decay: Reduce the max_lr after the completion of each cycle.
                  Ex. To reduce the max_lr by 20% after each cycle, set this value to 0.8.
        cycle_length: Initial number of epochs in a cycle.
        mult_factor: Scale cycle_length after each full cycle completion.

    # Attributes
        history: Dict mapping 'lr' and every key found in the batch logs to
                 the list of values recorded after each batch.
        best_weights: Model weights captured at the end of the most recent
                      completed cycle, or None if no cycle has completed yet.

    # References
        Blog post: jeremyjordan.me/nn-learning-rate
        Original paper: http://arxiv.org/abs/1608.03983
    '''

    def __init__(self,
                 min_lr,
                 max_lr,
                 steps_per_epoch,
                 lr_decay=1,
                 cycle_length=10,
                 mult_factor=2):
        # Let the base Callback set up its own state (e.g. the model reference).
        super().__init__()

        self.min_lr = min_lr
        self.max_lr = max_lr
        self.lr_decay = lr_decay

        self.batch_since_restart = 0
        self.epoch_since_restart = 0
        # Kept for backward compatibility / inspection: the 1-based epoch
        # number at which the current cycle is expected to end.
        self.next_restart = cycle_length

        self.steps_per_epoch = steps_per_epoch

        self.cycle_length = cycle_length
        self.mult_factor = mult_factor

        self.history = {}
        # Stays None until the first cycle completes; checked in on_train_end.
        self.best_weights = None

    def clr(self):
        '''Calculate the learning rate for the current position in the cycle.'''
        fraction_to_restart = self.batch_since_restart / (self.steps_per_epoch * self.cycle_length)
        lr = self.min_lr + 0.5 * (self.max_lr - self.min_lr) * (1 + np.cos(fraction_to_restart * np.pi))
        return lr

    def on_train_begin(self, logs=None):
        '''Initialize the learning rate to the maximum value at the start of training.'''
        K.set_value(self.model.optimizer.lr, self.max_lr)

    def on_batch_end(self, batch, logs=None):
        '''Record previous batch statistics and update the learning rate.'''
        logs = logs or {}
        self.history.setdefault('lr', []).append(K.get_value(self.model.optimizer.lr))
        for k, v in logs.items():
            self.history.setdefault(k, []).append(v)

        self.batch_since_restart += 1
        K.set_value(self.model.optimizer.lr, self.clr())

    def on_epoch_end(self, epoch, logs=None):
        '''Check for end of current cycle, apply restarts when necessary.'''
        # Count epochs locally instead of comparing against the absolute epoch
        # index, which never matches when training is resumed at a later epoch.
        self.epoch_since_restart += 1
        if self.epoch_since_restart >= self.cycle_length:
            self.batch_since_restart = 0
            self.epoch_since_restart = 0
            # np.ceil returns a float; keep the epoch counters integral.
            self.cycle_length = int(np.ceil(self.cycle_length * self.mult_factor))
            self.next_restart = epoch + 1 + self.cycle_length
            self.max_lr *= self.lr_decay
            self.best_weights = self.model.get_weights()
            # Restart at the (decayed) maximum right away; otherwise the first
            # batch of the new cycle would still run at min_lr.
            K.set_value(self.model.optimizer.lr, self.clr())

    def on_train_end(self, logs=None):
        '''Set weights to the values from the end of the most recent cycle for best performance.'''
        # If training stopped before any cycle completed there is nothing to
        # restore; keep the weights the model ended with.
        if self.best_weights is not None:
            self.model.set_weights(self.best_weights)
I would recommend adding logging to the callbacks to diagnose your problem.
def on_epoch_end(self, epoch, logs=None):
    '''Check for end of current cycle, apply restarts when necessary.

    `epoch` is the zero-based index of the epoch that just finished, so
    `epoch + 1` is compared against `self.next_restart`. Debug messages are
    emitted so the restart decisions can be traced in the training log.
    '''
    logger.debug('epoch end, checking to see if end of cycle...')
    if epoch + 1 == self.next_restart:
        logger.debug('cycle finished, saving weights...')
        self.batch_since_restart = 0
        # np.ceil returns a float; cast so cycle_length and next_restart stay
        # integers (otherwise they show up as e.g. 228.0).
        self.cycle_length = int(np.ceil(self.cycle_length * self.mult_factor))
        self.next_restart += self.cycle_length
        self.max_lr *= self.lr_decay
        self.best_weights = self.model.get_weights()
def on_train_end(self, logs=None):
    '''Set weights to the values from the end of the most recent cycle for best performance.

    If no cycle completed during training, no weights were stored; in that
    case the model keeps its final weights instead of raising AttributeError.
    '''
    best_weights = getattr(self, 'best_weights', None)
    if best_weights is None:
        logger.debug('finished training, no completed cycle so no weights to reload')
        return
    logger.debug('finished training, reloading weights from end of last cycle...')
    self.model.set_weights(best_weights)
Let me know if you're able to figure it out!
Hi @jeremyjordan,
Thanks for the nice work! What is the license under which you publish this code ?
See here for an "official" implementation without license issues: https://www.tensorflow.org/api_docs/python/tf/keras/experimental/CosineDecayRestarts
@jeremyjordan Sorry for coming back to this after so long, but I recently noticed something in the code that I am hoping you could explain.
Let's say I have a model that I previously trained for 89 epochs. I then restart training at a later date and wish to train for another 11 epochs, up to 100. In this case, it seems as though you will never hit a cycle restart unless you specify a cycle length > 89, due to this line:
if epoch + 1 == self.next_restart:
However, with such a high cycle length, you may never again hit another restart.
For example:
number_epochs = 100

# Make sure the log directory exists before anything is written into it.
os.makedirs(model.log_dir, exist_ok=True)
csv_logger = CSVLogger(os.path.join(model.log_dir, "epoch_logger.csv"), append=True)

# Save the model config to the log directory.
# The context manager closes the file even if printing fails.
with open(os.path.join(model.log_dir, "config.txt"), 'w') as config_file:
    print(config_list, file=config_file)

# Train the head branches
# Passing layers="heads" freezes all layers except the head
# layers. You can also pass a regular expression to select
# which layers to train by name pattern.

# Cycle length is what you want + epoch restart num
# (e.g if you want a cycle length of 2, restart epoch is 89, thus cycle length = 91)
# NOTE(review): steps_per_epoch is derived from number_epochs here, while the
# SGDRScheduler docs describe it as ceil(epoch_size / batch_size) - confirm
# this is intended.
schedule = SGDRScheduler(min_lr=1e-5,
                         max_lr=1e-2,
                         steps_per_epoch=np.ceil(number_epochs/config.BATCH_SIZE),
                         lr_decay=0.9,
                         cycle_length=91,
                         mult_factor=1.5)

model.train(dataset_train, dataset_val,
            learning_rate=1e-2,
            epochs=number_epochs,
            layers='heads',
            best_only=True,
            custom_callbacks=[schedule, csv_logger],
            augmentation=aug1)
Gives:
Starting at epoch 89. LR=0.01
Checkpoint Path: /home/b3020111/dolphin-recognition/Mask_RCNN-master/logs/ndd/above/od/ndd-above-od-1.10-tester20200324T1218/mask_rcnn_ndd-above-od-1.10-tester_{epoch:04d}.h5
Selecting layers to train
fpn_c5p5 (Conv2D)
fpn_c4p4 (Conv2D)
fpn_c3p3 (Conv2D)
fpn_c2p2 (Conv2D)
fpn_p5 (Conv2D)
fpn_p2 (Conv2D)
fpn_p3 (Conv2D)
fpn_p4 (Conv2D)
In model: rpn_model
rpn_conv_shared (Conv2D)
rpn_class_raw (Conv2D)
rpn_bbox_pred (Conv2D)
mrcnn_mask_conv1 (TimeDistributed)
mrcnn_mask_bn1 (TimeDistributed)
mrcnn_mask_conv2 (TimeDistributed)
mrcnn_mask_bn2 (TimeDistributed)
mrcnn_class_conv1 (TimeDistributed)
mrcnn_class_bn1 (TimeDistributed)
mrcnn_mask_conv3 (TimeDistributed)
mrcnn_mask_bn3 (TimeDistributed)
mrcnn_class_conv2 (TimeDistributed)
mrcnn_class_bn2 (TimeDistributed)
mrcnn_mask_conv4 (TimeDistributed)
mrcnn_mask_bn4 (TimeDistributed)
mrcnn_bbox_fc (TimeDistributed)
mrcnn_mask_deconv (TimeDistributed)
mrcnn_class_logits (TimeDistributed)
mrcnn_mask (TimeDistributed)
Epoch 90/100
99/100 [============================>.] - ETA: 3s - loss: 0.9831 - rpn_class_loss: 0.0082 - rpn_bbox_loss: 0.6570 - mrcnn_class_loss: 0.0171 - mrcnn_bbox_loss: 0.1787 - mrcnn_mask_loss: 0.1214
epoch end, checking to see if end of cycle...
epoch + 1 = 90 self.next_restart = 91
100/100 [==============================] - 378s 4s/step - loss: 0.9816 - rpn_class_loss: 0.0083 - rpn_bbox_loss: 0.6560 - mrcnn_class_loss: 0.0171 - mrcnn_bbox_loss: 0.1783 - mrcnn_mask_loss: 0.1213 - val_loss: 1.7121 - val_rpn_class_loss: 0.0104 - val_rpn_bbox_loss: 1.4068 - val_mrcnn_class_loss: 0.0109 - val_mrcnn_bbox_loss: 0.1635 - val_mrcnn_mask_loss: 0.1200
Epoch 91/100
99/100 [============================>.] - ETA: 2s - loss: 1.0819 - rpn_class_loss: 0.0092 - rpn_bbox_loss: 0.7405 - mrcnn_class_loss: 0.0164 - mrcnn_bbox_loss: 0.1952 - mrcnn_mask_loss: 0.1199
epoch end, checking to see if end of cycle...
epoch + 1 = 91 self.next_restart = 91
cycle finished, saving weights...
Next restart at 228.0
100/100 [==============================] - 300s 3s/step - loss: 1.0881 - rpn_class_loss: 0.0092 - rpn_bbox_loss: 0.7459 - mrcnn_class_loss: 0.0164 - mrcnn_bbox_loss: 0.1957 - mrcnn_mask_loss: 0.1202 - val_loss: 1.2089 - val_rpn_class_loss: 0.0113 - val_rpn_bbox_loss: 0.8753 - val_mrcnn_class_loss: 0.0101 - val_mrcnn_bbox_loss: 0.1912 - val_mrcnn_mask_loss: 0.1204
Epoch 92/100
99/100 [============================>.] - ETA: 2s - loss: 1.0154 - rpn_class_loss: 0.0100 - rpn_bbox_loss: 0.6705 - mrcnn_class_loss: 0.0160 - mrcnn_bbox_loss: 0.1912 - mrcnn_mask_loss: 0.1269
epoch end, checking to see if end of cycle...
epoch + 1 = 92 self.next_restart = 228.0
If the cycle length is changed from 91 to, say, 2, then the `epoch + 1` check is never satisfied, so `best_weights` is never stored, which causes the code to raise an error at the end of training.
Am I correct in this interpretation of how the cycle length is working in this code, and if so, is there a way to allow for smaller cycle lengths while still allowing for the restarting of training?
@jeremyjordan Thanks for the help. If I was training model heads first, and then finetuning all layers after, what would be the cycle length to use?
For example, if I have the following code:
Where the first scheduler sets up training for the head layers for 2 epochs, then the second scheduler sets up training all layers for 1 epoch, I still run into the best_weights error and am unsure why?
Output: