Last active
October 12, 2020 16:28
-
-
Save ptmcg/620fc74025da8dd73b7e3b180659e30d to your computer and use it in GitHub Desktop.
Runge-Kutta integrator in Python
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# rk.py
#
# Copyright 2020, Paul McGuire
#
from array import array
from typing import Callable, Optional, Sequence
class RKIntegrator:
    """
    Class used to perform fixed-step, fourth-order Runge-Kutta numerical
    integration of a set of ODE's.

    Works with a Sequence[float] representing current state [x, x', x'', ...].
    The state is stored as an array of double-precision floats (typecode 'd').

    Parameters:
    - dt (float): time step
    - deriv_fn (callable): function for computing derivatives; called as
        deriv_fn(t, state) with state [x, x', x'', ...], returns vector
        of computed derivatives [x', x'', x''', ...]
    - degree (int): optional argument to indicate the degree of the
        state X (number of derivatives maintained in state); the state
        is then initialized to all zeros
    - init_conds (list(float)): optional argument giving initial
        conditions for modeling state X

    Must specify exactly one of degree or init_conds.

    Attributes:
    - dt: the time step, as a float
    - t: current simulation time; starts at 0.0
    - x: current state vector
    - deriv_fn: the derivative function passed to the constructor

    Raises:
    - ValueError: if neither, or both, of degree and init_conds are given
    """

    def __init__(self,
                 dt: float,
                 deriv_fn: Callable[[float, Sequence[float]], Sequence[float]],
                 degree: int = 0,
                 init_conds: Optional[Sequence[float]] = None):
        if (degree == 0 and init_conds is None) or (degree != 0 and init_conds is not None):
            raise ValueError("must specify degree or initial conditions")

        self.dt = float(dt)
        self.t = 0.0

        # Typecode 'd' (C double) keeps the full precision of Python floats;
        # 'f' would round the state to 32 bits on every step.
        if init_conds is not None:
            self.x = array('d', init_conds)
        else:
            self.x = array('d', [0.0] * degree)

        self.deriv_fn = deriv_fn

    def integrate(self):
        """
        Generator that advances the state by one time step per iteration.

        After each step, self.t and self.x are updated and the pair
        (self.t, self.x) is yielded. A new array is bound to self.x on every
        step, so previously yielded states are not modified by later steps.

        The generator never finishes on its own; the caller decides when
        to stop iterating.
        """
        dt = self.dt
        dt_c = 0.0  # running compensation term for the time accumulation
        dx_func = self.deriv_fn

        while True:
            t2 = self.t + dt / 2.0

            # first increment: slope at the start of the interval
            delx0 = [dx_i * dt for dx_i in dx_func(self.t, self.x)]

            # second increment: slope at the midpoint, using delx0
            xv = [x_i + delx0_i / 2.0 for x_i, delx0_i in zip(self.x, delx0)]
            delx1 = [dx_i * dt for dx_i in dx_func(t2, xv)]

            # third increment: slope at the midpoint, using delx1
            xv = [x_i + delx1_i / 2.0 for x_i, delx1_i in zip(self.x, delx1)]
            delx2 = [dx_i * dt for dx_i in dx_func(t2, xv)]

            # trial state at the end of the interval, using delx2
            xv = [x_i + delx2_i for x_i, delx2_i in zip(self.x, delx2)]

            # avoid accumulating floating-point error
            # (from https://en.wikipedia.org/wiki/Kahan_summation_algorithm)
            y = dt - dt_c
            t = self.t + y
            dt_c = t - self.t - y
            self.t = t

            # fourth slope: evaluated at the end of the interval (self.t is
            # already advanced), then combined with weights 1, 2, 2, 1 over 6
            dx = dx_func(self.t, xv)
            self.x = array('d', (x_i + (delx0_i + dx_i * dt + 2.0 * (delx1_i + delx2_i)) / 6.0
                                 for x_i, dx_i, delx0_i, delx1_i, delx2_i
                                 in zip(self.x, dx, delx0, delx1, delx2)))

            yield self.t, self.x
def constant_acceleration(a: float) -> Callable[[float, Sequence[float]], Sequence[float]]:
    """
    Build a derivative function representing constant acceleration
    (common for modeling gravity).

    The returned function takes (t, x_vec) and returns a two-element list:
    the second entry of x_vec, followed by the constant a. The time
    argument is ignored.
    """
    def _derivative(t, x_vec):
        velocity = x_vec[1]
        return [velocity, a]

    return _derivative
if __name__ == "__main__":
    def is_whole(value: float) -> bool:
        """Return True when value is within 1e-9 of an integer."""
        return abs(value - round(value)) < 1e-9

    # example of motion with constant acceleration = 4 m/s², with dt=0.1 sec
    # (could also have specified with init_conds=[0.0, 0.0])
    accel = 4.0
    rk = RKIntegrator(dt=0.1, deriv_fn=constant_acceleration(accel), degree=2)

    # Stop half a step past t=10 rather than at exactly 10: t is an
    # accumulated float, and if it lands a hair above 10 an exact comparison
    # would drop the final row.
    t_stop = 10 + rk.dt / 2.0
    for t, x in rk.integrate():
        if t > t_stop:
            break
        if is_whole(t):
            # print current state, plus actual solution for x = ½at², to compare against x[0]
            print(round(t), ', '.join('%.2f' % xx for xx in x), 'actual x=', 0.5 * accel * t * t)

    print()

    # implementation to compare with http://doswa.com/blog/2009/01/02/fourth-order-runge-kutta-numerical-integration/
    def dx2(tt: float, x_vec: Sequence[float]) -> float:
        """Return the second derivative: -stiffness * x_vec[0] - damping * x_vec[1]."""
        stiffness = 1
        damping = -0.005
        return -stiffness * x_vec[0] - damping * x_vec[1]

    def dx(tt: float, x_vec: Sequence[float]) -> Sequence[float]:
        """Return the derivative vector [x_vec[1], dx2(tt, x_vec)] for the state x_vec."""
        return [
            x_vec[1],
            dx2(tt, x_vec)
        ]

    rk = RKIntegrator(dt=1.0 / 40.0, deriv_fn=dx, init_conds=[50.0, 5.0])
    for t, x in rk.integrate():
        # 100.1 leaves slack past t=100 so the final whole-second row is printed
        if t > 100.1:
            break
        if is_whole(t):
            print(round(t), ', '.join('%.2f' % xx for xx in x))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment