Last active
February 9, 2022 07:32
-
-
Save niklasschmitz/b00223b9e9ba2a37ed09539a264bf423 to your computer and use it in GitHub Desktop.
NLsolve ChainRules implicit differentiation
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
using NLsolve | |
using Zygote | |
using ChainRulesCore | |
using IterativeSolvers | |
using LinearMaps | |
using SparseArrays | |
using LinearAlgebra | |
using BenchmarkTools | |
using Random | |
Random.seed!(1234) | |
# Reverse-mode rule for `nlsolve(f, x0; kwargs...)` using implicit differentiation.
#
# The forward pass simply runs `nlsolve`. The pullback takes the cotangent of the
# `zero` field of the result, solves the adjoint linear system JT*Δfx = -Δx with GMRES
# (JT being the pullback of `f` w.r.t. its argument, evaluated at the solution), and then
# pulls Δfx back through `f` to obtain the tangent w.r.t. `f` itself, i.e. w.r.t. the
# variables `f` closes over. The initial guess `x0` receives a zero tangent.
function ChainRulesCore.rrule(
    config::RuleConfig{>:HasReverseMode}, ::typeof(nlsolve), f, x0; kwargs...
)
    result = nlsolve(f, x0; kwargs...)
    function nlsolve_pullback(Δresult)
        # A zero cotangent gives a zero gradient; skip the linear solve entirely
        # (gmres cannot operate on a ZeroTangent/NoTangent right-hand side).
        Δresult isa AbstractZero && return (NoTangent(), ZeroTangent(), ZeroTangent())
        # The AD system may hand over a lazy thunk; materialise it before using it
        # as the right-hand side of the linear solve.
        Δx = unthunk(Δresult.zero)
        Δx isa AbstractZero && return (NoTangent(), ZeroTangent(), ZeroTangent())
        x = result.zero
        _, f_pullback = rrule_via_ad(config, f, x)
        # Pullback of f w.r.t. x; unthunk so that LinearMap/gmres receive a plain array.
        JT(v) = unthunk(f_pullback(v)[2])
        # solve JT*Δfx = -Δx
        L = LinearMap(JT, length(x0))
        Δfx = gmres(L, -Δx)
        ∂f = f_pullback(Δfx)[1] # w.r.t. f itself (implicitly closed-over variables)
        return (NoTangent(), ∂f, ZeroTangent())
    end
    return result, nlsolve_pullback
end
# Problem size and strength of the quadratic nonlinearity.
const N = 10000
const nonlin = 0.1

# Sparse tridiagonal matrix: 10 on the main diagonal, -1 on the first super- and
# sub-diagonals.
const A = spdiagm(
    0 => fill(10.0, N),
    1 => fill(-1.0, N - 1),
    -1 => fill(-1.0, N - 1),
)

# Random parameter vector at which the gradient is evaluated.
const p0 = randn(N)

# Residual of the nonlinear system: linear part A*x, elementwise quadratic term, minus p.
function h(x, p)
    return A * x + nonlin * x .^ 2 - p
end

# Root of x -> h(x, p), found with Anderson acceleration (history size m = 10),
# starting from the zero vector.
function solve_x(p)
    solution = nlsolve(x -> h(x, p), zeros(N); method=:anderson, m=10)
    return solution.zero
end

# Scalar objective: sum of the entries of the solution for parameters p.
function obj(p)
    return sum(solve_x(p))
end
# A hand-written rrule for h is required: Zygote would otherwise densify the sparse
# matrix A, see https://github.com/FluxML/Zygote.jl/issues/931
function ChainRulesCore.rrule(::typeof(h), x, p)
    # Both cotangents are wrapped in thunks so each is only computed when requested.
    function h_pullback(Δy)
        x_bar = @thunk(A' * Δy + (2 * nonlin) * x .* Δy)
        p_bar = @thunk(-Δy)
        return (NoTangent(), x_bar, p_bar)
    end
    return h(x, p), h_pullback
end
# Gradient of obj at p0 as computed by Zygote.
g_auto = first(Zygote.gradient(obj, p0))

# Reference gradient: solve (A + Diagonal(2*nonlin*x))' * g = ones(N) directly with GMRES,
# where x is the solution of the nonlinear system at p0.
g_analytic = gmres(adjoint(A + Diagonal(2 * nonlin * solve_x(p0))), ones(N))

foreach(display, (g_auto, g_analytic))

# Mean absolute deviation between the two gradients.
@show sum(abs, g_auto - g_analytic) / N # 7.613631947123168e-17

# Timings recorded by the author for the AD gradient and the direct adjoint solve.
@btime Zygote.gradient(obj, p0); # 11.730 ms (784 allocations: 19.87 MiB)
@btime gmres((A + Diagonal(2*nonlin*solve_x(p0)))', ones(N)); # 11.409 ms (626 allocations: 17.50 MiB)

import Pkg
Pkg.status()
# Status `/tmp/nlsolve/Project.toml` | |
# [6e4b80f9] BenchmarkTools v1.2.2 | |
# [d360d2e6] ChainRulesCore v1.11.6 | |
# [42fd0dbc] IterativeSolvers v0.9.2 | |
# [7a12625a] LinearMaps v3.5.1 | |
# [2774e3e8] NLsolve v4.5.1 | |
# [e88e6eb3] Zygote v0.6.34 | |
# [37e2e46d] LinearAlgebra | |
# [9a3f8284] Random | |
# [2f01184e] SparseArrays |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
using NLsolve | |
using Zygote | |
using ChainRulesCore | |
using SparseArrays | |
using LinearAlgebra | |
using Random | |
Random.seed!(1234) | |
using IterativeSolvers | |
using LinearMaps | |
# Reverse-mode rule for `nlsolve(f, x0; kwargs...)` using implicit differentiation.
#
# The forward pass simply runs `nlsolve`. The pullback takes the cotangent of the
# `zero` field of the result and solves the adjoint linear system JT*Δfx = -Δx, where JT
# is the pullback of `f` w.r.t. its argument evaluated at the solution. The linear system
# is itself solved with `nlsolve`, reusing the caller's keyword arguments. Δfx is then
# pulled back through `f` to obtain the tangent w.r.t. `f` itself, i.e. w.r.t. the
# variables `f` closes over. The initial guess `x0` receives a zero tangent.
function ChainRulesCore.rrule(
    config::RuleConfig{>:HasReverseMode}, ::typeof(nlsolve), f, x0; kwargs...
)
    result = nlsolve(f, x0; kwargs...)
    function nlsolve_pullback(Δresult)
        # A zero cotangent gives a zero gradient; skip the adjoint solve entirely
        # (a ZeroTangent/NoTangent cannot be added to the array returned by JT).
        Δresult isa AbstractZero && return (NoTangent(), ZeroTangent(), ZeroTangent())
        # The AD system may hand over a lazy thunk; materialise it before use.
        Δx = unthunk(Δresult.zero)
        Δx isa AbstractZero && return (NoTangent(), ZeroTangent(), ZeroTangent())
        x = result.zero
        _, f_pullback = rrule_via_ad(config, f, x)
        # Pullback of f w.r.t. x; unthunk so the inner solver receives a plain array.
        JT(v) = unthunk(f_pullback(v)[2])
        # solve JT*Δfx = -Δx, posed as finding the root of v -> JT(v) + Δx
        Δfx = nlsolve(v -> JT(v) + Δx, zero(x); kwargs...).zero
        ∂f = f_pullback(Δfx)[1] # w.r.t. f itself (implicitly closed-over variables)
        return (NoTangent(), ∂f, ZeroTangent())
    end
    return result, nlsolve_pullback
end
# Problem size and strength of the quadratic nonlinearity.
const N = 100
const nonlin = 0.1

# Sparse tridiagonal matrix: 10 on the main diagonal, -1 on the first super- and
# sub-diagonals.
const A = spdiagm(
    0 => fill(10.0, N),
    1 => fill(-1.0, N - 1),
    -1 => fill(-1.0, N - 1),
)

# Random parameter vector at which the gradient is evaluated.
const p0 = randn(N)

# Residual of the nonlinear system: linear part A*x, elementwise quadratic term, minus p.
function h(x, p)
    return A * x + nonlin * x .^ 2 - p
end

# Root of x -> h(x, p), found with Anderson acceleration (history size m = 10),
# starting from the zero vector; the solver trace is printed.
function solve_x(p)
    solution = nlsolve(x -> h(x, p), zeros(N); method=:anderson, m=10, show_trace=true)
    return solution.zero
end

# Scalar objective: sum of the entries of the solution for parameters p.
function obj(p)
    return sum(solve_x(p))
end
# A hand-written rrule for h is required: Zygote would otherwise densify the sparse
# matrix A, see https://github.com/FluxML/Zygote.jl/issues/931
function ChainRulesCore.rrule(::typeof(h), x, p)
    # Both cotangents are wrapped in thunks so each is only computed when requested.
    function h_pullback(Δy)
        x_bar = @thunk(A' * Δy + (2 * nonlin) * x .* Δy)
        p_bar = @thunk(-Δy)
        return (NoTangent(), x_bar, p_bar)
    end
    return h(x, p), h_pullback
end
# Gradient of obj at p0 as computed by Zygote.
g_auto = first(Zygote.gradient(obj, p0))

# Reference gradient: solve (A + Diagonal(2*nonlin*x))' * g = ones(N) directly with GMRES,
# where x is the solution of the nonlinear system at p0.
g_analytic = gmres(adjoint(A + Diagonal(2 * nonlin * solve_x(p0))), ones(N))

foreach(display, (g_auto, g_analytic))

# Mean absolute deviation between the two gradients.
@show sum(abs, g_auto - g_analytic) / N # 8.878502030795765e-11
Cool! Where's the code? I couldn't find it in NonlinearSolve.jl (which doesn't even depend on ChainRulesCore)
See https://github.com/SciML/DiffEqSensitivity.jl/blob/master/src/steadystate_adjoint.jl. It uses the DiffEqSensitivity machinery (which should be renamed SciMLSensitivity.jl at this point, but I digress) to get all of the AD compatibility without Requires (it throws an error mentioning this if you differentiate without it). It has heuristics for switching between Jacobian-based and Jacobian-free methods based on problem size, does the same DiffEqSensitivity thing of using Zygote/Enzyme/etc. for VJPs (though it needs to be improved for this case), etc.
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
https://github.com/SciML/NonlinearSolve.jl implements an extensive version for this, which was used for https://github.com/SciML/FastDEQ.jl (https://arxiv.org/abs/2201.12240)