Last active
June 11, 2024 17:41
-
-
Save RicardoDominguez/f013d21a5991e863ffcf9076f5b9b34d to your computer and use it in GitHub Desktop.
BVP solver in JAX based on scipy.integrate.solve_bvp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""Boundary value problem solver.""" | |
import jax | |
import jax.numpy as jnp | |
# ------------------------------------------------------------------------------------------ | |
# Linear solver for bordered almost block diagonal (BABD) systems | |
# ------------------------------------------------------------------------------------------ | |
# Implementation as described in [1] Section 2.1 (structural orthogonal factorization). | |
# [1] M. Dessole and F. Marcuzzi, "A massively parallel algorithm for Bordered Almost Block | |
# Diagonal Systems on GPUs", 2022. | |
def getQR(T, S):
    """Stacked QR factorization of [T; S] (Equation (9) in [1]).

    Returns the full orthogonal factor Q of shape (2n, 2n) and the
    upper (n, n) block U of the triangular factor R = [U; 0].
    """
    stacked = jnp.concatenate((T, S), axis=0)  # (2n, n)
    Q, R = jnp.linalg.qr(stacked, mode='complete')
    half = R.shape[0] // 2
    return Q, R[:half]
def forward_reduce_sys(S, T):
    """Forward reduction of the BABD blocks up to Equation (12), LHS.

    Scans over the block rows, orthogonally eliminating the sub-diagonal
    blocks one at a time via `getQR`. Returns the scan carry (the two
    blocks of the reduced system) and, per row, the QR factor together
    with the parameters needed later for back substitution.
    """
    def step(carry, blocks):
        W_prev, V_prev = carry
        S_i, T_i = blocks
        n = S_i.shape[-1]
        Q, U = getQR(W_prev, S_i)
        V_full = Q[:n].T @ V_prev
        W_full = Q[n:].T @ T_i
        new_carry = (W_full[n:], V_full[n:])
        out = (Q, [U, V_full[:n], W_full[:n]])
        return new_carry, out
    return jax.lax.scan(step, (T[0], S[0]), (S[1:], T[1:]))
def forward_reduce_b(Q, b):
    """Forward reduction of the RHS up to Equation (12).

    Applies the QR factors produced by `forward_reduce_sys` to the
    right-hand side vector, one block row at a time.
    """
    def step(f_prev, inputs):
        Qi, b_i = inputs
        transformed = Qi.T @ jnp.concatenate((f_prev, b_i))
        n = f_prev.shape[0]
        return transformed[n:], transformed[:n]
    return jax.lax.scan(step, b[0], (Q, b[1:]))
def back_substitute(params, x0, xn):
    """Back substitution solving Equation (14), scanning the rows in reverse."""
    def step(x_next, p):
        U, V, W, f = p
        rhs = f - V @ x0 - W @ x_next
        x_i = jax.scipy.linalg.solve_triangular(U, rhs)
        return x_i, x_i
    _, xs = jax.lax.scan(step, xn, params, reverse=True)
    return xs
def BABD_factor(S, T, Ba, Bb):
    """Factor a BABD system down to its reduced (2n x 2n) form.

    The whole system is reduced in a single slice (P = 1), since vmapping
    the forward reduction over several slices P > 1 showed no measurable
    computational advantage.

    Parameters
    ----------
    S : ndarray, shape (m, n, n)
        Block diagonal.
    T : ndarray, shape (m, n, n)
        Block off-diagonal.
    Ba, Bb : ndarray, shape (n, n)
        Blocks of the last block row, first and last block columns.

    Returns
    -------
    lu_dec : tuple
        LU factorization of the reduced (2n x 2n) system.
    (Q, back_params) : tuple
        QR factors applied during the reduction (consumed by
        `forward_reduce_b` when reducing the RHS) and the parameters
        used by `back_substitute` when solving.

    References
    ----------
    .. [1] M. Dessole and F. Marcuzzi, "A massively parallel algorithm
       for Bordered Almost Block Diagonal Systems on GPUs", 2022.
    """
    # Reduce the system so that only a (2n x 2n) one has to be solved
    (Tr, Sr), reduction = forward_reduce_sys(S, T)
    # LU decomposition of the reduced system
    top = jnp.c_[Sr, Tr]
    bottom = jnp.c_[Ba, Bb]
    reduced = jnp.concatenate((top, bottom), axis=0)
    return jax.scipy.linalg.lu_factor(reduced), reduction
def BABD_solve(factorization, b):
    """Solve the BABD system A x = b given a factorization of A.

    Parameters
    ----------
    factorization : tuple
        As returned by `BABD_factor`.
    b : ndarray, shape (m*n,) or (m, n)
        Right-hand side of the system.
    """
    lu_dec, (Q, back_params) = factorization
    n = Q.shape[-1] // 2
    b = b.reshape(-1, n)
    # Reduce the RHS with the same orthogonal transformations
    b_reduced, f = forward_reduce_b(Q, b[:-1])
    # Solve the reduced (2n x 2n) system for the endpoint unknowns
    rhs = jnp.concatenate((b_reduced, b[-1]))
    x_ends = jax.scipy.linalg.lu_solve(lu_dec, rhs)
    half = x_ends.shape[0] // 2
    x0, xn = x_ends[:half], x_ends[half:]
    # Recover the interior unknowns by back substitution
    x_mid = back_substitute(back_params + [f], x0, xn)
    return jnp.concatenate((x0, x_mid.reshape(-1), xn))
# ------------------------------------------------------------------------------------------ | |
# Functions to compute the sparse Jacobian of the collocation system | |
# ------------------------------------------------------------------------------------------ | |
# Naturally, significantly more efficient than jax.jacfwd on the collocation objective.
def construct_jac(h, df_dy, df_dy_middle, dbc_dya, dbc_dyb):
    """Construct and factor the Jacobian of the collocation system.

    There are m * n residual functions: (m - 1) * n collocation residuals
    followed by n boundary condition residuals, in m * n variables (the
    n components of y at each of the m mesh nodes). For m = 4, n = 2 the
    Jacobian has the bordered almost block diagonal (BABD) sparsity

        1 1 2 2 0 0 0 0
        1 1 2 2 0 0 0 0
        0 0 1 1 2 2 0 0
        0 0 1 1 2 2 0 0
        0 0 0 0 1 1 2 2
        0 0 0 0 1 1 2 2
        3 3 0 0 0 0 4 4
        3 3 0 0 0 0 4 4

    where zeros are identically zero and the digits mark the block kinds
    described in the Parameters below. The n x n collocation blocks follow
    the formulas of [1]_ (p. 306). The BABD system is factored with the
    structured orthogonal factorization of [2]_ (Section 2.1), which is
    much more efficient than the dense linear solvers provided by jax and
    competitive with the sparse LU decomposition used in Scipy.

    Parameters
    ----------
    h : ndarray, shape (m-1,)
        Increments between the mesh nodes.
    df_dy : ndarray, shape (m, n, n)
        Jacobian of f w.r.t. y at the mesh nodes (the 1-blocks above).
    df_dy_middle : ndarray, shape (m-1, n, n)
        Jacobian of f w.r.t. y at the interval midpoints (the 2-blocks).
    dbc_dya, dbc_dyb : ndarray, shape (n, n)
        Jacobian of bc w.r.t. ya and yb (the 3- and 4-blocks).

    Returns
    -------
    factorization : tuple
        Structured orthogonal factorization of the BABD system; see
        `BABD_factor`.

    References
    ----------
    .. [1] J. Kierzenka, L. F. Shampine, "A BVP Solver Based on Residual
       Control and the Maltab PSE", ACM Trans. Math. Softw., Vol. 27,
       Number 3, pp. 299-316, 2001.
    .. [2] M. Dessole and F. Marcuzzi, "A massively parallel algorithm for
       Bordered Almost Block Diagonal Systems on GPUs", 2022.
    """
    m, n, _ = df_dy.shape
    hh = h[:, jnp.newaxis, jnp.newaxis]  # (m-1, 1, 1)
    eye = jnp.identity(n)[jnp.newaxis].repeat(m - 1, axis=0)
    # Diagonal n x n blocks
    diag = -eye - hh / 6 * (df_dy[:-1] + 2 * df_dy_middle)
    diag = diag - hh ** 2 / 12 * (df_dy_middle @ df_dy[:-1])
    # Off-diagonal n x n blocks
    off = eye - hh / 6 * (df_dy[1:] + 2 * df_dy_middle)
    off = off + hh ** 2 / 12 * (df_dy_middle @ df_dy[1:])
    return BABD_factor(diag, off, dbc_dya, dbc_dyb)
def prepare_jac(fun, bc):
    """Return a function that evaluates the factored Jacobian of the collocation system."""
    fun_jac = jax.vmap(jax.jacfwd(fun, argnums=1))
    bc_jac = jax.jacfwd(bc, argnums=([0, 1]))

    def sys_jac(x, h, y, y_middle):
        """Evaluate the Jacobian of the collocation system.

        Requiring `y_middle` as an argument means the system must be
        evaluated one extra time at the start of the Newton solve.
        Returns a factorization of the Jacobian; see `construct_jac`.

        Parameters
        ----------
        x : ndarray, shape (m,)
            Nodes of the mesh.
        h : ndarray, shape (m-1,)
            Increments between mesh nodes (i.e., np.diff(x)).
        y : ndarray, shape (m, n)
            Solution values at the mesh nodes.
        y_middle : ndarray, shape (m-1, n)
            Solution values at the mid-points of each mesh interval.
        """
        m = x.shape[0]
        # One vmapped Jacobian evaluation over nodes and midpoints together
        x_middle = x[:-1] + 0.5 * h
        x_all = jnp.concatenate((x, x_middle))
        y_all = jnp.concatenate((y, y_middle))
        df = fun_jac(x_all, y_all)
        dbc_dya, dbc_dyb = bc_jac(y[0], y[-1])
        return construct_jac(h, df[:m], df[m:], dbc_dya, dbc_dyb)

    return sys_jac
# ------------------------------------------------------------------------------------------ | |
# Functions to evaluate the collocation residuals | |
# ------------------------------------------------------------------------------------------ | |
def compute_ymiddle(fun, y, x, h):
    """Evaluate the spline solution at the midpoints of the mesh intervals.

    The BVP solution is sought as a cubic C1-continuous spline whose
    derivatives match the ODE rhs at the nodes `x`; this evaluation is
    needed to build the Jacobian of the collocation system.
    For the Parameters and Returns, see `collocation_fun`.
    """
    f = jax.vmap(fun)(x, y)
    hh = h[:, jnp.newaxis]
    y_middle = 0.5 * (y[:-1] + y[1:]) - 0.125 * hh * (f[1:] - f[:-1])
    return y_middle, f
def collocation_fun(fun, y, x, h):
    """Evaluate collocation residuals.

    The solution to the BVP is sought as a cubic C1 continuous spline with
    derivatives matching the ODE rhs at given nodes `x`. Collocation
    conditions are formed from the equality of the spline derivatives and
    the rhs of the ODE system at the middle points between nodes.
    Such a method belongs to the Lobatto IIIA family in the ODE
    literature; refer to [1]_ for the formula and some discussion.

    Parameters
    ----------
    fun : callable
        Right-hand side of the system, with signature ``fun(x, y)``:
        ``x`` a scalar, ``y`` of shape (n,), returning shape (n,).
    y : ndarray, shape (m, n)
        Solution values at the mesh nodes.
    x : ndarray, shape (m,)
        Nodes of the mesh.
    h : ndarray, shape (m-1,)
        Increments between the mesh nodes, that is, jnp.diff(x).

    Returns
    -------
    col_res : ndarray, shape (m-1, n)
        Collocation residuals at the middle points of the mesh intervals.
    (y_middle, f, f_middle) : tuple
        Spline values at the midpoints (m-1, n), rhs at the mesh
        nodes (m, n), and rhs at the midpoints computed from
        `y_middle` (m-1, n), respectively.

    References
    ----------
    .. [1] J. Kierzenka, L. F. Shampine, "A BVP Solver Based on Residual
       Control and the Maltab PSE", ACM Trans. Math. Softw., Vol. 27,
       Number 3, pp. 299-316, 2001.
    """
    y_middle, f = compute_ymiddle(fun, y, x, h)
    # BUG FIX: the midpoint abscissae must be computed with the 1-D `h`.
    # Previously `h` was reshaped to (m-1, 1) first, so x[:-1] + 0.5 * h
    # broadcast to an (m-1, m-1) matrix, passing a vector instead of a
    # scalar x to `fun` under vmap (silently wrong for x-dependent fun).
    x_middle = x[:-1] + 0.5 * h
    h = h[:, jnp.newaxis]
    f_middle = jax.vmap(fun)(x_middle, y_middle)
    col_res = y[1:] - y[:-1] - h / 6 * (f[:-1] + f[1:] + 4 * f_middle)
    return col_res, (y_middle, f, f_middle)
# ------------------------------------------------------------------------------------------ | |
# Functions to fit cubic splines | |
# ------------------------------------------------------------------------------------------ | |
# The solution to the BVP is sought as a cubic spline. | |
def fit_cubic_spline_coeffs(y, yp, h):
    """Fit the coefficients of a cubic spline, interval by interval.

    The coefficient formulas follow scipy.interpolate.CubicSpline.

    Parameters
    ----------
    y : ndarray, shape (m, n)
        Solution values at the mesh nodes.
    yp : ndarray, shape (m, n)
        ODE system evaluated at the mesh nodes.
    h : ndarray, shape (m-1,)
        Increments between mesh nodes, that is, jnp.diff(x).

    Returns
    -------
    c : ndarray with shape (m-1, n, 4)
    """
    hh = h[:, jnp.newaxis]
    dy = (y[1:] - y[:-1]) / hh
    curv = (yp[:-1] + yp[1:] - 2 * dy) / hh
    coeffs = (
        y[:-1],                      # constant term
        yp[:-1],                     # linear term
        (dy - yp[:-1]) / hh - curv,  # quadratic term
        curv / hh,                   # cubic term
    )
    return jnp.stack(coeffs, axis=-1)
def eval_cubic_spline(x, coeffs, xs):
    """Evaluate a cubic spline and its first derivative at ``xs``.

    Parameters
    ----------
    x : ndarray with shape (m,)
        Nodes of the spline.
    coeffs : ndarray with shape (m, n, 4)
        Coefficients of the spline, as returned by `fit_cubic_spline_coeffs`.
    xs : ndarray with shape (t,)
        Where to evaluate the spline.

    Returns
    -------
    values : ndarray with shape (t, n)
        Evaluation of the spline at ``xs``.
    derivs : ndarray with shape (t, n)
        First derivative of the spline at ``xs``.
    """
    # Interval of x containing each query point; clip so the right
    # endpoint maps to the last interval instead of falling off the end.
    ind = jnp.digitize(xs, x) - 1
    ind = jnp.clip(ind, 0, len(x) - 2)
    c = coeffs[ind]  # per-query spline coefficients
    t = (xs - x[ind])[:, jnp.newaxis, jnp.newaxis]
    t_powers = jnp.power(t, jnp.arange(4))
    # Renamed from `eval`, which shadowed the Python builtin.
    values = jnp.sum(c * t_powers, axis=-1)
    # d/dt of c0 + c1 t + c2 t^2 + c3 t^3 = c1 + 2 c2 t + 3 c3 t^2
    derivs = jnp.sum(c[..., 1:] * t_powers[..., :-1]
                     * jnp.array([[[1, 2, 3.]]]), axis=-1)
    return values, derivs
# ------------------------------------------------------------------------------------------ | |
# Damped Newton method | |
# ------------------------------------------------------------------------------------------ | |
class BacktrackLineSearch:
    def __init__(self, sigma=0.2, tau=0.5, n_trials=4, jit=False):
        """Backtracking (Armijo-style) line search.

        Parameters
        ----------
        sigma : float
            Minimum relative improvement of the criterion function to accept
            the step (Armijo constant).
        tau : float
            Step size decrease factor for backtracking.
        n_trials : int
            Maximum number of backtracking steps; the smallest step tried is
            then tau ** n_trials.
        jit : bool
            Whether to jit-compile the optimization loop (as a
            jax.lax.while_loop).
        """
        def _run(cost_fnc, y, step, cost, init_extras):
            """Iteratively reduce the step until the cost decreases enough.

            Accepts the first y_new = y - alpha * step satisfying
            c(y_new) < (1 - 2 * alpha * sigma) * c(y).

            Parameters
            ----------
            cost_fnc : callable such that ``cost, extras = cost_fnc(y)``
                Cost function taking some `y` and returning both a float
                `cost` and some `extras`.
            y : ndarray, shape (m, n)
                Base point from which steps are taken.
            step : ndarray, shape (m, n)
                Direction of the updates (subtracted from y).
            cost : float
                The cost function `cost_fnc` evaluated at the base point `y`.
            init_extras : tuple
                Initialization of the extra outputs of `cost_fnc`; required
                so jax.lax.while_loop sees a consistent carry structure.

            Returns
            -------
            iterations : int
                Number of backtracking steps taken.
            (y_new, cost_new, extras) : tuple
                Accepted point (hopefully meeting the criterion) together
                with ``cost_new, extras = cost_fnc(y_new)``.
            """
            def keep_going(val):
                # Terminate when either the backtracking condition holds
                # (cost sufficiently low) or the trial budget is spent.
                iteration, (_, cost_new, _) = val
                alpha = tau ** (iteration - 1)
                iters_low = iteration < n_trials + 1
                cost_high = cost_new >= (1 - 2 * alpha * sigma) * cost
                return iters_low & cost_high

            def backtrack_step(val):
                iteration, _ = val
                # Update solution value with the current step size
                alpha = tau ** iteration
                y_new = y - alpha * step
                # Compute the new cost
                cost_new, extras = cost_fnc(y_new)
                return iteration + 1, (y_new, cost_new, extras)

            # An initial cost of +inf guarantees at least one step is taken.
            # (Removed an unused `m, n = y.shape` that served no purpose.)
            val = 0, (jnp.zeros_like(y), jnp.array(jnp.inf), init_extras)
            if jit:
                val = jax.lax.while_loop(keep_going, backtrack_step, val)
            else:
                while keep_going(val):
                    val = backtrack_step(val)
            return val

        if jit:
            self.run = jax.jit(_run, static_argnums=0)
        else:
            self.run = _run
class BacktrackingNewton:
    def __init__(self, col_obj, get_jac, init_extras, bvp_tol, bc_tol, max_njev=4,
                 max_iter=8, jit=False):
        """ Simple Newton method with a backtracking line search.

        As advised in [1]_, an affine-invariant criterion function F = ||J^-1 r||^2
        is used, where J is the Jacobian matrix at the current iteration and r is
        the vector of collocation residuals (values of the system lhs).

        The method alternates between full Newton iterations and fixed-Jacobian
        iterations. The Jacobian is recomputed if a full Newton step does not
        meet the backtracking criterion, otherwise the same Jacobian is reused.

        There are other tricks proposed in [1]_, but they are not used as they
        don't seem to improve anything significantly, and even break the
        convergence on some test problems I tried.

        Parameters
        ----------
        col_obj : callable
            Function computing collocation residuals, and some other extra
            parameters such that (res, extra) = col_obj(y, x, h)
        get_jac : callable
            Returns the Jacobian of the collocation objective w.r.t `y` and
            evaluated at the mesh points `x`.
        init_extras : callable
            Takes the same arguments as `col_obj` and returns initial values
            for the extra parameters returned by `col_obj`.
        bvp_tol : float
            Tolerance to which we want to solve the BVP.
        bc_tol : float
            Tolerance to which we want to satisfy the boundary conditions.
        max_njev : int
            Maximum allowed number of Jacobian evaluations and factorizations,
            in other words, the maximum number of full Newton iterations. A
            small value is recommended in the literature.
        max_iter : int
            Maximum number of iterations, considering that some of them can be
            performed with the fixed Jacobian.
        jit : bool
            Whether to jit-compile the relevant loops.

        `self.solve(y, x, h)` returns
        -----------------------------
        y : ndarray, shape (m, n)
            Final iterate for the function values at the mesh nodes.
        res, extras = col_obj(y, x, h)

        References
        ----------
        .. [1] U. Ascher, R. Mattheij and R. Russell "Numerical Solution of
               Boundary Value Problems for Ordinary Differential Equations"
        """
        # Dispatch between a jittable while_loop and a plain Python loop.
        def _loop(cond, body, val):
            if jit:
                return jax.lax.while_loop(cond, body, val)
            else:
                while cond(val):
                    val = body(val)
                return val

        line_search = BacktrackLineSearch(jit=jit)

        def solve(y, x, h):
            m, n = y.shape
            # Initialization of the extra parameters returned by `col_obj` and
            # `backtrack_cost`. Required for jax.lax.while_loop (the carry must
            # have a consistent pytree structure and dims on every iteration).
            colobj_extras = (jnp.zeros((m - 1, n)), jnp.zeros((m, n)), jnp.zeros((m - 1) * n))
            btrack_extras = (jnp.zeros(m * n), colobj_extras, jnp.zeros((m, n)))
            # We know that the solution residuals at the middle points of the mesh
            # are connected with collocation residuals r_middle = 1.5 * col_res / h.
            # As our BVP solver tries to decrease relative residuals below a certain
            # tolerance, it seems reasonable to terminate Newton iterations by
            # comparison of r_middle / (1 + jnp.abs(f_middle)) with a certain threshold,
            # which we choose to be 1.5 orders lower than the BVP tolerance. We rewrite
            # the condition as col_res < tol_r * (1 + jnp.abs(f_middle)), then tol_r
            # should be computed as follows:
            # NOTE(review): tol_r has shape (m-1, 1) while the residuals it is later
            # compared against are flattened — broadcasting makes the jnp.all() a
            # conservative scalar check; confirm this is the intended comparison.
            tol_r = 2 / 3 * h[:, jnp.newaxis] * 5e-2 * bvp_tol

            def continue_newton(val):
                # Stop on iteration budget, Jacobian budget, or tolerance met.
                (iteration, njev, terminate_tol), _, _ = val
                return (iteration < max_iter) & (njev <= max_njev) & (~terminate_tol)

            # At each outer step the Jacobian is recomputed, then a number of inner Newton
            # steps are taken using the same Jacobian, as long as the full Newton step meets
            # the backtracking condition (i.e., the affine-invariant criterion). If on the
            # other hand the step size had to be reduced to meet the backtracking condition,
            # then the Jacobian is recomputed.
            def newton_step(val):
                (iteration, njev, _), (y, _), extras = val
                # Recompute and factor the Jacobian at the current iterate
                y_middle, _, _ = extras
                jac = get_jac(x, h, y, y_middle)

                # The cost function is the affine-invariant criterion F = ||J^-1 r||^2,
                # where J is the Jacobian and r are the collocation residuals.
                def backtrack_cost(y):
                    res, extras = col_obj(y, x, h)  # compute collocation residuals
                    # Newton step from the factored Jacobian
                    step = BABD_solve(jac, res).reshape(y.shape)
                    cost = jnp.sum(step ** 2)
                    return cost, (res, extras, step)

                # Continue taking steps with the same Jacobian as long as the step size
                # does not need to be decreased to meet the affine-invariant criterion.
                def do_step(val):
                    (iteration, _, _), y_step_cost, _ = val
                    n_trials, (y, cost, extras) = line_search.run(backtrack_cost,
                                                                  *y_step_cost, btrack_extras)
                    res, extras, step = extras
                    _, _, f_middle = extras
                    # res[:-n] are collocation residuals, res[-n:] the BC residuals.
                    col_res_cond = jnp.all(jnp.abs(res[:-n]) < tol_r * (1 + jnp.abs(f_middle)))
                    bc_res_cond = jnp.all(jnp.abs(res[-n:]) < bc_tol)
                    terminate_tol = col_res_cond & bc_res_cond
                    return (iteration + 1, n_trials, terminate_tol), (y, step, cost), (res, extras)

                def continue_stepping(val):
                    # Leave the inner loop once the line search had to backtrack
                    # (n_trials > 1), the tolerance is met, or the budget is spent.
                    (iteration, n_trials, terminate_tol), _, _ = val
                    return (iteration < max_iter) & (n_trials <= 1) & (~terminate_tol)

                # Initial step and cost at the current iterate
                cost, (_, _, step) = backtrack_cost(y)
                # Inner fixed-Jacobian loop
                res, extras = jnp.zeros(m * n), colobj_extras
                val = ((iteration, 0, jnp.array(False)), (y, step, cost), (res, extras))
                val = _loop(continue_stepping, do_step, val)
                (iteration, _, terminate_tol), (y, _, _), (res, extras) = val
                return (iteration, njev + 1, terminate_tol), (y, res), extras

            # Outer loop where the Jacobian is recomputed
            extras = init_extras(y, x, h)
            val = ((0, 0, jnp.array(False)), (y, jnp.zeros(m*n)), extras)
            (iterations, njevs, _), (y, res), extras = _loop(continue_newton, newton_step, val)
            return y, (res, extras)

        self.solve = solve
# ------------------------------------------------------------------------------------------ | |
# Functions to estimate the relative residuals of the approximate solution | |
# ------------------------------------------------------------------------------------------ | |
# 5-point Lobatto quadrature provides much more accurate estimations of the relative
# residuals compared to Simpson's rule (3-point Lobatto quadrature); however, it requires
# ~2*m additional evaluations of the ODE function, where m is the number of mesh points.
def estimate_rms_residuals_5lobatto(fun, coeffs, x, h, r_middle, f_middle):
    """Estimate RMS collocation residuals using 5-point Lobatto quadrature.

    The residuals are the difference between the derivative of the spline
    solution and the rhs of the ODE system, taken relative, i.e. divided by
    1 + jnp.abs(f). RMS values are the square roots of the normalized
    integrals of the squared relative residuals over each interval,
    estimated with 5-point Lobatto quadrature [1]_, using the fact that
    residuals at the mesh nodes are identically zero.

    Unlike [2]_, the integrals are normalized by the interval lengths (at
    the price of a factor h**0.5 in the convergence rate of the residuals)
    so that the return values can be read directly as RMS estimates.

    Parameters
    ----------
    fun : callable
        ODE function.
    coeffs : ndarray, shape (m-1, n, 4)
        Coefficients of a cubic spline parametrizing the approximate solution.
    x : ndarray, shape (m,)
        Nodes of the mesh.
    h : ndarray, shape (m-1,)
        Interval between the mesh nodes (i.e., np.diff(x)).
    r_middle : ndarray, shape (m-1, n)
        Residuals at the mid-point of each mesh interval.
    f_middle : ndarray, shape (m-1, n)
        Evaluation of the ODE function at the mid-point of each mesh interval.

    Returns
    -------
    rms_res : ndarray, shape (m-1,)
        Estimated rms values of the relative residuals over each mesh interval.

    References
    ----------
    .. [1] http://mathworld.wolfram.com/LobattoQuadrature.html
    .. [2] J. Kierzenka, L. F. Shampine, "A BVP Solver Based on Residual
       Control and the Maltab PSE", ACM Trans. Math. Softw., Vol. 27,
       Number 3, pp. 299-316, 2001.
    """
    x_middle = x[:-1] + 0.5 * h
    offset = 0.5 * h * (3 / 7) ** 0.5
    # The two interior non-midpoint Lobatto abscissae of every interval
    x_eval = jnp.r_[x_middle + offset, x_middle - offset]
    y_eval, yp_eval = eval_cubic_spline(x, coeffs, x_eval)
    f_eval = jax.vmap(fun)(x_eval, y_eval)
    rel = (yp_eval - f_eval) / (1 + jnp.abs(f_eval))
    rel_middle = r_middle / (1 + jnp.abs(f_middle))
    sq = jnp.sum(rel ** 2, axis=-1)
    sq_middle = jnp.sum(rel_middle ** 2, axis=-1)
    k = offset.shape[0]
    sq_pair = sq[:k] + sq[k:]
    # Lobatto weights: 32/45 for the midpoint, 49/90 for the off-center pair
    return (0.5 * (32 / 45 * sq_middle + 49 / 90 * sq_pair)) ** 0.5
def estimate_rms_residuals_simpson(r_middle, f_middle):
    """Estimate RMS collocation residuals using Simpson's rule.

    The residuals are the difference between the derivative of the spline
    solution and the rhs of the ODE system, taken relative, i.e. divided
    by 1 + jnp.abs(f). RMS values are the square roots of the normalized
    integrals of the squared relative residuals over each interval.

    Since the collocation objective drives exactly these midpoint
    residuals to zero, 5-point Lobatto quadrature gives significantly
    more accurate estimates — at the cost of ~2*m extra ODE evaluations.

    Parameters
    ----------
    r_middle : ndarray, shape (m-1, n)
        Residuals at the mid-point of each mesh interval.
    f_middle : ndarray, shape (m-1, n)
        Evaluation of the ODE function at the mid-point of each mesh interval.

    Returns
    -------
    rms_res : ndarray, shape (m-1,)
        Estimated rms values of the relative residuals over each interval.
    """
    # Residuals vanish at the mesh nodes, so only the midpoint term remains.
    rel = r_middle / (1 + jnp.abs(f_middle))
    sq = jnp.sum(rel ** 2, axis=-1)
    return ((2 / 3.) * sq) ** 0.5
# ------------------------------------------------------------------------------------------ | |
# Functions implementing the mesh selection strategy | |
# ------------------------------------------------------------------------------------------ | |
# In scipy's bvp_solver implementation, the mesh is iteratively refined using a local
# criterion based on the estimated relative residuals of the approximate solution. However,
# JAX has poor support for arrays with dynamic size, and an efficient implementation of the
# refinement strategy that is jittable and vmappable is non-trivial. We therefore use a global
# mesh selection strategy where node points are equidistributed w.r.t. some piecewise
# constant monitor function, without adding or removing any mesh nodes. This requires
# that a relatively fine initial mesh is used, and that the ODE function is
# sufficiently smooth. The authors of MATLAB bvp4c [1] find global mesh selection strategies
# to be inferior to local strategies. Similarly, [2] Chapter 9.5 further argues that mesh
# refinement tends to be beneficial since a coarse initial mesh can be used, thus resulting
# in substantially more computationally effective solutions.
# [1] J. Kierzenka, L. F. Shampine, "A BVP Solver Based on Residual Control and the Maltab | |
# PSE", ACM Trans. Math. Softw., Vol. 27, Number 3, pp. 299-316, 2001. | |
# [2] U. Ascher, R. Mattheij and R. Russell "Numerical Solution of Boundary Value Problems | |
# for Ordinary Differential Equations". | |
def monitor_fourth_derivative(x, h, coeffs):
    """Approximate the fourth derivative of the BVP solution.

    The BVP solution is parametrized by a cubic spline; [1]_ Chapter 9.3.1
    argues that higher-order derivatives make a more robust monitor
    function than relative residuals. It is unclear from [1]_ whether the
    third or fourth derivative should be used; preliminary tests favored
    the fourth (lower estimated residuals).

    Following [1]_, the fourth derivative is approximated by fitting a
    piecewise-linear function v(x) to the third derivative of the cubic
    spline at the subinterval midpoints; the monitor is then v'(x).

    Parameters
    ----------
    x : ndarray, shape (m,)
        Mesh nodes.
    h : ndarray, shape (m-1,)
        Interval between the mesh nodes, that is, np.diff(x).
    coeffs : ndarray, shape (m-1, n, 4)
        Coefficients of the cubic spline.

    Returns
    -------
    x_monitor : ndarray, shape (m+1,)
        Points at which the monitor function is evaluated.
    monitor : ndarray, shape (m,)
        Evaluation of the piecewise monitor function at x_monitor.

    References
    ----------
    .. [1] U. Ascher, R. Mattheij and R. Russell "Numerical Solution of
           Boundary Value Problems for Ordinary Differential Equations"
    """
    # Monitor nodes: the mesh endpoints plus every subinterval midpoint
    midpoints = x[:-1] + h / 2.
    x_monitor = jnp.r_[x[0], midpoints, x[-1]]  # (m+1,)
    # Third derivative of the cubic spline is 6 times the cubic coefficient
    third_der = coeffs[..., -1] * 6
    # [1] does not specify v(x) at the mesh extremes, so assume zero there
    # (which in general will not be the case).
    zeros_row = jnp.zeros((1, third_der.shape[-1]))
    padded = jnp.r_[zeros_row, third_der, zeros_row]
    v_prime = jnp.diff(padded, axis=0) / jnp.diff(x_monitor)[:, None]  # (m, n)
    # Evaluate the monitor per (9.17); max norm over components (p. 363)
    monitor = jnp.max(jnp.abs(v_prime), axis=-1) ** (1 / 4.)
    return x_monitor, monitor
def monitor_mazzia(x, coeffs, weight_infty=1., weight_1=1.):
    """Combine the L_inf norm of the solution with the L1 norm of its derivative.

    A very simplified version of the monitor function of [1]_ (p. 562):
    a linear combination of the L_infinity norm of the approximate
    solution and the L1 norm of its derivatives. In my experience,
    `monitor_fourth_derivative` tends to work significantly better.

    Parameters
    ----------
    x : ndarray, shape (m,)
        Mesh nodes.
    coeffs : ndarray, shape (m-1, n, 4)
        Coefficients of a cubic spline parametrizing the approximate solution.
    weight_infty : float
        Weight given to the infinity norm of the approximate solution.
    weight_1 : float
        Weight given to the L1 norm of the derivatives of the solution.

    Returns
    -------
    x_monitor : ndarray, shape (m,)
        Points at which the monitor function is evaluated.
    monitor : ndarray, shape (m-1,)
        Evaluation of the monitor function at `x_monitor`.

    References
    ----------
    .. [1] F. Mazzia, Mesh selection strategies of the code TOM for Boundary
           Value Problems, 2022.
    """
    y, yp = eval_cubic_spline(x, coeffs, x)
    # L1 norm of the derivative, summed over each interval's endpoints
    yp_l1 = jnp.sum(jnp.abs(yp), axis=-1)
    deriv_term = yp_l1[:-1] + yp_l1[1:]
    # Variation of the L_inf norm of the solution across each interval
    y_linf = jnp.max(jnp.abs(y), axis=-1)
    sol_term = jnp.abs(jnp.diff(y_linf))
    return x, weight_infty * sol_term + weight_1 * deriv_term
def equidistribute_mesh(x_monitor, monitor, n_points):
    """ Solve for a new mesh by equidistributing a piecewise constant monitor function.

    A piecewise constant monitor function is used for simplicity and computational
    efficiency, since integration and reverse interpolation is then trivial [1]_.

    Parameters
    ----------
    x_monitor : ndarray, shape (m,)
        Mesh nodes at which the monitor function is evaluated.
    monitor : ndarray, shape (m-1,)
        Piecewise constant monitor function, one value per interval of `x_monitor`.
    n_points : int
        Number of mesh points in the new mesh.

    Returns
    -------
    x_new : ndarray, shape (n_points,)
        New mesh nodes.

    References
    ----------
    .. [1] U. Ascher, R. Mattheij and R. Russell "Numerical Solution of
           Boundary Value Problems for Ordinary Differential Equations"
    """
    # Integrate the monitor function from x[0] to x[-1]. Piecewise constant,
    # so the cumulative integral at each node is a cumsum of width * value (9.18).
    integral = jnp.r_[jnp.array([0]), jnp.cumsum(jnp.diff(x_monitor) * monitor)]
    # Equidistribute across the n_points - 1 new intervals (9.19b)
    intervals = jnp.linspace(0, integral[-1], n_points)
    # Inverse interpolation (9.19) for a piecewise constant function.
    # `boxes` indexes into `monitor` (and the left edge of each old interval),
    # so the upper clip bound must be monitor.shape[0] - 1. Clipping by
    # n_points - 1 (as before) is wrong whenever n_points != m; JAX clamps
    # out-of-bounds gathers silently, which masked the error.
    boxes = jnp.clip(jnp.digitize(intervals, integral) - 1, 0, monitor.shape[0] - 1)
    increment_needed = intervals - integral[boxes]
    new_x = x_monitor[boxes] + increment_needed / monitor[boxes]
    return new_x
def print_iteration_header():
    """Print the column headers of the per-iteration progress table."""
    columns = ("Iteration", "Max residual", "Max BC residual",
               "Total nodes", "Nodes added")
    # Each column centered in a 15-character field.
    print("".join("{:^15}".format(col) for col in columns))
def print_iteration_progress(iteration, residual, bc_residual, total_nodes,
                             nodes_added):
    """Print one row of the per-iteration progress table.

    Residuals are shown in scientific notation with two decimals; every
    column is centered in a 15-character field to line up with the header.
    """
    row = (f"{iteration:^15}{residual:^15.2e}{bc_residual:^15.2e}"
           f"{total_nodes:^15}{nodes_added:^15}")
    print(row)
def solve_bvp(fun, bc, x, y, tol=1e-3, bc_tol=None, max_iterations=10,
              min_improv=0.05, eval_lobatto=True, verbose=False, jit=True):
    """Solve a boundary value problem for a system of ODEs.

    This function numerically solves a first order system of ODEs subject to
    two-point boundary conditions::

        dy / dx = f(x, y, p), a <= x <= b
        bc(y(a), y(b)) = 0

    Here x is a 1-D independent variable, y(x) is an N-D vector-valued function.
    For the problem to be determined, there must be n boundary conditions, i.e.,
    bc must be an n-D function.

    Parameters
    ----------
    fun : callable
        Right-hand side of the system. The calling signature is ``fun(x, y)``.
        All arguments are ndarray: ``x`` with shape (,), ``y`` with shape (d,).
        The return value must be an array with shape (d,).
    bc : callable
        Function evaluating residuals of the boundary conditions. The calling
        signature is ``bc(ya, yb)``. All arguments are ndarray: ``ya`` and
        ``yb`` with shape (n,). The return value must be an array with shape (n,).
    x : array_like, shape (m,)
        Initial mesh. Must be a strictly increasing sequence of real numbers
        with ``x[0]=a`` and ``x[-1]=b``. The initial mesh should be relatively
        fine since the initial mesh is not iteratively refined.
    y : array_like, shape (m, d)
        Initial guess for the function values at the mesh nodes.
    tol : float, optional
        Desired tolerance of the solution. If we define ``r = y' - f(x, y)``,
        where y is the found solution, then the solver tries to achieve on each
        mesh interval ``norm(r / (1 + abs(f))) < tol``, where ``norm`` is
        estimated in a root mean squared sense (using a numerical quadrature
        formula). Default is 1e-3.
    bc_tol : float, optional
        Desired absolute tolerance for the boundary condition residuals: `bc`
        value should satisfy ``abs(bc) < bc_tol`` component-wise.
        Equals to `tol` by default.
    max_iterations : int, optional
        Maximum number of iterations of the BVP solver.
    min_improv : float, optional
        Early stopping condition. Requires that the relative change in the
        maximum relative residual between two consecutive iterations be > than
        `min_improv`, otherwise the solution is returned. Since we do not
        refine the initial mesh, solutions often converge if the prescribed
        tolerance cannot be achieved with the given number of mesh nodes.
    eval_lobatto : bool, optional
        Whether to estimate the relative residuals using 5-point Lobatto
        quadrature as opposed to 3-point Lobatto quadrature. The former is
        significantly more accurate, but requires ~2*m additional evaluations
        of `fun` per iteration.
    verbose : bool, optional
        Prints some helpful information. Cannot be used together with jit=True.
    jit : bool, optional
        Whether to jit compile the whole iteration procedure.

    Returns
    -------
    x : ndarray, shape (m,)
        Nodes of the final mesh.
    y : ndarray, shape (m, d)
        Solution values at the mesh nodes.
    info : tuple
        ``(iteration, max_rms_res, max_bc_res)`` where `iteration` is the
        number of iterations performed, `max_rms_res` the maximum estimated
        relative residual of the solution, and `max_bc_res` the maximum
        residual of the boundary conditions. Convergence can be checked by the
        caller via ``max_rms_res < tol`` and ``max_bc_res < bc_tol``.

    Notes
    -----
    This function implements a 4th order collocation algorithm with the
    control of residuals similar to [1]_. A collocation system is solved
    by a damped Newton method with an affine-invariant criterion function as
    described in [3]_. Note that in contrast to [1]_, we do not iteratively
    refine the mesh using a local criterion but rather iteratively
    equidistribute the mesh points based on some monitor function, since JAX
    has poor support for dynamic arrays.

    Note that in [1]_ integral residuals are defined without normalization
    by interval lengths. So, their definition is different by a multiplier of
    h**0.5 (h is an interval length) from the definition used here.

    References
    ----------
    .. [1] J. Kierzenka, L. F. Shampine, "A BVP Solver Based on Residual
           Control and the Matlab PSE", ACM Trans. Math. Softw., Vol. 27,
           Number 3, pp. 299-316, 2001.
    .. [2] L.F. Shampine, P. H. Muir and H. Xu, "A User-Friendly Fortran BVP
           Solver".
    .. [3] U. Ascher, R. Mattheij and R. Russell "Numerical Solution of
           Boundary Value Problems for Ordinary Differential Equations".
    """
    if bc_tol is None:
        bc_tol = tol

    # Concatenation of the collocation residuals and the boundary condition residuals.
    # This is the objective whose root the Newton solver seeks.
    def col_obj(y, x, h):
        # Evaluate the collocation residuals
        y = y.reshape(x.shape[0], -1)
        col_res, (y_middle, f, f_middle) = collocation_fun(fun, y, x, h)
        # Evaluate the boundary condition residuals
        bc_res = bc(y[0], y[-1])
        res = jnp.hstack((col_res.ravel(), bc_res))
        # `extras` carries intermediate ODE evaluations so the outer loop can
        # re-use them instead of calling `fun` again.
        extras = (y_middle, f, f_middle.ravel())
        return res, extras

    # Initial values for the extra parameters returned by `col_obj`.
    # Called once at each call of `newton_solver.solve`.
    def init_col_obj_extras(y, x, h):
        m, n = y.shape
        # As it currently stands, it is needed to compute y_middle in order to construct
        # the Jacobian of the collocation system. This is unfortunate since it amounts
        # to an additional evaluation of the ODE function, resulting in higher time
        # to jit compile the bvp solver.
        y_middle, _ = compute_ymiddle(fun, y, x, h)
        extras = (y_middle, jnp.zeros((m, n)), jnp.zeros((m - 1) * n))
        return extras

    get_jac = prepare_jac(fun, bc)
    newton_solver = BacktrackingNewton(col_obj, get_jac, init_col_obj_extras,
                                       tol, bc_tol, jit=jit)

    if verbose and not jit:
        print_iteration_header()

    # Continue while: under the iteration budget AND (either the residuals are
    # above tolerance and still improving, or the BC residuals are not yet met).
    def loop_continue(vals):
        (iteration, max_rms_res, max_bc_res, rms_change), _ = vals
        rms_cond = (max_rms_res > tol) & (rms_change > min_improv)
        return (iteration < max_iterations) & (rms_cond | (max_bc_res > bc_tol))

    def loop_body(vals):
        (iteration, prev_max_rms_res, _, _), (_, _, x, y) = vals
        h = jnp.diff(x)
        y, (res, (_, f, f_middle)) = newton_solver.solve(y, x, h)
        # Re-use the ODE function evaluations from inside the Newton solver
        f_middle = f_middle.reshape(-1, f.shape[-1])
        res = res.reshape(-1, f.shape[-1])
        # Last row of `res` holds the boundary-condition residuals (see `col_obj`).
        bc_res = res[-1]
        col_res = res[:-1]
        # Fit a cubic spline as an approximate solution.
        spline_coeffs = fit_cubic_spline_coeffs(y, f, h)
        # Compute the residual at the mid-points of each interval.
        # This relation is not trivial, but can be verified.
        r_middle = 1.5 * col_res / h[:, jnp.newaxis]
        # Estimate the relative residuals of the approximate solution.
        if eval_lobatto:
            rms_res = estimate_rms_residuals_5lobatto(fun, spline_coeffs,
                                                      x, h, r_middle, f_middle)
        else:
            rms_res = estimate_rms_residuals_simpson(r_middle, f_middle)
        max_rms_res = jnp.max(rms_res)
        # Compute the ratio of improvement (condition for early stopping)
        rms_change = jnp.abs(max_rms_res - prev_max_rms_res) \
                     / jnp.minimum(max_rms_res, prev_max_rms_res)
        # Evaluate if the boundary condition is met
        max_bc_res = jnp.max(abs(bc_res))
        # Equidistribute the mesh according to the monitor function
        x_monitor, monitor = monitor_fourth_derivative(x, h, spline_coeffs)
        x_new = equidistribute_mesh(x_monitor, monitor, x.shape[0])
        y_new, _ = eval_cubic_spline(x, spline_coeffs, x_new)
        if verbose and not jit:
            print_iteration_progress(iteration, max_rms_res, max_bc_res, x.shape[0], 0)
        # Need to return both (x, y) and (x_new, y_new) in case the evaluated solution
        # meets the stopping criteria. Otherwise, the reported residuals would be inaccurate.
        return (iteration + 1, max_rms_res, max_bc_res, rms_change), (x, y, x_new, y_new)

    # Carry: (iteration, max_rms_res, max_bc_res, rms_change), (x, y, x_new, y_new).
    vals = ((0, jnp.array(jnp.inf), jnp.array(jnp.inf), jnp.array(1.)), (x, y, x, y))
    if jit:
        vals = jax.lax.while_loop(loop_continue, loop_body, vals)
    else:
        while loop_continue(vals):
            vals = loop_body(vals)
    # Report the mesh/solution that the returned residuals were evaluated on,
    # not the freshly equidistributed (but unsolved) one.
    (iteration, max_rms_res, max_bc_res, _), (x, y, _, _) = vals
    return x, y, (iteration, max_rms_res, max_bc_res)
if __name__ == "__main__":
    # Demo: Bratu-type problem y'' = -exp(y) with y(0) = y(1) = 0,
    # solved from two different initial guesses in a single vmapped call.
    def fun(x, y):
        return jnp.array([y[1], -jnp.exp(y[0])])

    def bc(ya, yb):
        return jnp.array([ya[0], yb[0]])

    n_nodes = 100
    mesh = jnp.linspace(0, 1, n_nodes)
    # The problem has two solutions; start one solve from each guess.
    guess_low = jnp.zeros((mesh.size, 2))
    guess_high = jnp.ones((mesh.size, 2)) * 3.
    solve = jax.jit(jax.vmap(lambda x, y: solve_bvp(fun, bc, x, y, tol=1e-5, jit=True)))
    xx = jnp.stack([mesh, mesh])
    yy = jnp.stack([guess_low, guess_high])
    _, _, (_, max_rms_res, max_bc_res) = solve(xx, yy)
    print('Max relative residuals: ', max_rms_res)
    print('Boundary condition residuals:', max_bc_res)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment