Only non-stiff ODE solvers are tested since torchdiffeq does not have methods for stiff ODEs. The ODEs are chosen to be representative of models seen in physics and model-informed drug development (MIDD) studies (quantitative systems pharmacology) in order to capture the performance on realistic scenarios.
Below are the timings relative to the fastest method (lower is better). For approximately 1 million ODEs and fewer, torchdiffeq was more than an order of magnitude slower than DifferentialEquations.jl on every tested problem, and often substantially slower than that. Note, however, that the relative performance of torchdiffeq improves as the number of ODEs increases.
Additionally, torchdiffeq either exhibited slower gradient calculations or the gradient calculation diverged. For more details on why torchdiffeq's gradient calculation diverges, see this manuscript along with many others detailing the stability of backsolve adjoint methods. Note that when `sensealg=BacksolveAdjoint()` is used in DifferentialEquations.jl on these problems it similarly diverges, indicating it is an issue with that algorithm (and why the DiffEq documentation does not recommend this algorithm on such problems!).
Number of ODEs | 3 | 28 | 768 | 3,072 | 12,288 | 49,152 | 196,608 | 786,432 |
---|---|---|---|---|---|---|---|---|
DifferentialEquations.jl | 1.0x | 1.0x | 1.0x | 1.0x | 1.0x | 1.0x | 1.0x | 1.0x |
DifferentialEquations.jl dopri5 | 1.0x | 1.6x | 2.8x | 2.7x | 3.0x | 3.0x | 3.9x | 2.8x |
torchdiffeq dopri5 | 4,900x | 190x | 840x | 220x | 82x | 31x | 24x | 17x |
Number of Parameters | 3 | 4 | 256 | 1,024 |
---|---|---|---|---|
DifferentialEquations.jl | 1.0x | 1.0x | 1.0x | 1.0x |
torchdiffeq dopri5 | 12,000x | 1,200x | ---- | ---- |
`----` means the gradient calculation diverged. To be clear, torchdiffeq did not successfully compute the gradient on any of the PDE experiments. All returned an error due to `dt` underflow, leading to the experiment being halted.
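
The divergence noted above can be illustrated with a toy problem. Backsolve-style adjoints reconstruct the forward state by integrating the ODE backwards in time, and for a dissipative system that reverse solve amplifies small errors. The following is a minimal sketch, assuming a scalar test equation `y' = -50 y`, a fixed-step RK4 integrator, and a perturbation of `1e-12`; none of these come from the benchmark scripts themselves, and the helper names are hypothetical.

```python
# Toy illustration (assumed setup, not the benchmark code):
# integrate y' = -LAM * y forward, then backward, and see how a tiny
# error in the final state is amplified by the reverse-time solve.

LAM = 50.0  # assumed decay rate for the toy problem


def decay(t, y):
    return -LAM * y


def rk4_step(f, t, y, h):
    # One classical fourth-order Runge-Kutta step of size h (h may be negative).
    k1 = f(t, y)
    k2 = f(t + h / 2, y + h / 2 * k1)
    k3 = f(t + h / 2, y + h / 2 * k2)
    k4 = f(t + h, y + h * k3)
    return y + h / 6 * (k1 + 2 * k2 + 2 * k3 + k4)


def integrate(f, t0, t1, y0, n_steps):
    # Fixed-step integration from t0 to t1; t1 < t0 integrates backwards.
    h = (t1 - t0) / n_steps
    t, y = t0, y0
    for _ in range(n_steps):
        y = rk4_step(f, t, y, h)
        t += h
    return y


y_end = integrate(decay, 0.0, 1.0, 1.0, 1000)            # forward solve
y0_recovered = integrate(decay, 1.0, 0.0, y_end, 1000)   # exact final state
y0_perturbed = integrate(decay, 1.0, 0.0, y_end + 1e-12, 1000)  # tiny error

print("forward y(1):            ", y_end)
print("recovered y(0):          ", y0_recovered)
print("recovered y(0), perturbed:", y0_perturbed)
```

In this sketch the reverse solve recovers the initial condition only when the final state is known essentially exactly. An error far below typical solver tolerances grows by roughly `exp(50)` on the way back, which is the kind of behavior that can drive an adaptive solver's step size towards zero.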
- DifferentialEquations.jl: 1.742 ms
- SciPy+Numba: 30.8 ms
- SciPy: 50.2 ms
- SciPy `solve_ivp`: 869 ms
- torchscript torchdiffeq (dopri5): 8.60 seconds
- DifferentialEquations.jl: 4.281 ms
- torchscript torchdiffeq: 51.9 seconds
- DifferentialEquations.jl: 1x
- SciPy+Numba: 18x slower
- SciPy: 29x slower
- torchscript torchdiffeq: 4,900x slower
- DifferentialEquations.jl: 1x
- torchscript torchdiffeq: 12,000x
- DifferentialEquations.jl: 2.118 ms
- DifferentialEquations.jl DP5: 3.407 ms
- SciPy+Numba: 2.6 ms
- SciPy: 13.2 ms
- torchscript torchdiffeq (dopri5): 405 ms
- DifferentialEquations.jl: 5.501 ms
- torchscript torchdiffeq: 6.33 seconds
- DifferentialEquations.jl: 1x
- DifferentialEquations.jl DP5: 1.6x slower
- SciPy+Numba: 1.2x slower
- SciPy: 6.2x slower
- torchscript torchdiffeq (dopri5): 190x slower
- DifferentialEquations.jl: 1x
- torchscript torchdiffeq: 1,200x slower
- DifferentialEquations.jl: 3.300 ms
- DifferentialEquations.jl DP5: 9.135 ms
- SciPy: 2.2 seconds
- SciPy+Numba: Failed to compile (numpy.ndarray)
- torchscript torchdiffeq (dopri5): 2.78 seconds
- DifferentialEquations.jl: 1x
- DifferentialEquations.jl DP5: 2.8x slower
- SciPy: 670x slower
- torchscript torchdiffeq (dopri5): 840x slower
- DifferentialEquations.jl: 14.397 ms
- DifferentialEquations.jl DP5: 38.608 ms
- SciPy: 6.71 seconds
- torchscript torchdiffeq (dopri5): 3.12 seconds
- DifferentialEquations.jl: 1x
- DifferentialEquations.jl DP5: 2.7x slower
- SciPy: 460x slower
- torchscript torchdiffeq (dopri5): 220x slower
- DifferentialEquations.jl: 64.192 ms
- DifferentialEquations.jl DP5: 192.216 ms
- SciPy: 174 seconds
- torchscript torchdiffeq (dopri5): 5.24 seconds
- DifferentialEquations.jl: 1x
- DifferentialEquations.jl DP5: 3.0x slower
- SciPy: 2,700x slower
- torchscript torchdiffeq (dopri5): 82x slower
- DifferentialEquations.jl: 299.512 ms
- DifferentialEquations.jl DP5: 907.863 ms
- torchscript torchdiffeq (dopri5): 9.29 seconds
- DifferentialEquations.jl: 1x
- DifferentialEquations.jl DP5: 3.0x slower
- torchscript torchdiffeq (dopri5): 31x slower
- DifferentialEquations.jl: 1.586 seconds
- DifferentialEquations.jl DP5: 6.195 seconds
- torchscript torchdiffeq (dopri5): 37.5 seconds
- DifferentialEquations.jl: 1x
- DifferentialEquations.jl DP5: 3.9x slower
- torchscript torchdiffeq (dopri5): 24x slower
- DifferentialEquations.jl: 10.3 seconds
- DifferentialEquations.jl DP5: 29.3 seconds
- torchscript torchdiffeq (dopri5): 172.59 seconds
- DifferentialEquations.jl: 1x
- DifferentialEquations.jl DP5: 2.8x slower
- torchscript torchdiffeq (dopri5): 17x slower
The torchscript versions are kept as separate scripts to allow the JIT compilation to occur, and are called before timing to exclude JIT time, as suggested by the PyTorch documentation. Python results were scaled by the number of runs in `timeit`. Note that the SciPy timing increase in the reaction-diffusion problem is due to `lsoda` switching to its BDF method and utilizing the Jacobian: with this diffusion coefficient this is unnecessary and leads to a large slowdown.
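
The timing procedure described above (call once before timing, then divide the total by the number of runs) can be sketched as follows. This is a minimal illustration using only the standard library. The `workload` function is a hypothetical stand-in for a solver call, and taking the minimum over several repeats is an assumption of this sketch rather than something the benchmark scripts are known to do.

```python
import timeit


def workload(n=10_000):
    # Hypothetical stand-in for a solver call such as an ODE solve.
    total = 0.0
    for i in range(n):
        total += i * 0.5
    return total


def time_per_call(func, number=20, repeat=3):
    # Warm-up call: any JIT compilation or caching happens here,
    # outside the timed region.
    func()
    # Each entry is the total time for `number` calls.
    totals = timeit.repeat(func, number=number, repeat=repeat)
    # Scale by the number of runs to get seconds per call
    # (taking the best repeat is an assumption of this sketch).
    return min(totals) / number


seconds = time_per_call(workload)
print(f"{seconds * 1e3:.3f} ms per call")
```

For a torchscript function, the same pattern would apply with the scripted solver call in place of `workload`, so that compilation cost is paid in the warm-up call rather than in the measurement.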
Howdy, I noticed that you're calling the old `odeint`, which is deprecated. I tried reproducing your Lorenz Python results using scipy's newer `solve_ivp`, and achieved a dramatic speed boost. On my machine the forward pass timing improved from 45ish ms to ~1.1 ms. Using DOP853 via `solve_ivp`'s `method` kwarg I get timings just a hair above half a ms.
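
For reference, the two SciPy entry points mentioned in this comment can be compared with a sketch like the one below. The Lorenz parameters are the standard ones, but the initial condition, time span, and tolerances are assumptions chosen for illustration and may not match the benchmark scripts or the commenter's setup.

```python
import numpy as np
from scipy.integrate import odeint, solve_ivp


def lorenz(t, u, sigma=10.0, rho=28.0, beta=8.0 / 3.0):
    # Standard Lorenz system with time as the first argument.
    x, y, z = u
    return [sigma * (y - x), x * (rho - z) - y, x * y - beta * z]


u0 = [1.0, 0.0, 0.0]    # assumed initial condition
t_span = (0.0, 100.0)   # assumed time span

# Newer interface: the explicit Runge-Kutta method is picked via `method`.
sol = solve_ivp(lorenz, t_span, u0, method="DOP853", rtol=1e-6, atol=1e-6)

# Older interface: LSODA-based, needs an explicit output grid.
# `tfirst=True` lets it reuse the same f(t, u) signature.
t_grid = np.linspace(t_span[0], t_span[1], 1001)
legacy = odeint(lorenz, u0, t_grid, tfirst=True, rtol=1e-6, atol=1e-6)

print("solve_ivp steps:", sol.t.size, "success:", sol.success)
print("odeint output shape:", legacy.shape)
```

Wrapping either call in `timeit` gives a per-call timing. The numbers will depend on the machine and on the tolerances chosen, so they are not reproduced here.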