Quick comparison of a few optimizers on the 2-simplex: GD, Adam, Mirror Descent, Adam in mirror space, LBFGS in mirror space
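To run it, pass the file to uv run (the shebang and the PEP 723 inline metadata below should let uv pull the pinned dependencies automatically), or mark it executable and invoke it directly.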
#!/usr/bin/env -S uv run --script
# /// script
# requires-python = ">=3.12"
# dependencies = [
#   "jax==0.5.0",
#   "jaxopt>=0.8.5",
#   "optax>=0.2.4",
#   "matplotlib>=3.10.1",
#   "pyqt6>=6.9.0",  # for matplotlib gui
# ]
# ///
import logging
from typing import Any, Callable, Tuple, Optional, Sequence, Dict, NamedTuple

import numpy as np
import jax
import jax.numpy as jnp
import optax
from jaxopt import LBFGS
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec

# ==========================================================================
# Configuration Constants & Styles
# ==========================================================================
TARGET_PMF: jnp.ndarray = jnp.array([0.6, 0.3, 0.1])
WEIGHTS: jnp.ndarray = jnp.array([0.1, 50.0, 3.0])
INITIAL_P2: jnp.ndarray = jnp.array([0.2, 0.7])

LR_GD: float = 1.5e-2
LR_GD_MIRROR: float = 7e-2
LR_ADAM: float = 1.5e-2
STEPS: int = 250
MIN_LOSS: float = 1e-9

LBFGS_OPTS: Dict[str, Any] = dict(
    stepsize=0.1,
    maxls=20,
    history_size=10,
    jit=True,
    unroll=False,
    maxiter=1,
)

GRID_PRIMAL: int = 200
GRID_DUAL: int = 250
DUAL_LIM: float = 4.0
EPS_LOG: float = 1e-15
EPS_LOSS: float = 1e-16
LINE_WIDTH: float = 2.0


class MethodStyle(NamedTuple):
    color: str
    name: str


colors = [
    "#000000",  # Black
    "#FF0000",  # Bright Red
    "#008000",  # Normal Green
    "#0000FF",  # Bright Blue
    "#FFA500",  # Orange
    "#00FFFF",  # Cyan
]

METHOD_STYLES: Dict[str, MethodStyle] = {
    "GD+Proj": MethodStyle(colors[0], "GD + Proj"),
    "Adam+Proj": MethodStyle(colors[1], "Adam + Proj"),
    "Mirror(GD)": MethodStyle(colors[2], "Mirror Descent (GD)"),
    "AdamHybrid": MethodStyle(colors[3], "Adam Hybrid + KL"),
    "LBFGS(KL)": MethodStyle(colors[4], "LBFGS (KL)"),
}
METHOD_ORDER: Sequence[str] = list(METHOD_STYLES.keys())

START_MARKER: Dict[str, Any] = {'marker': 'o', 'ms': 9, 'mfc': 'lime', 'mec': 'black', 'label': 'Start', 'ls': 'none', 'zorder': 10}
END_MARKER: Dict[str, Any] = {'marker': 'X', 'ms': 9, 'mfc': 'cyan', 'mec': 'black', 'label': 'End', 'ls': 'none', 'zorder': 10}
OPTIMUM_MARKER: Dict[str, Any] = {'marker': '*', 'ms': 15, 'mfc': 'red', 'mec': 'white', 'label': 'Optimum $q$', 'ls': 'none', 'zorder': 10}

# ==========================================================================
# Logging Setup
# ==========================================================================
logging.basicConfig(
    level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s", datefmt="%H:%M:%S"
)

# ==========================================================================
# Core Mathematical Functions
# ==========================================================================
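# Weighted squared error between p and the target PMF. The weights differ by a
# factor of ~500 (0.1 vs 50.0), making the quadratic strongly anisotropic.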
@jax.jit
def objective(p: jnp.ndarray) -> jnp.ndarray:
    return jnp.sum(WEIGHTS * (p - TARGET_PMF) ** 2)
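

# The objective is exactly zero at p = TARGET_PMF, so F_MIN is 0 up to float
# error; it is only used below to plot the suboptimality f(p_t) - f(q) on a
# log scale.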
F_MIN: float = float(objective(TARGET_PMF))
grad_objective: Callable[[jnp.ndarray], jnp.ndarray] = jax.jit(jax.grad(objective))


def lift_to_3d(p2: jnp.ndarray) -> jnp.ndarray:
    p3 = 1.0 - jnp.sum(p2)
    return jnp.array([p2[0], p2[1], p3])
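

# Euclidean projection onto the probability simplex via the standard
# sort-and-threshold algorithm: find the largest rho with
# u_rho > (sum_{i<=rho} u_i - 1) / rho, then clip v - theta at zero.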
@jax.jit
def project_to_simplex(v: jnp.ndarray) -> jnp.ndarray:
    num_dims = len(v)
    u = jnp.sort(v)[::-1]
    cssv = jnp.cumsum(u)
    rho_candidates = u * jnp.arange(1, num_dims + 1) > (cssv - 1)
    rho = jnp.where(rho_candidates, jnp.arange(1, num_dims + 1), 0).max()
    rho = jnp.maximum(1, rho)
    theta = (cssv[rho - 1] - 1.0) / rho
    w = jnp.maximum(v - theta, 0)
    return w


@jax.jit
def project_p2_on_simplex(p2_candidate: jnp.ndarray) -> jnp.ndarray:
    p3_candidate = 1.0 - jnp.sum(p2_candidate)
    v3 = jnp.array([p2_candidate[0], p2_candidate[1], p3_candidate])
    p3_projected = project_to_simplex(v3)
    return p3_projected[:2]
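

# Mirror-map pair for the negative-entropy (KL) geometry: log(.) sends a point
# on the simplex to dual space, softmax maps back (softmax(log p) == p for p on
# the simplex, up to the EPS_LOG smoothing).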
@jax.jit
def log_map(p: jnp.ndarray) -> jnp.ndarray:
    return jnp.log(p + EPS_LOG)


@jax.jit
def exp_map(y: jnp.ndarray) -> jnp.ndarray:
    return jax.nn.softmax(y)


@jax.jit
def objective_p2(p2: jnp.ndarray) -> jnp.ndarray:
    return objective(lift_to_3d(p2))


grad_objective_p2: Callable[[jnp.ndarray], jnp.ndarray] = jax.jit(jax.grad(objective_p2))


@jax.jit
def objective_logits(s: jnp.ndarray) -> jnp.ndarray:
    return objective(exp_map(s))


@jax.jit
def objective_log_ratio(s_prime: jnp.ndarray) -> jnp.ndarray:
    s_full = jnp.array([s_prime[0], s_prime[1], 0.0])
    return objective(exp_map(s_full))
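

# Convert a (p1, p2) trajectory to the log-ratio coordinates
# s'_i = log(p_i / p_3) used for the dual-space plots; clipping keeps the logs
# finite when a trajectory touches the boundary of the simplex.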
def convert_p2_to_log_ratio(traj_p2: np.ndarray) -> np.ndarray:
    p1 = traj_p2[:, 0]
    p2 = traj_p2[:, 1]
    p3 = np.clip(1.0 - p1 - p2, EPS_LOG, None)
    p1_safe = np.clip(p1, EPS_LOG, None)
    p2_safe = np.clip(p2, EPS_LOG, None)
    s1_prime = np.log(p1_safe / p3)
    s2_prime = np.log(p2_safe / p3)
    return np.stack([s1_prime, s2_prime], axis=-1)


# ==========================================================================
# Optimizer Core Update Logic
# ==========================================================================
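# Projected gradient descent: Euclidean step in (p1, p2), then projection back
# onto the simplex.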
@jax.jit
def _gd_proj_update(p2: jnp.ndarray, learning_rate: float) -> jnp.ndarray:
    grad_p2 = grad_objective_p2(p2)
    p2_candidate = p2 - learning_rate * grad_p2
    return project_p2_on_simplex(p2_candidate)
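

# Same projection scheme, but the step comes from an optax transformation
# (Adam in this script) instead of a raw gradient step.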
def _optax_proj_step(
    p2: jnp.ndarray,
    state: Any,
    optimizer: optax.GradientTransformation,
) -> Tuple[jnp.ndarray, Any]:
    grad_p2 = grad_objective_p2(p2)
    updates, new_state = optimizer.update(grad_p2, state, p2)
    p2_candidate = p2 + updates
    p2_next = project_p2_on_simplex(p2_candidate)
    return p2_next, new_state
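

# Entropic mirror descent: step additively in log space and map back through
# softmax, which is the multiplicative update p_{t+1} ∝ p_t * exp(-lr * grad).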
@jax.jit
def _gd_mirror_update(p2: jnp.ndarray, learning_rate: float) -> jnp.ndarray:
    p_full = lift_to_3d(p2)
    grad_full = grad_objective(p_full)
    y_next = log_map(p_full) - learning_rate * grad_full
    p_next_full = exp_map(y_next)
    return p_next_full[:2]
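

# "Adam hybrid": the primal gradient is preconditioned by Adam, but the
# resulting update is applied in log (mirror) space before mapping back with
# softmax, combining Adam's adaptivity with the simplex geometry.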
def _optax_mirror_hybrid_step(
    p2: jnp.ndarray,
    state: Any,
    optimizer: optax.GradientTransformation,
) -> Tuple[jnp.ndarray, Any]:
    p_full = lift_to_3d(p2)
    grad_full = grad_objective(p_full)
    updates, new_state = optimizer.update(grad_full, state, p_full)
    y_next = log_map(p_full) + updates
    p_next_full = exp_map(y_next)
    return p_next_full[:2], new_state


# ==========================================================================
# Unified Run Loop
# ==========================================================================
OptimizerStep = Callable[[jnp.ndarray, Any], Tuple[jnp.ndarray, Any]]


def run_optimizer(
    optimizer_name: str,
    step_func: OptimizerStep,
    initial_params: jnp.ndarray,
    initial_state: Optional[Any],
    num_steps: int,
) -> Tuple[np.ndarray, np.ndarray]:
    params = initial_params
    state = initial_state
    is_p2_space = initial_params.shape == (2,)
    loss_func = objective_p2 if is_p2_space else objective_logits
    trajectory = [np.array(params)]
    losses = [float(loss_func(params))]
    logging.info(f"Running {optimizer_name}...")
    for step in range(num_steps):
        params, state = step_func(params, state)
        trajectory.append(np.array(params))
        losses.append(float(loss_func(params)))
        if step % 50 == 0 or step == num_steps - 1:
            logging.debug(f"Step {step+1}/{num_steps}, Loss: {losses[-1]:.6f}")
    return np.stack(trajectory), np.array(losses)


# ==========================================================================
# Prepare and Run Optimizers
# ==========================================================================
all_results: Dict[str, Tuple[np.ndarray, np.ndarray]] = {}


def gd_proj_runner(p2: jnp.ndarray, state: None) -> Tuple[jnp.ndarray, None]:
    return _gd_proj_update(p2, LR_GD), None


def gd_mirror_runner(p2: jnp.ndarray, state: None) -> Tuple[jnp.ndarray, None]:
    return _gd_mirror_update(p2, LR_GD_MIRROR), None


adam_opt = optax.adam(LR_ADAM)


def adam_proj_runner(p2: jnp.ndarray, state: Any) -> Tuple[jnp.ndarray, Any]:
    return _optax_proj_step(p2, state, adam_opt)


def adam_hybrid_runner(p2: jnp.ndarray, state: Any) -> Tuple[jnp.ndarray, Any]:
    return _optax_mirror_hybrid_step(p2, state, adam_opt)
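

# LBFGS works directly on unconstrained softmax logits s in R^3, so no
# projection is needed; its update() is called once per outer step, mirroring
# the single-step loop used for the other methods.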
lbfgs_logit_solver = LBFGS(fun=objective_logits, value_and_grad=False, **LBFGS_OPTS)
lbfgs_logit_runner = lbfgs_logit_solver.update

state_adam_proj = adam_opt.init(INITIAL_P2)
state_adam_hybrid = adam_opt.init(lift_to_3d(INITIAL_P2))
s0 = log_map(lift_to_3d(INITIAL_P2))
state_lbfgs_logit = lbfgs_logit_solver.init_state(s0)

optimizers_to_run: Dict[str, Tuple[OptimizerStep, jnp.ndarray, Optional[Any]]] = {
    "GD+Proj": (gd_proj_runner, INITIAL_P2, None),
    "Adam+Proj": (adam_proj_runner, INITIAL_P2, state_adam_proj),
    "Mirror(GD)": (gd_mirror_runner, INITIAL_P2, None),
    "AdamHybrid": (adam_hybrid_runner, INITIAL_P2, state_adam_hybrid),
    "LBFGS(KL)": (lbfgs_logit_runner, s0, state_lbfgs_logit),
}

for name in METHOD_ORDER:
    style = METHOD_STYLES[name]
    runner_func, init_params, init_state = optimizers_to_run[name]
    traj_params, loss = run_optimizer(
        style.name, runner_func, init_params, init_state, STEPS
    )
    all_results[name] = (traj_params, loss)

# ==========================================================================
# Prepare Data for Plotting
# ==========================================================================
primal_trajectories_p2: Dict[str, np.ndarray] = {}
loss_data: Dict[str, np.ndarray] = {}
dual_trajectories_s_prime: Dict[str, np.ndarray] = {}

for name in METHOD_ORDER:
    traj_params, loss = all_results[name]
    loss_data[name] = loss
    if name == "LBFGS(KL)":
        primal_trajectories_p2[name] = np.array(jax.vmap(exp_map)(traj_params)[:, :2])
    else:
        primal_trajectories_p2[name] = np.array(traj_params)
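    # Once the loss drops below MIN_LOSS, blank out the rest of the trajectory
    # and loss curve so the log-scale plots are not dominated by floating-point
    # noise near the optimum.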
    loss_below_threshold = loss < MIN_LOSS
    if np.any(loss_below_threshold):
        first_below_idx = np.argmax(loss_below_threshold)
        primal_trajectories_p2[name][first_below_idx + 1:] = np.nan
        loss[first_below_idx + 1:] = np.nan
    dual_trajectories_s_prime[name] = convert_p2_to_log_ratio(primal_trajectories_p2[name])

plot_styles_ordered = [METHOD_STYLES[name] for name in METHOD_ORDER]
primal_traj_list = [primal_trajectories_p2[name] for name in METHOD_ORDER]
dual_traj_list = [dual_trajectories_s_prime[name] for name in METHOD_ORDER]
loss_list = [loss_data[name] for name in METHOD_ORDER]

# ==========================================================================
# Create Contour Grids
# ==========================================================================
@jax.jit
def _calculate_primal_z(p2_grid):
    return jax.vmap(objective_p2)(p2_grid)


@jax.jit
def _calculate_dual_z(s_prime_grid):
    return jax.vmap(objective_log_ratio)(s_prime_grid)


def create_primal_grid_data() -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    logging.info("Creating primal contour grid...")
    xs = np.linspace(0, 1, GRID_PRIMAL)
    ys = np.linspace(0, 1, GRID_PRIMAL)
    X, Y = np.meshgrid(xs, ys)
    mask = (X + Y) <= 1.0
    p2_grid = jnp.stack([X[mask], Y[mask]], axis=-1)
    Z_masked = _calculate_primal_z(p2_grid)
    Z = np.full_like(X, np.nan, dtype=float)
    Z[mask] = np.array(Z_masked)
    return X, Y, Z


def create_dual_grid_data() -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    logging.info("Creating dual contour grid...")
    s1_coords = np.linspace(-DUAL_LIM, DUAL_LIM * 2, GRID_DUAL)
    s2_coords = np.linspace(-DUAL_LIM, DUAL_LIM * 2, GRID_DUAL)
    S1, S2 = np.meshgrid(s1_coords, s2_coords)
    S_prime_grid = jnp.stack([S1.ravel(), S2.ravel()], axis=-1)
    Z_dual_flat = _calculate_dual_z(S_prime_grid)
    Z_dual = np.array(Z_dual_flat).reshape(S1.shape)
    return S1, S2, Z_dual


X_prim, Y_prim, Z_prim = create_primal_grid_data()
S1_dual, S2_dual, Z_dual = create_dual_grid_data()

# ==========================================================================
# Generate Plots
# ==========================================================================
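# Figure 1: one primal-space trajectory panel per optimizer plus a shared
# suboptimality curve. Figure 2: all trajectories overlaid on the primal
# simplex and in dual log-ratio coordinates.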
plt.style.use('seaborn-v0_8-darkgrid')
plt.rcParams['lines.markersize'] = 10

logging.info("Generating Figure 1...")
n_opts = len(METHOD_ORDER)
n_cols_fig1 = 3
n_rows_fig1_traj = (n_opts + n_cols_fig1 - 1) // n_cols_fig1
fig1_height_ratios = [1] * n_rows_fig1_traj + [0.7]
fig1 = plt.figure(figsize=(4.5 * n_cols_fig1, 4 * n_rows_fig1_traj + 3))
gs1 = GridSpec(n_rows_fig1_traj + 1, n_cols_fig1, figure=fig1,
               height_ratios=fig1_height_ratios, hspace=0.3, wspace=0.2)

for idx, name in enumerate(METHOD_ORDER):
    row, col = divmod(idx, n_cols_fig1)
    ax = fig1.add_subplot(gs1[row, col])
    style = METHOD_STYLES[name]
    traj = primal_traj_list[idx]
    ax.contourf(X_prim, Y_prim, Z_prim, levels=30, cmap="viridis", alpha=0.6)
    ax.plot(traj[:, 0], traj[:, 1],
            color=style.color,
            linestyle='-',
            lw=LINE_WIDTH,
            marker='.',
            ms=8)
    ax.plot(traj[0, 0], traj[0, 1], **START_MARKER)
    ax.plot(traj[-1, 0], traj[-1, 1], **END_MARKER)
    ax.plot(TARGET_PMF[0], TARGET_PMF[1], **OPTIMUM_MARKER)
    ax.set_title(style.name, fontsize=10, wrap=True)
    ax.set_xticks([]); ax.set_yticks([])
    ax.set_xlim(0, 1); ax.set_ylim(0, 1)
    ax.set_aspect('equal', adjustable='box')

ax_loss = fig1.add_subplot(gs1[n_rows_fig1_traj, :])
for i, name in enumerate(METHOD_ORDER):
    style = METHOD_STYLES[name]
    loss = loss_list[i]
    suboptimality = np.maximum(loss - F_MIN, EPS_LOSS)
    ax_loss.semilogy(np.arange(len(loss)), suboptimality,
                     label=style.name,
                     color=style.color,
                     linestyle='-',
                     lw=LINE_WIDTH)
ax_loss.set_xlabel("Iteration")
ax_loss.set_ylabel(r"$f(p_t) - f(q)$ (log scale)")
ax_loss.set_title("Suboptimality over Iterations")
ax_loss.legend(ncol=min(3, n_opts), fontsize='small', loc='upper right', frameon=True, fancybox=True, edgecolor='black', shadow=True)
ax_loss.grid(True, which="both", ls=":", alpha=0.7)
ax_loss.set_xlim(0, STEPS)
ax_loss.set_ylim(MIN_LOSS, None)
plt.show()

logging.info("Generating Figure 2...")
fig2 = plt.figure(figsize=(14, 6.5))
gs2 = GridSpec(1, 2, figure=fig2, width_ratios=[1, 1.15])

axp = fig2.add_subplot(gs2[0])
contour_primal = axp.contourf(X_prim, Y_prim, Z_prim, levels=30, cmap="viridis", alpha=0.6)
for i, name in enumerate(METHOD_ORDER):
    style = METHOD_STYLES[name]
    traj = primal_traj_list[i]
    axp.plot(traj[:, 0], traj[:, 1],
             marker='.', ms=7,
             linestyle='-',
             lw=LINE_WIDTH,
             label=style.name,
             color=style.color)
axp.plot(*INITIAL_P2, **START_MARKER)
axp.plot(*TARGET_PMF[:2], **OPTIMUM_MARKER)
axp.set_xlabel(r"$p_1$")
axp.set_ylabel(r"$p_2$")
axp.set_title(r"Primal Space ($\Delta^2$) Overlay")
axp.legend(fontsize='small', loc='upper right')
axp.set_xlim(0, 1); axp.set_ylim(0, 1)
axp.set_aspect('equal', adjustable='box')
axp.grid(True, which="both", ls=":", alpha=0.6)

axd = fig2.add_subplot(gs2[1])
contour_dual = axd.contourf(S1_dual, S2_dual, Z_dual, levels=40, cmap="viridis", alpha=0.7)
fig2.colorbar(contour_dual, ax=axd, shrink=0.9, label=r"$f(p)$ in dual coords")
for i, name in enumerate(METHOD_ORDER):
    style = METHOD_STYLES[name]
    traj = dual_traj_list[i]
    axd.plot(traj[:, 0], traj[:, 1],
             marker='.', ms=7,
             linestyle='-',
             lw=LINE_WIDTH,
             label=style.name,
             color=style.color)
start_dual = convert_p2_to_log_ratio(np.array([INITIAL_P2]))[0]
optimum_dual = convert_p2_to_log_ratio(np.array([TARGET_PMF[:2]]))[0]
axd.plot(*start_dual, **START_MARKER)
axd.plot(*optimum_dual, **OPTIMUM_MARKER)
axd.set_xlabel(r"$s'_1 = \log(p_1/p_3)$")
axd.set_ylabel(r"$s'_2 = \log(p_2/p_3)$")
axd.set_title(r"Dual Log-Ratio Space ($\mathbb{R}^2$) Overlay")
axd.set_xlim(-DUAL_LIM, DUAL_LIM * 2)
axd.set_ylim(-DUAL_LIM, DUAL_LIM * 2)
axd.legend(fontsize='small', loc='upper right')
axd.grid(True, which="both", ls=":", alpha=0.6)
plt.show()

output_file1 = "figure1_primal_loss.png"
output_file2 = "figure2_primal_dual.png"
fig1.savefig(output_file1, dpi=80, bbox_inches="tight")
fig2.savefig(output_file2, dpi=80, bbox_inches="tight")
print(f"Figure 1 saved at: {output_file1}")
print(f"Figure 2 saved at: {output_file2}")
logging.info("Script finished successfully.")