Skip to content

Instantly share code, notes, and snippets.

@yig
Created December 25, 2014 22:33
Show Gist options
  • Save yig/3544de541c98b73badb6 to your computer and use it in GitHub Desktop.
Save yig/3544de541c98b73badb6 to your computer and use it in GitHub Desktop.
'''
Author: Yotam Gingold <yotam (strudel) yotamgingold.com>
License: Public Domain. (I, Yotam Gingold, the author, dedicate any copyright to the Public Domain.)
http://creativecommons.org/publicdomain/zero/1.0/
'''
from math import sqrt
from numpy import asfarray
def sgn( x ):
    '''
    Sign of 'x' as a plain int: -1 if x is strictly below 0.0, otherwise 1.

    Unlike x/abs(x) this is defined at zero: 0.0 and -0.0 both give 1, as
    does anything for which `x < 0.0` is False (e.g. NaN).
    '''
    return -1 if x < 0.0 else 1
def solve_linear( coeffs ):
    r'''
    Solve coeffs[0] + coeffs[1]*x = 0 for x.

    coeffs is a length-2 list/tuple/array where coeffs[N] multiplies x^N,
    i.e. the equation is \sum coeffs[i] x^i = 0. Returns a one-element list
    [x] holding the real solution.

    A zero slope (coeffs[1] == 0) has no finite solution. Instead of raising,
    the root is reported as +/-1e20, on the side opposite the sign of
    coeffs[0] (coeffs[0] == 0 counts as positive, per sgn()).
    '''
    assert len( coeffs ) == 2
    slope = float( coeffs[1] )
    intercept = float( coeffs[0] )
    if slope == 0.0:
        ## -b/a diverges toward -sgn(b)*infinity; use a large finite stand-in.
        return [ -sgn( intercept ) * 1e20 ]
    return [ -intercept / slope ]
def solve_quadratic( coeffs, eps=1e-7 ):
    r'''
    Takes as parameter a list/tuple/array of coefficients for a
    quadratic polynomial, where coefficient[N] is the coefficient of x^N.
    Returns a list of real solutions to the equation
    \sum coefficients[i] x^i = 0.

    Accepts 2 or 3 coefficients:
      * 3 coefficients [c, b, a] describe a x^2 + b x + c. If |a| < eps the
        x^2 term is dropped and [c, b] is handed to solve_linear().
      * 2 coefficients [c, b] describe the monic quadratic x^2 + b x + c
        (a leading coefficient of 1 is implied).

    Returns [] when there is no real root, [0.] when b and c are both
    smaller than eps in magnitude, and otherwise two roots [x0, x1]
    (equal to each other when the discriminant is exactly zero).

    eps: magnitude below which a coefficient is treated as zero.
    '''
    assert len( coeffs ) == 2 or len( coeffs ) == 3
    ## Copy into a fresh list of Python floats. This works for lists, tuples
    ## and numpy arrays alike. (Under Python 3 `map` returns an unindexable
    ## iterator, and `coeffs + [1]` fails for tuples / broadcasts for arrays.)
    coeffs = [ float( coeff ) for coeff in coeffs ]
    if len( coeffs ) == 3 and abs( coeffs[2] ) < eps:
        return solve_linear( coeffs[:2] )
    if len( coeffs ) == 2:
        coeffs.append( 1. )
    c, b, a = coeffs
    ## a * x^2 = 0 => the only root is x = 0
    if abs( b ) < eps and abs( c ) < eps: return [ 0. ]
    discr = b*b - 4*a*c
    if discr < 0.: return []
    ## Cancellation-avoiding form: choose the sign that adds |b| and
    ## sqrt(discr) rather than subtracting them; the roots are q/a and c/q.
    q = -.5 * ( b + sgn(b) * sqrt( discr ) )
    ## The textbook expression is [ q/a, c/q ]; guard each division so a zero
    ## denominator yields a large finite stand-in instead of raising.
    try: x0 = q/a
    except ZeroDivisionError: x0 = sgn(q) * 1e20
    try: x1 = c/q
    except ZeroDivisionError: x1 = sgn(c) * 1e20
    return [ x0, x1 ]
def solve_cubic( coeffs, eps=1e-7 ):
    r'''
    Takes as parameter a list/tuple/array of coefficients for a
    cubic polynomial, where coefficient[N] is the coefficient of x^N.
    Returns a list of real solutions to the equation
    \sum coefficients[i] x^i = 0.

    Accepts 3 or 4 coefficients:
      * 4 coefficients [d, c, b, a] describe a x^3 + b x^2 + c x + d. If
        |a| < eps the first three are handed to solve_quadratic(); otherwise
        the polynomial is divided through by a.
      * 3 coefficients [C, B, A] describe the monic cubic
        x^3 + A x^2 + B x + C (a leading coefficient of 1 is implied).

    Returns one root when there is a single real root, two when the
    discriminant is exactly zero with a double root (the double root is
    listed once; a triple root is listed once), and three for three distinct
    real roots. The discriminant is compared exactly, so a double root
    perturbed by rounding may come back as a single root.

    The caller's coefficient sequence is never modified.

    eps: magnitude below which the x^3 coefficient is treated as zero.
    '''
    from math import sqrt, acos, cos, pi
    import numpy

    assert len( coeffs ) == 3 or len( coeffs ) == 4
    ## Always work on a private float64 copy so the in-place division below
    ## cannot rescale an array the caller still owns.
    coeffs = numpy.array( coeffs, dtype = float )
    if len( coeffs ) == 4:
        if abs( coeffs[3] ) < eps:
            return solve_quadratic( coeffs[:3] )
        coeffs /= coeffs[3]

    ## Method taken from http://home.att.net/~srschmitt/script_exact_cubic.html,
    ## which is similar to what appears in Numerical Recipes in C.
    A = coeffs[2]
    B = coeffs[1]
    C = coeffs[0]
    Q = (3*B - A*A)/9
    R = (9*A*B - 27*C - 2*A*A*A)/54
    D = Q*Q*Q + R*R # polynomial discriminant

    ts = []
    if D >= 0:
        ## D > 0: one real root plus a complex-conjugate pair.
        ## D == 0: all roots real, at least two of them equal.
        sqrtD = sqrt( D )
        ## Real cube roots taken as sgn(x)*|x|^(1/3), since pow() of a
        ## negative base with exponent 1/3 is not real.
        S = sgn(R + sqrtD)*pow(abs(R + sqrtD),(1./3.))
        T = sgn(R - sqrtD)*pow(abs(R - sqrtD),(1./3.))
        ts.append( -A/3 + (S + T) )
        ## With D == 0, S == T and there is also a double root at -A/3 - S.
        ## When S == 0 (triple root) it would just repeat the root above.
        if D == 0 and S != 0:
            ts.append( -A/3 - S )
    else:
        ## Three distinct real roots (trigonometric form). D < 0 implies Q < 0,
        ## so sqrt(-Q^3) is positive in exact arithmetic. Rounding can push
        ## the ratio slightly outside [-1, 1], where acos() raises ValueError,
        ## so clamp it.
        ratio = R/sqrt(-Q*Q*Q)
        th = acos( max( -1., min( 1., ratio ) ) )
        m = 2*sqrt(-Q)
        ts.append( m*cos(th/3) - A/3 )
        ts.append( m*cos((th + 2*pi)/3) - A/3 )
        ts.append( m*cos((th + 4*pi)/3) - A/3 )
    return ts
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment