God's Chosen Schedule
import math
import torch
from torch.optim.lr_scheduler import _LRScheduler
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class SchedulePhase:
    """Defines a phase in the learning rate schedule."""
    percent: float  # Percentage of total steps this phase covers
    lr_type: str    # 'linear_warmup', 'constant', or 'cosine_decay'
    start_lr: float
    end_lr: Optional[float] = None  # Only needed for 'linear_warmup' and 'cosine_decay'
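
# Hedged usage sketch (the values here are illustrative assumptions, not
# from the original gist): a phase covering the first 10% of training that
# warms the lr up linearly from 0 to 3e-4 would be declared as
#
#   SchedulePhase(percent=10.0, lr_type='linear_warmup',
#                 start_lr=0.0, end_lr=3e-4)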

class StepLRScheduler(_LRScheduler):
    def __init__(self,
                 optimizer,
                 total_steps: int,
                 phases: List[SchedulePhase],
                 last_epoch: int = -1):
        """
        A flexible learning rate scheduler driven by percentages of the total steps.

        Args:
            optimizer: PyTorch optimizer
            total_steps: Total number of training steps
            phases: List of SchedulePhase objects defining the schedule
            last_epoch: The index of the last epoch
        """
        self.total_steps = total_steps
        self.phases = phases

        # Validate that percentages sum to approximately 100
        total_percent = sum(phase.percent for phase in phases)
        if not math.isclose(total_percent, 100.0, rel_tol=1e-3):
            raise ValueError(f"Phase percentages must sum to approximately 100, got {total_percent}")

        # Compute the cumulative step boundary at which each phase ends
        self.phase_boundaries = []
        current_steps = 0
        for phase in phases:
            phase_steps = max(int((phase.percent / 100.0) * total_steps), 1)
            current_steps += phase_steps
            self.phase_boundaries.append(current_steps)

        # Pin the final boundary to total_steps exactly, absorbing any
        # rounding drift from the int() truncation above
        self.phase_boundaries[-1] = total_steps

        super().__init__(optimizer, last_epoch)

    def get_lr(self) -> List[float]:
        """Calculate the learning rate for the current step."""
        step = self.last_epoch

        # Find the current phase; default to the last phase so steps past
        # the final boundary don't fall back to phase 0
        current_phase_idx = len(self.phases) - 1
        for idx, boundary in enumerate(self.phase_boundaries):
            if step <= boundary:
                current_phase_idx = idx
                break

        # Start step and length of the current phase
        start_step = 0 if current_phase_idx == 0 else self.phase_boundaries[current_phase_idx - 1]
        phase_length = self.phase_boundaries[current_phase_idx] - start_step

        # Fractional progress through the current phase, clamped to [0, 1]
        # so overshooting total_steps holds the final lr rather than
        # extrapolating past it (max() avoids division by zero)
        progress = min((step - start_step) / max(phase_length, 1), 1.0)
        phase = self.phases[current_phase_idx]

        if phase.lr_type == 'linear_warmup':
            lr = phase.start_lr + progress * (phase.end_lr - phase.start_lr)
        elif phase.lr_type == 'constant':
            lr = phase.start_lr
        elif phase.lr_type == 'cosine_decay':
            cosine_decay = 0.5 * (1 + math.cos(math.pi * progress))
            lr = phase.end_lr + (phase.start_lr - phase.end_lr) * cosine_decay
        else:
            raise ValueError(f"Unknown lr_type: {phase.lr_type}")

        # The same schedule is applied to every parameter group
        return [lr for _ in self.base_lrs]
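
# Hedged sanity sketch (not part of the original gist; the SGD optimizer,
# step counts, and lr values below are illustrative assumptions): a 50/50
# warmup + cosine schedule should hit its endpoint lrs exactly.
def _sanity_check():
    opt = torch.optim.SGD([torch.nn.Parameter(torch.zeros(1))], lr=1e-3)
    sched = StepLRScheduler(opt, total_steps=100, phases=[
        SchedulePhase(percent=50, lr_type='linear_warmup', start_lr=0.0, end_lr=1e-3),
        SchedulePhase(percent=50, lr_type='cosine_decay', start_lr=1e-3, end_lr=1e-4),
    ])
    sched.last_epoch = 50   # boundary: end of warmup
    assert math.isclose(sched.get_lr()[0], 1e-3)
    sched.last_epoch = 100  # boundary: end of decay
    assert math.isclose(sched.get_lr()[0], 1e-4)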

# Example usage:
def create_scheduler(optimizer, total_steps: int):
    """Creates a scheduler with a similar structure to the original example."""
    phases = [
        SchedulePhase(
            percent=5,        # Small percentage for warmup
            lr_type='linear_warmup',
            start_lr=0.0,
            end_lr=2.2e-4
        ),
        SchedulePhase(
            percent=62.6,     # Main training phase
            lr_type='constant',
            start_lr=2.2e-4
        ),
        SchedulePhase(
            percent=29.1,     # Decay phase
            lr_type='cosine_decay',
            start_lr=2.2e-4,
            end_lr=2.2e-5
        ),
        SchedulePhase(
            percent=2.2,      # First final phase
            lr_type='constant',
            start_lr=2.2e-5
        ),
        SchedulePhase(
            percent=1.1,      # Second final phase
            lr_type='constant',
            start_lr=7.3e-6
        )
    ]
    return StepLRScheduler(
        optimizer,
        total_steps=total_steps,
        phases=phases,
    )

# Example setup and visualization
if __name__ == "__main__":
    model = torch.nn.Linear(10, 10)  # Dummy model
    optimizer = torch.optim.AdamW(model.parameters(), lr=2.2e-4)

    # Example with 10000 total steps
    total_steps = 10000
    scheduler = create_scheduler(optimizer, total_steps)

    # Visualization: sweep last_epoch directly to sample the schedule
    # without running any training
    import matplotlib.pyplot as plt
    import numpy as np

    steps = np.linspace(0, scheduler.total_steps, 1000)
    lrs = []
    for step in steps:
        scheduler.last_epoch = int(step)
        lrs.append(scheduler.get_lr()[0])

    plt.figure(figsize=(12, 6))
    plt.plot(steps, lrs)
    plt.xlabel('Steps')
    plt.ylabel('Learning Rate')
    plt.yscale('log')
    plt.title('Learning Rate Schedule')
    plt.grid(True)
    plt.savefig('lr_schedule.png')
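
    # Hedged training-loop sketch (illustrative assumptions: random data,
    # MSE loss, one scheduler step per optimizer step). A fresh scheduler
    # is created because the plotting loop above mutated last_epoch.
    scheduler = create_scheduler(optimizer, total_steps)
    for _ in range(100):
        x, y = torch.randn(4, 10), torch.randn(4, 10)
        loss = torch.nn.functional.mse_loss(model(x), y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        scheduler.step()  # advances last_epoch so get_lr() picks the right phase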