Skip to content

Instantly share code, notes, and snippets.

@jverzani
Last active April 17, 2018 16:24
Show Gist options
  • Save jverzani/9e5c8546233edeb22db4f27f51de5c4f to your computer and use it in GitHub Desktop.
Save jverzani/9e5c8546233edeb22db4f27f51de5c4f to your computer and use it in GitHub Desktop.
modify Roots to work with Unitful
## new framework for Roots
module Rts
using Compat
if VERSION < v"0.7.0"
using Missings
else
using Printf
end
using ForwardDiff
## Return a function computing the `k`-th derivative of `f` (via ForwardDiff).
## `k == 0` gives `f` itself, evaluated at `float(x)`; a negative `k` is an error.
function D(f::Function, k::Int=1)
    if k < 0
        error("The order of the derivative must be non-negative")
    elseif k == 0
        return x -> f(float(x))
    end
    # differentiate once, then recurse for the remaining k-1 orders
    fprime = x -> ForwardDiff.derivative(f, x)
    return D(fprime, k - 1)
end
# A zero is found by specifying:
# the method to use <: AbstractUnivariateZeroMethod
# the function(s) <: CallableFunction
# the initial state through a value for x either x, [a,b], or (a,b) <: AbstractUnivariateZeroState
# the options (e.g., tolerances) <: UnivariateZeroOptions
### Methods
# Root of the method hierarchy; `find_zero` dispatches on subtypes of this type.
abstract type AbstractUnivariateZeroMethod end
# Methods whose state is initialized from a pair of points checked for a sign change.
abstract type AbstractBisection <: AbstractUnivariateZeroMethod end
# Secant-like methods whose state is initialized with two starting points.
abstract type AbstractSecant <: AbstractUnivariateZeroMethod end
### States
abstract type AbstractUnivariateZeroState end
## Mutable iteration state shared by the zero-finding methods.
## `T` is the type of the iterates `x_n`, `S` the type of the values `f(x_n)`.
## The "previous" slots may be `missing` for methods started from a single point.
mutable struct UnivariateZeroState{T,S} <: AbstractUnivariateZeroState
    xn1::T                     # current iterate
    xn0::Union{Missing, T}     # previous iterate (or other bracket end)
    fxn1::S                    # f(xn1)
    fxn0::Union{Missing, S}    # f(xn0)
    steps::Int                 # number of update steps taken
    fnevals::Int               # number of function evaluations
    stopped::Bool              # stopped, but may not have converged
    x_converged::Bool          # converged via |x_n - x_{n-1}| < ϵ
    f_converged::Bool          # converged via |f(x_n)| < ϵ
    convergence_failed::Bool
    message::String            # concrete field type; any AbstractString is converted
end
## Add `k` to the function-evaluation counter; returns the new count.
function incfn(o::UnivariateZeroState, k=1)
    o.fnevals = o.fnevals + k
end
## Add `k` to the step counter; returns the new count.
function incsteps(o::UnivariateZeroState, k=1)
    o.steps = o.steps + k
end
# initialize state for most methods
## Default state: a single starting point `x`, one function evaluation,
## and no "previous" iterate yet (those slots are `missing`).
function init_state(method::Any, fs, x)
    xstart = float(x)
    fstart = fs(xstart)
    return UnivariateZeroState(xstart, missing,
                               fstart, missing,
                               0, 1,
                               false, false, false, false,
                               "")
end
### Options
# Tolerances and limits read by `assess_convergence` and `find_zero`.
mutable struct UnivariateZeroOptions{Q,R,S,T}
xabstol::Q # absolute tolerance used when comparing successive x values
xreltol::R # relative tolerance used when comparing successive x values
abstol::S # absolute tolerance for f(x)
reltol::T # relative tolerance for f(x)
maxevals::Int # limit compared against `state.steps`
maxfnevals::Int # limit compared against `state.fnevals`
verbose::Bool # when true, `find_zero` records iterates and prints a trace
end
# Allow for override of default tolerances. Useful, say, for method like bisection
## Build the options for a run. Any tolerance left as `missing` gets a default
## derived from the initial state; extra keyword arguments are accepted and ignored.
function init_options(::Any,
                      state;
                      xabstol=missing,
                      xreltol=missing,
                      abstol=missing,
                      reltol=missing,
                      maxevals::Int=40,
                      maxfnevals::Int=typemax(Int),
                      verbose::Bool=false,
                      kwargs...)
    x, fx = state.xn1, state.fxn1
    # absolute tolerances carry the units of x and f(x); relative ones are unitless
    xatol = ismissing(xabstol) ? zero(x) : xabstol
    xrtol = ismissing(xreltol) ? eps(x)/oneunit(x) : xreltol
    fatol = ismissing(abstol) ? 4 * eps(fx) : abstol
    frtol = ismissing(reltol) ? 4 * eps(fx)/oneunit(fx) : reltol
    return UnivariateZeroOptions(xatol, xrtol, fatol, frtol,
                                 maxevals, maxfnevals,
                                 verbose)
end
### Functions
abstract type CallableFunction end
## Wraps a bare function `f`; derivatives, when requested, come from `D`.
struct DerivativeFree{F} <: CallableFunction
    f::F
end
## `F(x)` evaluates the wrapped function.
function (F::DerivativeFree)(x::Number)
    return F.f(x)
end
## `F(x, n)` forwards the derivative order to the `Val`-dispatched methods below.
function (F::DerivativeFree)(x::Number, n::Int)
    return F(x, Val{n})
end
function (F::DerivativeFree)(x::Number, ::Type{Val{1}})
    fp = D(F.f)
    return fp(x)
end
function (F::DerivativeFree)(x::Number, ::Type{Val{2}})
    fpp = D(F.f, 2)
    return fpp(x)
end
## Fallback for non-`Number` arguments.
function (F::DerivativeFree)(x)
    return F.f(x)
end
## A function `f` bundled with a caller-supplied first derivative `fp`.
struct FirstDerivative{F, Fp} <: CallableFunction
    f::F
    fp::Fp
end
## `F(x)` evaluates `f`.
function (F::FirstDerivative)(x::Number)
    return F.f(x)
end
## `F(x, n)` forwards the derivative order to the `Val`-dispatched methods below.
function (F::FirstDerivative)(x::Number, n::Int)
    return F(x, Val{n})
end
function (F::FirstDerivative)(x::Number, ::Type{Val{1}})
    return F.fp(x)
end
## The second derivative is obtained by differentiating `fp` once with `D`.
function (F::FirstDerivative)(x::Number, ::Type{Val{2}})
    fpp = D(F.fp, 1)
    return fpp(x)
end
## A function `f` bundled with caller-supplied first (`fp`) and second (`fpp`) derivatives.
struct SecondDerivative{F,Fp,Fpp} <: CallableFunction
    f::F
    fp::Fp
    fpp::Fpp
end
## `F(x)` evaluates `f`.
function (F::SecondDerivative)(x::Number)
    return F.f(x)
end
## `F(x, n)` forwards the derivative order to the `Val`-dispatched methods below.
function (F::SecondDerivative)(x::Number, n::Int)
    return F(x, Val{n})
end
function (F::SecondDerivative)(x::Number, ::Type{Val{1}})
    return F.fp(x)
end
function (F::SecondDerivative)(x::Number, ::Type{Val{2}})
    return F.fpp(x)
end
## Wrap the user's input in a `CallableFunction`.
## `fs` is either a single callable `f`, or a tuple `(f,)`, `(f, fp)` or `(f, fp, fpp)`.
## Throws an `ArgumentError` for an empty tuple; entries beyond the third are ignored.
function callable_function(fs)
    if isa(fs, Tuple)
        # take the length of the tuple itself, not of the comparison `fs == 1`
        n = length(fs)
        n == 0 && throw(ArgumentError("expected (f,), (f, fp) or (f, fp, fpp); got an empty tuple"))
        n == 1 && return DerivativeFree(fs[1])
        n == 2 && return FirstDerivative(fs[1], fs[2])
        return SecondDerivative(fs[1], fs[2], fs[3])
    end
    return DerivativeFree(fs)
end
## has UnivariateZeroProblem converged?
## allow missing values in isapprox
## `isapprox` with positional tolerances that reports `false` (rather than
## `missing`) when either value is `missing`.
function _isapprox(a, b, rtol, atol)
    anymissing = ismissing(a) || ismissing(b)
    return _isapprox(Val{anymissing}, a, b, rtol, atol)
end
function _isapprox(::Type{Val{true}}, a, b, rtol, atol)
    return false
end
function _isapprox(::Type{Val{false}}, a, b, rtol, atol)
    return isapprox(a, b; rtol=rtol, atol=atol)
end
## Decide whether the iteration should stop.
##
## Returns `true` when the loop in `find_zero` must inspect the state (converged,
## stopped, or failed) and `false` when another update step should be taken.
## Side effects: may set `state.stopped`, `state.convergence_failed`,
## `state.f_converged`, `state.x_converged` and `state.message`.
function assess_convergence(method::Any, state, options)
    xn0, xn1 = state.xn0, state.xn1
    fxn1 = state.fxn1

    # already flagged as converged by an earlier call
    if state.x_converged || state.f_converged
        return true
    end

    if state.steps > options.maxevals
        state.stopped = true
        state.message = "too many steps taken."
        return true
    end

    if state.fnevals > options.maxfnevals
        state.stopped = true
        state.message = "too many function evaluations taken."
        return true
    end

    if isnan(xn1)
        state.convergence_failed = true
        state.message = "NaN produced by algorithm."
        return true
    end

    if isinf(fxn1)
        state.convergence_failed = true
        state.message = "Inf produced by algorithm."
        return true
    end

    # Dimensionless scale max(1, |x_n|). Dividing by `oneunit` strips any units so
    # the tolerance below carries the units of f(x) only.
    λ = max(1, abs(xn1 / oneunit(xn1)))
    ftol = max(options.abstol, λ * options.reltol * oneunit(fxn1))

    # |f(x_n)| ≤ max(abstol, max(1,|x_n|) * reltol): the test the trace output
    # describes. Comparing f(x_n) with f(x_{n-1}) instead would accept flat,
    # non-zero functions and would miss an exact zero.
    if abs(fxn1) <= ftol
        state.f_converged = true
        return true
    end

    if _isapprox(xn1, xn0, options.xreltol, options.xabstol)
        # Heuristic check that f is small too in unitless way
        if abs(fxn1)/oneunit(fxn1) <= cbrt(ftol/oneunit(ftol))
            state.x_converged = true
            return true
        end
    end

    if state.stopped
        if state.message == ""
            error("no message? XXX debug this XXX")
        end
        return true
    end

    return false
end
## Print a summary of a finished run followed by the recorded iterates.
## `state` supplies the convergence flags, counters and message; `xns`/`fxns`
## are the recorded x and f(x) values; `method` is only interpolated into the text.
function show_trace(state, xns, fxns, method)
    converged = state.x_converged || state.f_converged
    println("Results of univariate zero finding:\n")
    if converged
        println("* Converged to: $(xns[end])")
        println("* Algorithm: $(method)")
        println("* iterations: $(state.steps)")
        println("* function evaluations: $(state.fnevals)")
        state.x_converged && println("* stopped as x_n ≈ x_{n-1} using atol=xabstol, rtol=xreltol")
        state.f_converged && state.message == "" && println("* stopped as |f(x_n)| ≤ max(δ, max(1,|x|)⋅ϵ) using δ = abstol, ϵ = reltol")
        state.message != "" && println("* Note: $(state.message)")
    else
        println("* Convergence failed: $(state.message)")
        println("* Algorithm $(method)")
    end
    println("")
    println("Trace:")
    # Walk both histories in lockstep; labels are zero-based (x_0, x_1, ...).
    # This avoids `endof`, which no longer exists on Julia 1.0 and later.
    for (k, (xi, fxi)) in enumerate(zip(xns, fxns))
        i = k - 1
        x_i, fx_i = "x_$i", "f(x_$i)"
        println(@sprintf("%s = % 18.16f,\t %s = % 18.16f", x_i, float(xi), fx_i, float(fxi)))
    end
    println("")
end
## fs can be f, (f,fp), or (f, fp, fpp)
## User-facing entry point. `fs` can be f, (f,fp), or (f, fp, fpp); `x0` is the
## starting value(s). Keyword arguments are passed on to `init_options`.
function find_zero(method::AbstractUnivariateZeroMethod, fs, x0; kwargs...)
    xstart = float.(x0)
    F = callable_function(fs)
    state = init_state(method, F, xstart)
    options = init_options(method, state; kwargs...)
    return find_zero(method, F, options, state)
end
## Error thrown by `find_zero` when the iteration stops without converging.
struct ConvergenceFailed <: Exception
    reason::String
end

## Main iteration loop: alternate `assess_convergence` and `update_state` until
## the state reports convergence (returns `state.xn1`) or failure (throws
## `ConvergenceFailed`). With `options.verbose` the iterates are recorded and
## printed through `show_trace`.
function find_zero(M::AbstractUnivariateZeroMethod,
                   F::CallableFunction,
                   options::UnivariateZeroOptions,
                   state::AbstractUnivariateZeroState
                   )
    # history of iterates; only extended when verbose=true
    if isa(M, AbstractSecant)
        xns, fxns = [state.xn0, state.xn1], [state.fxn0, state.fxn1]
    else
        xns, fxns = [state.xn1], [state.fxn1]
    end

    ## XXX removed bracket check here
    while true
        if assess_convergence(M, state, options)
            if state.stopped && !(state.x_converged || state.f_converged)
                ## stopped is a heuristic, there was an issue with an approximate derivative
                ## say it converged if pretty close, else say convergence failed.
                ## (Is this a good idea?)
                fxstar = state.fxn1
                tol = options.abstol
                if abs(fxstar/oneunit(fxstar)) <= (tol/oneunit(tol))^(2/3)
                    msg = "Algorithm stopped early, but |f(xn)| < ϵ^(2/3), where ϵ = abstol"
                    state.message = state.message == "" ? msg : state.message * "\n\t" * msg
                    state.f_converged = true
                else
                    state.convergence_failed = true
                end
            end

            if state.x_converged || state.f_converged
                options.verbose && show_trace(state, xns, fxns, M)
                return state.xn1
            end

            if state.convergence_failed
                # `show_trace` takes (state, xns, fxns, method); it has no `F` argument
                options.verbose && show_trace(state, xns, fxns, M)
                throw(ConvergenceFailed("Stopped at: xn = $(state.xn1)"))
            end
        end

        update_state(M, F, state, options)

        if options.verbose
            push!(xns, state.xn1)
            push!(fxns, state.fxn1)
        end
    end
end
##################################################
## utilities
## issue with approx derivative
## A derivative approximation is unusable when it is zero, NaN or infinite.
function isissue(x)
    return (x == 0.0) || isnan(x) || isinf(x)
end
## Approximate f'(x) by the divided difference f[a,b] = (fb - fa)/(b - a).
## Returns `(slope, issue)`, where `issue` flags an unusable slope.
function _fbracket(a, b, fa, fb)
    rise = fb - fa
    run = b - a
    if rise == 0 && run == 0
        # 0/0: report an infinite slope and flag it
        return Inf, true
    end
    slope = rise / run
    return slope, isissue(slope)
end
## Step used to form the Steffensen difference quotient: `fx` itself when it is
## small, otherwise `fx` clamped in magnitude to max(1, |x|) * sqrt(eps(T)).
## `abs` is used for the magnitudes so that no `norm` (LinearAlgebra) is needed.
function steff_step(x::T, fx) where {T}
    thresh = max(1, abs(x)) * sqrt(eps(T)) # max(1, sqrt(abs(x/fx))) * 1e-6
    return abs(fx) <= thresh ? fx : sign(fx) * thresh
end
##################################################
## State for secant-like methods. `x` is either a pair (Vector or Tuple) giving
## both starting points, or a single value from which a second point is made.
function init_state(method::AbstractSecant, fs, x)
    havepair = isa(x, Vector) || isa(x, Tuple)
    if havepair
        xold, xnew = x[1], x[2]
        fold, fnew = fs(xold), fs(xnew)
    else
        # need an initial x0,x1 if two not specified
        xa = x
        fxa = fs(xa)
        h = max(1/100, min(abs(fxa/oneunit(fxa)), abs(xa/oneunit(xa)/100)))
        xb = xa + h*oneunit(xa)
        fxb = fs(xb)
        # switch: the perturbed point becomes the older iterate
        xold, xnew, fold, fnew = xb, xa, fxb, fxa
    end
    xs = promote(xnew, xold)
    fxs = promote(fnew, fold)
    return UnivariateZeroState(xs[1], xs[2],
                               fxs[1], fxs[2],
                               0, 2,
                               false, false, false, false,
                               "")
end
## Order1 -- SecantMethod
# Field-less type used to select the secant `update_state` method.
mutable struct Order1 <: AbstractSecant end
# `Secant` is another name for `Order1`.
const Secant = Order1
## One secant step: use the divided difference through the last two iterates as
## the slope. If that slope is unusable the state is flagged `stopped` and left
## otherwise unchanged. Always returns `nothing`.
function update_state(method::Order1, Compat.@nospecialize(fs), o::UnivariateZeroState, options::UnivariateZeroOptions)
    incsteps(o)

    slope, bad = _fbracket(o.xn0, o.xn1, o.fxn0, o.fxn1)
    if bad
        o.stopped = true
        o.message = "Derivative approximation had issues"
        return
    end

    # shift current -> previous, then take the step
    o.xn0, o.fxn0 = o.xn1, o.fxn1
    o.xn1 = o.xn1 - o.fxn1 / slope
    o.fxn1 = fs(o.xn1)
    incfn(o)

    return nothing
end
##################################################
# Field-less type used to select the Steffensen `update_state` method.
mutable struct Steffensen <: AbstractUnivariateZeroMethod
end
# `Order2` is another name for `Steffensen`.
const Order2 = Steffensen
## One Steffensen step: the slope is the divided difference between the current
## iterate and an auxiliary point offset by `steff_step`. Two function
## evaluations per step. If the slope is unusable the state is flagged `stopped`.
function update_state(method::Steffensen, fs, o::UnivariateZeroState{T}, options::UnivariateZeroOptions{T}) where {T <: Number}
    S = eltype(o.fxn1)
    incsteps(o)

    # auxiliary point and its function value (type-asserted as in the iterate types)
    h = steff_step(o.xn1, o.fxn1)::T
    wn = o.xn1 + h
    fwn = fs(wn)::S
    incfn(o)

    slope, bad = _fbracket(o.xn1, wn, o.fxn1, fwn)
    if bad
        o.stopped = true
        o.message = "Derivative approximation had issues"
        return
    end

    # shift current -> previous, then take the step
    o.xn0, o.fxn0 = o.xn1, o.fxn1
    o.xn1 = o.xn1 - o.fxn1 / slope
    o.fxn1 = fs(o.xn1)
    incfn(o)

    return nothing
end
## Convenience wrapper: run `find_zero` with the `Steffensen` method.
## `find_zero` takes the method first, then the function(s), then the start value.
steffenson(f, x0; kwargs...) = find_zero(Steffensen(), f, x0; kwargs...)
##################################################
##
# ## helper function
## Normalize a bracket: promote both ends to a common float type, order them so
## the smaller comes first, and move an infinite end to the adjacent float.
function adjust_bracket(x0)
    lo, hi = float.(promote(x0...))
    lo, hi = lo > hi ? (hi, lo) : (lo, hi)
    isinf(lo) && (lo = nextfloat(lo))
    isinf(hi) && (hi = prevfloat(hi))
    return lo, hi
end
## State for bracketing methods. `x` holds the two ends of the interval; after
## `adjust_bracket` the function is evaluated at both ends.
## Throws an `ArgumentError` when the two values have the same (non-zero) sign.
function init_state(method::AbstractBisection, fs, x::Union{Tuple{T,T}, Vector{T}}) where {T <: Real}
    x0, x2 = adjust_bracket(x)
    y0, y2 = promote(fs(x0), fs(x2))
    if sign(y0) * sign(y2) > 0
        # message spelled out here: no `bracketing_error` constant exists in this module
        throw(ArgumentError("The interval [a,b] is not a bracketing interval.\n" *
                            "You need f(a) and f(b) to have different signs (f(a) * f(b) < 0)."))
    end
    return UnivariateZeroState(x0, x2,
                               y0, y2,
                               0, 2,
                               false, false, false, false,
                               "")
end
# False-position method. `reduction_factor` is the key used by `update_state`
# to look up a reduction formula in `galdino`.
mutable struct FalsePosition <: AbstractBisection
reduction_factor::Union{Int, Symbol}
# With no argument the `:anderson_bjork` key is used.
FalsePosition(x=:anderson_bjork) = new(x)
end
## One false-position step on the bracket (o.xn0, o.xn1).
## A new point is taken on the secant line (or the midpoint when the secant
## move would be very short); the bracket end that is kept either moves to the
## old `b` or has its function value scaled by the selected reduction factor.
## Returns `nothing`.
function update_state(method::FalsePosition, fs, o, options)
    a, b = o.xn0, o.xn1
    fa, fb = o.fxn0, o.fxn1

    lambda = fb / (fb - fa)
    tau = 1e-10                   # some engineering to avoid short moves
    if !(tau < abs(lambda) < 1-tau)
        lambda = 1/2
    end

    x = b - lambda * (b-a)
    fx = fs(x)
    incfn(o)
    incsteps(o)

    if iszero(fx)
        # exact zero: record it and leave the other bracket end untouched
        o.xn1 = x
        o.fxn1 = fx
        return
    end

    if sign(fx)*sign(fb) < 0
        a, fa = b, fb
    else
        fa = galdino[method.reduction_factor](fa, fb, fx)
    end
    b, fb = x, fx

    o.xn0, o.xn1 = a, b
    o.fxn0, o.fxn1 = fa, fb

    return nothing
end
# The 12 reduction factors offered by Galdino, keyed by number; each maps
# (fa, fb, fx) to the replacement value for fa. Declared `const` so that
# lookups from `update_state` do not go through an untyped global.
const galdino = Dict{Union{Int,Symbol},Function}(
    1 => (fa, fb, fx) -> fa*fb/(fb+fx),
    2 => (fa, fb, fx) -> (fa - fb)/2,
    3 => (fa, fb, fx) -> (fa - fx)/(2 + fx/fb),
    4 => (fa, fb, fx) -> (fa - fx)/(1 + fx/fb)^2,
    5 => (fa, fb, fx) -> (fa -fx)/(1.5 + fx/fb)^2,
    6 => (fa, fb, fx) -> (fa - fx)/(2 + fx/fb)^2,
    7 => (fa, fb, fx) -> (fa + fx)/(2 + fx/fb)^2,
    8 => (fa, fb, fx) -> fa/2,
    9 => (fa, fb, fx) -> fa/(1 + fx/fb)^2,
    10 => (fa, fb, fx) -> (fa-fx)/4,
    11 => (fa, fb, fx) -> fx*fa/(fb+fx),
    12 => (fa, fb, fx) -> (fa * (1-fx/fb > 0 ? 1-fx/fb : 1/2))
)
# give common names
for (nm, i) in [(:pegasus, 1), (:illinois, 8), (:anderson_bjork, 12)]
    galdino[nm] = galdino[i]
end
# function find_zero(f, x0::Tuple{T,S}, method::FalsePosition; kwargs...) where {T<:Number,S<:Number}
# x = adjust_bracket(x0)
# prob, options = derivative_free_setup(method, DerivativeFree(f), x; kwargs...)
# find_zero(prob, method, options)
# end
end
## Example: secant iteration on a function whose argument and value carry units.
# `Rts` is defined in this file rather than installed as a package, so it is
# loaded by relative path.
using .Rts
using Unitful
g = u"g"
## Must adjust tolerances
## xabstol eps(units of x)
## abstol eps(units of f(x))
## xreltol, reltol dimensionless
f(x) = g^2 * sin(x/g)
x = (3.0g, 4.0g)
Rts.find_zero(Rts.Secant(), f, x)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment