Created
February 13, 2023 13:11
-
-
Save fhchl/05a0c7c8e802e79932571eb331dabfea to your computer and use it in GitHub Desktop.
Comparing diffrax with scipy.integrate.solve_ivp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import time | |
import numpy as np | |
from scipy.integrate import solve_ivp | |
from scipy.interpolate import CubicSpline | |
import jax | |
import jax.numpy as jnp | |
import diffrax as dfx | |
def vector_field_np(x, u, *, m=1, r=2, r2=3, k=4):
    """Spring-mass-damper system with nonlinear drag (NumPy implementation).

    .. math:: m ẍ + r ẋ + r2 ẋ |ẋ| + k x = u
              y = x

    Args:
        x: State ``[x1, x2]``, i.e. position and velocity (``ẋ1 = x2``).
        u: Input, the forcing term ``u`` in the equation above.
        m: Mass (default 1).
        r: Linear damping coefficient (default 2).
        r2: Quadratic drag coefficient (default 3).
        k: Spring stiffness (default 4).

    Returns:
        ``np.ndarray`` with the state derivative ``[ẋ1, ẋ2]``.
    """
    x1, x2 = x
    # |x2| * x2 keeps the sign of the velocity, so the quadratic drag
    # always opposes the direction of motion.
    accel = (u - r * x2 - r2 * np.abs(x2) * x2 - k * x1) / m
    return np.array([x2, accel])
def vector_field_jnp(x, u, *, m=1, r=2, r2=3, k=4):
    """Spring-mass-damper system with nonlinear drag (JAX implementation).

    .. math:: m ẍ + r ẋ + r2 ẋ |ẋ| + k x = u
              y = x

    Args:
        x: State ``[x1, x2]``, i.e. position and velocity (``ẋ1 = x2``).
        u: Input, the forcing term ``u`` in the equation above.
        m: Mass (default 1).
        r: Linear damping coefficient (default 2).
        r2: Quadratic drag coefficient (default 3).
        k: Spring stiffness (default 4).

    Returns:
        ``jax.Array`` with the state derivative ``[ẋ1, ẋ2]``.
    """
    x1, x2 = x
    # |x2| * x2 keeps the sign of the velocity, so the quadratic drag
    # always opposes the direction of motion.
    accel = (u - r * x2 - r2 * jnp.abs(x2) * x2 - k * x1) / m
    return jnp.array([x2, accel])
def solve_scipy(t, x0, u, *, rtol=1e-8, atol=1e-8):
    """Simulate the system with SciPy's ``solve_ivp`` under a spline-interpolated input.

    Args:
        t: Strictly increasing 1-D sample times; the solution is reported at
            exactly these times.
        x0: Initial state ``[x1, x2]``.
        u: Input samples at times ``t``, interpolated with ``CubicSpline``.
        rtol: Relative tolerance of the adaptive step-size controller.
        atol: Absolute tolerance of the adaptive step-size controller.

    Returns:
        Array of shape ``(len(t), len(x0))`` with the state at each time in ``t``.

    Raises:
        RuntimeError: If the integrator stops before reaching ``t[-1]``.
    """
    ufun = CubicSpline(t, u)

    def vector_field(tau, x):
        return vector_field_np(x, ufun(tau))

    sol = solve_ivp(
        vector_field,
        t_span=(t[0], t[-1]),
        y0=x0,
        t_eval=t,
        atol=atol,
        rtol=rtol,
    )
    # On failure solve_ivp still returns the points reached so far; a silently
    # truncated trajectory would make any later comparison meaningless.
    if not sol.success:
        raise RuntimeError(f"solve_ivp failed: {sol.message}")
    # sol.y is (n_states, n_times); transpose to one row per time sample.
    return sol.y.T
def solve_diffrax(t, x0, u, *, rtol=1e-8, atol=1e-8):
    """Simulate the system with diffrax (Dopri5) under a cubic-interpolated input.

    Args:
        t: Increasing 1-D sample times with at least two entries; the first
            spacing ``t[1] - t[0]`` is used as the initial step size, and the
            solution is saved at every time in ``t``.
        x0: Initial state ``[x1, x2]``.
        u: Input samples at times ``t``.
        rtol: Relative tolerance of the PID step-size controller.
        atol: Absolute tolerance of the PID step-size controller.

    Returns:
        The saved states (``Solution.ys``), one row per time in ``t``.

    Note:
        The input is interpolated with cubic Hermite coefficients computed by
        ``backward_hermite_coefficients``. This is not the same interpolant
        as SciPy's ``CubicSpline``, so part of any mismatch with
        ``solve_scipy`` comes from the input interpolation, not the solver.
    """
    ucoeffs = dfx.backward_hermite_coefficients(t, u)
    ufun = dfx.CubicInterpolation(t, ucoeffs)

    def vector_field(tau, x, args):
        del args  # no extra arguments are passed to diffeqsolve
        return vector_field_jnp(x, ufun.evaluate(tau))

    sol = dfx.diffeqsolve(
        dfx.ODETerm(vector_field),
        solver=dfx.Dopri5(),
        t0=t[0],
        t1=t[-1],
        dt0=t[1] - t[0],
        y0=x0,
        saveat=dfx.SaveAt(ts=t),
        stepsize_controller=dfx.PIDController(atol=atol, rtol=rtol),
    )
    return sol.ys
# Benchmark: 10 s of a 1 Hz sine input sampled at 96 kHz, solved with SciPy
# and with diffrax (eager, first JIT call, and cached JIT call).

# JAX defaults to float32. That makes the 1e-8 tolerances smaller than machine
# epsilon and cannot represent the 96 kHz grid up to t = 10 s accurately, so
# enable 64-bit mode before any JAX array is created. (Ideally this is set
# right after `import jax`; harmless if JAX_ENABLE_X64 is already set.)
jax.config.update("jax_enable_x64", True)

samplerate = 96000
duration = 10
t = np.arange(int(samplerate * duration)) / samplerate
x0 = np.array([1.0, 0.0])
u = np.sin(2 * np.pi * 1 * t)

start = time.perf_counter()
x_np = solve_scipy(t, x0, u)
print(f"Scipy: {time.perf_counter() - start}s")

t = jnp.asarray(t)
u = jnp.asarray(u)
x0 = jnp.asarray(x0)

# JAX dispatches work asynchronously: block_until_ready() makes each timing
# include the actual computation, not just the dispatch.

# without JIT (note: diffrax's diffeqsolve compiles its own internals on the
# first call, so this timing still includes some tracing/compilation)
start = time.perf_counter()
x_jnp = solve_diffrax(t, x0, u).block_until_ready()
print(f"diffrax without JIT: {time.perf_counter() - start}s")

# with JIT and compilation
jit_solve = jax.jit(solve_diffrax)
start = time.perf_counter()
x_jnp = jit_solve(t, x0, u).block_until_ready()
print(f"diffrax with JIT+compile: {time.perf_counter() - start}s")

# with JIT (compiled function reused)
start = time.perf_counter()
x_jnp = jit_solve(t, x0, u).block_until_ready()
print(f"diffrax with JIT: {time.perf_counter() - start}s")

# Mean squared error normalised by the variance of the SciPy trajectory
# (np.var is taken over all samples of both states).
print(np.mean((x_jnp - x_np) ** 2 / np.var(x_np)))
assert np.allclose(x_jnp, x_np, rtol=1e-5, atol=1e-5)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
On my laptop: