Last active
June 22, 2023 11:29
-
-
Save santiago-salas-v/70a99b92a77c9152c1fc801822d0f234 to your computer and use it in GitHub Desktop.
Laguerre method python implementation
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from numpy import array, exp, real, imag, empty, finfo
# Ref.: Press, William H., et al. "Numerical Recipes in C++: The Art of
# Scientific Computing." (2007); routines zroots and laguer.
def zroots(a, polish=False):
    """Find all roots of a polynomial by Laguerre's method with deflation.

    The polynomial is ``sum(a[i] * x**i for i in range(len(a)))``, i.e. the
    coefficients are given in ascending order of power.

    Parameters
    ----------
    a : sequence of real or complex numbers
        Polynomial coefficients, constant term first. The input is not
        modified; a complex working copy is used for deflation.
    polish : bool, optional
        When true, every root found on the deflated polynomials is refined
        once more with ``laguerre`` against the original (undeflated)
        coefficients.

    Returns
    -------
    numpy.ndarray of complex
        The ``len(a) - 1`` roots. ``roots[len(a) - 2]`` is the root found
        first and ``roots[0]`` the one found last; no sorting is applied.

    Raises
    ------
    ValueError
        If ``a`` is empty.
    Exception
        Whatever ``laguerre`` raises when it fails to converge.
    """
    eps = 1.0e-14  # relative threshold below which an imaginary part is dropped

    # Work on a complex copy: deflation writes complex values back into the
    # coefficient array, which a real/integer ndarray could not store.
    coeffs = array(a, dtype=complex)
    m = len(coeffs) - 1
    if m < 0:
        raise ValueError("zroots requires at least one coefficient")

    roots = empty(m, dtype=complex)
    ad = coeffs.copy()  # coefficients of the successively deflated polynomial

    for j in range(m - 1, -1, -1):
        # Start at zero to favor convergence to the smallest remaining root.
        _, x, _ = laguerre(ad[:j + 2].copy(), 0.0)

        # Treat a negligible imaginary part as roundoff and snap to the real axis.
        if abs(imag(x)) <= 2.0 * eps * abs(real(x)):
            x = complex(real(x), 0.0)
        roots[j] = x

        # Forward deflation: divide the current polynomial by (x - root)
        # using synthetic division, in place.
        b = ad[j + 1]
        for jj in range(j, -1, -1):
            c = ad[jj]
            ad[jj] = b
            b = x * b + c

    if polish:
        # Refine each root against the original coefficients to remove the
        # error accumulated through repeated deflation.
        for j in range(m):
            _, roots[j], _ = laguerre(coeffs, roots[j])

    return roots
def laguerre(a, x):
    """Improve a guess for one root of a polynomial with Laguerre's method.

    The polynomial is ``sum(a[i] * x**i for i in range(len(a)))``, i.e. the
    coefficients are given in ascending order of power.

    Parameters
    ----------
    a : sequence of real or complex numbers
        Polynomial coefficients, constant term first. Not modified.
    x : real or complex number
        Starting guess. It is converted to complex so that the square root
        in Laguerre's formula can leave the real axis.

    Returns
    -------
    tuple
        ``(a, x, its)``: the coefficients exactly as passed in, the refined
        root, and the number of iterations performed.

    Raises
    ------
    RuntimeError
        If no convergence is reached within ``mt * mr`` iterations
        (message: "too many iterations in laguerre"). This is very unusual
        and can occur only for complex roots; try a different starting
        guess.
    """
    # Try to break (rare) limit cycles with mr different fractional step
    # values, once every mt steps, for maxit total allowed iterations.
    mr = 8
    mt = 10
    maxit = mt * mr
    frac = (0.0, 0.5, 0.25, 0.75, 0.13, 0.38, 0.62, 0.88, 1.0)
    eps = finfo(float).eps  # estimated fractional roundoff error

    m = len(a) - 1
    x = complex(x)

    for its in range(1, maxit + 1):
        # Horner evaluation of the polynomial (b), its first derivative (d)
        # and half of its second derivative (f), together with a running
        # estimate (err) of the roundoff error in evaluating the polynomial.
        b = a[m]
        err = abs(b)
        d = f = 0.0
        abx = abs(x)
        for j in range(m - 1, -1, -1):
            f = x * f + d
            d = x * d + b
            b = x * b + a[j]
            err = abs(b) + abx * err
        err *= eps

        if abs(b) <= err:
            # The polynomial value is within roundoff of zero: on the root.
            return a, x, its

        # The generic case: use Laguerre's formula.
        g = d / b
        g2 = g ** 2
        h = g2 - 2.0 * f / b
        sq = (float(m - 1) * (float(m) * h - g2)) ** 0.5
        gp = g + sq
        gm = g - sq
        abp = abs(gp)
        abm = abs(gm)
        if abp < abm:
            # Keep the denominator of larger magnitude (smaller step).
            gp = gm
        if max(abp, abm) > 0.0:
            dx = float(m) / gp
        else:
            # Both denominators vanish: take a step of size 1 + |x| in a
            # direction that depends on the iteration count
            # (equivalent to polar(1 + abx, its)).
            dx = (1 + abx) * exp(its * 1j)

        x1 = x - dx
        if x == x1:
            # The step no longer changes x: converged.
            return a, x, its

        if its % mt != 0:
            x = x1
        else:
            # Every mt-th iteration take a fractional step instead, to
            # break a possible limit cycle.
            x -= frac[its // mt] * dx

    raise RuntimeError("too many iterations in laguerre")
# Self-test: solve a handful of polynomials (coefficients in ascending
# order of power) and check that every returned root gives a residual
# |p(root)| below 1e-10.
all_tests_pass = True
for coeffs in [
        [-1000.0, 1000002.0, -2000.001, 1.0],
        [-6.855188152137764e-15,
         -7.500004043464861e-01,
         2.668263562685250e+07,
         5.993293694242965e+11,
         3.621535214588474e+11],
        [432, -144, -3, 1],
        [-1, 0, 9, 28]]:
    top_power = len(coeffs) - 1
    soln = zroots(coeffs)

    # Residual magnitude of the polynomial at each computed root.
    null_werte = []
    for root in soln:
        value = sum(c * root**k for k, c in enumerate(coeffs))
        magnitude = (value.real**2 + value.imag**2)**(1 / 2)
        all_tests_pass = all_tests_pass and magnitude < 1e-10
        null_werte.append(magnitude)

    # Human-readable rendering of the polynomial, one term per coefficient.
    terms = []
    for power, coeff in enumerate(coeffs):
        if abs(imag(coeff)) <= finfo(float).eps:
            coeff_text = '{:.2f}'.format(real(coeff))
        else:
            coeff_text = '({:.2f} + {:g})'.format(real(coeff), imag(coeff))

        if power == 1:
            suffix = 'x + '
        elif power == top_power:
            suffix = 'x^{:d}'.format(power)
        elif power > 0:
            suffix = 'x^{:d} + '.format(power)
        else:
            suffix = ' + '
        terms.append(coeff_text + suffix)

    root_texts = ['{:e} + {:g}j'.format(real(item), imag(item))
                  for item in soln]
    print('p(x)=' + ''.join(terms))
    print('soln: ' + str(root_texts))
    print('f(soln_i):' + str(null_werte) + '\n')

print('all solutions correct to 1e-10? ' + str(all_tests_pass))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment