Last active
March 20, 2023 13:05
changing optimizer and scheduler during training
import pytorch_lightning as pl
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam
from torch.optim.lr_scheduler import StepLR


class MyModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer1 = nn.Linear(10, 5)
        self.layer2 = nn.Linear(5, 1)
        self.loss_fn = nn.MSELoss()
        # Initial optimizer and scheduler; both are replaced later in training.
        self.optimizer = Adam(self.parameters(), lr=1e-3)
        self.lr_scheduler = StepLR(self.optimizer, step_size=10, gamma=0.1)

    def forward(self, x):
        x = self.layer1(x)
        x = F.relu(x)
        x = self.layer2(x)
        return x

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = self.loss_fn(y_hat, y)
        self.log('train_loss', loss)
        return loss

    def configure_optimizers(self):
        return {'optimizer': self.optimizer, 'lr_scheduler': self.lr_scheduler}

    # This hook was called on_epoch_end() in older Lightning releases (removed since).
    def on_train_epoch_end(self):
        # current_epoch is zero-based, so this runs once epoch index 20 has finished.
        if self.current_epoch == 20:
            # Change optimizer and lr_scheduler
            self.optimizer = Adam(self.parameters(), lr=1e-4)
            self.lr_scheduler = StepLR(self.optimizer, step_size=5, gamma=0.5)
            # Assign a whole new list so the trainer's property setter runs.
            self.trainer.optimizers = [self.optimizer]
            # Lightning < 1.6 used trainer.lr_schedulers[0]['scheduler'] instead.
            self.trainer.lr_scheduler_configs[0].scheduler = self.lr_scheduler


"""
another source: https://github.com/Lightning-AI/lightning/issues/3095
PyTorch Lightning builds its optimizer and learning rate scheduler from the return value of configure_optimizers() in your LightningModule when training starts. To change them during training, create new ones inside a hook and hand them to the trainer.
Reassigning attributes on the module is not enough on its own, because the trainer keeps its own references to the optimizer and scheduler it was given. Those references have to be updated as well.
In the example above, we define a PyTorch Lightning module MyModel with an Adam optimizer (lr=1e-3) and a StepLR scheduler (step_size=10, gamma=0.1). configure_optimizers() returns a dictionary containing both.
In the epoch-end hook (on_epoch_end() in older Lightning releases, on_train_epoch_end() in newer ones), we swap in a new Adam optimizer (lr=1e-4) and a new StepLR scheduler (step_size=5, gamma=0.5) when current_epoch reaches 20. We first update the module's optimizer and lr_scheduler attributes, then update the trainer through trainer.optimizers and its scheduler list (trainer.lr_schedulers in older releases, trainer.lr_scheduler_configs from 1.6 on). Index 0 refers to the first optimizer and scheduler, since this example has only one of each.
"""
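To see what the swap in the example above does to the learning rate, the two StepLR phases can be written out in closed form. This is only a sketch in plain Python, not part of the gist: `step_lr`, `lr_for_epoch`, and the `switch_epoch` parameter are hypothetical names. It assumes the replacement scheduler starts counting from zero at the first epoch after the switch; the exact offset depends on whether Lightning steps the scheduler before or after the hook runs, so check against your version.

```python
def step_lr(base_lr: float, step_size: int, gamma: float, epochs_elapsed: int) -> float:
    """Closed form of StepLR: multiply by gamma once every step_size epochs."""
    return base_lr * gamma ** (epochs_elapsed // step_size)


def lr_for_epoch(epoch: int, switch_epoch: int = 21) -> float:
    """Learning rate in effect while training the zero-based `epoch`.

    switch_epoch (hypothetical) is the first epoch trained with the
    replacement optimizer; 21 matches a swap made when current_epoch == 20.
    """
    if epoch < switch_epoch:
        # Phase 1: Adam(lr=1e-3) with StepLR(step_size=10, gamma=0.1)
        return step_lr(1e-3, 10, 0.1, epoch)
    # Phase 2: fresh Adam(lr=1e-4) with StepLR(step_size=5, gamma=0.5).
    # Assumption: its epoch counter restarts at the switch.
    return step_lr(1e-4, 5, 0.5, epoch - switch_epoch)


if __name__ == "__main__":
    for epoch in (0, 10, 20, 21, 26, 31):
        print(f"epoch {epoch:2d}: lr = {lr_for_epoch(epoch):.2e}")
```

Under these assumptions the first schedule has already decayed to about 1e-5 by epoch 20, so the swap raises the learning rate back to 1e-4 before the new scheduler halves it every 5 epochs.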