Skip to content

Instantly share code, notes, and snippets.

@jeremyjordan
Last active December 4, 2023 13:41
Show Gist options
  • Save jeremyjordan/5a222e04bb78c242f5763ad40626c452 to your computer and use it in GitHub Desktop.
Save jeremyjordan/5a222e04bb78c242f5763ad40626c452 to your computer and use it in GitHub Desktop.
Keras Callback for implementing Stochastic Gradient Descent with Restarts
from keras.callbacks import Callback
import keras.backend as K
import numpy as np
class SGDRScheduler(Callback):
    '''Cosine annealing learning rate scheduler with periodic restarts.

    Within a cycle the learning rate follows half a cosine wave from
    `max_lr` down to `min_lr`. When a cycle ends the rate jumps back up
    (a "warm restart"), `max_lr` is scaled by `lr_decay`, the cycle length
    is scaled by `mult_factor`, and the model weights at that point are
    remembered so they can be restored when training finishes.

    # Usage
        ```python
            schedule = SGDRScheduler(min_lr=1e-5,
                                     max_lr=1e-2,
                                     steps_per_epoch=np.ceil(epoch_size/batch_size),
                                     lr_decay=0.9,
                                     cycle_length=5,
                                     mult_factor=1.5)
            model.fit(X_train, Y_train, epochs=100, callbacks=[schedule])
        ```

    # Arguments
        min_lr: The lower bound of the learning rate range for the experiment.
        max_lr: The upper bound of the learning rate range for the experiment.
        steps_per_epoch: Number of mini-batches in the dataset.
            Calculated as `np.ceil(epoch_size/batch_size)`.
        lr_decay: Reduce the max_lr after the completion of each cycle.
            Ex. To reduce the max_lr by 20% after each cycle, set this
            value to 0.8.
        cycle_length: Initial number of epochs in a cycle.
        mult_factor: Scale cycle_length after each full cycle completion.

    # Notes
        Cycles are counted in epochs seen by this callback, not by the
        absolute epoch index Keras passes to `on_epoch_end`. Training that is
        resumed with `initial_epoch > 0` therefore still restarts every
        `cycle_length` epochs.

    # References
        Blog post: jeremyjordan.me/nn-learning-rate
        Original paper: http://arxiv.org/abs/1608.03983
    '''

    def __init__(self,
                 min_lr,
                 max_lr,
                 steps_per_epoch,
                 lr_decay=1,
                 cycle_length=10,
                 mult_factor=2):
        super().__init__()

        self.min_lr = min_lr
        self.max_lr = max_lr
        self.lr_decay = lr_decay

        self.batch_since_restart = 0
        # Epochs completed while this callback was attached; drives restarts.
        self.epochs_completed = 0
        self.next_restart = cycle_length

        self.steps_per_epoch = steps_per_epoch

        self.cycle_length = cycle_length
        self.mult_factor = mult_factor

        # Weights captured at the end of the most recent completed cycle.
        # Stays None until the first cycle finishes.
        self.best_weights = None

        self.history = {}

    def clr(self):
        '''Calculate the learning rate for the current position in the cycle.'''
        fraction_to_restart = self.batch_since_restart / (self.steps_per_epoch * self.cycle_length)
        lr = self.min_lr + 0.5 * (self.max_lr - self.min_lr) * (1 + np.cos(fraction_to_restart * np.pi))
        return lr

    def on_train_begin(self, logs=None):
        '''Initialize the learning rate to the maximum value at the start of training.'''
        K.set_value(self.model.optimizer.lr, self.max_lr)

    def on_batch_end(self, batch, logs=None):
        '''Record previous batch statistics and update the learning rate.'''
        logs = logs or {}
        self.history.setdefault('lr', []).append(K.get_value(self.model.optimizer.lr))
        for k, v in logs.items():
            self.history.setdefault(k, []).append(v)

        self.batch_since_restart += 1
        K.set_value(self.model.optimizer.lr, self.clr())

    def on_epoch_end(self, epoch, logs=None):
        '''Check for end of current cycle, apply restarts when necessary.'''
        # Count epochs locally instead of trusting `epoch`: Keras starts that
        # index at `initial_epoch`, which would push resumed runs past
        # `next_restart` without ever matching it.
        self.epochs_completed += 1
        if self.epochs_completed == self.next_restart:
            self.batch_since_restart = 0
            # Keep the cycle length an int so `next_restart` stays an int.
            self.cycle_length = int(np.ceil(self.cycle_length * self.mult_factor))
            self.next_restart += self.cycle_length
            self.max_lr *= self.lr_decay
            self.best_weights = self.model.get_weights()

    def on_train_end(self, logs=None):
        '''Set weights to the values from the end of the most recent cycle for best performance.

        If training stopped before any cycle completed there is nothing to
        restore, so the model keeps its current weights.
        '''
        if self.best_weights is not None:
            self.model.set_weights(self.best_weights)
@Trotts
Copy link

Trotts commented Sep 3, 2019

@jeremyjordan Thanks for the help. If I was training model heads first, and then finetuning all layers after, what would be the cycle length to use?

For example, if I have the following code:

schedule = SGDRScheduler(min_lr=1e-5,
                        max_lr=1e-2,
                        steps_per_epoch=np.ceil(100/config.BATCH_SIZE),
                        lr_decay=0.9,
                        cycle_length=2,
                        mult_factor=1.5)
model.train(dataset_train, dataset_val, 
                        learning_rate=1e-2, 
                        epochs=2, 
                        layers='heads',best_only=True, custom_callbacks = [schedule, csv_logger], augmentation=c[5])
schedule = SGDRScheduler(min_lr=1e-5/10,
                        max_lr=1e-2/10,
                        steps_per_epoch=np.ceil(100/config.BATCH_SIZE),
                        lr_decay=0.9,
                        cycle_length=1,
                        mult_factor=1.5)
 model.train(dataset_train, dataset_val, 
                        learning_rate=1e-2/10, 
                        epochs=3, 
                        layers='all',best_only=True, custom_callbacks = [schedule, csv_logger], augmentation=c[5])

Here the first scheduler sets up training for the head layers for 2 epochs, then the second scheduler sets up training all layers for 1 epoch. I still run into the best_weights error and am unsure why.

Output:

Selecting layers to train
fpn_c5p5               (Conv2D)
fpn_c4p4               (Conv2D)
fpn_c3p3               (Conv2D)
fpn_c2p2               (Conv2D)
fpn_p5                 (Conv2D)
fpn_p2                 (Conv2D)
fpn_p3                 (Conv2D)
fpn_p4                 (Conv2D)
In model:  rpn_model
    rpn_conv_shared        (Conv2D)
    rpn_class_raw          (Conv2D)
    rpn_bbox_pred          (Conv2D)
mrcnn_mask_conv1       (TimeDistributed)
mrcnn_mask_bn1         (TimeDistributed)
mrcnn_mask_conv2       (TimeDistributed)
mrcnn_mask_bn2         (TimeDistributed)
mrcnn_class_conv1      (TimeDistributed)
mrcnn_class_bn1        (TimeDistributed)
mrcnn_mask_conv3       (TimeDistributed)
mrcnn_mask_bn3         (TimeDistributed)
mrcnn_class_conv2      (TimeDistributed)
mrcnn_class_bn2        (TimeDistributed)
mrcnn_mask_conv4       (TimeDistributed)
mrcnn_mask_bn4         (TimeDistributed)
mrcnn_bbox_fc          (TimeDistributed)
mrcnn_mask_deconv      (TimeDistributed)
mrcnn_class_logits     (TimeDistributed)
mrcnn_mask             (TimeDistributed)


Epoch 1/2
100/100 [==============================] - 503s 5s/step - loss: 1.1976 - rpn_class_loss: 0.0139 - rpn_bbox_loss: 0.4676 - mrcnn_class_loss: 0.0649 - mrcnn_bbox_loss: 0.4371 - mrcnn_mask_loss: 0.2139 - val_loss: 0.9937 - val_rpn_class_loss: 0.0084 - val_rpn_bbox_loss: 0.4431 - val_mrcnn_class_loss: 0.0509 - val_mrcnn_bbox_loss: 0.3357 - val_mrcnn_mask_loss: 0.1553
Epoch 2/2
100/100 [==============================] - 355s 4s/step - loss: 0.7547 - rpn_class_loss: 0.0053 - rpn_bbox_loss: 0.3258 - mrcnn_class_loss: 0.0556 - mrcnn_bbox_loss: 0.2167 - mrcnn_mask_loss: 0.1511 - val_loss: 0.8487 - val_rpn_class_loss: 0.0063 - val_rpn_bbox_loss: 0.4210 - val_mrcnn_class_loss: 0.0482 - val_mrcnn_bbox_loss: 0.2187 - val_mrcnn_mask_loss: 0.1544
Saving best model only...

Starting at epoch 2. LR=0.001

Checkpoint Path: XXX
Selecting layers to train
conv1                  (Conv2D)
bn_conv1               (BatchNorm)
res2a_branch2a         (Conv2D)
bn2a_branch2a          (BatchNorm)
res2a_branch2b         (Conv2D)
bn2a_branch2b          (BatchNorm)
res2a_branch2c         (Conv2D)
res2a_branch1          (Conv2D)
bn2a_branch2c          (BatchNorm)
bn2a_branch1           (BatchNorm)
res2b_branch2a         (Conv2D)
bn2b_branch2a          (BatchNorm)
res2b_branch2b         (Conv2D)
bn2b_branch2b          (BatchNorm)
res2b_branch2c         (Conv2D)
bn2b_branch2c          (BatchNorm)
res2c_branch2a         (Conv2D)
bn2c_branch2a          (BatchNorm)
res2c_branch2b         (Conv2D)
bn2c_branch2b          (BatchNorm)
res2c_branch2c         (Conv2D)
bn2c_branch2c          (BatchNorm)
res3a_branch2a         (Conv2D)
bn3a_branch2a          (BatchNorm)
res3a_branch2b         (Conv2D)
bn3a_branch2b          (BatchNorm)
res3a_branch2c         (Conv2D)
res3a_branch1          (Conv2D)
bn3a_branch2c          (BatchNorm)
bn3a_branch1           (BatchNorm)
res3b_branch2a         (Conv2D)
bn3b_branch2a          (BatchNorm)
res3b_branch2b         (Conv2D)
bn3b_branch2b          (BatchNorm)
res3b_branch2c         (Conv2D)
bn3b_branch2c          (BatchNorm)
res3c_branch2a         (Conv2D)
bn3c_branch2a          (BatchNorm)
res3c_branch2b         (Conv2D)
bn3c_branch2b          (BatchNorm)
res3c_branch2c         (Conv2D)
bn3c_branch2c          (BatchNorm)
res3d_branch2a         (Conv2D)
bn3d_branch2a          (BatchNorm)
res3d_branch2b         (Conv2D)
bn3d_branch2b          (BatchNorm)
res3d_branch2c         (Conv2D)
bn3d_branch2c          (BatchNorm)
res4a_branch2a         (Conv2D)
bn4a_branch2a          (BatchNorm)
res4a_branch2b         (Conv2D)
bn4a_branch2b          (BatchNorm)
res4a_branch2c         (Conv2D)
res4a_branch1          (Conv2D)
bn4a_branch2c          (BatchNorm)
bn4a_branch1           (BatchNorm)
res4b_branch2a         (Conv2D)
bn4b_branch2a          (BatchNorm)
res4b_branch2b         (Conv2D)
bn4b_branch2b          (BatchNorm)
res4b_branch2c         (Conv2D)
bn4b_branch2c          (BatchNorm)
res4c_branch2a         (Conv2D)
bn4c_branch2a          (BatchNorm)
res4c_branch2b         (Conv2D)
bn4c_branch2b          (BatchNorm)
res4c_branch2c         (Conv2D)
bn4c_branch2c          (BatchNorm)
res4d_branch2a         (Conv2D)
bn4d_branch2a          (BatchNorm)
res4d_branch2b         (Conv2D)
bn4d_branch2b          (BatchNorm)
res4d_branch2c         (Conv2D)
bn4d_branch2c          (BatchNorm)
res4e_branch2a         (Conv2D)
bn4e_branch2a          (BatchNorm)
res4e_branch2b         (Conv2D)
bn4e_branch2b          (BatchNorm)
res4e_branch2c         (Conv2D)
bn4e_branch2c          (BatchNorm)
res4f_branch2a         (Conv2D)
bn4f_branch2a          (BatchNorm)
res4f_branch2b         (Conv2D)
bn4f_branch2b          (BatchNorm)
res4f_branch2c         (Conv2D)
bn4f_branch2c          (BatchNorm)
res5a_branch2a         (Conv2D)
bn5a_branch2a          (BatchNorm)
res5a_branch2b         (Conv2D)
bn5a_branch2b          (BatchNorm)
res5a_branch2c         (Conv2D)
res5a_branch1          (Conv2D)
bn5a_branch2c          (BatchNorm)
bn5a_branch1           (BatchNorm)
res5b_branch2a         (Conv2D)
bn5b_branch2a          (BatchNorm)
res5b_branch2b         (Conv2D)
bn5b_branch2b          (BatchNorm)
res5b_branch2c         (Conv2D)
bn5b_branch2c          (BatchNorm)
res5c_branch2a         (Conv2D)
bn5c_branch2a          (BatchNorm)
res5c_branch2b         (Conv2D)
bn5c_branch2b          (BatchNorm)
res5c_branch2c         (Conv2D)
bn5c_branch2c          (BatchNorm)
fpn_c5p5               (Conv2D)
fpn_c4p4               (Conv2D)
fpn_c3p3               (Conv2D)
fpn_c2p2               (Conv2D)
fpn_p5                 (Conv2D)
fpn_p2                 (Conv2D)
fpn_p3                 (Conv2D)
fpn_p4                 (Conv2D)
In model:  rpn_model
    rpn_conv_shared        (Conv2D)
    rpn_class_raw          (Conv2D)
    rpn_bbox_pred          (Conv2D)
mrcnn_mask_conv1       (TimeDistributed)
mrcnn_mask_bn1         (TimeDistributed)
mrcnn_mask_conv2       (TimeDistributed)
mrcnn_mask_bn2         (TimeDistributed)
mrcnn_class_conv1      (TimeDistributed)
mrcnn_class_bn1        (TimeDistributed)
mrcnn_mask_conv3       (TimeDistributed)
mrcnn_mask_bn3         (TimeDistributed)
mrcnn_class_conv2      (TimeDistributed)
mrcnn_class_bn2        (TimeDistributed)
mrcnn_mask_conv4       (TimeDistributed)
mrcnn_mask_bn4         (TimeDistributed)
mrcnn_bbox_fc          (TimeDistributed)
mrcnn_mask_deconv      (TimeDistributed)
mrcnn_class_logits     (TimeDistributed)
mrcnn_mask             (TimeDistributed)
Epoch 3/3
100/100 [==============================] - 567s 6s/step - loss: 0.5076 - rpn_class_loss: 0.0035 - rpn_bbox_loss: 0.1999 - mrcnn_class_loss: 0.0333 - mrcnn_bbox_loss: 0.1361 - mrcnn_mask_loss: 0.1341 - val_loss: 0.7756 - val_rpn_class_loss: 0.0063 - val_rpn_bbox_loss: 0.4040 - val_mrcnn_class_loss: 0.0305 - val_mrcnn_bbox_loss: 0.1886 - val_mrcnn_mask_loss: 0.1455

---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
<ipython-input-5-462313418ed9> in <module>
    157                         learning_rate=1e-2/10,
    158                         epochs=3,
--> 159                         layers='all',best_only=True, custom_callbacks = [schedule, csv_logger], augmentation=c[5])
    160 
    161     # Define the InferenceConfig

~/dolphin-recognition/Mask_RCNN-master/mrcnn/model.py in train(self, train_dataset, val_dataset, learning_rate, epochs, layers, augmentation, custom_callbacks, no_augmentation_sources, best_only)
   2387             max_queue_size=100,
   2388             workers=workers,
-> 2389             use_multiprocessing=True,
   2390         )
   2391 

~/.virtualenvs/dolphin-detection/lib/python3.6/site-packages/keras/legacy/interfaces.py in wrapper(*args, **kwargs)
     89                 warnings.warn('Update your `' + object_name +
     90                               '` call to the Keras 2 API: ' + signature, stacklevel=2)
---> 91             return func(*args, **kwargs)
     92         wrapper._original_function = func
     93         return wrapper

~/.virtualenvs/dolphin-detection/lib/python3.6/site-packages/keras/engine/training.py in fit_generator(self, generator, steps_per_epoch, epochs, verbose, callbacks, validation_data, validation_steps, class_weight, max_queue_size, workers, use_multiprocessing, shuffle, initial_epoch)
   2224                     val_enqueuer.stop()
   2225 
-> 2226         callbacks.on_train_end()
   2227         return self.history
   2228 

~/.virtualenvs/dolphin-detection/lib/python3.6/site-packages/keras/callbacks.py in on_train_end(self, logs)
    137         logs = logs or {}
    138         for callback in self.callbacks:
--> 139             callback.on_train_end(logs)
    140 
    141     def __iter__(self):

<ipython-input-4-a9934d14c2aa> in on_train_end(self, logs)
     79     def on_train_end(self, logs={}):
     80         '''Set weights to the values from the end of the most recent cycle for best performance.'''
---> 81         self.model.set_weights(self.best_weights)

AttributeError: 'SGDRScheduler' object has no attribute 'best_weights'

@jeremyjordan
Copy link
Author

I would recommend adding logging to the callbacks to diagnose your problem.

    def on_epoch_end(self, epoch, logs={}):
        '''Check for end of current cycle, apply restarts when necessary.'''
        logger.debug('epoch end, checking to see if end of cycle...')
        if epoch + 1 == self.next_restart:
            logger.debug('cycle finished, saving weights...')
            self.batch_since_restart = 0
            self.cycle_length = np.ceil(self.cycle_length * self.mult_factor)
            self.next_restart += self.cycle_length
            self.max_lr *= self.lr_decay
            self.best_weights = self.model.get_weights()

    def on_train_end(self, logs={}):
        '''Set weights to the values from the end of the most recent cycle for best performance.'''
        logger.debug('finished training, reloading weights from end of last cycle...')
        self.model.set_weights(self.best_weights)

Let me know if you're able to figure it out!

@pauqilaz
Copy link

Hi @jeremyjordan,
Thanks for the nice work! What is the license under which you publish this code?

@PhilipMay
Copy link

PhilipMay commented May 3, 2020

See here for an "official" implementation without license issues: https://www.tensorflow.org/api_docs/python/tf/keras/experimental/CosineDecayRestarts

@Trotts
Copy link

Trotts commented Jun 26, 2020

@jeremyjordan Sorry for coming back to this after so long, but I recently noticed something in the code that I am hoping you could explain.

Let's say I have a model that I have previously trained for 89 epochs. I then restart training at a later date and wish to train for another 11 epochs, up to 100. In this case, it seems as though you will never hit a cycle restart unless you specify a cycle length > 89, due to this line:

if epoch + 1 == self.next_restart:

However, with such a high cycle length, you may never again hit another restart.

For example:

number_epochs = 100

if not os.path.exists(model.log_dir):
    os.makedirs(model.log_dir)
    
csv_logger = CSVLogger(os.path.join(model.log_dir, "epoch_logger.csv"), append=True)

# Save the model config to the log directory
print(config_list, file=open(os.path.join(model.log_dir,"config.txt"), 'w'))

# Train the head branches
# Passing layers="heads" freezes all layers except the head
# layers. You can also pass a regular expression to select
# which layers to train by name pattern.

# Cycle length is what you want + epoch restart num 
# (e.g if you want a cycle length of 2, restart epoch is 89, thus cycle length = 91)
schedule = SGDRScheduler(min_lr=1e-5,
                        max_lr=1e-2,
                        steps_per_epoch=np.ceil(number_epochs/config.BATCH_SIZE),
                        lr_decay=0.9,
                        cycle_length=91, 
                        mult_factor=1.5)

model.train(dataset_train, dataset_val, 
                learning_rate=1e-2, 
                epochs=number_epochs, 
                layers='heads',best_only=True, custom_callbacks = [schedule, csv_logger], augmentation= aug1)

Gives:

Starting at epoch 89. LR=0.01

Checkpoint Path: /home/b3020111/dolphin-recognition/Mask_RCNN-master/logs/ndd/above/od/ndd-above-od-1.10-tester20200324T1218/mask_rcnn_ndd-above-od-1.10-tester_{epoch:04d}.h5
Selecting layers to train
fpn_c5p5               (Conv2D)
fpn_c4p4               (Conv2D)
fpn_c3p3               (Conv2D)
fpn_c2p2               (Conv2D)
fpn_p5                 (Conv2D)
fpn_p2                 (Conv2D)
fpn_p3                 (Conv2D)
fpn_p4                 (Conv2D)
In model:  rpn_model
    rpn_conv_shared        (Conv2D)
    rpn_class_raw          (Conv2D)
    rpn_bbox_pred          (Conv2D)
mrcnn_mask_conv1       (TimeDistributed)
mrcnn_mask_bn1         (TimeDistributed)
mrcnn_mask_conv2       (TimeDistributed)
mrcnn_mask_bn2         (TimeDistributed)
mrcnn_class_conv1      (TimeDistributed)
mrcnn_class_bn1        (TimeDistributed)
mrcnn_mask_conv3       (TimeDistributed)
mrcnn_mask_bn3         (TimeDistributed)
mrcnn_class_conv2      (TimeDistributed)
mrcnn_class_bn2        (TimeDistributed)
mrcnn_mask_conv4       (TimeDistributed)
mrcnn_mask_bn4         (TimeDistributed)
mrcnn_bbox_fc          (TimeDistributed)
mrcnn_mask_deconv      (TimeDistributed)
mrcnn_class_logits     (TimeDistributed)
mrcnn_mask             (TimeDistributed)
Epoch 90/100
 99/100 [============================>.] - ETA: 3s - loss: 0.9831 - rpn_class_loss: 0.0082 - rpn_bbox_loss: 0.6570 - mrcnn_class_loss: 0.0171 - mrcnn_bbox_loss: 0.1787 - mrcnn_mask_loss: 0.1214
 epoch end, checking to see if end of cycle...

 epoch + 1 =  90  self.next_restart =  91
100/100 [==============================] - 378s 4s/step - loss: 0.9816 - rpn_class_loss: 0.0083 - rpn_bbox_loss: 0.6560 - mrcnn_class_loss: 0.0171 - mrcnn_bbox_loss: 0.1783 - mrcnn_mask_loss: 0.1213 - val_loss: 1.7121 - val_rpn_class_loss: 0.0104 - val_rpn_bbox_loss: 1.4068 - val_mrcnn_class_loss: 0.0109 - val_mrcnn_bbox_loss: 0.1635 - val_mrcnn_mask_loss: 0.1200
Epoch 91/100
 99/100 [============================>.] - ETA: 2s - loss: 1.0819 - rpn_class_loss: 0.0092 - rpn_bbox_loss: 0.7405 - mrcnn_class_loss: 0.0164 - mrcnn_bbox_loss: 0.1952 - mrcnn_mask_loss: 0.1199
 epoch end, checking to see if end of cycle...

 epoch + 1 =  91  self.next_restart =  91

 cycle finished, saving weights...
Next restart at  228.0
100/100 [==============================] - 300s 3s/step - loss: 1.0881 - rpn_class_loss: 0.0092 - rpn_bbox_loss: 0.7459 - mrcnn_class_loss: 0.0164 - mrcnn_bbox_loss: 0.1957 - mrcnn_mask_loss: 0.1202 - val_loss: 1.2089 - val_rpn_class_loss: 0.0113 - val_rpn_bbox_loss: 0.8753 - val_mrcnn_class_loss: 0.0101 - val_mrcnn_bbox_loss: 0.1912 - val_mrcnn_mask_loss: 0.1204
Epoch 92/100
 99/100 [============================>.] - ETA: 2s - loss: 1.0154 - rpn_class_loss: 0.0100 - rpn_bbox_loss: 0.6705 - mrcnn_class_loss: 0.0160 - mrcnn_bbox_loss: 0.1912 - mrcnn_mask_loss: 0.1269
 epoch end, checking to see if end of cycle...

 epoch + 1 =  92  self.next_restart =  228.0

If the cycle length is changed from 91 to say, 2, then the epoch + 1 check is never fulfilled, and thus a best_weights is never stored causing the code to error at the end of the training.

Am I correct in this interpretation of how the cycle length is working in this code, and if so, is there a way to allow for smaller cycle lengths while still allowing for the restarting of training?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment