Skip to content

Instantly share code, notes, and snippets.

@jtravs
Created April 6, 2013 18:17
Show Gist options
  • Save jtravs/5327056 to your computer and use it in GitHub Desktop.
Code to find a zero of a continuous real function within a bracketing interval
using Test
import Zeros.fzero
# test problems from Table 1 of paper referenced in fzero.jl
# (`@test_approx_eq` is gone from the Test standard library; `≈` is `isapprox`
# with its default relative tolerance)
# 1
@test fzero(x -> sin(x) - x/2, pi/2, pi) ≈ 1.89549426703398094714
# 2: -2 * sum over i = 1..20 of (2i - 5)^2/(x - i^2)^3, with poles at x = i^2
function test2(x)
    # a generator avoids allocating a temporary array on every evaluation
    return -2*sum((2i - 5)^2/(x - i^2)^3 for i = 1:20)
end
@test fzero(test2, 1 + 1e-9, 4 - 1e-9) ≈ 3.0229153472730568
@test fzero(test2, 100 + 1e-9, 121 - 1e-9) ≈ 110.02653274766949
# 3: a*x*exp(b*x), whose only zero is x = 0
test3(x, a, b) = a*x*exp(b*x)
# x = 0 is an exact zero of a*x*exp(b*x); the original tolerance was scaled to
# the compared values, which amounts to requiring an exact hit
@test fzero(x -> test3(x, -40, -1), -9, 31) == 0
@test fzero(x -> test3(x, -100, -2), -9, 31) == 0
@test fzero(x -> test3(x, -200, -3), -9, 31) == 0
# 4: x^n - a, with positive root a^(1/n)
test4(x, n, a) = x^n - a
@test fzero(x -> test4(x, 4, 0.2), 0, 5) ≈ 0.668740304976422
@test fzero(x -> test4(x, 12, 1), 0, 5) ≈ 1.0
@test fzero(x -> test4(x, 8, 1), -0.95, 4.05) ≈ 1.0
@test fzero(x -> test4(x, 12, 1), -0.95, 4.05) ≈ 1.0
# 5
@test fzero(x -> sin(x) - 0.5, 0, 1.5) ≈ 0.5235987755982989
# 6: 2x*exp(-n) - 2*exp(-n*x) + 1 on [0, 1]
function test6(x, n)
    linear_part = 2x*exp(-n)
    return linear_part - 2*exp(-n*x) + 1
end
@test fzero(x -> test6(x, 1), 0, 1) ≈ 0.42247770964123665
@test fzero(x -> test6(x, 20), 0, 1) ≈ 0.03465735902085385
@test fzero(x -> test6(x, 100), 0, 1) ≈ 0.006931471805599453
# 7: (1 + (1 - n)^2)*x - (1 - n*x)^2 on [0, 1]
function test7(x, n)
    slope = 1 + (1 - n)^2
    return slope*x - (1 - n*x)^2
end
@test fzero(x -> test7(x, 5), 0, 1) ≈ 0.0384025518406219
@test fzero(x -> test7(x, 10), 0, 1) ≈ 0.0099000099980005
@test fzero(x -> test7(x, 20), 0, 1) ≈ 0.0024937500390620117
# 8: x^2 - (1 - x)^n on [0, 1]
test8(x, n) = x^2 - (1 - x)^n
@test fzero(x -> test8(x, 2), 0, 1) ≈ 0.5
@test fzero(x -> test8(x, 5), 0, 1) ≈ 0.345954815848242
@test fzero(x -> test8(x, 10), 0, 1) ≈ 0.24512233375330725
@test fzero(x -> test8(x, 20), 0, 1) ≈ 0.16492095727644096
# 9: (1 + (1 - n)^4)*x - (1 - n*x)^4 on [0, 1]
function test9(x, n)
    slope = 1 + (1 - n)^4
    return slope*x - (1 - n*x)^4
end
@test fzero(x -> test9(x, 1), 0, 1) ≈ 0.2755080409994844
@test fzero(x -> test9(x, 2), 0, 1) ≈ 0.1377540204997422
@test fzero(x -> test9(x, 20), 0, 1) ≈ 7.668595122185337e-6
# 10: exp(-n*x)*(x - 1) + x^n on [0, 1]
test10(x, n) = exp(-n*x)*(x - 1) + x^n
@test fzero(x -> test10(x, 1), 0, 1) ≈ 0.401058137541547
@test fzero(x -> test10(x, 5), 0, 1) ≈ 0.5161535187579336
@test fzero(x -> test10(x, 20), 0, 1) ≈ 0.5527046666784878
# 11: (n*x - 1)/((n - 1)*x), with root x = 1/n
test11(x, n) = (n*x - 1)/((n - 1)*x)
@test fzero(x -> test11(x, 2), 0.01, 1) ≈ 0.5
@test fzero(x -> test11(x, 15), 0.01, 1) ≈ 0.0666666666666666
@test fzero(x -> test11(x, 20), 0.01, 1) ≈ 0.05
# 12: x^(1/n) - n^(1/n), with root x = n
function test12(x, n)
    p = 1/n
    return x^p - n^p
end
@test fzero(x -> test12(x, 2), 1, 100) ≈ 2
@test fzero(x -> test12(x, 6), 1, 100) ≈ 6
@test fzero(x -> test12(x, 33), 1, 100) ≈ 33
# 13: x*exp(-1/x^2) for x != 0 and 0 at x == 0 (Table 1 of the paper). The
# function is extremely flat around its root, which is what this problem tests.
function test13(x)
    # explicit branch: 1/x^2 is not defined at x == 0
    if x == 0
        return zero(x)
    else
        return x/exp(1/x^2)
    end
end
# exp(1/x^2) overflows to Inf for |x| below about 0.0375, so test13 is exactly
# zero on a whole neighbourhood of 0 and the solver may stop anywhere inside
# it: check for a floating-point zero of the function close to 0
let root13 = fzero(test13, -1, 4)
    @test test13(root13) == 0
    @test abs(root13) < 0.04
end
# 14: piecewise; n/20*(x/1.5 + sin(x) - 1) for x >= 0, the constant -n/20
# for x < 0
test14(x, n) = x >= 0 ? n/20*(x/1.5 + sin(x) - 1) : -n/20
@test fzero(x -> test14(x, 1), -1e4, pi/2) ≈ 0.6238065189616123
@test fzero(x -> test14(x, 40), -1e4, pi/2) ≈ 0.6238065189616123
# 15: piecewise; exp((n + 1)*x/2*1e3) - 1.859 on [0, 2e-3/(1 + n)],
# the constant ℯ - 1.859 to the right of that and -0.859 for x < 0
function test15(x, n)
    if x > 2e-3/(1 + n)
        # Euler's number is `ℯ` in Julia >= 1.0 (the old `e` binding is gone)
        return ℯ - 1.859
    elseif x < 0
        return -0.859
    else
        return exp((n + 1)*x/2*1e3) - 1.859
    end
end
@test fzero(x -> test15(x, 20), -1e4, 1e4) ≈ 0.00005905130559421971
@test fzero(x -> test15(x, 40), -1e4, 1e4) ≈ 0.000030245790670210097
@test fzero(x -> test15(x, 100), -1e4, 1e4) ≈ 0.000012277994232461523
@test fzero(x -> test15(x, 1000), -1e4, 1e4) ≈ 1.2388385788997142e-6
module Zeros

export fzero

# The function fzero finds a root of a continuous function within a provided
# interval [a, b], without requiring derivatives.
# It is based on algorithm 4.2 described in:
# 1. G. E. Alefeld, F. A. Potra, and Y. Shi, "Algorithm 748: enclosing zeros of
# continuous functions," ACM Trans. Math. Softw. 21, 327–344 (1995).
# Copyright (c) 2013 John C. Travers <[email protected]>
# Released under the MIT license.
# NOTE: the code here was derived directly from the algorithms described in the
# above paper, no ACM copyrighted code was used. Some implementation tips (but
# no codes) were obtained by reading the Boost C++ sources in math/tools.
# TODO the number of function evaluations could be reduced further by passing
# function values between the helpers instead of re-evaluating f

# Two function values closer than this are treated as equal by `distinct`.
# Kept at module level because Julia does not allow `const` on local variables.
const MIN_DIFF = floatmin(Float64)*32

"""
    fzero(f, a, b; tolerance=0.0, lambda=0.7, mu=0.5, max_iter=1000)

Find a zero of the continuous real function `f` inside the bracket `[a, b]`
without using derivatives (algorithm 4.2 of Alefeld, Potra and Shi, 1995).

# Arguments
- `f`: callable taking one real argument.
- `a`, `b`: bracket end points; `a < b` is required and `f(a)`, `f(b)` must not
  have the same sign. Integer end points are converted to floating point.

# Keywords
- `tolerance`: absolute tolerance added to the scaled machine-precision
  tolerance used in the stopping test.
- `lambda`: fraction of that tolerance used to keep trial points away from the
  bracket end points.
- `mu`: when one iteration shrinks the bracket by less than this factor, an
  extra bisection step is taken.
- `max_iter`: maximum number of iterations.

# Returns
An estimate of the zero of `f`. An end point at which `f` is exactly zero is
returned directly.

# Throws
`ErrorException` if `a >= b`, if `f(a)` and `f(b)` have the same sign, or if
no root is found within `max_iter` iterations.
"""
function fzero(f, a, b;
               tolerance=0.0, lambda=0.7, mu=0.5, max_iter=1000)
    # TODO accept a scalar guess and try to construct our own bracket
    if a >= b
        error("on input a < b and f(a)f(b) < 0 must both hold")
    end
    # work in floating point so the bracket variables keep a single type
    a, b = promote(float(a), float(b))
    fa = f(a)
    fb = f(b)
    # an end point that is already an exact zero is the answer
    fa == 0 && return a
    fb == 0 && return b
    # compare signs rather than the product f(a)*f(b): the product can
    # underflow to zero for tiny function values and wrongly reject the bracket
    if sign(fa) == sign(fb)
        error("on input a < b and f(a)f(b) < 0 must both hold")
    end

    # start with a secant approximation
    c = secant(f, a, b)
    # re-bracket and check termination
    stop, other = bracket(f, a, b, c, tolerance, lambda)
    stop && return other[1]
    a, b, d = other
    # `e` carries a point from one iteration into the next, so it must be bound
    # before the loop: a variable first assigned inside a `for` body is fresh on
    # every iteration. The initial value is never read, because iteration
    # n == 2 always takes the quadratic step.
    e = d
    for n = 2:max_iter
        # use either a cubic (if possible) or quadratic interpolation
        if n > 2 && distinct(f, a, b, d, e)
            c = ipzero(f, a, b, d, e)
        else
            c = newton_quadratic(f, a, b, d, 2)
        end
        # re-bracket and check termination
        stop, other = bracket(f, a, b, c, tolerance, lambda)
        stop && return other[1]
        ab, bb, db = other
        eb = d
        # use another cubic (if possible) or quadratic interpolation
        if distinct(f, ab, bb, db, eb)
            cb = ipzero(f, ab, bb, db, eb)
        else
            cb = newton_quadratic(f, ab, bb, db, 3)
        end
        # re-bracket and check termination
        stop, other = bracket(f, ab, bb, cb, tolerance, lambda)
        stop && return other[1]
        ab, bb, db = other
        # double length secant step from the end point with the smaller |f|;
        # fall back to bisection when it moves more than half the bracket
        # (the `<=` form also rejects a NaN step)
        fab = f(ab)
        fbb = f(bb)
        u, fu = abs(fab) < abs(fbb) ? (ab, fab) : (bb, fbb)
        cb = u - 2*fu/(fbb - fab)*(bb - ab)
        ch = abs(cb - u) <= (bb - ab)/2 ? cb : ab + (bb - ab)/2
        # re-bracket and check termination
        stop, other = bracket(f, ab, bb, ch, tolerance, lambda)
        stop && return other[1]
        ah, bh, dh = other
        # if not converging fast enough bracket again on a bisection
        if bh - ah < mu*(b - a)
            a, b, d = ah, bh, dh
            e = db
        else
            e = dh
            stop, other = bracket(f, ah, bh, ah + (bh - ah)/2,
                                  tolerance, lambda)
            stop && return other[1]
            a, b, d = other
        end
    end
    error("root not found within max_iter iterations")
end

# Scaled stopping tolerance 2*eps*|u| + tolerance, where u is whichever of a, b
# has the smaller |f| and eps is the machine precision of u's type.
# Based on the algorithm on page 340 of [1].
function tole(a, b, fa, fb, tolerance)
    u = abs(fa) < abs(fb) ? abs(a) : abs(b)
    return 2u*eps(one(u)) + tolerance
end

# Bracket the root.
# inputs:
# - f: the function
# - a, b: the current bracket with f(a)f(b) < 0
# - c within [a, b]: current best guess of the root; it is pushed at least
#   2*lambda*tole away from a and b (or replaced by the midpoint when the
#   bracket is already that narrow)
# - tolerance, lambda: stopping criteria
#
# if the root is not yet found, return
#   false, (ab, bb, d)
# with:
# - [ab, bb] a new interval within [a, b] with f(ab)f(bb) < 0
# - d the end point of [a, b] that was dropped
#
# if the root is found, return
#   true, (x0, fx0)
# with x0 the root and fx0 the function value at x0
#
# errors if c is outside [a, b].
# Based on the algorithm on page 341 of [1].
function bracket(f, a, b, c, tolerance, lambda)
    if !(a <= c <= b)
        error("c must be in (a,b)")
    end
    fa = f(a)
    fb = f(b)
    delta = lambda*tole(a, b, fa, fb, tolerance)
    if b - a <= 4delta
        c = (a + b)/2
    elseif c <= a + 2delta
        c = a + 2delta
    elseif c >= b - 2delta
        c = b - 2delta
    end
    fc = f(c)
    if fc == 0
        return true, (c, fc)
    elseif sign(fa)*sign(fc) < 0
        aa, bb, db = a, c, b
        # reuse the values already computed instead of calling f again
        faa, fbb = fa, fc
    else
        aa, bb, db = c, b, a
        faa, fbb = fc, fb
    end
    if bb - aa < 2*tole(aa, bb, faa, fbb, tolerance)
        # converged: return the end point with the smaller |f|
        if abs(faa) < abs(fbb)
            return true, (aa, faa)
        else
            return true, (bb, fbb)
        end
    end
    return false, (aa, bb, db)
end

# Take a secant step; if the result lies within a few ulps of a or b (or is
# not a number, e.g. because f returned an infinity), bisect instead.
function secant(f, a, b)
    fa = f(a)
    c = a - fa/(f(b) - fa)*(b - a)
    tol = 5*eps(one(float(a)))
    if !(a + abs(a)*tol < c < b - abs(b)*tol)
        return a + (b - a)/2
    end
    return c
end

# Approximate the zero of f by k Newton steps on the quadratic interpolating
# f at a, b and d. Falls back to a secant step when the quadratic degenerates
# (A == 0) or the result is not strictly inside (a, b), NaN included.
# Based on the algorithm on page 330 of [1].
function newton_quadratic(f, a, b, d, k::Int)
    fa = f(a)
    fb = f(b)
    fd = f(d)
    B = (fb - fa)/(b - a)
    A = ((fd - fb)/(d - b) - B)/(d - a)
    if A == 0
        return secant(f, a, b)
    end
    # start from the end point where the quadratic is convex towards the root
    r = A*fa > 0 ? a : b
    for i = 1:k
        r -= (fa + (B + A*(r - b))*(r - a))/(B + A*(2*r - a - b))
    end
    if !(a < r < b)
        r = secant(f, a, b)
    end
    return r
end

# Approximate the zero of f by inverse cubic interpolation through a, b, c and
# d. The caller passes its (a, b, d, e), so c here is the caller's d. If the
# result is not strictly inside (a, b), NaN included, fall back to a quadratic
# step through a, b and c, as the paper (and Boost) do.
# Based on the algorithm on page 333 of [1].
function ipzero(f, a, b, c, d)
    fa = f(a)
    fb = f(b)
    fc = f(c)
    fd = f(d)
    Q11 = (c - d)*fc/(fd - fc)
    Q21 = (b - c)*fb/(fc - fb)
    Q31 = (a - b)*fa/(fb - fa)
    D21 = (b - c)*fc/(fc - fb)
    D31 = (a - b)*fb/(fb - fa)
    Q22 = (D21 - Q11)*fb/(fd - fb)
    Q32 = (D31 - Q21)*fa/(fc - fa)
    D32 = (D31 - Q21)*fc/(fc - fa)
    Q33 = (D32 - Q22)*fa/(fd - fa)
    x = a + (Q31 + Q32 + Q33)
    if !(a < x < b)
        x = newton_quadratic(f, a, b, c, 3)
    end
    return x
end

# floating point comparison: true when x and y differ by less than MIN_DIFF
almost_equal(x, y) = abs(x - y) < MIN_DIFF

# true when f(a), f(b), f(d) and f(e) are pairwise distinct (per almost_equal),
# which inverse cubic interpolation requires
function distinct(f, a, b, d, e)
    f1 = f(a)
    f2 = f(b)
    f3 = f(d)
    f4 = f(e)
    return !(almost_equal(f1, f2) || almost_equal(f1, f3) || almost_equal(f1, f4) ||
             almost_equal(f2, f3) || almost_equal(f2, f4) || almost_equal(f3, f4))
end

end
@johnmyleswhite
Copy link

This looks great. My own style (which isn't standard at all) would be to add type constraints on all of the inputs:

function fzero(f::Function, a::Real, b::Real;
               tolerance::Real=0.0, lambda::Real=0.7, mu::Real=0.5, max_iter::Integer=1000)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment