Created
August 22, 2018 20:30
-
-
Save javipus/ba149aa7f296a16900bf7edc7a4f1e09 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from sympy import simplify, Pow | |
from sympy.polys import Poly, domains, polyroots | |
from sympy.assumptions import assuming, ask, Q | |
import signal | |
from functools import wraps | |
# Exceptions #
class GaloisOverflow(Exception):
    """
    Raised when a polynomial's degree is above 4.

    This is the limit of the closed-form polyroots solvers used by this module.
    """
class TimeOutError(Exception):
    """Raised by the `timeout` decorator when a wrapped call exceeds its time limit."""
# Helpers #
def timeout(seconds = 30, message = 'Function call took too long!'):
    """
    Decorator factory that interrupts the wrapped call after `seconds` seconds.

    The limit is implemented with SIGALRM (signal.alarm). It therefore only works
    where SIGALRM exists (POSIX) and only when the wrapped function is called
    from the main thread, since signal.signal refuses other threads.

    @param seconds: Whole number of seconds passed to signal.alarm.
    @param message: Message of the TimeOutError raised when the limit expires.
    @return Decorator that applies the time limit to a function.
    """
    def decorator(f):
        def _timeout_handler(signum, frame):
            raise TimeOutError(message)

        @wraps(f)  # keep f's __name__/__doc__ so e.g. help(intersect) shows the real docs
        def wrapper(*args, **kwds):
            previous_handler = signal.signal(signal.SIGALRM, _timeout_handler)
            signal.alarm(seconds)
            try:
                return f(*args, **kwds)
            finally:
                # Cancel the pending alarm first, so it cannot fire after the
                # caller's handler is back in place. Then restore that handler
                # instead of leaving ours installed process-wide.
                signal.alarm(0)
                if previous_handler is not None:  # None: handler was not installed from Python
                    signal.signal(signal.SIGALRM, previous_handler)
        return wrapper
    return decorator
def _solver(deg):
    """
    Choose the sympy.polys.polyroots closed-form root finder for a given degree.

    @param deg: Polynomial degree; only 1, 2, 3 and 4 are supported.
    @return One of polyroots.roots_linear / roots_quadratic / roots_cubic / roots_quartic.
    @raise GaloisOverflow: For any other degree.
    """
    candidates = (
        (1, polyroots.roots_linear),
        (2, polyroots.roots_quadratic),
        (3, polyroots.roots_cubic),
        (4, polyroots.roots_quartic),
    )
    for degree, finder in candidates:
        if deg == degree:
            return finder
    raise GaloisOverflow('Degree must be at most 4!')
# Main function #
# TODO prevent expr.subs(x, root) from doing floating point evaluation - I want to keep a sympy expression
# NOTE(review): _as_poly builds polynomials over domains.RR (sympy's floating-point
# real domain), which is presumably where the Floats come from - confirm before changing.
def _as_poly(p, x):
    """
    Return curve `p` as a sympy Poly in `x`.

    Anything that already has a .degree() method (e.g. a Poly) is returned unchanged.

    @param p: Sympy expression, Poly, or plain number.
    @param x: Sympy symbol used as generator when `p` has to be converted.
    @return Poly representing `p`.
    @raise TypeError: If `p` needs converting but `x` is None.
    """
    try:
        p.degree()
    except AttributeError:
        pass
    else:
        return p
    if x is None:
        raise TypeError('Need to pass a value for x if expressions are not of type Poly!')
    try:
        coeffs = list(p.free_symbols - {x})
        coeffs += [Pow(coeff, -1) for coeff in coeffs] # just in case a symbol is dividing
    except AttributeError: # not a sympy expression - could be float, int, etc.
        coeffs = None
    return Poly(p, x, domain = domains.RR[coeffs] if coeffs else domains.RR)


def _interval_condition(interval):
    """
    Map `intersect`'s `interval` argument to the sympy predicate applied to both bounds.

    @return Q.nonnegative for a closed [a, b] list, Q.positive for an open (a, b) tuple,
            None when no interval is given (falsy `interval`).
    @raise TypeError: For any other interval type (previously a NameError mid-computation).
    """
    if not interval:
        return None
    if isinstance(interval, list):
        return Q.nonnegative
    if isinstance(interval, tuple):
        return Q.positive
    raise TypeError('interval must be a tuple (open) or a list (closed), got {!r}'.format(interval))


def _filter_roots(roots, removeComplex, interval, cond):
    """
    Drop roots that are provably non-real and/or provably outside `interval`.

    Must be called inside the caller's `assuming(...)` context. A root is kept
    whenever `ask` cannot decide (returns None).

    @param roots: Candidate roots as returned by a solver.
    @param removeComplex: Whether to drop roots for which Q.real is False.
    @param interval: The (a, b) / [a, b] bounds, or a falsy value for no filtering.
    @param cond: Predicate from _interval_condition(interval).
    @return List of surviving roots.
    """
    if removeComplex:
        print('Filtering out complex roots...')
        roots = [r for r in roots if ask(Q.real(r)) in (True, None)]
    if interval:
        print('Filtering out solutions outside of {}...'.format(interval))
        lower, upper = interval[0], interval[1]
        # Reject only when a bound check is definitely False. The old
        # `(lo and hi) in (True, None)` kept roots with lo=None, hi=False.
        roots = [r for r in roots
                 if ask(cond(r - lower)) is not False and ask(cond(upper - r)) is not False]
    return roots


@timeout()
def intersect(*args, x = None, interval = None, removeComplex = True, doAssume = (), solver = None, **kwds):
    """
    Calculate pairwise intersection of a family of curves described by polynomials of degree <= 4.

    Runs under the module's `timeout()` decorator (30 s default) and prints progress messages.

    @param args: Sympy expressions. Must be polynomials of degree <= 4.
    @param x: Sympy symbol. Independent variable of all polynomials. If None, polynomial expressions must be of type Poly.
    @param interval: Return solutions only in given interval.
    - If tuple (a, b), consider the open interval a < x < b.
    - If list [a, b], consider the closed interval a <= x <= b.
    - If None, return solutions in all R.
    - Half-open intervals like [a, b) are not supported.
    @param removeComplex: Only return real solutions (roots whose realness cannot be decided are kept).
    @param doAssume: Assumption or iterable of assumptions about the polynomial coefficients using the class AssumptionKeys, e.g. Q.positive(a), Q.real(b).
    @param solver: Preferred solver to find roots. If None, the ones in polyroots are chosen per pair from the degree of p - q.
    @param kwds: Keyword arguments to be passed to the solver.
    @return List of tuples (p, q, points), one per pair in order (0, 1), (0, 2), ..., where points is the list of points (x, y) where p and q intersect, e.g. (x**2, x+1, [(1/2 + sqrt(5)/2, 3/2 + sqrt(5)/2), (1/2 - sqrt(5)/2, 3/2 - sqrt(5)/2)]). If no isolated intersections are found (including when p - q is constant), points is the empty list.
    @raise TypeError: Fewer than two curves, missing `x` for non-Poly input, or an unsupported interval type.
    @raise GaloisOverflow: A polynomial of degree > 4 without a custom `solver`.
    """
    N = len(args)
    if N < 2:
        raise TypeError('Need at least two curves to intersect!')
    if not hasattr(doAssume, '__len__'):
        doAssume = [doAssume]
    cond = _interval_condition(interval)  # validate up front, before any expensive work

    print('Pre-processing...')
    ps = []
    for p in args:
        p = _as_poly(p, x)
        if p.degree() > 4 and not solver:
            raise GaloisOverflow('Degrees must be at most 4!')
        ps.append(p)
    print('Done!')

    sols = []
    print('\nIntersection begins:\n')
    for i in range(N):
        for j in range(i+1, N):
            print('Intersecting p{} with p{}...'.format(i, j))
            p, q = ps[i], ps[j]
            diff = p - q
            # Use deg(p - q), not max(deg p, deg q): equal leading terms cancel,
            # and a higher-degree closed-form solver cannot handle the shorter polynomial.
            deg = diff.degree()
            if deg < 1:
                # Constant difference: the curves never meet or coincide everywhere;
                # either way there are no isolated intersection points to report.
                x_star = []
            else:
                # Pick per pair; never rebind `solver`, or later pairs reuse this pair's routine.
                pair_solver = solver if solver else _solver(deg)
                x_star = pair_solver(diff, **kwds)
                with assuming(*doAssume): # This check takes like forever :(
                    x_star = _filter_roots(x_star, removeComplex, interval, cond)
            print('Evaluating solutions (if any)...')
            p_expr = p.as_expr()
            # Fall back to the Poly's own generator when x was not supplied.
            gen = x if x is not None else p.gen
            xy_star = [(simplify(root), simplify(p_expr.subs(gen, root))) for root in x_star]
            sols.append((p_expr, q.as_expr(), xy_star))
            print('Done!\n')
    return sols
if __name__ == '__main__':
    # Demo: intersect a parabola in s with a straight line in s, keeping only
    # intersections on the closed interval [0, L].
    from sympy import symbols, init_printing
    init_printing()

    x, p, theta, u = symbols('x p theta u', real = True)
    # The original left `positive = True` commented out for these symbols.
    s, L, k = symbols('s L k', real = True)

    quadratic_path = x - u + (p - theta) * s + .5 * k * s**2 / L
    straight_path = u + theta * s

    results = intersect(quadratic_path, straight_path, x = s, interval = [0, L])
    for curve_a, curve_b, points in results:
        print('{} intersects {} at\n{}'.format(curve_a, curve_b, points))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment