Last active
May 27, 2024 13:27
-
-
Save maxiimilian/67113eb1d60a5d8ceca212fbcad100c9 to your computer and use it in GitHub Desktop.
Function to find all roots of a scalar function within a given interval (bracket), based on SciPy's `root_scalar`. The resolution `n` only needs to be high enough that every sign change of the function (i.e. every root) is captured by the sampling grid.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import warnings | |
from typing import Callable, Iterable | |
import numpy as np | |
from scipy.optimize import root_scalar | |
def multi_root(f: Callable, bracket: Iterable[float], args: Iterable = (), n: int = 30) -> np.ndarray:
    """Find all roots of `f` in `bracket`, given that resolution `n` covers every sign change.

    `f` is sampled on `n` equidistant points spanning `bracket`. Sample points at which
    `f` is exactly zero are returned as roots directly; every pair of adjacent samples
    with strictly opposite signs is refined with `scipy.optimize.root_scalar`.

    Parameters
    ----------
    f: Callable
        Function to be evaluated as ``f(x, *args)``. It must accept a NumPy array `x`
        (used for the initial sampling) as well as a scalar `x` (used by the refinement).
    bracket: Sequence of two floats
        Specifies interval within which roots are searched.
    args: Iterable, optional
        Extra positional arguments passed to `f`. Converted to a tuple so that the
        sampling step and the refinement step pass them to `f` identically.
    n: int
        Number of points sampled equidistantly from bracket to evaluate `f`.
        Resolution has to be high enough to cover sign changes of all roots but not finer than that.
        Actual roots are found using `scipy.optimize.root_scalar`.

    Returns
    -------
    roots: np.ndarray
        Sorted array containing all unique roots that were found in `bracket`.

    Warns
    -----
    UserWarning
        If a root finder did not converge, or if the same root was found more than once.
    """
    # `root_scalar` wraps any non-tuple `args` into a 1-tuple, whereas `f(x, *args)`
    # unpacks it. Normalise once so both call sites pass the same arguments.
    args = tuple(args)

    # Evaluate function in given bracket
    x = np.linspace(*bracket, n)
    y = np.asarray(f(x, *args))

    # Samples that hit a root exactly need no refinement. Bracketing them would yield
    # the same root twice (once from each neighbouring interval).
    exact_roots = x[y == 0]

    # Intervals whose endpoints have strictly opposite signs. Multiplying the signs
    # (not the values) avoids underflow of y[i] * y[i+1]; NaN samples compare False
    # and are therefore never used as bracket endpoints.
    signs = np.sign(y)
    sign_changes = np.nonzero(signs[:-1] * signs[1:] < 0)[0]

    # Refine each bracketed sign change
    results = [
        root_scalar(f=f, args=args, bracket=(x[s], x[s + 1]))
        for s in sign_changes
    ]
    refined = np.array(
        [r.root if r.converged else np.nan for r in results],
        dtype=float,
    )

    if np.any(np.isnan(refined)):
        warnings.warn("Not all root finders converged for estimated brackets! Maybe increase resolution `n`.")
        refined = refined[~np.isnan(refined)]

    roots = np.concatenate([exact_roots, refined])
    roots_unique = np.unique(roots)
    if len(roots_unique) != len(roots):
        warnings.warn("One root was found multiple times. "
                      "Try to increase or decrease resolution `n` to see if this warning disappears.")
    return roots_unique
if __name__ == '__main__':
    def poly1(x):
        """Quartic with known roots -4, -2, 1 and 5."""
        return (x + 4) * (x + 2) * (x - 1) * (x - 5)

    # Previously the result was computed and discarded; show it so the demo is useful.
    roots = multi_root(poly1, [-5, 6])
    print(roots)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Thanks a lot!