@chenyaofo
Last active March 7, 2025 09:33
Armijo Line Search
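A PyTorch implementation of backtracking line search with the Armijo (sufficient decrease) condition. Starting from an initial step size, the trial step is repeatedly shrunk by `beta` until `f(θ − η·∇f(θ)) ≤ f(θ) − c1·η·‖∇f(θ)‖²` holds; the accepted step is scaled up by `restart_coefficient` to seed the next search, and `minimal_step_size` is returned if no step satisfies the condition within `max_iterations`.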
import torch
import torch.nn as nn
import torch.optim as optim
import typing
from functools import partial
class BaseLineSearch:
    def _backup_parameters(self, parameters: typing.List[torch.Tensor]):
        # Keep a detached copy of the parameters so trial steps can be undone.
        self.original_params = [p.detach().clone() for p in parameters]

    def _restore_parameters(self, parameters: typing.List[torch.Tensor]):
        with torch.no_grad():
            for param, ori_param in zip(parameters, self.original_params):
                param.data.copy_(ori_param)

    def _apply_gradient_update(self, step_size, parameters: typing.List[torch.Tensor]):
        # Take a trial gradient-descent step: p <- p - step_size * p.grad.
        with torch.no_grad():
            for param in parameters:
                if param.grad is not None:
                    param.data.sub_(step_size * param.grad)

    def _release_stored_parameters(self):
        self.original_params.clear()

    def __call__(self, parameters: typing.List[torch.Tensor], loss_fn: typing.Callable) -> float:
        return self.line_search(parameters, loss_fn)

    def line_search(self, parameters: typing.List[torch.Tensor], loss_fn: typing.Callable) -> float:
        parameters = list(parameters)  # materialize in case a generator was passed
        self._backup_parameters(parameters)
        best_step_size = self.line_search_impl(parameters=parameters, loss_fn=loss_fn)
        self._restore_parameters(parameters)
        self._release_stored_parameters()
        return best_step_size

    def line_search_impl(self, parameters: typing.List[torch.Tensor], loss_fn: typing.Callable) -> float:
        raise NotImplementedError("Subclasses must implement line_search_impl method")
class ArmijoLineSearch(BaseLineSearch):
    def __init__(self,
                 beta: float = 0.5,
                 c1: float = 1.0e-4,
                 restart_coefficient: float = 10,
                 max_iterations: int = 100,
                 initial_step_size: float = 1.0,
                 minimal_step_size: float = 1.0e-8,
                 ):
        self.beta = beta                                # backtracking shrink factor
        self.c1 = c1                                    # sufficient-decrease constant
        self.restart_coefficient = restart_coefficient  # enlarge the accepted step to seed the next search
        self.max_iterations = max_iterations
        self.initial_step_size = initial_step_size
        self.minimal_step_size = minimal_step_size

    def _check_armijo_condition(self):
        # Armijo (sufficient decrease) condition for the gradient-descent direction:
        # f(theta - eta * g) <= f(theta) - c1 * eta * ||g||^2
        return self.new_loss <= self.current_loss - self.c1 * self.current_step_size * self.grad_norm_sq

    def _compute_gradient_norm(self, parameters: typing.List[torch.Tensor]) -> float:
        with torch.no_grad():
            grad_norm = [torch.norm(p.grad) ** 2 for p in parameters if p.grad is not None]
            return torch.sum(torch.stack(grad_norm)).item()

    def _update_current_step(self):
        self.current_step_size *= self.beta

    def line_search_impl(self, parameters: typing.List[torch.Tensor], loss_fn: typing.Callable) -> float:
        self.current_step_size = self.initial_step_size
        self.current_loss = loss_fn()
        self.grad_norm_sq = self._compute_gradient_norm(parameters)
        with torch.no_grad():
            for _ in range(self.max_iterations):
                self._apply_gradient_update(self.current_step_size, parameters)
                self.new_loss = loss_fn()
                if self._check_armijo_condition():
                    # Accept this step and start the next search from a larger step size.
                    self.initial_step_size = self.restart_coefficient * self.current_step_size
                    return self.current_step_size
                # Reject: shrink the step and undo the trial update before trying again.
                self._update_current_step()
                self._restore_parameters(parameters)
        return self.minimal_step_size
# Usage example
class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(10, 1)

    def forward(self, x):
        return self.fc(x)


# Initialize components
model = SimpleModel()
criterion = nn.MSELoss()
armijo_line_searcher = ArmijoLineSearch()
optimizer = optim.SGD(model.parameters(), lr=1.0)  # the initial learning rate is overwritten by the line search
x_train = torch.randn(32, 10)
y_train = torch.randn(32, 1)


def loss_fn(model, criterion, x_train, y_train):
    outputs = model(x_train)
    loss = criterion(outputs, y_train)
    return loss.item()


for epoch in range(100):
    optimizer.zero_grad()
    outputs = model(x_train)
    loss = criterion(outputs, y_train)
    loss.backward()
    # Find a step size satisfying the Armijo condition for the current gradients.
    step_size = armijo_line_searcher(
        list(model.parameters()),
        partial(loss_fn, model, criterion, x_train, y_train)
    )
    # Use the found step size as the learning rate for this SGD step.
    for param_group in optimizer.param_groups:
        param_group['lr'] = step_size
    optimizer.step()
    with torch.no_grad():
        current_loss = criterion(model(x_train), y_train).item()
    print(f"Epoch {epoch+1:2d} | Loss: {current_loss:.4f} | Step: {step_size:.4f}")