
@clbarnes
Last active May 14, 2021 12:33
Decorator to wrap a derivative function in order to constrain the output of an ODE solver
"""
Note that this only stops the dependent variable from moving *further* past the constraints.
The derivative function is evaluated as if y were exactly at the constraint value,
so with a small step size any overshoot should be minimal,
but your solution may still contain slightly out-of-range values.
"""
import numpy as np
from functools import wraps


def constrain(constraints):
    """
    Decorator which wraps a derivative function, to be passed to an ODE solver,
    so that the solution space is constrained.

    Example:

        @constrain([0, 1])
        def f(t, y):
            dy_dt = ...  # your ODE
            return dy_dt

        # Use any solver you like! Note that odeint calls f(y, t) unless tfirst=True.
        solution = scipy.integrate.odeint(f, y0, t, tfirst=True)

    If the solution reaches or passes 0 or 1, the wrapped f evaluates the ODE as if y
    were exactly at that bound, and ignores any component of dy_dt which would push y
    further out of range (that component is set to 0 instead).

    :param constraints: (low, high) pair applied to every element of y;
        use None for an unconstrained side.
    """
    lower, upper = constraints
    if lower is not None and upper is not None and not lower < upper:
        raise ValueError("lower constraint must be less than upper constraint")
    # Treat None as an open (infinite) bound.
    lower = -np.inf if lower is None else lower
    upper = np.inf if upper is None else upper

    def wrap(f):
        @wraps(f)
        def wrapper(t, y, *args, **kwargs):
            y = np.asarray(y, dtype=float)
            too_low = y <= lower
            too_high = y >= upper
            # Evaluate the ODE as if out-of-range values were exactly at the bound.
            y = np.clip(y, lower, upper)
            # Copy into a float array (f may return a tuple) so that mask assignment works.
            result = np.array(f(t, y, *args, **kwargs), dtype=float)
            # too_low.sum() counts the True entries, i.e. how many elements result[too_low]
            # selects. At the lower bound forbid a negative derivative; at the upper bound
            # forbid a positive one (clamping to the bound value itself was a bug).
            result[too_low] = np.maximum(result[too_low], np.zeros(too_low.sum()))
            result[too_high] = np.minimum(result[too_high], np.zeros(too_high.sum()))
            return result
        return wrapper
    return wrap
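To see the caveat from the module docstring in action, here is a minimal sketch that restates the clamping rule compactly and drives it with a hand-written fixed-step forward Euler loop instead of a SciPy solver, so it needs only NumPy. Both `clamp_rhs` and `euler` are hypothetical helpers for illustration, not part of the gist; `clamp_rhs` uses `np.where` rather than boolean-mask assignment, but applies the same rule (no further decrease at the lower bound, no further increase at the upper bound).

```python
import numpy as np


def clamp_rhs(f, lower, upper):
    """Hypothetical compact restatement of the constrain() rule, for illustration."""
    def wrapper(t, y):
        # Evaluate the ODE as if y were exactly at the bound.
        y = np.clip(np.asarray(y, dtype=float), lower, upper)
        dy = np.array(f(t, y), dtype=float)
        # At or below the lower bound, forbid any further decrease.
        dy = np.where(y <= lower, np.maximum(dy, 0.0), dy)
        # At or above the upper bound, forbid any further increase.
        dy = np.where(y >= upper, np.minimum(dy, 0.0), dy)
        return dy
    return wrapper


def euler(f, y0, t0, t1, n_steps):
    """Hypothetical fixed-step forward Euler integrator."""
    y = np.asarray(y0, dtype=float)
    h = (t1 - t0) / n_steps
    for i in range(n_steps):
        y = y + h * f(t0 + i * h, y)
    return y


def growth(t, y):
    # Constant growth rate of 2: unconstrained, y goes from 0.5 to 2.5 over t in [0, 1].
    return np.full_like(y, 2.0)


unconstrained = euler(growth, [0.5], 0.0, 1.0, 100)
constrained = euler(clamp_rhs(growth, 0.0, 1.0), [0.5], 0.0, 1.0, 100)
# unconstrained ends near 2.5; constrained stops at 1.0 or at most one step (0.02) past it
print(unconstrained, constrained)
```

Because the clamp only acts once y has reached the bound, the final Euler step before that can carry y slightly past 1.0; a smaller step size shrinks that overshoot, as the docstring says.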
@pauvretimo

Hello, sorry, but I don't understand what "too_low.sum()" and "too_high.sum()" on lines 51 and 52 do. Aren't too_high and too_low boolean?

@clbarnes
Author

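It's been a while since I've looked at this, but I believe they are indeed boolean arrays. In Python, bool is a subclass of int: True == 1 and False == 0, and NumPy's sum treats boolean arrays the same way. So to count the True values in a boolean array you can use sum, because it adds 1 for every True and 0 for every False. That count is exactly the number of elements selected by result[too_low], so it sizes an array of matching length to compare against.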
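A small sketch of the counting trick described above, using made-up values for y and a lower bound of 0 (the values are illustrative only):

```python
import numpy as np

y = np.array([-0.2, 0.3, 0.0, 1.4])
lower = 0.0

too_low = y <= lower   # boolean mask, one entry per element of y
n_low = too_low.sum()  # each True counts as 1, so this is 2
print(too_low, n_low)

# The count equals the length of y[too_low], which is why too_low.sum() can size
# an array (np.ones(...) or np.zeros(...)) that lines up with the masked elements.
print(y[too_low].shape, np.ones(n_low).shape)

# Plain Python behaves the same way, because bool is a subclass of int.
print(isinstance(True, int), sum([True, False, True]))
```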

@pauvretimo

I figured out the problem: by default odeint passes the arguments as (y, t) rather than (t, y) (this can be changed with the "tfirst" parameter), so I just swapped them in your code, and it works perfectly now.
Another mistake I made was returning a tuple from my ODE function instead of a NumPy array, so np.maximum could not be applied to “result”.
Thank you for your answer.
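Passing tfirst=True to odeint is the simplest fix for the argument order. As an alternative, here is a sketch of a tiny adapter (the name `y_first` is hypothetical, not part of SciPy or this gist) that exposes an f(t, y, *args) function in odeint's default (y, t, *args) order, together with an ODE function that returns a NumPy array rather than a tuple. The last part shows why a tuple fails: it cannot be indexed with a boolean mask such as too_low.

```python
import numpy as np


def y_first(f):
    """Hypothetical adapter: expose f(t, y, *args) as g(y, t, *args), the argument
    order odeint uses by default (tfirst=False)."""
    def g(y, t, *args):
        return f(t, y, *args)
    return g


def rhs(t, y):
    # Return a NumPy array rather than a tuple, so that a wrapper such as constrain()
    # can apply boolean masks like result[too_low] to it.
    return np.array([-y[0], t])


g = y_first(rhs)
out = g(np.array([2.0, 0.0]), 0.5)  # called in odeint's default (y, t) order
print(out)

# A tuple result cannot be indexed with a boolean mask.
tuple_failed = False
try:
    (-2.0, 0.5)[np.array([True, False])]
except TypeError as err:
    tuple_failed = True
    print("tuple result:", err)
```

With this adapter, odeint(y_first(f), y0, t) should behave like odeint(f, y0, t, tfirst=True) for a function f written in (t, y) order.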
