@gsoykan
Last active March 20, 2023 13:05
changing optimizer and scheduler during training
import pytorch_lightning as pl
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam
from torch.optim.lr_scheduler import StepLR


class MyModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer1 = nn.Linear(10, 5)
        self.layer2 = nn.Linear(5, 1)
        self.loss_fn = nn.MSELoss()
        # Initial optimizer and scheduler, returned by configure_optimizers()
        self.optimizer = Adam(self.parameters(), lr=1e-3)
        self.lr_scheduler = StepLR(self.optimizer, step_size=10, gamma=0.1)

    def forward(self, x):
        x = self.layer1(x)
        x = F.relu(x)
        x = self.layer2(x)
        return x

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = self.loss_fn(y_hat, y)
        self.log('train_loss', loss)
        return loss

    def configure_optimizers(self):
        return {'optimizer': self.optimizer, 'lr_scheduler': self.lr_scheduler}

    def on_epoch_end(self):
        if self.current_epoch == 20:
            # Change optimizer and lr_scheduler after 20 epochs
            self.optimizer = Adam(self.parameters(), lr=1e-4)
            self.lr_scheduler = StepLR(self.optimizer, step_size=5, gamma=0.5)
            # Overwrite the optimizer/scheduler the Trainer is already using
            # (trainer.lr_schedulers is the pre-2.0 Lightning attribute).
            self.trainer.optimizers[0] = self.optimizer
            self.trainer.lr_schedulers[0]['scheduler'] = self.lr_scheduler
"""
another source: https://github.com/Lightning-AI/lightning/issues/3095
In PyTorch Lightning, the optimizer and learning rate scheduler are set up in the configure_optimizers() method of your LightningModule. To change them during training, create the new optimizer and scheduler and then replace the entries the Trainer already holds in its trainer.optimizers and trainer.lr_schedulers lists from a hook such as on_epoch_end().
In the example above, MyModel is defined with an Adam optimizer and a StepLR scheduler, and configure_optimizers() returns both in a dictionary.
In on_epoch_end(), after 20 epochs we build a new Adam optimizer (with a lower learning rate) and a new StepLR scheduler, store them on the module, and overwrite trainer.optimizers[0] and trainer.lr_schedulers[0]['scheduler'] so the Trainer uses them for the remaining epochs. Index 0 is used because this example has only one optimizer and one scheduler.
"""