@yberreby
Last active April 20, 2025 01:16
Quick comparison of a few optimizers on the 2-simplex: GD, Adam, Mirror Descent, Adam in mirror space, LBFGS in mirror space
#!/usr/bin/env -S uv run --script
# /// script
# requires-python = ">=3.12"
# dependencies = [
# "jax==0.5.0",
# "jaxopt>=0.8.5",
# "optax>=0.2.4",
# "matplotlib>=3.10.1",
# "pyqt6>=6.9.0", # for matplotlib gui
# ]
# ///
import logging
from typing import Any, Callable, Tuple, Optional, Sequence, Dict, NamedTuple
import numpy as np
import jax
import jax.numpy as jnp
import optax
from jaxopt import LBFGS
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec
# ==========================================================================
# Configuration Constants & Styles
# ==========================================================================
TARGET_PMF: jnp.ndarray = jnp.array([0.6, 0.3, 0.1])
WEIGHTS: jnp.ndarray = jnp.array([.1, 50.0, 3.0])
INITIAL_P2: jnp.ndarray = jnp.array([0.2, 0.7])
LR_GD: float = 1.5e-2
LR_GD_MIRROR: float = 7e-2
LR_ADAM: float = 1.5e-2
STEPS: int = 250
MIN_LOSS: float = 1e-9
LBFGS_OPTS: Dict[str, Any] = dict(
    stepsize=0.1,
    maxls=20,
    history_size=10,
    jit=True,
    unroll=False,
    maxiter=1,
)
GRID_PRIMAL: int = 200
GRID_DUAL: int = 250
DUAL_LIM: float = 4.0
EPS_LOG: float = 1e-15
EPS_LOSS: float = 1e-16
LINE_WIDTH: float = 2.0
class MethodStyle(NamedTuple):
    color: str
    name: str
colors = [
    "#000000",  # Black
    "#FF0000",  # Bright Red
    "#008000",  # Normal Green
    "#0000FF",  # Bright Blue
    "#FFA500",  # Orange
    "#00FFFF",  # Cyan
]
METHOD_STYLES: Dict[str, MethodStyle] = {
    "GD+Proj": MethodStyle(colors[0], "GD + Proj"),
    "Adam+Proj": MethodStyle(colors[1], "Adam + Proj"),
    "Mirror(GD)": MethodStyle(colors[2], "Mirror Descent (GD)"),
    "AdamHybrid": MethodStyle(colors[3], "Adam Hybrid + KL"),
    "LBFGS(KL)": MethodStyle(colors[4], "LBFGS (KL)"),
}
METHOD_ORDER: Sequence[str] = list(METHOD_STYLES.keys())
START_MARKER: Dict[str, Any] = {'marker':'o', 'ms':9, 'mfc':'lime', 'mec':'black', 'label':'Start', 'ls':'none', 'zorder': 10}
END_MARKER: Dict[str, Any] = {'marker':'X', 'ms':9, 'mfc':'cyan', 'mec':'black', 'label':'End', 'ls':'none', 'zorder': 10}
OPTIMUM_MARKER: Dict[str, Any] = {'marker':'*', 'ms':15, 'mfc':'red', 'mec':'white', 'label':'Optimum $q$', 'ls':'none', 'zorder': 10}
# ==========================================================================
# Logging Setup
# ==========================================================================
logging.basicConfig(
    level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s", datefmt="%H:%M:%S"
)
# ==========================================================================
# Core Mathematical Functions
# ==========================================================================
@jax.jit
def objective(p: jnp.ndarray) -> jnp.ndarray:
    return jnp.sum(WEIGHTS * (p - TARGET_PMF) ** 2)
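# The objective is an anisotropic quadratic centered at TARGET_PMF, so its
# minimum over the simplex is attained there (with value 0); F_MIN is still
# computed explicitly and used later for the suboptimality plot.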
F_MIN: float = float(objective(TARGET_PMF))
grad_objective: Callable[[jnp.ndarray], jnp.ndarray] = jax.jit(jax.grad(objective))
def lift_to_3d(p2: jnp.ndarray) -> jnp.ndarray:
    p3 = 1.0 - jnp.sum(p2)
    return jnp.array([p2[0], p2[1], p3])
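# Euclidean projection onto the probability simplex using the standard
# sort-and-threshold scheme: sort descending, find the largest rho such that
# u_rho * rho > cumsum(u)_rho - 1, then clip at theta = (cumsum(u)_rho - 1) / rho.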
@jax.jit
def project_to_simplex(v: jnp.ndarray) -> jnp.ndarray:
    num_dims = len(v)
    u = jnp.sort(v)[::-1]
    cssv = jnp.cumsum(u)
    rho_candidates = u * jnp.arange(1, num_dims + 1) > (cssv - 1)
    rho = jnp.where(rho_candidates, jnp.arange(1, num_dims + 1), 0).max()
    rho = jnp.maximum(1, rho)
    theta = (cssv[rho - 1] - 1.0) / rho
    w = jnp.maximum(v - theta, 0)
    return w
@jax.jit
def project_p2_on_simplex(p2_candidate: jnp.ndarray) -> jnp.ndarray:
    p3_candidate = 1.0 - jnp.sum(p2_candidate)
    v3 = jnp.array([p2_candidate[0], p2_candidate[1], p3_candidate])
    p3_projected = project_to_simplex(v3)
    return p3_projected[:2]
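# Mirror-map pair: log_map lifts a simplex point to dual (log) coordinates,
# with EPS_LOG guarding against log(0) on the boundary, and exp_map (softmax)
# maps dual coordinates back onto the simplex.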
@jax.jit
def log_map(p: jnp.ndarray) -> jnp.ndarray:
    return jnp.log(p + EPS_LOG)
@jax.jit
def exp_map(y: jnp.ndarray) -> jnp.ndarray:
    return jax.nn.softmax(y)
@jax.jit
def objective_p2(p2: jnp.ndarray) -> jnp.ndarray:
    return objective(lift_to_3d(p2))
grad_objective_p2: Callable[[jnp.ndarray], jnp.ndarray] = jax.jit(jax.grad(objective_p2))
@jax.jit
def objective_logits(s: jnp.ndarray) -> jnp.ndarray:
    return objective(exp_map(s))
@jax.jit
def objective_log_ratio(s_prime: jnp.ndarray) -> jnp.ndarray:
    s_full = jnp.array([s_prime[0], s_prime[1], 0.0])
    return objective(exp_map(s_full))
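# Dual "log-ratio" coordinates used in the second figure: s'_i = log(p_i / p_3).
# Fixing the third logit to 0 removes softmax's invariance to constant shifts,
# so each interior simplex point corresponds to a unique point in R^2.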
def convert_p2_to_log_ratio(traj_p2: np.ndarray) -> np.ndarray:
    p1 = traj_p2[:, 0]
    p2 = traj_p2[:, 1]
    p3 = np.clip(1.0 - p1 - p2, EPS_LOG, None)
    p1_safe = np.clip(p1, EPS_LOG, None)
    p2_safe = np.clip(p2, EPS_LOG, None)
    s1_prime = np.log(p1_safe / p3)
    s2_prime = np.log(p2_safe / p3)
    return np.stack([s1_prime, s2_prime], axis=-1)
# ==========================================================================
# Optimizer Core Update Logic
# ==========================================================================
@jax.jit
def _gd_proj_update(p2: jnp.ndarray, learning_rate: float) -> jnp.ndarray:
    grad_p2 = grad_objective_p2(p2)
    p2_candidate = p2 - learning_rate * grad_p2
    return project_p2_on_simplex(p2_candidate)
def _optax_proj_step(
    p2: jnp.ndarray,
    state: Any,
    optimizer: optax.GradientTransformation,
) -> Tuple[jnp.ndarray, Any]:
    grad_p2 = grad_objective_p2(p2)
    updates, new_state = optimizer.update(grad_p2, state, p2)
    p2_candidate = p2 + updates
    p2_next = project_p2_on_simplex(p2_candidate)
    return p2_next, new_state
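# Mirror descent with the negative-entropy mirror map (exponentiated gradient):
# take a gradient step in log space and map back through softmax, i.e.
# p_{t+1, i} is proportional to p_{t, i} * exp(-lr * g_i).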
@jax.jit
def _gd_mirror_update(p2: jnp.ndarray, learning_rate: float) -> jnp.ndarray:
    p_full = lift_to_3d(p2)
    grad_full = grad_objective(p_full)
    y_next = log_map(p_full) - learning_rate * grad_full
    p_next_full = exp_map(y_next)
    return p_next_full[:2]
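# Hybrid variant: an optax optimizer (Adam here) preconditions the raw gradient,
# and the resulting update is applied in log (mirror) space before mapping back
# with softmax; the optimizer state is carried for the full 3-d point.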
def _optax_mirror_hybrid_step(
    p2: jnp.ndarray,
    state: Any,
    optimizer: optax.GradientTransformation,
) -> Tuple[jnp.ndarray, Any]:
    p_full = lift_to_3d(p2)
    grad_full = grad_objective(p_full)
    updates, new_state = optimizer.update(grad_full, state, p_full)
    y_next = log_map(p_full) + updates
    p_next_full = exp_map(y_next)
    return p_next_full[:2], new_state
# ==========================================================================
# Unified Run Loop
# ==========================================================================
OptimizerStep = Callable[[jnp.ndarray, Any], Tuple[jnp.ndarray, Any]]
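# Every runner follows the same protocol: (params, state) -> (new_params, new_state).
# jaxopt's LBFGS.update returns an OptStep (params, state) pair, which unpacks the
# same way, so it can be plugged into the loop below unchanged.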
def run_optimizer(
    optimizer_name: str,
    step_func: OptimizerStep,
    initial_params: jnp.ndarray,
    initial_state: Optional[Any],
    num_steps: int,
) -> Tuple[np.ndarray, np.ndarray]:
    params = initial_params
    state = initial_state
    is_p2_space = initial_params.shape == (2,)
    loss_func = objective_p2 if is_p2_space else objective_logits
    trajectory = [np.array(params)]
    losses = [float(loss_func(params))]
    logging.info(f"Running {optimizer_name}...")
    for step in range(num_steps):
        params, state = step_func(params, state)
        trajectory.append(np.array(params))
        losses.append(float(loss_func(params)))
        if step % 50 == 0 or step == num_steps - 1:
            logging.debug(f"Step {step+1}/{num_steps}, Loss: {losses[-1]:.6f}")
    return np.stack(trajectory), np.array(losses)
# ==========================================================================
# Prepare and Run Optimizers
# ==========================================================================
all_results: Dict[str, Tuple[np.ndarray, np.ndarray]] = {}
def gd_proj_runner(p2: jnp.ndarray, state: None) -> Tuple[jnp.ndarray, None]:
    return _gd_proj_update(p2, LR_GD), None
def gd_mirror_runner(p2: jnp.ndarray, state: None) -> Tuple[jnp.ndarray, None]:
    return _gd_mirror_update(p2, LR_GD_MIRROR), None
adam_opt = optax.adam(LR_ADAM)
def adam_proj_runner(p2: jnp.ndarray, state: Any) -> Tuple[jnp.ndarray, Any]:
    return _optax_proj_step(p2, state, adam_opt)
def adam_hybrid_runner(p2: jnp.ndarray, state: Any) -> Tuple[jnp.ndarray, Any]:
    return _optax_mirror_hybrid_step(p2, state, adam_opt)
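# LBFGS optimizes unconstrained logits s, with softmax(s) giving the simplex point.
# Its .update method is used directly as a step function, so the unified loop
# below records one LBFGS iteration per outer step.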
lbfgs_logit_solver = LBFGS(fun=objective_logits, value_and_grad=False, **LBFGS_OPTS)
lbfgs_logit_runner = lbfgs_logit_solver.update
state_adam_proj = adam_opt.init(INITIAL_P2)
state_adam_hybrid = adam_opt.init(lift_to_3d(INITIAL_P2))
s0 = log_map(lift_to_3d(INITIAL_P2))
state_lbfgs_logit = lbfgs_logit_solver.init_state(s0)
optimizers_to_run: Dict[str, Tuple[OptimizerStep, jnp.ndarray, Optional[Any]]] = {
    "GD+Proj": (gd_proj_runner, INITIAL_P2, None),
    "Adam+Proj": (adam_proj_runner, INITIAL_P2, state_adam_proj),
    "Mirror(GD)": (gd_mirror_runner, INITIAL_P2, None),
    "AdamHybrid": (adam_hybrid_runner, INITIAL_P2, state_adam_hybrid),
    "LBFGS(KL)": (lbfgs_logit_runner, s0, state_lbfgs_logit),
}
for name in METHOD_ORDER:
    style = METHOD_STYLES[name]
    runner_func, init_params, init_state = optimizers_to_run[name]
    traj_params, loss = run_optimizer(
        style.name, runner_func, init_params, init_state, STEPS
    )
    all_results[name] = (traj_params, loss)
# ==========================================================================
# Prepare Data for Plotting
# ==========================================================================
primal_trajectories_p2: Dict[str, np.ndarray] = {}
loss_data: Dict[str, np.ndarray] = {}
dual_trajectories_s_prime: Dict[str, np.ndarray] = {}
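# LBFGS iterates live in logit space, so they are mapped back to the simplex for
# plotting; once a run's loss drops below MIN_LOSS, subsequent points are set to
# NaN so fully-converged trajectories do not clutter the figures.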
for name in METHOD_ORDER:
    traj_params, loss = all_results[name]
    loss_data[name] = loss
    if name == "LBFGS(KL)":
        primal_trajectories_p2[name] = np.array(jax.vmap(exp_map)(traj_params)[:, :2])
    else:
        primal_trajectories_p2[name] = np.array(traj_params)
    loss_below_threshold = loss < MIN_LOSS
    if np.any(loss_below_threshold):
        first_below_idx = np.argmax(loss_below_threshold)
        primal_trajectories_p2[name][first_below_idx+1:] = np.nan
        loss[first_below_idx+1:] = np.nan
    dual_trajectories_s_prime[name] = convert_p2_to_log_ratio(primal_trajectories_p2[name])
plot_styles_ordered = [METHOD_STYLES[name] for name in METHOD_ORDER]
primal_traj_list = [primal_trajectories_p2[name] for name in METHOD_ORDER]
dual_traj_list = [dual_trajectories_s_prime[name] for name in METHOD_ORDER]
loss_list = [loss_data[name] for name in METHOD_ORDER]
# ==========================================================================
# Create Contour Grids
# ==========================================================================
@jax.jit
def _calculate_primal_z(p2_grid):
    return jax.vmap(objective_p2)(p2_grid)
@jax.jit
def _calculate_dual_z(s_prime_grid):
    return jax.vmap(objective_log_ratio)(s_prime_grid)
def create_primal_grid_data() -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    logging.info("Creating primal contour grid...")
    xs = np.linspace(0, 1, GRID_PRIMAL)
    ys = np.linspace(0, 1, GRID_PRIMAL)
    X, Y = np.meshgrid(xs, ys)
    mask = (X + Y) <= 1.0
    p2_grid = jnp.stack([X[mask], Y[mask]], axis=-1)
    Z_masked = _calculate_primal_z(p2_grid)
    Z = np.full_like(X, np.nan, dtype=float)
    Z[mask] = np.array(Z_masked)
    return X, Y, Z
def create_dual_grid_data() -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    logging.info("Creating dual contour grid...")
    s1_coords = np.linspace(-DUAL_LIM, DUAL_LIM * 2, GRID_DUAL)
    s2_coords = np.linspace(-DUAL_LIM, DUAL_LIM * 2, GRID_DUAL)
    S1, S2 = np.meshgrid(s1_coords, s2_coords)
    S_prime_grid = jnp.stack([S1.ravel(), S2.ravel()], axis=-1)
    Z_dual_flat = _calculate_dual_z(S_prime_grid)
    Z_dual = np.array(Z_dual_flat).reshape(S1.shape)
    return S1, S2, Z_dual
X_prim, Y_prim, Z_prim = create_primal_grid_data()
S1_dual, S2_dual, Z_dual = create_dual_grid_data()
# ==========================================================================
# Generate Plots
# ==========================================================================
plt.style.use('seaborn-v0_8-darkgrid')
plt.rcParams['lines.markersize'] = 10
logging.info("Generating Figure 1...")
n_opts = len(METHOD_ORDER)
n_cols_fig1 = 3
n_rows_fig1_traj = (n_opts + n_cols_fig1 - 1) // n_cols_fig1
fig1_height_ratios = [1] * n_rows_fig1_traj + [0.7]
fig1 = plt.figure(figsize=(4.5 * n_cols_fig1, 4 * n_rows_fig1_traj + 3))
gs1 = GridSpec(n_rows_fig1_traj + 1, n_cols_fig1, figure=fig1,
               height_ratios=fig1_height_ratios, hspace=0.3, wspace=0.2)
for idx, name in enumerate(METHOD_ORDER):
    row, col = divmod(idx, n_cols_fig1)
    ax = fig1.add_subplot(gs1[row, col])
    style = METHOD_STYLES[name]
    traj = primal_traj_list[idx]
    ax.contourf(X_prim, Y_prim, Z_prim, levels=30, cmap="viridis", alpha=0.6)
    ax.plot(traj[:, 0], traj[:, 1],
            color=style.color,
            linestyle='-',
            lw=LINE_WIDTH,
            marker='.',
            ms=8)
    ax.plot(traj[0, 0], traj[0, 1], **START_MARKER)
    ax.plot(traj[-1, 0], traj[-1, 1], **END_MARKER)
    ax.plot(TARGET_PMF[0], TARGET_PMF[1], **OPTIMUM_MARKER)
    ax.set_title(style.name, fontsize=10, wrap=True)
    ax.set_xticks([]); ax.set_yticks([])
    ax.set_xlim(0, 1); ax.set_ylim(0, 1)
    ax.set_aspect('equal', adjustable='box')
ax_loss = fig1.add_subplot(gs1[n_rows_fig1_traj, :])
for i, name in enumerate(METHOD_ORDER):
    style = METHOD_STYLES[name]
    loss = loss_list[i]
    suboptimality = np.maximum(loss - F_MIN, EPS_LOSS)
    ax_loss.semilogy(np.arange(len(loss)), suboptimality,
                     label=style.name,
                     color=style.color,
                     linestyle='-',
                     lw=LINE_WIDTH)
ax_loss.set_xlabel("Iteration")
ax_loss.set_ylabel(r"$f(p_t) - f(q)$ (log scale)")
ax_loss.set_title("Suboptimality over Iterations")
ax_loss.legend(ncol=min(3, n_opts), fontsize='small', loc='upper right', frameon=True, fancybox=True, edgecolor='black', shadow=True)
ax_loss.grid(True, which="both", ls=":", alpha=0.7)
ax_loss.set_xlim(0, STEPS)
ax_loss.set_ylim(MIN_LOSS, None)
plt.show()
logging.info("Generating Figure 2...")
fig2 = plt.figure(figsize=(14, 6.5))
gs2 = GridSpec(1, 2, figure=fig2, width_ratios=[1, 1.15])
axp = fig2.add_subplot(gs2[0])
contour_primal = axp.contourf(X_prim, Y_prim, Z_prim, levels=30, cmap="viridis", alpha=0.6)
for i, name in enumerate(METHOD_ORDER):
    style = METHOD_STYLES[name]
    traj = primal_traj_list[i]
    axp.plot(traj[:, 0], traj[:, 1],
             marker='.', ms=7,
             linestyle='-',
             lw=LINE_WIDTH,
             label=style.name,
             color=style.color)
axp.plot(*INITIAL_P2, **START_MARKER)
axp.plot(*TARGET_PMF[:2], **OPTIMUM_MARKER)
axp.set_xlabel(r"$p_1$")
axp.set_ylabel(r"$p_2$")
axp.set_title(r"Primal Space ($\Delta^2$) Overlay")
axp.legend(fontsize='small', loc='upper right')
axp.set_xlim(0, 1); axp.set_ylim(0, 1)
axp.set_aspect('equal', adjustable='box')
axp.grid(True, which="both", ls=":", alpha=0.6)
axd = fig2.add_subplot(gs2[1])
contour_dual = axd.contourf(S1_dual, S2_dual, Z_dual, levels=40, cmap="viridis", alpha=0.7)
fig2.colorbar(contour_dual, ax=axd, shrink=0.9, label=r"$f(p)$ in dual coords")
for i, name in enumerate(METHOD_ORDER):
    style = METHOD_STYLES[name]
    traj = dual_traj_list[i]
    axd.plot(traj[:, 0], traj[:, 1],
             marker='.', ms=7,
             linestyle='-',
             lw=LINE_WIDTH,
             label=style.name,
             color=style.color)
start_dual = convert_p2_to_log_ratio(np.array([INITIAL_P2]))[0]
optimum_dual = convert_p2_to_log_ratio(np.array([TARGET_PMF[:2]]))[0]
axd.plot(*start_dual, **START_MARKER)
axd.plot(*optimum_dual, **OPTIMUM_MARKER)
axd.set_xlabel(r"$s'_1 = \log(p_1/p_3)$")
axd.set_ylabel(r"$s'_2 = \log(p_2/p_3)$")
axd.set_title(r"Dual Log-Ratio Space ($\mathbb{R}^2$) Overlay")
axd.set_xlim(-DUAL_LIM, DUAL_LIM * 2)
axd.set_ylim(-DUAL_LIM, DUAL_LIM * 2)
axd.legend(fontsize='small', loc='upper right')
axd.grid(True, which="both", ls=":", alpha=0.6)
plt.show()
output_file1 = "figure1_primal_loss.png"
output_file2 = "figure2_primal_dual.png"
fig1.savefig(output_file1, dpi=80, bbox_inches="tight")
fig2.savefig(output_file2, dpi=80, bbox_inches="tight")
print(f"Figure 1 saved at: {output_file1}")
print(f"Figure 2 saved at: {output_file2}")
logging.info("Script finished successfully.")