Created
July 21, 2019 16:50
-
-
Save antoine-levitt/1c08f3181df31b81e912541b17d5acf9 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
diff --git a/src/solvers/anderson.jl b/src/solvers/anderson.jl | |
index 0d47b6a..60dc122 100644 | |
--- a/src/solvers/anderson.jl | |
+++ b/src/solvers/anderson.jl | |
@@ -42,8 +42,8 @@ AndersonCache(df, ::Anderson{0}) = | |
@views function anderson_(df::Union{NonDifferentiable, OnceDifferentiable}, | |
initial_x::AbstractArray{T}, | |
- xtol::T, | |
- ftol::T, | |
+ xtol::Real, | |
+ ftol::Real, | |
iterations::Integer, | |
store_trace::Bool, | |
show_trace::Bool, | |
@@ -80,7 +80,7 @@ AndersonCache(df, ::Anderson{0}) = | |
update!(tr, | |
iter, | |
maximum(abs, fx), | |
- iter > 1 ? sqeuclidean(cache.g, cache.x) : convert(T,NaN), | |
+ iter > 1 ? sqeuclidean(cache.g, cache.x) : convert(real(T),NaN), | |
dt, | |
store_trace, | |
show_trace) | |
@@ -187,5 +187,5 @@ function anderson(df::Union{NonDifferentiable, OnceDifferentiable}, | |
aa_start::Integer, | |
droptol::Real, | |
cache::AndersonCache) where T | |
- anderson_(df, initial_x, convert(T, xtol), convert(T, ftol), iterations, store_trace, show_trace, extended_trace, beta, aa_start, droptol, cache) | |
+ anderson_(df, initial_x, convert(real(T), xtol), convert(real(T), ftol), iterations, store_trace, show_trace, extended_trace, beta, aa_start, droptol, cache) | |
end | |
diff --git a/src/solvers/broyden.jl b/src/solvers/broyden.jl | |
index 97173eb..08fc21e 100644 | |
--- a/src/solvers/broyden.jl | |
+++ b/src/solvers/broyden.jl | |
@@ -23,8 +23,8 @@ end | |
function broyden_(df::Union{NonDifferentiable, OnceDifferentiable}, | |
initial_x::AbstractArray{T}, | |
- xtol::T, | |
- ftol::T, | |
+ xtol::Real, | |
+ ftol::Real, | |
iterations::Integer, | |
store_trace::Bool, | |
show_trace::Bool, | |
@@ -120,7 +120,7 @@ function broyden(df::Union{NonDifferentiable, OnceDifferentiable}, | |
show_trace::Bool, | |
extended_trace::Bool, | |
linesearch) where T | |
- broyden_(df, initial_x, convert(T, xtol), convert(T, ftol), iterations, store_trace, show_trace, extended_trace, linesearch) | |
+ broyden_(df, initial_x, convert(real(T), xtol), convert(real(T), ftol), iterations, store_trace, show_trace, extended_trace, linesearch) | |
end | |
# A derivative-free line search and global convergence | |
diff --git a/test/complex.jl b/test/complex.jl | |
index b4b88b6..610b34a 100644 | |
--- a/test/complex.jl | |
+++ b/test/complex.jl | |
@@ -1,22 +1,24 @@ | |
@testset "complex" begin | |
function f!(F, x) | |
- F[1] = x[1]*x[2] + 1 | |
- F[2] = x[1]^2 + x[2]^2 - 2 | |
+ F[1] = x[1]*x[2] + (1+im) | |
+ F[2] = x[1]^2 + x[2]^2 - (2-3im) | |
end | |
function f_real!(F::AbstractArray{T}, x::AbstractArray{T}) where {T<:Real} | |
f!(reinterpret(Complex{T}, F), reinterpret(Complex{T}, x)) | |
end | |
-for alg in [:trust_region, :newton] | |
- sol = nlsolve(f!, [1.0+0.1im, 2+1im], method = alg, store_trace=true, extended_trace=true) | |
- sol_real = nlsolve(f_real!, reinterpret(Float64, [1.0+0.1im, 2+1im]), method = alg, store_trace=true, extended_trace=true) | |
+for alg in [:newton,:trust_region,:anderson] # TODO add broyden | |
+ sol = nlsolve(f!, [1.0+0.1im, 2+1im], method = alg, store_trace=true, extended_trace=true, iterations=100, m=10, beta=0.01) | |
+ sol_real = nlsolve(f_real!, reinterpret(Float64, [1.0+0.1im, 2+1im]), method = alg, store_trace=true, extended_trace=true, iterations=100, m=10, beta=0.01) | |
@test converged(sol) == converged(sol_real) | |
@test sol.zero ≈ reinterpret(ComplexF64, sol_real.zero) | |
- @test sol.iterations == sol_real.iterations | |
- @test sol.f_calls == sol_real.f_calls | |
- @test sol.g_calls == sol_real.g_calls | |
- @test all(sol_real.trace[i].stepnorm == sol_real.trace[i].stepnorm for i in 2:sol.iterations) | |
- @test all(norm(sol.trace[i].metadata["f(x)"]) ≈ norm(sol_real.trace[i].metadata["f(x)"]) for i in 1:5) | |
+ if alg in (:newton, :trust_region) # these are supposed to be exactly the same (in exact arithmetic) | |
+ @test sol.iterations == sol_real.iterations | |
+ @test sol.f_calls == sol_real.f_calls | |
+ @test sol.g_calls == sol_real.g_calls | |
+ @test all(sol.trace[i].stepnorm ≈ sol_real.trace[i].stepnorm for i in 2:sol.iterations) | |
+ @test all(norm(sol.trace[i].metadata["f(x)"]) ≈ norm(sol_real.trace[i].metadata["f(x)"]) for i in 1:5) | |
+ end | |
end | |
end |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment