Created
March 16, 2023 18:31
-
-
Save Edenhofer/ece9a2e3e8c67721dbdd706b3966f04c to your computer and use it in GitHub Desktop.
JAX enters an infinite loop for trust-ncg minimization
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# %% | |
from functools import partial | |
from typing import NamedTuple, Optional, Tuple, Union | |
import jax | |
from jax import lax | |
from jax import numpy as jnp | |
N_RESET = 20 | |
class CGResults(NamedTuple):
    """Immutable record describing the outcome of a CG-type solve.

    NOTE(review): nothing in the visible part of this file constructs or
    reads this type; the field meanings below (other than ``nfev``) are
    inferred from their names and should be confirmed against the caller.
    """

    # presumably the solution estimate
    x: jnp.ndarray
    # presumably the number of iterations performed
    nit: Union[int, jnp.ndarray]
    # number of matrix-evaluations
    nfev: Union[int, jnp.ndarray]
    # presumably a status/diagnostic code
    info: Union[int, jnp.ndarray]
    # presumably whether the solve converged
    success: Union[bool, jnp.ndarray]
# The following is code adapted from Nicholas Mancuso to work with pytrees | |
class _QuadSubproblemResult(NamedTuple): | |
step: jnp.ndarray | |
hits_boundary: Union[bool, jnp.ndarray] | |
pred_f: Union[float, jnp.ndarray] | |
nit: Union[int, jnp.ndarray] | |
nfev: Union[int, jnp.ndarray] | |
njev: Union[int, jnp.ndarray] | |
nhev: Union[int, jnp.ndarray] | |
success: Union[bool, jnp.ndarray] | |
class _CGSteihaugState(NamedTuple): | |
z: jnp.ndarray | |
r: jnp.ndarray | |
d: jnp.ndarray | |
step: jnp.ndarray | |
energy: Union[None, float, jnp.ndarray] | |
hits_boundary: Union[bool, jnp.ndarray] | |
done: Union[bool, jnp.ndarray] | |
nit: Union[int, jnp.ndarray] | |
nhev: Union[int, jnp.ndarray] | |
def second_order_approx(
    p: jnp.ndarray,
    cur_val: Union[float, jnp.ndarray],
    g: jnp.ndarray,
    hessp_at_xk,
) -> Union[float, jnp.ndarray]:
    """Evaluate the quadratic model ``cur_val + <g, p> + 0.5 * <p, H p>``.

    Parameters
    ----------
    p:
        Displacement at which the model is evaluated.
    cur_val:
        Constant term of the model.
    g:
        Linear coefficient (paired with ``p`` through ``jnp.vdot``).
    hessp_at_xk:
        Callable applied to ``p``; its result is paired with ``p`` to form
        the quadratic term. Called exactly once per evaluation.

    Returns
    -------
    The scalar model value.
    """
    linear_term = jnp.vdot(g, p)
    curvature_term = 0.5 * jnp.vdot(p, hessp_at_xk(p))
    return cur_val + linear_term + curvature_term
def get_boundaries_intersections(
    z: jnp.ndarray, d: jnp.ndarray, trust_radius: Union[float, jnp.ndarray]
):
    """Solve ``<z + t*d, z + t*d> = trust_radius**2`` for the scalar ``t``.

    The quadratic ``a*t**2 + b*t + c = 0`` is solved with the
    ``copysign`` form of the quadratic formula, so that ``b`` and the
    square root of the discriminant are always added with equal sign.

    Returns
    -------
    A tuple ``(t_low, t_high)`` holding the two roots ordered so that
    ``t_low`` is the one selected by ``ta < tb``.
    """
    quad_coef = jnp.vdot(d, d)
    lin_coef = 2 * jnp.vdot(z, d)
    const_coef = jnp.vdot(z, z) - trust_radius**2

    root_disc = jnp.sqrt(lin_coef * lin_coef - 4 * quad_coef * const_coef)
    # same-sign sum of the linear coefficient and the discriminant root
    shifted = lin_coef + jnp.copysign(root_disc, lin_coef)

    first_root = -shifted / (2 * quad_coef)
    second_root = -2 * const_coef / shifted

    first_is_smaller = first_root < second_root
    t_low = jnp.where(first_is_smaller, first_root, second_root)
    t_high = jnp.where(first_is_smaller, second_root, first_root)
    return (t_low, t_high)
def _cg_steihaug_subproblem(
    cur_val: Union[float, jnp.ndarray],
    g: jnp.ndarray,
    hessp_at_xk,
    *,
    trust_radius: Union[float, jnp.ndarray],
    tr_norm_ord: Union[None, int, float, jnp.ndarray] = None,
    resnorm: Optional[float],
    absdelta: Optional[float] = None,
    norm_ord: Union[None, int, float, jnp.ndarray] = None,
    miniter: Union[None, int] = None,
    maxiter: Union[None, int] = None,
) -> _QuadSubproblemResult:
    """Solve the trust-region quadratic subproblem with a CG-Steihaug loop.

    Parameters
    ----------
    cur_val:
        Constant term of the quadratic model.
    g:
        Gradient; used as the initial residual and to size the iteration
        fallback (``20 * g.size``).
    hessp_at_xk:
        Callable mapping a vector to its Hessian-vector product.
    trust_radius:
        Radius against which the norm of each new iterate is compared.
    tr_norm_ord:
        Order of the norm used for that comparison; ``None`` means
        ``jnp.inf``.
    resnorm:
        Residual-norm threshold: the iterate is accepted once the norm of
        the new residual drops below it. ``None`` disables this criterion.
    absdelta:
        Energy-decrease threshold: once at least ``miniter`` iterations
        have run, the iterate is accepted when the decrease lies in
        ``[-eps * |energy|, absdelta)``. ``None`` disables this criterion.
    norm_ord:
        Order of the norm applied to the residual; ``None`` means 2.
    miniter:
        Minimum iteration count for the ``absdelta`` criterion; ``None``
        means ``min(6, maxiter)`` (or ``min(6, 20 * g.size)`` when
        ``maxiter`` is ``None`` as well).
    maxiter:
        Iteration cap; ``None`` means
        ``max(min(200, 20 * g.size), miniter)``.

    Returns
    -------
    _QuadSubproblemResult
        The selected step, whether it lies on the boundary, the model value
        at the step, the iteration count and the Hessian-vector product
        count. ``nfev`` and ``njev`` are reported as 0 and ``success`` as
        True.
    """
    # NOTE: the previous unused function-scope import of
    # `jax.experimental.host_callback` was dropped; that module is deprecated
    # and missing from recent JAX releases, which made this function raise
    # ImportError before doing any work. The redundant local `import jax`
    # was dropped too since the module is imported at file level.
    if tr_norm_ord is None:
        tr_norm_ord = jnp.inf  # taken from JAX
    if norm_ord is None:
        norm_ord = 2  # TODO: change to 1
    maxiter_fallback = 20 * g.size  # taken from SciPy's NewtonCG minimizer
    if miniter is None:
        # must be resolved before `maxiter`, whose default depends on it
        miniter = jnp.minimum(
            6, maxiter if maxiter is not None else maxiter_fallback
        )
    if maxiter is None:
        maxiter = jnp.maximum(jnp.minimum(200, maxiter_fallback), miniter)

    common_dtp = g.dtype
    eps = 6. * jnp.finfo(common_dtp).eps

    # second-order Taylor series approximation at the current values, gradient,
    # and hessian
    soa = partial(
        second_order_approx, cur_val=cur_val, g=g, hessp_at_xk=hessp_at_xk
    )

    # helpers for internal switches in the main CGSteihaug logic
    def noop(
        param: Tuple[_CGSteihaugState, Union[float, jnp.ndarray]]
    ) -> _CGSteihaugState:
        """Branch 0: no termination criterion fired; keep the state as is."""
        iterp, z_next = param
        return iterp

    def step1(
        param: Tuple[_CGSteihaugState, Union[float, jnp.ndarray]]
    ) -> _CGSteihaugState:
        """Branch 1 (``dBd <= 0``): move to the better boundary point."""
        iterp, z_next = param
        z, d, nhev = iterp.z, iterp.d, iterp.nhev
        ta, tb = get_boundaries_intersections(z, d, trust_radius)
        pa = z + ta * d
        pb = z + tb * d
        # each model evaluation applies the Hessian once, hence `nhev + 2`
        p_boundary = jnp.where(soa(pa) < soa(pb), pa, pb)
        return iterp._replace(
            step=p_boundary, nhev=nhev + 2, hits_boundary=True, done=True
        )

    def step2(
        param: Tuple[_CGSteihaugState, Union[float, jnp.ndarray]]
    ) -> _CGSteihaugState:
        """Branch 2 (iterate left the region): stop at the second root."""
        iterp, z_next = param
        z, d = iterp.z, iterp.d
        _, tb = get_boundaries_intersections(z, d, trust_radius)
        p_boundary = z + tb * d
        return iterp._replace(step=p_boundary, hits_boundary=True, done=True)

    def step3(
        param: Tuple[_CGSteihaugState, Union[float, jnp.ndarray]]
    ) -> _CGSteihaugState:
        """Branch 3 (iterate accepted): take the new interior iterate."""
        iterp, z_next = param
        return iterp._replace(step=z_next, hits_boundary=False, done=True)

    # initialize the step
    p_origin = jnp.zeros_like(g)

    # init the state for the first iteration
    # NOTE(review): with g == 0 the first `alpha` below is 0/0 — confirm
    # that callers never pass a vanishing gradient.
    init_param = _CGSteihaugState(
        z=p_origin,
        r=g,
        d=-g,
        step=p_origin,
        energy=0.,
        hits_boundary=False,
        done=maxiter == 0,
        nit=0,
        nhev=0
    )

    # Search for the min of the approximation of the objective function.
    def body_f(iterp: _CGSteihaugState) -> _CGSteihaugState:
        """Run one CG update and pick the matching termination branch."""
        z, r, d = iterp.z, iterp.r, iterp.d
        energy, nit = iterp.energy, iterp.nit

        nit += 1
        jax.debug.print("in body {nit} \\\\ 1 ::", nit=nit)

        Bd = hessp_at_xk(d)
        dBd = jnp.vdot(d, Bd)

        r_squared = jnp.vdot(r, r)
        alpha = r_squared / dBd
        z_next = z + alpha * d

        r_next = r + alpha * Bd
        r_next_squared = jnp.vdot(r_next, r_next)

        beta_next = r_next_squared / r_squared
        d_next = -r_next + beta_next * d
        jax.debug.print("in body {nit} \\\\ 2 ::", nit=nit)

        accept_z_next = nit >= maxiter
        jax.debug.print(
            "in body {nit} \\\\ 3 :: accept_z_next={accept_z_next}",
            nit=nit,
            accept_z_next=accept_z_next
        )
        # `resnorm` is annotated Optional; skip the residual criterion for
        # None instead of failing on the comparison.
        if resnorm is not None:
            if norm_ord == 2:
                r_next_norm = jnp.sqrt(r_next_squared)
            else:
                r_next_norm = jnp.linalg.norm(r_next, ord=norm_ord)
            accept_z_next = accept_z_next | (r_next_norm < resnorm)

        # Relative to a plain CG, `z_next` is negative
        energy_next = jnp.vdot((r_next + g) / 2, z_next)
        energy_diff = energy - energy_next
        if absdelta is not None:
            neg_energy_eps = -eps * jnp.abs(energy)
            accept_z_next = accept_z_next | (
                (energy_diff >= neg_energy_eps) & (energy_diff < absdelta) &
                (nit >= miniter)
            )
        jax.debug.print("in body {nit} \\\\ 4 ::", nit=nit)

        # include a junk switch to catch the case where none should be executed
        z_next_norm = jnp.linalg.norm(z_next, ord=tr_norm_ord)
        jax.debug.print("in body {nit} \\\\ 5 :: pre-index", nit=nit)
        # argmax returns the first True entry, or 0 (the junk slot -> noop)
        # when no criterion fired
        index = jnp.argmax(
            jnp.array(
                [False, dBd <= 0, z_next_norm >= trust_radius, accept_z_next]
            )
        )
        jax.debug.print("in body {nit} \\\\ 6 :: pre-switch {index}", nit=nit, index=index)
        iterp = lax.switch(index, [noop, step1, step2, step3], (iterp, z_next))
        jax.debug.print("in body {nit} \\\\ 7 :: post-switch", nit=nit)

        # the `+ 1` accounts for the single `hessp_at_xk(d)` call above
        iterp = iterp._replace(
            z=z_next,
            r=r_next,
            d=d_next,
            energy=energy_next,
            nhev=iterp.nhev + 1,
            nit=nit
        )

        return iterp

    def cond_f(iterp: _CGSteihaugState) -> bool:
        """Keep looping until one of the termination branches set ``done``."""
        jax.debug.print(
            "cond_f={c} maxiter={maxiter}", c=~iterp.done, maxiter=maxiter
        )
        return jnp.logical_not(iterp.done)

    # perform inner optimization to solve the constrained
    # quadratic subproblem using cg
    jax.debug.print("looped {result.done} {result}", result=init_param)
    result = lax.while_loop(cond_f, body_f, init_param)
    jax.debug.print("looped {result.done} {result}", result=result)

    # one more Hessian-vector product for the final model evaluation
    pred_f = soa(result.step)
    result = _QuadSubproblemResult(
        step=result.step,
        hits_boundary=result.hits_boundary,
        pred_f=pred_f,
        nit=result.nit,
        nfev=0,
        njev=0,
        nhev=result.nhev + 1,
        success=True
    )

    return result
def rosenbrock(np):
    """Build the Rosenbrock test objective.

    The ``np`` argument is accepted but not used; the returned function
    always computes with ``jnp``.
    """
    def func(x):
        neighbour_gaps = jnp.diff(x)
        unit_offsets = 1. - x[:-1]
        return jnp.sum(100. * neighbour_gaps**2 + unit_offsets**2)

    return func
def himmelblau(np):
    """Build the two-variable Himmelblau test objective.

    The ``np`` argument is accepted but not used.
    """
    def func(p):
        x, y = p
        first = x**2 + y - 11.
        second = x + y**2 - 7.
        return first**2 + second**2

    return func
def matyas(np):
    """Build the two-variable Matyas test objective.

    The ``np`` argument is accepted but not used.
    """
    def func(p):
        x, y = p
        bowl = 0.26 * (x**2 + y**2)
        coupling = 0.48 * x * y
        return bowl - coupling

    return func
def eggholder(np):
    """Build the two-variable Eggholder test objective.

    The ``np`` argument is accepted but not used; the returned function
    always computes with ``jnp``.
    """
    def func(p):
        x, y = p
        first = -(y + 47.) * jnp.sin(jnp.sqrt(jnp.abs(x / 2. + y + 47.)))
        second = x * jnp.sin(jnp.sqrt(jnp.abs(x - (y + 47.))))
        return first - second

    return func
def hessp(primals, tangents):
    """Hessian-vector product of the module-level ``fun``.

    Pushes ``tangents`` through the gradient of ``fun`` at ``primals`` with
    a forward-mode JVP and returns the tangent output.
    """
    grad_of_fun = jax.grad(fun)
    _, tangent_out = jax.jvp(grad_of_fun, (primals, ), (tangents, ))
    return tangent_out
# Driver: run the CG-Steihaug subproblem once for the Eggholder objective,
# starting from the point (100, 100).
fun = eggholder(jnp)
x0 = 100. * jnp.ones(2)
f0, g0 = jax.value_and_grad(fun)(x0)

kwargs = dict(
    absdelta=0.,
    resnorm=0.,
    trust_radius=1.,
    norm_ord=1,
)
_cg_steihaug_subproblem(f0, g0, partial(hessp, x0), **kwargs)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment