

@niklasschmitz
Last active February 9, 2022 07:32
NLsolve ChainRules implicit differentiation
using NLsolve
using Zygote
using ChainRulesCore
using IterativeSolvers
using LinearMaps
using SparseArrays
using LinearAlgebra
using BenchmarkTools
using Random
Random.seed!(1234)
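function ChainRulesCore.rrule(config::RuleConfig{>:HasReverseMode}, ::typeof(nlsolve), f, x0; kwargs...)
    result = nlsolve(f, x0; kwargs...)
    function nlsolve_pullback(Δresult)
        # Implicit function theorem: at the root x of f (f(x) = 0), the cotangent Δfx
        # solves J(x)ᵀ Δfx = -Δx, and the cotangent for f (i.e. its closed-over
        # parameters) is the VJP of f at x with Δfx.
        Δx = Δresult.zero
        x = result.zero
        _, f_pullback = rrule_via_ad(config, f, x)
        JT(v) = f_pullback(v)[2] # w.r.t. x, i.e. v -> J(x)ᵀv
        # solve JT*Δfx = -Δx matrix-free with GMRES
        L = LinearMap(JT, length(x0))
        Δfx = gmres(L, -Δx)
        ∂f = f_pullback(Δfx)[1] # w.r.t. f itself (implicitly closed-over variables)
        # assumes the root found does not depend on the initial guess x0
        return (NoTangent(), ∂f, ZeroTangent())
    end
    return result, nlsolve_pullback
end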
const N = 10000
const nonlin = 0.1
const A = spdiagm(0 => fill(10.0, N), 1 => fill(-1.0, N-1), -1 => fill(-1.0, N-1))
const p0 = randn(N)
h(x, p) = A*x + nonlin*x.^2 - p
solve_x(p) = nlsolve(x -> h(x, p), zeros(N), method=:anderson, m=10).zero
obj(p) = sum(solve_x(p))
# we need an rrule for h, as Zygote would otherwise densify the sparse matrix A
# https://github.com/FluxML/Zygote.jl/issues/931
function ChainRulesCore.rrule(::typeof(h), x, p)
    y = h(x, p)
    # ∂h/∂x = A + Diagonal(2nonlin*x) and ∂h/∂p = -I, so the VJPs are:
    function my_h_pullback(ȳ)
        ∂x = @thunk(A'ȳ + 2nonlin*x.*ȳ)
        ∂p = @thunk(-ȳ)
        return (NoTangent(), ∂x, ∂p)
    end
    return y, my_h_pullback
end
g_auto = Zygote.gradient(obj, p0)[1]
g_analytic = gmres((A + Diagonal(2*nonlin*solve_x(p0)))', ones(N))
display(g_auto)
display(g_analytic)
@show sum(abs, g_auto - g_analytic) / N # 7.613631947123168e-17
@btime Zygote.gradient(obj, p0); # 11.730 ms (784 allocations: 19.87 MiB)
@btime gmres((A + Diagonal(2*nonlin*solve_x(p0)))', ones(N)); # 11.409 ms (626 allocations: 17.50 MiB)
import Pkg; Pkg.status()
# Status `/tmp/nlsolve/Project.toml`
# [6e4b80f9] BenchmarkTools v1.2.2
# [d360d2e6] ChainRulesCore v1.11.6
# [42fd0dbc] IterativeSolvers v0.9.2
# [7a12625a] LinearMaps v3.5.1
# [2774e3e8] NLsolve v4.5.1
# [e88e6eb3] Zygote v0.6.34
# [37e2e46d] LinearAlgebra
# [9a3f8284] Random
# [2f01184e] SparseArrays
# nlsolve_rrule_nlsolve.jl
using NLsolve
using Zygote
using ChainRulesCore
using SparseArrays
using LinearAlgebra
using Random
Random.seed!(1234)
using IterativeSolvers
using LinearMaps
function ChainRulesCore.rrule(config::RuleConfig{>:HasReverseMode}, ::typeof(nlsolve), f, x0; kwargs...)
    result = nlsolve(f, x0; kwargs...)
    function nlsolve_pullback(Δresult)
        Δx = Δresult.zero
        x = result.zero
        _, f_pullback = rrule_via_ad(config, f, x)
        JT(v) = f_pullback(v)[2] # w.r.t. x
        # solve JT*Δfx = -Δx by finding the root of v -> JT(v) + Δx,
        # reusing the same nlsolve options (method, m, ...) as the primal solve
        Δfx = nlsolve(v -> JT(v) + Δx, zero(x); kwargs...).zero
        ∂f = f_pullback(Δfx)[1] # w.r.t. f itself (implicitly closed-over variables)
        return (NoTangent(), ∂f, ZeroTangent())
    end
    return result, nlsolve_pullback
end
const N = 100
const nonlin = 0.1
const A = spdiagm(0 => fill(10.0, N), 1 => fill(-1.0, N-1), -1 => fill(-1.0, N-1))
const p0 = randn(N)
h(x, p) = A*x + nonlin*x.^2 - p
solve_x(p) = nlsolve(x -> h(x, p), zeros(N), method=:anderson, m=10, show_trace=true).zero
obj(p) = sum(solve_x(p))
# we need an rrule for h, as Zygote would otherwise densify the sparse matrix A
# https://github.com/FluxML/Zygote.jl/issues/931
function ChainRulesCore.rrule(::typeof(h), x, p)
    y = h(x, p)
    # ∂h/∂x = A + Diagonal(2nonlin*x) and ∂h/∂p = -I, so the VJPs are:
    function my_h_pullback(ȳ)
        ∂x = @thunk(A'ȳ + 2nonlin*x.*ȳ)
        ∂p = @thunk(-ȳ)
        return (NoTangent(), ∂x, ∂p)
    end
    return y, my_h_pullback
end
g_auto = Zygote.gradient(obj, p0)[1]
g_analytic = gmres((A + Diagonal(2*nonlin*solve_x(p0)))', ones(N))
display(g_auto)
display(g_analytic)
@show sum(abs, g_auto - g_analytic) / N # 8.878502030795765e-11
@niklasschmitz
Author

Added a second version, nlsolve_rrule_nlsolve.jl, that calls nlsolve again inside the rrule (instead of gmres) to solve the adjoint linear system. This could be extended further, e.g. to Newton methods, to share and re-use dense Jacobian computations and inversions between the primal and reverse solves.
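A minimal sketch of that Newton idea, under stated assumptions: the dense Jacobian is available in closed form, and `newton_lu` is a hypothetical helper (not part of NLsolve) that returns the LU factorization of the Jacobian evaluated at the solution. The reverse pass then reuses that factorization for the adjoint solve Jᵀλ = x̄ instead of running a second iterative solve. It only needs the LinearAlgebra stdlib and uses a small dense version (n = 5) of the test problem h(x, p) = A*x + nonlin*x.^2 - p from above.

```julia
using LinearAlgebra

# Hypothetical helper (not part of NLsolve): Newton's method for f(x) = 0 that
# returns the solution together with the LU factorization of the Jacobian
# evaluated at that solution, so the reverse pass can reuse it.
function newton_lu(f, jac, x0; tol=1e-12, maxiter=50)
    x = copy(x0)
    for _ in 1:maxiter
        F = lu(jac(x))                       # one dense factorization per iteration
        r = f(x)
        norm(r) < tol && return (x, F)       # F is the Jacobian factorization at x
        x = x - F \ r                        # Newton step
    end
    error("Newton did not converge in $maxiter iterations")
end

# Small dense version of the test problem (n = 5 instead of N = 10000).
n = 5
nonlin = 0.1
A = diagm(0 => fill(10.0, n), 1 => fill(-1.0, n - 1), -1 => fill(-1.0, n - 1))
h(x, p) = A * x + nonlin * x .^ 2 - p
jac_h(x) = A + Diagonal(2nonlin * x)        # ∂h/∂x; note ∂h/∂p = -I

# obj(p) = sum(x(p)) and its gradient via the implicit function theorem.
function obj_and_grad(p)
    x, F = newton_lu(x -> h(x, p), jac_h, zeros(length(p)))
    x̄ = ones(length(x))                     # ∂obj/∂x
    λ = F' \ x̄                              # adjoint solve Jᵀλ = x̄, reusing the primal LU
    return sum(x), λ                         # ∂obj/∂p = -(∂h/∂p)ᵀλ = λ
end

p = [0.3, -1.2, 0.8, 2.0, -0.5]
val, g = obj_and_grad(p)
```

The gradient is the same quantity as `g_analytic` above, but no extra factorization or Krylov solve is needed in the reverse pass. For the N = 10000 sparse problem a dense LU would be expensive, which is where the Jacobian-free gmres/nlsolve route pays off.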

@JulienPascal

Hi,

Thank you for sharing. Just a message to let you know that the code in nlsolve_rrule_nlsolve.jl fails on the new version of Julia (1.7.1):

julia> g_auto = Zygote.gradient(obj, p0)[1]
Iter     f(x) inf-norm    Step 2-norm 
------   --------------   --------------
     1     2.972698e+00              NaN
     2     3.365277e+01     1.349215e+04
     3     4.001596e+00     2.537379e+02
     4     5.332478e-01     3.256602e+00
     5     5.055990e-02     3.535115e-02
     6     2.729677e-02     1.034314e-02
     7     3.854888e-02     2.297229e-02
     8     4.837852e-03     2.994361e-04
     9     5.931250e-04     3.540640e-06
    10     5.630686e-05     4.839974e-08
    11     3.178730e-05     1.258959e-08
    12     2.866611e-05     1.210490e-08
    13     7.650714e-06     4.651253e-10
    14     5.700195e-07     3.826030e-12
    15     5.974346e-06     4.443265e-10
    16     6.331072e-07     5.080536e-12
    17     1.378531e-07     3.095069e-13
    18     8.331548e-08     1.096986e-13
    19     6.617559e-08     6.484640e-14
    20     1.873820e-07     4.225102e-13
    21     1.078152e-07     1.607114e-13
    22     1.421534e-07     2.987304e-13
    23     3.088052e-07     1.817473e-12
    24     3.539620e-07     1.882567e-12
    25     2.712886e-07     9.741503e-13
    26     2.472265e-07     6.618116e-13
    27     6.316930e-09     6.378649e-16
ERROR: MethodError: no method matching getindex(::Tangent{Any, NamedTuple{(:method, :initial_x, :zero, :residual_norm, :iterations, :x_converged, :xtol, :f_converged, :ftol, :trace, :f_calls, :g_calls), Tuple{ZeroTangent, ZeroTangent, FillArrays.Fill{Float64, 1, Tuple{Base.OneTo{Int64}}}, ZeroTangent, ZeroTangent, ZeroTangent, ZeroTangent, ZeroTangent, ZeroTangent, ZeroTangent, ZeroTangent, ZeroTangent}}})
Closest candidates are:
  getindex(::Tangent{P, T}, ::Symbol) where {P, T<:NamedTuple} at ~/.julia/packages/ChainRulesCore/IFusD/src/tangent_types/tangent.jl:92
  getindex(::Tangent{P, T}, ::Int64) where {P, T<:Union{Tuple, NamedTuple}} at ~/.julia/packages/ChainRulesCore/IFusD/src/tangent_types/tangent.jl:88
  getindex(::Tangent, ::Any) where {P, T<:AbstractDict} at ~/.julia/packages/ChainRulesCore/IFusD/src/tangent_types/tangent.jl:96
Stacktrace:
 [1] (::var"#nlsolve_pullback#3"{Base.Pairs{Symbol, Any, Tuple{Symbol, Symbol, Symbol}, NamedTuple{(:method, :m, :show_trace), Tuple{Symbol, Int64, Bool}}}, Zygote.ZygoteRuleConfig{Zygote.Context}, var"#6#7"{Vector{Float64}}, NLsolve.SolverResults{Float64, Float64, Vector{Float64}, Vector{Float64}}})(Δresult::Tangent{Any, NamedTuple{(:method, :initial_x, :zero, :residual_norm, :iterations, :x_converged, :xtol, :f_converged, :ftol, :trace, :f_calls, :g_calls), Tuple{ZeroTangent, ZeroTangent, FillArrays.Fill{Float64, 1, Tuple{Base.OneTo{Int64}}}, ZeroTangent, ZeroTangent, ZeroTangent, ZeroTangent, ZeroTangent, ZeroTangent, ZeroTangent, ZeroTangent, ZeroTangent}}})
   @ Main ./REPL[10]:4
 [2] ZBack
   @ ~/.julia/packages/Zygote/umM0L/src/compiler/chainrules.jl:204 [inlined]
 [3] (::Zygote.var"#kw_zpullback#42"{var"#nlsolve_pullback#3"{Base.Pairs{Symbol, Any, Tuple{Symbol, Symbol, Symbol}, NamedTuple{(:method, :m, :show_trace), Tuple{Symbol, Int64, Bool}}}, Zygote.ZygoteRuleConfig{Zygote.Context}, var"#6#7"{Vector{Float64}}, NLsolve.SolverResults{Float64, Float64, Vector{Float64}, Vector{Float64}}}})(dy::Base.RefValue{Any})
   @ Zygote ~/.julia/packages/Zygote/umM0L/src/compiler/chainrules.jl:230
 [4] Pullback
   @ ./REPL[16]:1 [inlined]
 [5] (::typeof(∂(solve_x)))(Δ::FillArrays.Fill{Float64, 1, Tuple{Base.OneTo{Int64}}})
   @ Zygote ~/.julia/packages/Zygote/umM0L/src/compiler/interface2.jl:0
 [6] Pullback
   @ ./REPL[17]:1 [inlined]
 [7] (::Zygote.var"#57#58"{typeof(∂(obj))})(Δ::Float64)
   @ Zygote ~/.julia/packages/Zygote/umM0L/src/compiler/interface.jl:41
 [8] gradient(f::Function, args::Vector{Float64})
   @ Zygote ~/.julia/packages/Zygote/umM0L/src/compiler/interface.jl:76
 [9] top-level scope
   @ REPL[19]:1
julia> versioninfo()
Julia Version 1.7.1
Commit ac5cc99908 (2021-12-22 19:35 UTC)
Platform Info:
  OS: Linux (x86_64-pc-linux-gnu)
  CPU: Intel(R) Core(TM) i7-8850H CPU @ 2.60GHz
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-12.0.1 (ORCJIT, skylake)

Everything works fine if I use an older version, though. For instance:

julia> versioninfo()
Julia Version 1.6.1
Commit 6aaedecc44 (2021-04-23 05:59 UTC)
Platform Info:
  OS: Linux (x86_64-pc-linux-gnu)
  CPU: Intel(R) Core(TM) i7-8850H CPU @ 2.60GHz
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-11.0.1 (ORCJIT, skylake)

@niklasschmitz
Author

Thanks @JulienPascal, I have now updated the example to fix this error. It now works for me on both Julia 1.6.4 and 1.7.1.

@JulienPascal

Thank you @niklasschmitz for the update. It works really well on Julia 1.7.1.

@ChrisRackauckas

@antoine-levitt

Cool! Where's the code? I couldn't find it in NonlinearSolve.jl (which doesn't even depend on ChainRulesCore)

@ChrisRackauckas

The code is at https://github.com/SciML/DiffEqSensitivity.jl/blob/master/src/steadystate_adjoint.jl. It uses the DiffEqSensitivity machinery (which should be renamed SciMLSensitivity.jl at this point, but I digress) to get all of the AD compatibility without Requires (it throws an error saying so if you differentiate without it). It has heuristics for switching between Jacobian-based and Jacobian-free methods based on problem size, uses the same DiffEqSensitivity selection among Zygote/Enzyme/etc. for VJPs (though that needs to be improved for this case), and so on.
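To illustrate what a size-based switch between Jacobian-based and Jacobian-free adjoint solves can look like, here is a hypothetical sketch. It is not DiffEqSensitivity's actual heuristic or API, and `dense_threshold=100` is an arbitrary assumption. Below the threshold it materializes the Jacobian and does a dense solve; above it, it only needs a vector-Jacobian product v -> Jᵀv, which it feeds to a minimal hand-written GMRES (a stand-in for `IterativeSolvers.gmres` from the gist, so the example depends only on LinearAlgebra).

```julia
using LinearAlgebra

# Minimal unrestarted GMRES that only needs a function op(v) = M*v (here M = Jᵀ).
# A stand-in for IterativeSolvers.gmres so that this sketch only needs the stdlib.
function gmres_matfree(op, b; tol=1e-10, maxiter=length(b))
    β = norm(b)
    β == 0 && return zero(b)
    Q = zeros(length(b), maxiter + 1)      # orthonormal Krylov basis
    H = zeros(maxiter + 1, maxiter)        # upper Hessenberg matrix from Arnoldi
    Q[:, 1] = b / β
    for k in 1:maxiter
        w = op(Q[:, k])
        for j in 1:k                       # modified Gram-Schmidt orthogonalization
            H[j, k] = dot(Q[:, j], w)
            w -= H[j, k] * Q[:, j]
        end
        H[k+1, k] = norm(w)
        rhs = zeros(k + 1)
        rhs[1] = β
        y = H[1:k+1, 1:k] \ rhs            # small least-squares problem
        if norm(H[1:k+1, 1:k] * y - rhs) <= tol * β   # estimated residual norm
            return Q[:, 1:k] * y
        end
        Q[:, k+1] = w / H[k+1, k]
    end
    error("GMRES did not converge in $maxiter iterations")
end

# Hypothetical size-based switch; the threshold is an arbitrary assumption.
#   jacobian(): returns the dense Jacobian J
#   vjp(v):     returns Jᵀv without forming J
function adjoint_solve(jacobian, vjp, x̄; dense_threshold=100)
    if length(x̄) <= dense_threshold
        return jacobian()' \ x̄             # Jacobian-based: materialize J, dense solve
    else
        return gmres_matfree(vjp, x̄)       # Jacobian-free: only VJPs are needed
    end
end

n, nonlin = 5, 0.1
A = diagm(0 => fill(10.0, n), 1 => fill(-1.0, n - 1), -1 => fill(-1.0, n - 1))
x = [0.1, -0.2, 0.3, 0.05, -0.4]            # stand-in for a converged primal solution
J = A + Diagonal(2nonlin * x)
x̄ = ones(n)

λ_dense = adjoint_solve(() -> J, v -> J' * v, x̄)                     # n = 5 ≤ 100: dense path
λ_free = adjoint_solve(() -> J, v -> J' * v, x̄; dense_threshold=0)   # forces the Jacobian-free path
```

Both paths should agree to solver tolerance. In an actual rrule the `vjp` argument would come from `rrule_via_ad`, as `JT` does in the gist.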
