Skip to content

Instantly share code, notes, and snippets.

@fhchl
Created February 13, 2023 13:11
Show Gist options
  • Save fhchl/05a0c7c8e802e79932571eb331dabfea to your computer and use it in GitHub Desktop.
Save fhchl/05a0c7c8e802e79932571eb331dabfea to your computer and use it in GitHub Desktop.
Comparing diffrax with scipy.integrate.solve_ivp
import time
import numpy as np
from scipy.integrate import solve_ivp
from scipy.interpolate import CubicSpline
import jax
import jax.numpy as jnp
import diffrax as dfx
def vector_field_np(x, u, *, m=1, r=2, r2=3, k=4):
    """Spring-mass-damper system with nonlinear drag (NumPy version).

    .. math:: m ẍ + r ẋ + r2 ẋ |ẋ| + k x = u

    The output equation is ``y = x``.

    Parameters
    ----------
    x : sequence of length 2
        State ``(x1, x2)``, where ``x2`` is the time derivative of ``x1``.
    u : float
        Input force at the current time.
    m, r, r2, k : float, keyword-only
        Mass, linear damping, quadratic drag and stiffness coefficients.
        The defaults reproduce the originally hard-coded values.

    Returns
    -------
    numpy.ndarray
        State derivative ``[ẋ1, ẋ2]``.
    """
    x1, x2 = x
    # r2 * |x2| * x2 is the quadratic drag term: its sign follows the velocity.
    return np.array([x2, (u - r * x2 - r2 * np.abs(x2) * x2 - k * x1) / m])
def vector_field_jnp(x, u, *, m=1, r=2, r2=3, k=4):
    """Spring-mass-damper system with nonlinear drag (JAX version).

    .. math:: m ẍ + r ẋ + r2 ẋ |ẋ| + k x = u

    The output equation is ``y = x``.

    Parameters
    ----------
    x : array of length 2
        State ``(x1, x2)``, where ``x2`` is the time derivative of ``x1``.
    u : float or scalar array
        Input force at the current time.
    m, r, r2, k : float, keyword-only
        Mass, linear damping, quadratic drag and stiffness coefficients.
        The defaults reproduce the originally hard-coded values.

    Returns
    -------
    jax.Array
        State derivative ``[ẋ1, ẋ2]``.
    """
    x1, x2 = x
    # r2 * |x2| * x2 is the quadratic drag term: its sign follows the velocity.
    return jnp.array([x2, (u - r * x2 - r2 * jnp.abs(x2) * x2 - k * x1) / m])
def solve_scipy(t, x0, u, *, rtol=1e-8, atol=1e-8):
    """Simulate the system with ``scipy.integrate.solve_ivp``.

    The sampled input ``u`` is interpolated with a cubic spline so that the
    adaptive solver can evaluate it at arbitrary times.

    Parameters
    ----------
    t : 1-D array
        Sample times. Integration runs from ``t[0]`` to ``t[-1]`` and the
        solution is reported at every entry of ``t``.
    x0 : array of length 2
        Initial state.
    u : 1-D array
        Input samples, one per entry of ``t``.
    rtol, atol : float, keyword-only
        Solver tolerances (defaults are the original hard-coded values).

    Returns
    -------
    numpy.ndarray
        States with shape ``(len(t), len(x0))``.

    Raises
    ------
    RuntimeError
        If ``solve_ivp`` reports failure. Without this check, a failed solve
        silently returns fewer time points than ``t``.
    """
    ufun = CubicSpline(t, u)

    def vector_field(tau, x):
        return vector_field_np(x, ufun(tau))

    sol = solve_ivp(
        vector_field,
        t_span=(t[0], t[-1]),
        y0=x0,
        t_eval=t,
        atol=atol,
        rtol=rtol,
    )
    if not sol.success:
        raise RuntimeError(f"solve_ivp failed: {sol.message}")
    # solve_ivp stores states as (n_states, n_times); transpose to time-major.
    return sol.y.T
def solve_diffrax(t, x0, u, *, rtol=1e-8, atol=1e-8, max_steps=4096):
    """Simulate the system with diffrax (Dopri5 + PID step-size control).

    The sampled input ``u`` is interpolated with backward-Hermite cubic
    coefficients so that it can be evaluated at arbitrary times.

    Parameters
    ----------
    t : 1-D array
        Sample times. Integration runs from ``t[0]`` to ``t[-1]``, the
        initial step is ``t[1] - t[0]``, and the solution is saved at every
        entry of ``t``.
    x0 : array of length 2
        Initial state.
    u : 1-D array
        Input samples, one per entry of ``t``.
    rtol, atol : float, keyword-only
        Step-size controller tolerances (defaults are the original values).
    max_steps : int, keyword-only
        Maximum number of solver steps passed to ``diffeqsolve``. The default
        of 4096 matches diffrax's own default. This must be a Python int, so
        when wrapping with ``jax.jit`` and passing it explicitly, mark it as
        static (``static_argnames=("max_steps",)``).

    Returns
    -------
    jax.Array
        States saved at each time in ``t``.
    """
    ucoeffs = dfx.backward_hermite_coefficients(t, u)
    ufun = dfx.CubicInterpolation(t, ucoeffs)

    def vector_field(tau, x, _args):
        # diffrax passes an ``args`` argument; this system does not use it.
        return vector_field_jnp(x, ufun.evaluate(tau))

    sol = dfx.diffeqsolve(
        dfx.ODETerm(vector_field),
        solver=dfx.Dopri5(),
        t0=t[0],
        t1=t[-1],
        dt0=t[1] - t[0],
        y0=x0,
        saveat=dfx.SaveAt(ts=t),
        stepsize_controller=dfx.PIDController(rtol=rtol, atol=atol),
        max_steps=max_steps,
    )
    return sol.ys
# Benchmark: 10 s of a 1 Hz sinusoidal input sampled at 96 kHz.
samplerate = 96000
duration = 10
t = np.arange(int(samplerate * duration)) / samplerate
x0 = np.array([1.0, 0.0])
u = np.sin(2 * np.pi * 1 * t)

# perf_counter is the monotonic, high-resolution clock intended for timing.
start = time.perf_counter()
x_np = solve_scipy(t, x0, u)
print(f"Scipy: {time.perf_counter() - start}s")

# NOTE(review): unless jax_enable_x64 is enabled, these conversions produce
# float32 arrays, so the 1e-8 tolerances inside solve_diffrax are below
# float32 resolution. Enable x64 at startup for a like-for-like comparison
# with scipy's float64 run.
t = jnp.asarray(t)
u = jnp.asarray(u)
x0 = jnp.asarray(x0)

# JAX dispatches work asynchronously. block_until_ready() makes each timing
# include the actual computation, not just the dispatch.

# without JIT (an outer jax.jit is not applied; diffrax may still compile internally)
start = time.perf_counter()
x_jnp = solve_diffrax(t, x0, u).block_until_ready()
print(f"diffrax without JIT: {time.perf_counter() - start}s")

# with JIT and compilation (first call traces and compiles)
jit_solve = jax.jit(solve_diffrax)
start = time.perf_counter()
x_jnp = jit_solve(t, x0, u).block_until_ready()
print(f"diffrax with JIT+compile: {time.perf_counter() - start}s")

# with JIT (cached compiled function)
start = time.perf_counter()
x_jnp = jit_solve(t, x0, u).block_until_ready()
print(f"diffrax with JIT: {time.perf_counter() - start}s")

# Mean squared error between the two solutions, normalised by the variance of the scipy solution.
print(np.mean((x_jnp - x_np)**2 / np.var(x_np)))
assert np.allclose(x_jnp, x_np, rtol=1e-5, atol=1e-5)
@fhchl
Copy link
Author

fhchl commented Feb 13, 2023

Output of the script on my laptop (CPU only, as the JAX warning below shows):

Scipy: 0.11803364753723145s
WARNING:jax._src.lib.xla_bridge:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
diffrax without JIT: 3.021653413772583s
diffrax with JIT+compile: 2.473569631576538s
diffrax with JIT: 0.05382871627807617s
1.3684053e-12

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment