lr_scheduler.py by @vgoklani, forked from andrewnc/lr_scheduler.py. Created December 27, 2024.
God's Chosen Schedule
import math
from dataclasses import dataclass
from typing import List, Optional

import torch
from torch.optim.lr_scheduler import _LRScheduler

@dataclass
class SchedulePhase:
    """Defines one phase of the learning rate schedule."""
    percent: float  # Percentage of total steps this phase covers
    lr_type: str  # 'linear_warmup', 'constant', or 'cosine_decay'
    start_lr: float
    end_lr: Optional[float] = None  # Only needed for linear_warmup and cosine_decay
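
# For instance, a short warmup phase could be declared as (illustrative
# values, not the schedule used below):
#   SchedulePhase(percent=10, lr_type='linear_warmup', start_lr=0.0, end_lr=3e-4)
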
class StepLRScheduler(_LRScheduler):
    def __init__(self,
                 optimizer,
                 total_steps: int,
                 phases: List[SchedulePhase],
                 last_epoch: int = -1):
        """
        A flexible learning rate scheduler driven by percentages of total steps.

        Args:
            optimizer: PyTorch optimizer
            total_steps: Total number of training steps
            phases: List of SchedulePhase objects defining the schedule
            last_epoch: The index of the last epoch
        """
        self.total_steps = total_steps
        self.phases = phases

        # Validate that the phase percentages sum to approximately 100.
        total_percent = sum(phase.percent for phase in phases)
        if not math.isclose(total_percent, 100.0, rel_tol=1e-3):
            raise ValueError(f"Phase percentages must sum to approximately 100, got {total_percent}")

        # Convert percentages into cumulative step boundaries, one per phase.
        self.phase_boundaries = []
        current_steps = 0
        for phase in phases:
            phase_steps = max(int((phase.percent / 100.0) * total_steps), 1)
            current_steps += phase_steps
            self.phase_boundaries.append(current_steps)
        # Rounding can leave the last boundary off by a few steps, so pin it
        # to total_steps exactly.
        self.phase_boundaries[-1] = total_steps

        super().__init__(optimizer, last_epoch)
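
    # For instance (hypothetical numbers): total_steps=1000 with three phases
    # covering 10% / 60% / 30% yields phase_boundaries == [100, 700, 1000].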

    def get_lr(self) -> List[float]:
        """Calculate the learning rate for the current step."""
        step = self.last_epoch

        # Find the phase containing this step; any step past the final
        # boundary falls through to the last phase.
        current_phase_idx = len(self.phases) - 1
        for idx, boundary in enumerate(self.phase_boundaries):
            if step <= boundary:
                current_phase_idx = idx
                break

        # Start step and length of the current phase.
        start_step = 0 if current_phase_idx == 0 else self.phase_boundaries[current_phase_idx - 1]
        phase_length = self.phase_boundaries[current_phase_idx] - start_step

        # Fractional progress through the current phase, clamped to [0, 1].
        progress = (step - start_step) / max(phase_length, 1)  # Avoid division by zero
        progress = min(max(progress, 0.0), 1.0)

        phase = self.phases[current_phase_idx]
        if phase.lr_type == 'linear_warmup':
            lr = phase.start_lr + progress * (phase.end_lr - phase.start_lr)
        elif phase.lr_type == 'constant':
            lr = phase.start_lr
        elif phase.lr_type == 'cosine_decay':
            cosine_decay = 0.5 * (1 + math.cos(math.pi * progress))
            lr = phase.end_lr + (phase.start_lr - phase.end_lr) * cosine_decay
        else:
            raise ValueError(f"Unknown lr_type: {phase.lr_type}")

        return [lr for _ in self.base_lrs]
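
# Sanity check on the cosine_decay branch of get_lr: at progress == 0 the
# factor 0.5 * (1 + cos(0)) is 1, so lr == start_lr; at progress == 1 it is
# 0.5 * (1 + cos(pi)) == 0, so lr == end_lr.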

# Example usage:
def create_scheduler(optimizer, total_steps: int):
    """
    Creates a scheduler with a similar structure to the original example.
    """
    phases = [
        SchedulePhase(
            percent=5,  # Small percentage for warmup
            lr_type='linear_warmup',
            start_lr=0.0,
            end_lr=2.2e-4
        ),
        SchedulePhase(
            percent=62.6,  # Main training phase
            lr_type='constant',
            start_lr=2.2e-4
        ),
        SchedulePhase(
            percent=29.1,  # Decay phase
            lr_type='cosine_decay',
            start_lr=2.2e-4,
            end_lr=2.2e-5
        ),
        SchedulePhase(
            percent=2.2,  # First final phase
            lr_type='constant',
            start_lr=2.2e-5
        ),
        SchedulePhase(
            percent=1.1,  # Second final phase
            lr_type='constant',
            start_lr=7.3e-6
        )
    ]
    return StepLRScheduler(
        optimizer,
        total_steps=total_steps,
        phases=phases,
    )
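
# A minimal sketch of wiring this scheduler into a training loop. The names
# `loader` and `train_step` below are hypothetical placeholders, not part of
# this gist; the point is that scheduler.step() runs once per optimizer step:
#
#   scheduler = create_scheduler(optimizer, total_steps=len(loader))
#   for batch in loader:
#       loss = train_step(batch)
#       loss.backward()
#       optimizer.step()
#       optimizer.zero_grad()
#       scheduler.step()  # Advances last_epoch and applies the new LR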

# Example setup and visualization
if __name__ == "__main__":
    model = torch.nn.Linear(10, 10)  # Dummy model
    optimizer = torch.optim.AdamW(model.parameters(), lr=2.2e-4)

    # Example with 10000 total steps
    total_steps = 10000
    scheduler = create_scheduler(optimizer, total_steps)

    # Visualization: sample the schedule by setting last_epoch directly.
    # This mutates the scheduler's state, so only do it outside of training.
    import matplotlib.pyplot as plt
    import numpy as np

    steps = np.linspace(0, scheduler.total_steps, 1000)
    lrs = []
    for step in steps:
        scheduler.last_epoch = int(step)
        lrs.append(scheduler.get_lr()[0])

    plt.figure(figsize=(12, 6))
    plt.plot(steps, lrs)
    plt.xlabel('Steps')
    plt.ylabel('Learning Rate')
    plt.yscale('log')
    plt.title('Learning Rate Schedule')
    plt.grid(True)
    plt.savefig('lr_schedule.png')