modify Roots to work with Unitful
## new framework for Roots
module Rts

using Compat
if VERSION < v"0.7.0"
    using Missings
else
    using LinearAlgebra: norm   # norm moved out of Base in 0.7
    using Printf: @sprintf      # @sprintf moved out of Base in 0.7
end
using ForwardDiff

function D(f::Function, k::Int=1)
    k < 0 && error("The order of the derivative must be non-negative")
    k == 0 && return(x -> f(float(x)))
    D(x -> ForwardDiff.derivative(f, x), k-1)
end
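# For example, D(sin)(0.0) returns cos(0.0) and D(sin, 2)(0.0) returns -sin(0.0),
# each computed by (nested) forward-mode automatic differentiation.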
# A zero is found by specifying:
#   the method to use                                            <: AbstractUnivariateZeroMethod
#   the function(s)                                              <: CallableFunction
#   the initial state, through a value for x: x, [a,b], or (a,b) <: AbstractUnivariateZeroState
#   the options (e.g., tolerances)                               <: UnivariateZeroOptions
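# A sketch of how the pieces combine (using types defined further below;
# illustrative only):
#     find_zero(Order1(), sin, 3.0)                        # secant method from one starting value
#     find_zero(FalsePosition(:illinois), sin, (3.0, 4.0)) # a bracketing method over an interval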
### Methods
abstract type AbstractUnivariateZeroMethod end
abstract type AbstractBisection <: AbstractUnivariateZeroMethod end
abstract type AbstractSecant <: AbstractUnivariateZeroMethod end

### States
abstract type AbstractUnivariateZeroState end

mutable struct UnivariateZeroState{T,S} <: AbstractUnivariateZeroState
    xn1::T
    xn0::Union{Missing, T}
    fxn1::S
    fxn0::Union{Missing, S}
    steps::Int
    fnevals::Int
    stopped::Bool             # stopped, but may not have converged
    x_converged::Bool         # converged via |x_n - x_{n-1}| < ϵ
    f_converged::Bool         # converged via |f(x_n)| < ϵ
    convergence_failed::Bool
    message::AbstractString
end

# exception thrown when an algorithm stops without converging
struct ConvergenceFailed <: Exception
    msg::AbstractString
end
incfn(o::UnivariateZeroState, k=1) = o.fnevals += k
incsteps(o::UnivariateZeroState, k=1) = o.steps += k

# initialize state for most methods
function init_state(method::Any, fs, x)
    x1 = float(x)
    fx1 = fs(x1); fnevals = 1
    state = UnivariateZeroState(x1, missing,
                                fx1, missing,
                                0, fnevals,
                                false, false, false, false,
                                "")
    state
end
### Options
mutable struct UnivariateZeroOptions{Q,R,S,T}
    xabstol::Q
    xreltol::R
    abstol::S
    reltol::T
    maxevals::Int
    maxfnevals::Int
    verbose::Bool
end

# Allow for override of the default tolerances. Useful, say, for a method like bisection
function init_options(::Any,
                      state;
                      xabstol=missing,
                      xreltol=missing,
                      abstol=missing,
                      reltol=missing,
                      maxevals::Int=40,
                      maxfnevals::Int=typemax(Int),
                      verbose::Bool=false,
                      kwargs...)

    ## Where we set defaults
    options = UnivariateZeroOptions(ismissing(xabstol) ? zero(state.xn1)                          : xabstol, # units of x
                                    ismissing(xreltol) ? eps(state.xn1)/oneunit(state.xn1)        : xreltol, # unitless
                                    ismissing(abstol)  ? 4 * eps(state.fxn1)                      : abstol,  # units of f(x)
                                    ismissing(reltol)  ? 4 * eps(state.fxn1)/oneunit(state.fxn1)  : reltol,  # unitless
                                    maxevals, maxfnevals,
                                    verbose)
    options
end
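# Illustrative only: for a unitful problem the absolute tolerances must carry
# units -- xabstol the units of x, abstol the units of f(x) -- while the
# relative tolerances stay dimensionless, e.g. (assuming Unitful's u"..." macro is in scope)
#     find_zero(Order1(), f, 3.0u"m"; xabstol = 1e-10 * u"m", xreltol = 1e-12)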
### Functions
abstract type CallableFunction end

struct DerivativeFree{F} <: CallableFunction
    f::F
end
(F::DerivativeFree)(x::Number) = F.f(x)
(F::DerivativeFree)(x::Number, n::Int) = F(x, Val{n})
(F::DerivativeFree)(x::Number, ::Type{Val{1}}) = D(F.f)(x)
(F::DerivativeFree)(x::Number, ::Type{Val{2}}) = D(F.f,2)(x)
(F::DerivativeFree)(x) = F.f(x)

struct FirstDerivative{F, Fp} <: CallableFunction
    f::F
    fp::Fp
end
(F::FirstDerivative)(x::Number) = F.f(x)
(F::FirstDerivative)(x::Number, n::Int) = F(x, Val{n})
(F::FirstDerivative)(x::Number, ::Type{Val{1}}) = F.fp(x)
(F::FirstDerivative)(x::Number, ::Type{Val{2}}) = D(F.fp,1)(x)

struct SecondDerivative{F,Fp,Fpp} <: CallableFunction
    f::F
    fp::Fp
    fpp::Fpp
end
(F::SecondDerivative)(x::Number) = F.f(x)
(F::SecondDerivative)(x::Number, n::Int) = F(x, Val{n})
(F::SecondDerivative)(x::Number, ::Type{Val{1}}) = F.fp(x)
(F::SecondDerivative)(x::Number, ::Type{Val{2}}) = F.fpp(x)

function callable_function(fs)
    if isa(fs, Tuple)
        length(fs) == 1 && return DerivativeFree(fs[1])
        length(fs) == 2 && return FirstDerivative(fs[1], fs[2])
        return SecondDerivative(fs[1], fs[2], fs[3])
    end
    DerivativeFree(fs)
end
## has UnivariateZeroProblem converged?
## allow missing values in isapprox
_isapprox(a, b, rtol, atol) = _isapprox(Val{ismissing(a) || ismissing(b)}, a, b, rtol, atol)
_isapprox(::Type{Val{true}}, a, b, rtol, atol) = false
_isapprox(::Type{Val{false}}, a, b, rtol, atol) = isapprox(a, b, rtol=rtol, atol=atol)
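# For example, _isapprox(missing, 1.0, 1e-8, 0.0) is false (an unset xn0/fxn0
# never signals convergence), while _isapprox(1.0, 1.0 + 1e-12, 1e-8, 0.0) is true.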
function assess_convergence(method::Any, state, options)
    xn0, xn1 = state.xn0, state.xn1
    fxn0, fxn1 = state.fxn0, state.fxn1

    if (state.x_converged || state.f_converged)
        return true
    end

    if state.steps > options.maxevals
        state.stopped = true
        state.message = "too many steps taken."
        return true
    end

    if state.fnevals > options.maxfnevals
        state.stopped = true
        state.message = "too many function evaluations taken."
        return true
    end

    if isnan(xn1)
        state.convergence_failed = true
        state.message = "NaN produced by algorithm."
        return true
    end

    if isinf(fxn1)
        state.convergence_failed = true
        state.message = "Inf produced by algorithm."
        return true
    end

    λ = max(oneunit(real(xn1)), norm(xn1))

    if _isapprox(fxn1, fxn0, options.reltol, options.abstol) #abs(fxn1) <= max(options.abstol, λ * options.reltol)
        state.f_converged = true
        return true
    end

    if _isapprox(xn1, xn0, options.xreltol, options.xabstol)
        # Heuristic check that f is small too, in a unitless way (abstol has the
        # units of f(x) and λ the units of x, so strip each term of its unit
        # before combining and comparing)
        tol = max(options.abstol/oneunit(options.abstol), λ/oneunit(λ) * options.reltol)
        if abs(fxn1)/oneunit(fxn1) <= cbrt(tol)
            state.x_converged = true
            return true
        end
    end

    if state.stopped
        if state.message == ""
            error("no message? XXX debug this XXX")
        end
        return true
    end

    return false
end
function show_trace(state, xns, fxns, method)
    converged = state.x_converged || state.f_converged

    println("Results of univariate zero finding:\n")
    if converged
        println("* Converged to: $(xns[end])")
        println("* Algorithm: $(method)")
        println("* iterations: $(state.steps)")
        println("* function evaluations: $(state.fnevals)")
        state.x_converged && println("* stopped as x_n ≈ x_{n-1} using atol=xabstol, rtol=xreltol")
        state.f_converged && state.message == "" && println("* stopped as |f(x_n)| ≤ max(δ, max(1,|x|)⋅ϵ) using δ = abstol, ϵ = reltol")
        state.message != "" && println("* Note: $(state.message)")
    else
        println("* Convergence failed: $(state.message)")
        println("* Algorithm $(method)")
    end

    println("")
    println("Trace:")
    itr, offset = 0:(endof(xns)-1), 1
    for i in itr
        x_i, fx_i, xi, fxi = "x_$i", "f(x_$i)", xns[i+offset], fxns[i+offset]
        println(@sprintf("%s = % 18.16f,\t %s = % 18.16f", x_i, float(xi), fx_i, float(fxi)))
    end
    println("")
end
## fs can be f, (f, fp), or (f, fp, fpp)
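## e.g. a tuple (sin, cos) wraps into FirstDerivative, a bare sin into
## DerivativeFree (whose derivatives, if a method asks for them, come from
## ForwardDiff via D above).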
function find_zero(method::AbstractUnivariateZeroMethod, fs, x0; kwargs...)
    x = float.(x0)
    F = callable_function(fs)
    state = init_state(method, F, x)
    options = init_options(method, state; kwargs...)
    find_zero(method, F, options, state)
end
function find_zero(M::AbstractUnivariateZeroMethod,
                   F::CallableFunction,
                   options::UnivariateZeroOptions,
                   state::AbstractUnivariateZeroState)

    # in case verbose=true
    if isa(M, AbstractSecant)
        xns, fxns = [state.xn0, state.xn1], [state.fxn0, state.fxn1]
    else
        xns, fxns = [state.xn1], [state.fxn1]
    end

    ## XXX removed bracket check here

    while true
        val = assess_convergence(M, state, options)

        if val
            if state.stopped && !(state.x_converged || state.f_converged)
                ## stopped is a heuristic; there was an issue with an approximate derivative.
                ## Say it converged if pretty close, else say convergence failed.
                ## (Is this a good idea?)
                xstar, fxstar = state.xn1, state.fxn1
                tol = options.abstol
                if abs(fxstar/oneunit(fxstar)) <= (tol/oneunit(tol))^(2/3)
                    msg = "Algorithm stopped early, but |f(xn)| < ϵ^(2/3), where ϵ = abstol"
                    state.message = state.message == "" ? msg : state.message * "\n\t" * msg
                    state.f_converged = true
                else
                    state.convergence_failed = true
                end
            end

            if state.x_converged || state.f_converged
                options.verbose && show_trace(state, xns, fxns, M)
                return state.xn1
            end

            if state.convergence_failed
                options.verbose && show_trace(state, xns, fxns, M)
                throw(ConvergenceFailed("Stopped at: xn = $(state.xn1)"))
            end
        end

        update_state(M, F, state, options)

        if options.verbose
            push!(xns, state.xn1)
            push!(fxns, state.fxn1)
        end
    end
end
##################################################
## utilities

## issue with an approximate derivative
## (iszero works for Unitful quantities, where a comparison to a bare 0.0 would not)
isissue(x) = iszero(x) || isnan(x) || isinf(x)

## use the bracket f[a,b] to approximate f'(x)
function _fbracket(a, b, fa, fb)
    num, den = fb - fa, b - a
    iszero(num) && iszero(den) && return Inf, true
    out = num / den
    out, isissue(out)
end
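# For example, with f(x) = x^2 the bracket f[1, 2] = (4 - 1)/(2 - 1) = 3, which
# matches f'(1.5) exactly. With units, f[a, b] carries the units of f(x)/x, the
# same units as f'.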
function steff_step(x::T, fx) where {T}
    thresh = max(1, norm(x)) * sqrt(eps(T)) # max(1, sqrt(abs(x/fx))) * 1e-6
    norm(fx) <= thresh ? fx : sign(fx) * thresh
end
##################################################

function init_state(method::AbstractSecant, fs, x)
    if isa(x, Vector) || isa(x, Tuple)
        x0, x1 = x[1], x[2]
        fx0, fx1 = fs(x0), fs(x1)
    else
        # need an initial x0, x1 if two not specified
        x0 = x
        fx0 = fs(x0)
        stepsize = max(1/100, min(abs(fx0/oneunit(fx0)), abs(x0/oneunit(x0)/100)))
        x1 = x0 + stepsize*oneunit(x0)
        x0, x1, fx0, fx1 = x1, x0, fs(x1), fx0 # switch
    end

    state = UnivariateZeroState(promote(x1, x0)...,
                                promote(fx1, fx0)...,
                                0, 2,
                                false, false, false, false,
                                "")
    state
end
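# For example, starting from x0 = 1.0 with f(x0) = 0.5 the step size is
# max(1/100, min(0.5, 0.01)) = 0.01, giving the second point 1.01; the
# oneunit factors keep the step in the units of x when x0 is a quantity.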
## Order1 -- SecantMethod
mutable struct Order1 <: AbstractSecant end
const Secant = Order1

function update_state(method::Order1, Compat.@nospecialize(fs), o::UnivariateZeroState, options::UnivariateZeroOptions)
    incsteps(o)

    fp, issue = _fbracket(o.xn0, o.xn1, o.fxn0, o.fxn1)
    if issue
        o.stopped = true
        o.message = "Derivative approximation had issues"
        return
    end

    o.xn0 = o.xn1
    o.fxn0 = o.fxn1
    o.xn1 = o.xn1 - o.fxn1 / fp
    o.fxn1 = fs(o.xn1)
    incfn(o)

    nothing
end
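# The update above is the secant step x_{n+1} = x_n - f(x_n) / f[x_{n-1}, x_n].
# It is dimensionally consistent: f(x_n) / f[x_{n-1}, x_n] has the units of x,
# which is why this method runs unchanged on Unitful quantities.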
##################################################

mutable struct Steffensen <: AbstractUnivariateZeroMethod
end
const Order2 = Steffensen

function update_state(method::Steffensen, fs, o::UnivariateZeroState{T}, options::UnivariateZeroOptions{T}) where {T <: Number}
    S = eltype(o.fxn1)
    incsteps(o)

    wn = o.xn1 + steff_step(o.xn1, o.fxn1)::T
    fwn = fs(wn)::S
    incfn(o)

    fp, issue = _fbracket(o.xn1, wn, o.fxn1, fwn)
    if issue
        o.stopped = true
        o.message = "Derivative approximation had issues"
        return
    end

    o.xn0 = o.xn1
    o.fxn0 = o.fxn1
    o.xn1 = o.xn1 - o.fxn1 / fp #xn1
    o.fxn1 = fs(o.xn1)
    incfn(o)

    nothing
end
steffensen(f, x0; kwargs...) = find_zero(Steffensen(), f, x0; kwargs...)
##################################################
##
## helper function
function adjust_bracket(x0)
    u, v = float.(promote(x0...))
    if u > v
        u, v = v, u
    end

    if isinf(u)
        u = nextfloat(u)
    end
    if isinf(v)
        v = prevfloat(v)
    end
    u, v
end
function init_state(method::AbstractBisection, fs, x::Union{Tuple{T,T}, Vector{T}}) where {T <: Real}
    x0, x2 = adjust_bracket(x)
    y0, y2 = promote(fs(x0), fs(x2))
    # f must change sign over [a,b] for a bracketing method
    sign(y0) * sign(y2) > 0 && throw(ArgumentError("[a, b] is not a bracketing interval: f(a) and f(b) must have opposite signs"))

    state = UnivariateZeroState(x0, x2,
                                y0, y2,
                                0, 2,
                                false, false, false, false,
                                "")
    state
end
mutable struct FalsePosition <: AbstractBisection
    reduction_factor::Union{Int, Symbol}
    FalsePosition(x=:anderson_bjork) = new(x)
end

function update_state(method::FalsePosition, fs, o, options)
    a, b = o.xn0, o.xn1
    fa, fb = o.fxn0, o.fxn1

    lambda = fb / (fb - fa)
    tau = 1e-10 # some engineering to avoid short moves
    if !(tau < norm(lambda) < 1-tau)
        lambda = 1/2
    end
    x = b - lambda * (b-a)
    fx = fs(x)
    incfn(o)
    incsteps(o)

    if iszero(fx)
        o.xn1 = x
        o.fxn1 = fx
        return
    end

    if sign(fx)*sign(fb) < 0
        a, fa = b, fb
    else
        fa = galdino[method.reduction_factor](fa, fb, fx)
    end
    b, fb = x, fx

    o.xn0, o.xn1 = a, b
    o.fxn0, o.fxn1 = fa, fb

    nothing
end
# the 12 reduction factors offered by Galdino
galdino = Dict{Union{Int,Symbol},Function}(1  => (fa, fb, fx) -> fa*fb/(fb + fx),
                                           2  => (fa, fb, fx) -> (fa - fb)/2,
                                           3  => (fa, fb, fx) -> (fa - fx)/(2 + fx/fb),
                                           4  => (fa, fb, fx) -> (fa - fx)/(1 + fx/fb)^2,
                                           5  => (fa, fb, fx) -> (fa - fx)/(1.5 + fx/fb)^2,
                                           6  => (fa, fb, fx) -> (fa - fx)/(2 + fx/fb)^2,
                                           7  => (fa, fb, fx) -> (fa + fx)/(2 + fx/fb)^2,
                                           8  => (fa, fb, fx) -> fa/2,
                                           9  => (fa, fb, fx) -> fa/(1 + fx/fb)^2,
                                           10 => (fa, fb, fx) -> (fa - fx)/4,
                                           11 => (fa, fb, fx) -> fx*fa/(fb + fx),
                                           12 => (fa, fb, fx) -> fa * (1 - fx/fb > 0 ? 1 - fx/fb : 1/2)
                                           )
# give common names
for (nm, i) in [(:pegasus, 1), (:illinois, 8), (:anderson_bjork, 12)]
    galdino[nm] = galdino[i]
end
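# For example, galdino[:illinois](fa, fb, fx) is fa/2 (entry 8) and
# galdino[:pegasus](fa, fb, fx) is fa*fb/(fb + fx) (entry 1): the retained
# endpoint's function value is scaled down so plain regula falsi does not stall.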
# function find_zero(f, x0::Tuple{T,S}, method::FalsePosition; kwargs...) where {T<:Number,S<:Number}
#     x = adjust_bracket(x0)
#     prob, options = derivative_free_setup(method, DerivativeFree(f), x; kwargs...)
#     find_zero(prob, method, options)
# end

end  # module Rts
using Rts
using Unitful
g = u"g"

## Tolerances must be adjusted for units:
##   xabstol          -- eps, in the units of x
##   abstol           -- eps, in the units of f(x)
##   xreltol, reltol  -- dimensionless
f(x) = g^2 * sin(x/g)
x = (3.0g, 4.0g)
Rts.find_zero(Rts.Secant(), f, x)
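## A hedged check of the call above (assuming the secant iteration converges on
## this bracket): the zero of sin(x/g) between 3.0g and 4.0g is pi*g, and any
## overridden absolute tolerances must carry the units of x and of f(x):
xstar = Rts.find_zero(Rts.Secant(), f, x; xabstol = 1e-12 * g, abstol = 1e-12 * g^2)
xstar ≈ pi * g   # expected to hold approximately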