Skip to content

Instantly share code, notes, and snippets.

View danielkelshaw's full-sized avatar
:octocat:
Approximating the Posterior

Daniel Kelshaw danielkelshaw

:octocat:
Approximating the Posterior
View GitHub Profile
"""A 1-dimensional example of adaptive mesh refinement in JAX. In this case, a simple
implementation of quadrature.
Static shapes don't mean you can't do this. Heap allocation is *not* necessary!
Not extensively tested; if you find any bugs, leave a comment below.
"""
import functools as ft
from collections.abc import Callable
@RicardoDominguez
RicardoDominguez / jax_bvp_solver.py
Last active June 11, 2024 17:41
BVP solver in JAX based on scipy.integrate.solve_bvp
"""Boundary value problem solver."""
import jax
import jax.numpy as jnp
# ------------------------------------------------------------------------------------------
# Linear solver for bordered almost block diagonal (BABD) systems
# ------------------------------------------------------------------------------------------
# Implementation as described in [1] Section 2.1 (structural orthogonal factorization).
import torch
def jacobian(y, x, create_graph=False):
jac = []
flat_y = y.reshape(-1)
grad_y = torch.zeros_like(flat_y)
for i in range(len(flat_y)):
grad_y[i] = 1.
grad_x, = torch.autograd.grad(flat_y, x, grad_y, retain_graph=True, create_graph=create_graph)
jac.append(grad_x.reshape(x.shape))