Armijo Line Search
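This gist implements backtracking line search with the Armijo (sufficient-decrease) condition for PyTorch models. Starting from an initial step size α, the search shrinks α by a factor β (so with β = 0.5 the trial steps are 1.0, 0.5, 0.25, ...) until the trial point x − α·∇f(x) satisfies

    f(x − α·∇f(x)) ≤ f(x) − c1 · α · ‖∇f(x)‖²

and returns the first α that passes, falling back to a minimal step size if the iteration budget runs out.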
import typing
from functools import partial

import torch
import torch.nn as nn
import torch.optim as optim


class BaseLineSearch:
    """Handles parameter backup/restore so subclasses only probe trial steps."""

    def _backup_parameters(self, parameters: typing.List[torch.Tensor]):
        self.original_params = [p.detach().clone() for p in parameters]

    def _restore_parameters(self, parameters: typing.List[torch.Tensor]):
        with torch.no_grad():
            for param, ori_param in zip(parameters, self.original_params):
                param.copy_(ori_param)

    def _apply_gradient_update(self, step_size: float, parameters: typing.List[torch.Tensor]):
        # Take a trial gradient-descent step: p <- p - step_size * p.grad.
        with torch.no_grad():
            for param in parameters:
                if param.grad is not None:
                    param.sub_(step_size * param.grad)

    def _release_stored_parameters(self):
        self.original_params.clear()

    def __call__(self, parameters: typing.List[torch.Tensor], loss_fn: typing.Callable[[], float]) -> float:
        return self.line_search(parameters, loss_fn)

    def line_search(self, parameters: typing.List[torch.Tensor], loss_fn: typing.Callable[[], float]) -> float:
        parameters = list(parameters)  # materialize in case a generator was passed
        self._backup_parameters(parameters)
        best_step_size = self.line_search_impl(parameters=parameters, loss_fn=loss_fn)
        # The search only evaluates trial points; the actual update is left to the caller.
        self._restore_parameters(parameters)
        self._release_stored_parameters()
        return best_step_size

    def line_search_impl(self, parameters: typing.List[torch.Tensor], loss_fn: typing.Callable[[], float]) -> float:
        raise NotImplementedError("Subclasses must implement line_search_impl")
class ArmijoLineSearch(BaseLineSearch):
    def __init__(self,
                 beta: float = 0.5,
                 c1: float = 1.0e-4,
                 restart_coefficient: float = 10.0,
                 max_iterations: int = 100,
                 initial_step_size: float = 1.0,
                 minimal_step_size: float = 1.0e-8,
                 ):
        self.beta = beta  # backtracking shrink factor
        self.c1 = c1  # sufficient-decrease constant
        self.restart_coefficient = restart_coefficient  # grows the starting step for the next search
        self.max_iterations = max_iterations
        self.initial_step_size = initial_step_size
        self.minimal_step_size = minimal_step_size

    def _check_armijo_condition(self) -> bool:
        # Sufficient decrease: f(x - a*g) <= f(x) - c1 * a * ||g||^2.
        return self.new_loss <= self.current_loss - self.c1 * self.current_step_size * self.grad_norm_sq

    def _compute_gradient_norm_sq(self, parameters: typing.List[torch.Tensor]) -> float:
        # Squared gradient norm ||g||^2, summed over all parameters with gradients.
        with torch.no_grad():
            sq_norms = [torch.norm(p.grad) ** 2 for p in parameters if p.grad is not None]
            return torch.sum(torch.stack(sq_norms)).item()

    def _update_current_step(self):
        self.current_step_size *= self.beta

    def line_search_impl(self, parameters: typing.List[torch.Tensor], loss_fn: typing.Callable[[], float]) -> float:
        self.current_step_size = self.initial_step_size
        self.current_loss = loss_fn()
        self.grad_norm_sq = self._compute_gradient_norm_sq(parameters)
        with torch.no_grad():
            for _ in range(self.max_iterations):
                self._apply_gradient_update(self.current_step_size, parameters)
                self.new_loss = loss_fn()
                if self._check_armijo_condition():
                    # Warm restart: begin the next search slightly above the accepted step.
                    self.initial_step_size = self.restart_coefficient * self.current_step_size
                    return self.current_step_size
                self._update_current_step()
                # Undo the failed trial step before trying a smaller one.
                self._restore_parameters(parameters)
        return self.minimal_step_size
# Usage example
class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(10, 1)

    def forward(self, x):
        return self.fc(x)


# Initialize components
model = SimpleModel()
criterion = nn.MSELoss()
armijo_line_searcher = ArmijoLineSearch()
optimizer = optim.SGD(model.parameters(), lr=1.0)  # the initial learning rate is overwritten every epoch

x_train = torch.randn(32, 10)
y_train = torch.randn(32, 1)


def loss_fn(model, criterion, x_train, y_train) -> float:
    outputs = model(x_train)
    loss = criterion(outputs, y_train)
    return loss.item()


for epoch in range(100):
    optimizer.zero_grad()
    outputs = model(x_train)
    loss = criterion(outputs, y_train)
    loss.backward()
    # Search for a step size satisfying the Armijo condition under the current gradients.
    step_size = armijo_line_searcher(
        list(model.parameters()),
        partial(loss_fn, model, criterion, x_train, y_train)
    )
    # Apply the accepted step size through the optimizer.
    for param_group in optimizer.param_groups:
        param_group['lr'] = step_size
    optimizer.step()
    with torch.no_grad():
        current_loss = criterion(model(x_train), y_train).item()
    print(f"Epoch {epoch+1:2d} | Loss: {current_loss:.4f} | Step: {step_size:.4f}")