using NLsolve
using Zygote
using ChainRulesCore
using IterativeSolvers
using LinearMaps
using SparseArrays
using LinearAlgebra
using BenchmarkTools
using Random
Random.seed!(1234)

# Implicit differentiation of nlsolve: the pullback solves the adjoint linear
# system JT*Δfx = -Δx matrix-free with GMRES, using the pullback of f as JT.
function ChainRulesCore.rrule(config::RuleConfig{>:HasReverseMode}, ::typeof(nlsolve), f, x0; kwargs...)
    result = nlsolve(f, x0; kwargs...)
    function nlsolve_pullback(Δresult)
        Δx = Δresult.zero
        x = result.zero
        _, f_pullback = rrule_via_ad(config, f, x)
        JT(v) = f_pullback(v)[2] # w.r.t. x
        # solve JT*Δfx = -Δx
        L = LinearMap(JT, length(x0))
        Δfx = gmres(L, -Δx)
        ∂f = f_pullback(Δfx)[1] # w.r.t. f itself (implicitly closed-over variables)
        return (NoTangent(), ∂f, ZeroTangent())
    end
    return result, nlsolve_pullback
end

const N = 10000
const nonlin = 0.1
const A = spdiagm(0 => fill(10.0, N), 1 => fill(-1.0, N-1), -1 => fill(-1.0, N-1))
const p0 = randn(N)
h(x, p) = A*x + nonlin*x.^2 - p
solve_x(p) = nlsolve(x -> h(x, p), zeros(N), method=:anderson, m=10).zero
obj(p) = sum(solve_x(p))

# need an rrule for h as Zygote otherwise densifies the sparse matrix A
# https://github.com/FluxML/Zygote.jl/issues/931
function ChainRulesCore.rrule(::typeof(h), x, p)
    y = h(x, p)
    function my_h_pullback(ȳ)
        ∂x = @thunk(A'ȳ + 2nonlin*x.*ȳ)
        ∂p = @thunk(-ȳ)
        return (NoTangent(), ∂x, ∂p)
    end
    return y, my_h_pullback
end

g_auto = Zygote.gradient(obj, p0)[1]
g_analytic = gmres((A + Diagonal(2*nonlin*solve_x(p0)))', ones(N))
display(g_auto)
display(g_analytic)
@show sum(abs, g_auto - g_analytic) / N # 7.613631947123168e-17

@btime Zygote.gradient(obj, p0); # 11.730 ms (784 allocations: 19.87 MiB)
@btime gmres((A + Diagonal(2*nonlin*solve_x(p0)))', ones(N)); # 11.409 ms (626 allocations: 17.50 MiB)

import Pkg; Pkg.status()
# Status `/tmp/nlsolve/Project.toml`
# [6e4b80f9] BenchmarkTools v1.2.2
# [d360d2e6] ChainRulesCore v1.11.6
# [42fd0dbc] IterativeSolvers v0.9.2
# [7a12625a] LinearMaps v3.5.1
# [2774e3e8] NLsolve v4.5.1
# [e88e6eb3] Zygote v0.6.34
# [37e2e46d] LinearAlgebra
# [9a3f8284] Random
# [2f01184e] SparseArrays
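For reference, the g_analytic check above follows from the implicit function theorem. At the solution x*(p) we have h(x*(p), p) = 0, so differentiating gives

    (∂h/∂x) dx*/dp + ∂h/∂p = 0,   with   ∂h/∂x = A + Diag(2*nonlin*x*)   and   ∂h/∂p = -I,

hence dx*/dp = (A + Diag(2*nonlin*x*))^-1. Since obj(p) = sum(x*(p)) = 1' x*(p), the gradient is

    ∇obj = (dx*/dp)' 1 = (A + Diag(2*nonlin*x*))^-T 1,

which is exactly what gmres((A + Diagonal(2*nonlin*solve_x(p0)))', ones(N)) computes.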
# nlsolve_rrule_nlsolve.jl: variant that solves the adjoint linear system
# JT*Δfx = -Δx with a recursive nlsolve call instead of gmres.
using NLsolve
using Zygote
using ChainRulesCore
using SparseArrays
using LinearAlgebra
using Random
Random.seed!(1234)
using IterativeSolvers
using LinearMaps

function ChainRulesCore.rrule(config::RuleConfig{>:HasReverseMode}, ::typeof(nlsolve), f, x0; kwargs...)
    result = nlsolve(f, x0; kwargs...)
    function nlsolve_pullback(Δresult)
        Δx = Δresult.zero
        x = result.zero
        _, f_pullback = rrule_via_ad(config, f, x)
        JT(v) = f_pullback(v)[2] # w.r.t. x
        # solve JT*Δfx = -Δx by running nlsolve on the affine residual v -> JT(v) + Δx
        Δfx = nlsolve(v -> JT(v) + Δx, zero(x); kwargs...).zero
        ∂f = f_pullback(Δfx)[1] # w.r.t. f itself (implicitly closed-over variables)
        return (NoTangent(), ∂f, ZeroTangent())
    end
    return result, nlsolve_pullback
end

const N = 100
const nonlin = 0.1
const A = spdiagm(0 => fill(10.0, N), 1 => fill(-1.0, N-1), -1 => fill(-1.0, N-1))
const p0 = randn(N)
h(x, p) = A*x + nonlin*x.^2 - p
solve_x(p) = nlsolve(x -> h(x, p), zeros(N), method=:anderson, m=10, show_trace=true).zero
obj(p) = sum(solve_x(p))

# need an rrule for h as Zygote otherwise densifies the sparse matrix A
# https://github.com/FluxML/Zygote.jl/issues/931
function ChainRulesCore.rrule(::typeof(h), x, p)
    y = h(x, p)
    function my_h_pullback(ȳ)
        ∂x = @thunk(A'ȳ + 2nonlin*x.*ȳ)
        ∂p = @thunk(-ȳ)
        return (NoTangent(), ∂x, ∂p)
    end
    return y, my_h_pullback
end

g_auto = Zygote.gradient(obj, p0)[1]
g_analytic = gmres((A + Diagonal(2*nonlin*solve_x(p0)))', ones(N))
display(g_auto)
display(g_analytic)
@show sum(abs, g_auto - g_analytic) / N # 8.878502030795765e-11
Added a second version, nlsolve_rrule_nlsolve.jl, that recursively calls nlsolve again inside the rrule (instead of gmres). This could be extended further, e.g. to Newton methods, to share and re-use dense Jacobian computations and inversions between the primal and reverse solves; see the sketch below.
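A minimal sketch of that idea (not part of the gist; the solver newton_with_jacobian, its return shape, and the tolerances are made up for illustration): run a dense Newton solve, factorize the Jacobian once at the converged solution, and reuse the same LU factors, transposed, for the adjoint solve JT*Δfx = -Δx in the pullback.

using LinearAlgebra
using ChainRulesCore

# Hypothetical Newton solver that also returns the LU factorization of the
# Jacobian at the solution, so the rrule below can reuse it.
function newton_with_jacobian(f, jac, x0; maxiter=100, tol=1e-10)
    x = copy(x0)
    for _ in 1:maxiter
        fx = f(x)
        norm(fx) < tol && break
        x = x - (lu(jac(x)) \ fx)        # Newton step with a dense LU solve
    end
    F = lu(jac(x))                       # factorization at the solution, reused in reverse
    return (zero=x, jac_factorization=F)
end

function ChainRulesCore.rrule(config::RuleConfig{>:HasReverseMode},
                              ::typeof(newton_with_jacobian), f, jac, x0; kwargs...)
    result = newton_with_jacobian(f, jac, x0; kwargs...)
    function newton_pullback(Δresult)
        Δx = Δresult.zero
        x = result.zero
        # Reuse the primal LU factors for the adjoint solve JT*Δfx = -Δx
        # (just triangular solves, no extra Krylov or fixed-point iteration).
        Δfx = -(result.jac_factorization' \ collect(Δx))
        _, f_pullback = rrule_via_ad(config, f, x)
        ∂f = f_pullback(Δfx)[1]          # w.r.t. f itself (closed-over variables), as in the gist
        return (NoTangent(), ∂f, ZeroTangent(), ZeroTangent())
    end
    return result, newton_pullback
end

The point of the sketch is the sharing: the Jacobian factorization computed for the primal Newton solve also serves the transposed adjoint solve, which the gmres and recursive-nlsolve variants above have to set up from scratch.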
Hi,
Thank you for sharing. Just a message to let you know that the code nlsolve_rrule_nlsolve.jl fails on the new version of Julia (1.7.1):
julia> g_auto = Zygote.gradient(obj, p0)[1]
Iter f(x) inf-norm Step 2-norm
------ -------------- --------------
1 2.972698e+00 NaN
2 3.365277e+01 1.349215e+04
3 4.001596e+00 2.537379e+02
4 5.332478e-01 3.256602e+00
5 5.055990e-02 3.535115e-02
6 2.729677e-02 1.034314e-02
7 3.854888e-02 2.297229e-02
8 4.837852e-03 2.994361e-04
9 5.931250e-04 3.540640e-06
10 5.630686e-05 4.839974e-08
11 3.178730e-05 1.258959e-08
12 2.866611e-05 1.210490e-08
13 7.650714e-06 4.651253e-10
14 5.700195e-07 3.826030e-12
15 5.974346e-06 4.443265e-10
16 6.331072e-07 5.080536e-12
17 1.378531e-07 3.095069e-13
18 8.331548e-08 1.096986e-13
19 6.617559e-08 6.484640e-14
20 1.873820e-07 4.225102e-13
21 1.078152e-07 1.607114e-13
22 1.421534e-07 2.987304e-13
23 3.088052e-07 1.817473e-12
24 3.539620e-07 1.882567e-12
25 2.712886e-07 9.741503e-13
26 2.472265e-07 6.618116e-13
27 6.316930e-09 6.378649e-16
ERROR: MethodError: no method matching getindex(::Tangent{Any, NamedTuple{(:method, :initial_x, :zero, :residual_norm, :iterations, :x_converged, :xtol, :f_converged, :ftol, :trace, :f_calls, :g_calls), Tuple{ZeroTangent, ZeroTangent, FillArrays.Fill{Float64, 1, Tuple{Base.OneTo{Int64}}}, ZeroTangent, ZeroTangent, ZeroTangent, ZeroTangent, ZeroTangent, ZeroTangent, ZeroTangent, ZeroTangent, ZeroTangent}}})
Closest candidates are:
getindex(::Tangent{P, T}, ::Symbol) where {P, T<:NamedTuple} at ~/.julia/packages/ChainRulesCore/IFusD/src/tangent_types/tangent.jl:92
getindex(::Tangent{P, T}, ::Int64) where {P, T<:Union{Tuple, NamedTuple}} at ~/.julia/packages/ChainRulesCore/IFusD/src/tangent_types/tangent.jl:88
getindex(::Tangent, ::Any) where {P, T<:AbstractDict} at ~/.julia/packages/ChainRulesCore/IFusD/src/tangent_types/tangent.jl:96
Stacktrace:
[1] (::var"#nlsolve_pullback#3"{Base.Pairs{Symbol, Any, Tuple{Symbol, Symbol, Symbol}, NamedTuple{(:method, :m, :show_trace), Tuple{Symbol, Int64, Bool}}}, Zygote.ZygoteRuleConfig{Zygote.Context}, var"#6#7"{Vector{Float64}}, NLsolve.SolverResults{Float64, Float64, Vector{Float64}, Vector{Float64}}})(Δresult::Tangent{Any, NamedTuple{(:method, :initial_x, :zero, :residual_norm, :iterations, :x_converged, :xtol, :f_converged, :ftol, :trace, :f_calls, :g_calls), Tuple{ZeroTangent, ZeroTangent, FillArrays.Fill{Float64, 1, Tuple{Base.OneTo{Int64}}}, ZeroTangent, ZeroTangent, ZeroTangent, ZeroTangent, ZeroTangent, ZeroTangent, ZeroTangent, ZeroTangent, ZeroTangent}}})
@ Main ./REPL[10]:4
[2] ZBack
@ ~/.julia/packages/Zygote/umM0L/src/compiler/chainrules.jl:204 [inlined]
[3] (::Zygote.var"#kw_zpullback#42"{var"#nlsolve_pullback#3"{Base.Pairs{Symbol, Any, Tuple{Symbol, Symbol, Symbol}, NamedTuple{(:method, :m, :show_trace), Tuple{Symbol, Int64, Bool}}}, Zygote.ZygoteRuleConfig{Zygote.Context}, var"#6#7"{Vector{Float64}}, NLsolve.SolverResults{Float64, Float64, Vector{Float64}, Vector{Float64}}}})(dy::Base.RefValue{Any})
@ Zygote ~/.julia/packages/Zygote/umM0L/src/compiler/chainrules.jl:230
[4] Pullback
@ ./REPL[16]:1 [inlined]
[5] (::typeof(∂(solve_x)))(Δ::FillArrays.Fill{Float64, 1, Tuple{Base.OneTo{Int64}}})
@ Zygote ~/.julia/packages/Zygote/umM0L/src/compiler/interface2.jl:0
[6] Pullback
@ ./REPL[17]:1 [inlined]
[7] (::Zygote.var"#57#58"{typeof(∂(obj))})(Δ::Float64)
@ Zygote ~/.julia/packages/Zygote/umM0L/src/compiler/interface.jl:41
[8] gradient(f::Function, args::Vector{Float64})
@ Zygote ~/.julia/packages/Zygote/umM0L/src/compiler/interface.jl:76
[9] top-level scope
@ REPL[19]:1
julia> versioninfo()
Julia Version 1.7.1
Commit ac5cc99908 (2021-12-22 19:35 UTC)
Platform Info:
OS: Linux (x86_64-pc-linux-gnu)
CPU: Intel(R) Core(TM) i7-8850H CPU @ 2.60GHz
WORD_SIZE: 64
LIBM: libopenlibm
LLVM: libLLVM-12.0.1 (ORCJIT, skylake)
Everything works fine if I use an older version though. For instance:
julia> versioninfo()
Julia Version 1.6.1
Commit 6aaedecc44 (2021-04-23 05:59 UTC)
Platform Info:
OS: Linux (x86_64-pc-linux-gnu)
CPU: Intel(R) Core(TM) i7-8850H CPU @ 2.60GHz
WORD_SIZE: 64
LIBM: libopenlibm
LLVM: libLLVM-11.0.1 (ORCJIT, skylake)
Thanks @JulienPascal, I have now updated the example to fix this error; it now works for me on both Julia 1.6.4 and 1.7.1.
Thank you @niklasschmitz for the update. It works really well on Julia 1.7.1.
https://github.com/SciML/NonlinearSolve.jl implements a more extensive version of this, which was used for https://github.com/SciML/FastDEQ.jl (https://arxiv.org/abs/2201.12240).
Cool! Where's the code? I couldn't find it in NonlinearSolve.jl (which doesn't even depend on ChainRulesCore)
https://github.com/SciML/DiffEqSensitivity.jl/blob/master/src/steadystate_adjoint.jl. It uses the DiffEqSensitivity machinery (which should be renamed SciMLSensitivity.jl at this point, but I digress) to get all of the AD compatibility without Requires (it throws an error mentioning this if you differentiate without it). It has heuristics for switching between Jacobian-based and Jacobian-free modes based on problem size, uses the same DiffEqSensitivity approach of Zygote/Enzyme/etc. for VJPs (though it needs to be improved for this case), and so on.
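A rough usage sketch of what that looks like from the user side (based on my reading of the linked code, not on this gist; the exact package versions and the way the adjoint hooks into solve are assumptions):

using NonlinearSolve, DiffEqSensitivity, Zygote, SparseArrays, LinearAlgebra

# Same kind of residual as in the gist, written as f(u, p) for NonlinearProblem.
const Nns = 100
const Ans = spdiagm(0 => fill(10.0, Nns), 1 => fill(-1.0, Nns-1), -1 => fill(-1.0, Nns-1))
hns(u, p) = Ans*u + 0.1 .* u.^2 .- p

function obj_ns(p)
    prob = NonlinearProblem(hns, zeros(Nns), p)
    sol = solve(prob, NewtonRaphson())   # Jacobian-based Newton solve
    return sum(sol.u)
end

# Loading DiffEqSensitivity is what provides the steady-state adjoint rule for
# differentiating through `solve`; without it, Zygote errors with a message
# pointing there (as described above).
g_ns = Zygote.gradient(obj_ns, randn(Nns))[1]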
Adapted version of the Zygote nlsolve implicit differentiation examples (initial code by @antoine-levitt and @tkf) discussed in JuliaNLSolvers/NLsolve.jl#205 to use ChainRules instead of Zygote.@adjoint