-
-
Save niklasschmitz/b00223b9e9ba2a37ed09539a264bf423 to your computer and use it in GitHub Desktop.
using NLsolve | |
using Zygote | |
using ChainRulesCore | |
using IterativeSolvers | |
using LinearMaps | |
using SparseArrays | |
using LinearAlgebra | |
using BenchmarkTools | |
using Random | |
# Seed the default RNG explicitly so that p0 = randn(N) is reproducible between runs.
Random.seed!(Random.default_rng(), 1234)
# Reverse-mode rule for `nlsolve(f, x0; kwargs...)` based on the implicit function theorem.
#
# Forward pass: run `nlsolve` and return its result object unchanged.
# Pullback: with x = result.zero and J = ∂f/∂x evaluated at x, solve the adjoint system
#     Jᵀ λ = -Δx
# matrix-free with GMRES. Each product Jᵀv is one reverse pass through `f`, obtained from
# `rrule_via_ad`. Pulling λ back through `f` then gives the tangent of `f` itself, i.e.
# of the variables it closes over.
#
# Only the `zero` field of the result is differentiated. If no cotangent reaches it, zero
# tangents are returned without any linear solve. `x0` always receives `ZeroTangent()`.
# A warning is emitted when the forward solve or the adjoint GMRES solve reports
# non-convergence, since the implicit gradient is only valid at an actual root.
function ChainRulesCore.rrule(config::RuleConfig{>:HasReverseMode}, ::typeof(nlsolve), f, x0; kwargs...)
    result = nlsolve(f, x0; kwargs...)
    NLsolve.converged(result) ||
        @warn "nlsolve did not converge; the implicit gradient may be inaccurate"
    function nlsolve_pullback(Δresult)
        no_grad = (NoTangent(), ZeroTangent(), ZeroTangent())
        # Cotangents may arrive as thunks, or as structural zeros when the root is unused.
        Δres = unthunk(Δresult)
        Δres isa AbstractZero && return no_grad
        Δx = unthunk(Δres.zero)
        (Δx === nothing || Δx isa AbstractZero) && return no_grad
        x = result.zero
        _, f_pullback = rrule_via_ad(config, f, x)
        JT(v) = f_pullback(v)[2]  # Jᵀv: pullback of f w.r.t. x
        # solve Jᵀ λ = -Δx matrix-free
        L = LinearMap(JT, length(x))
        λ, history = gmres(L, -Δx; log = true)
        history.isconverged ||
            @warn "GMRES did not converge on the adjoint system; gradient may be inaccurate"
        ∂f = f_pullback(λ)[1]  # w.r.t. f itself (implicitly closed-over variables)
        return (NoTangent(), ∂f, ZeroTangent())
    end
    return result, nlsolve_pullback
end
# Problem size and strength of the quadratic term in h.
const N = 10000
const nonlin = 0.1
# Sparse tridiagonal matrix: 10 on the main diagonal, -1 on the first sub- and
# super-diagonals.
const A = let offdiag = fill(-1.0, N - 1)
    spdiagm(-1 => offdiag, 0 => fill(10.0, N), 1 => offdiag)
end
# Random parameter vector at which the gradient of obj is evaluated.
const p0 = randn(N)
# Residual h(x, p) = A x + nonlin·x.^2 - p (elementwise square); solve_x finds its root in x.
h(x, p) = A * x .+ nonlin .* x .^ 2 .- p
# Root x*(p) of h(·, p), found with Anderson acceleration (history m = 10) from x = 0.
function solve_x(p)
    result = nlsolve(x -> h(x, p), zeros(N); method = :anderson, m = 10)
    return result.zero
end
# Scalar objective: sum of the entries of the root x*(p).
obj(p) = sum(solve_x(p))
# Hand-written rrule for h: without it Zygote densifies the sparse matrix A when
# differentiating A*x (https://github.com/FluxML/Zygote.jl/issues/931).
function ChainRulesCore.rrule(::typeof(h), x, p)
    # ∂h/∂x = A + Diagonal(2nonlin .* x) and ∂h/∂p = -I; the pullback applies their
    # adjoints to ȳ lazily through thunks.
    h_pullback(ȳ) = (NoTangent(), @thunk(A'ȳ .+ (2nonlin) .* x .* ȳ), @thunk(-ȳ))
    return h(x, p), h_pullback
end
# Gradient of obj via Zygote, which dispatches to the custom nlsolve rrule.
g_auto = first(Zygote.gradient(obj, p0))
# Reference gradient: with J = A + Diagonal(2 nonlin x*) at the root x* = solve_x(p0),
# ∇obj(p0) = J⁻ᵀ 𝟙, obtained here with GMRES.
g_analytic = gmres(adjoint(A + Diagonal(2 * nonlin * solve_x(p0))), ones(N))
foreach(display, (g_auto, g_analytic))
@show sum(abs, g_auto - g_analytic) / N # 7.613631947123168e-17
# Timings from the original run are noted above each benchmark.
# 11.730 ms (784 allocations: 19.87 MiB)
@btime Zygote.gradient(obj, p0);
# 11.409 ms (626 allocations: 17.50 MiB)
@btime gmres(adjoint(A + Diagonal(2 * nonlin * solve_x(p0))), ones(N));
import Pkg
# Print the package environment; the versions used for the results above were:
Pkg.status()
# Status `/tmp/nlsolve/Project.toml`
# [6e4b80f9] BenchmarkTools v1.2.2
# [d360d2e6] ChainRulesCore v1.11.6
# [42fd0dbc] IterativeSolvers v0.9.2
# [7a12625a] LinearMaps v3.5.1
# [2774e3e8] NLsolve v4.5.1
# [e88e6eb3] Zygote v0.6.34
# [37e2e46d] LinearAlgebra
# [9a3f8284] Random
# [2f01184e] SparseArrays
using NLsolve | |
using Zygote | |
using ChainRulesCore | |
using SparseArrays | |
using LinearAlgebra | |
using Random | |
# Seed the default RNG explicitly so that p0 = randn(N) is reproducible between runs.
Random.seed!(Random.default_rng(), 1234)
using IterativeSolvers | |
using LinearMaps | |
# Reverse-mode rule for `nlsolve(f, x0; kwargs...)` based on the implicit function theorem.
#
# Forward pass: run `nlsolve` and return its result object unchanged.
# Pullback: with x = result.zero and J = ∂f/∂x evaluated at x, the adjoint system
#     Jᵀ λ = -Δx
# is itself solved with `nlsolve`, as the root of v ↦ Jᵀv + Δx. That solve starts from zero
# and reuses the forward solver's kwargs. Each Jᵀv is one reverse pass through `f`,
# obtained from `rrule_via_ad`. Pulling λ back through `f` gives the tangent of `f`
# itself, i.e. of the variables it closes over.
#
# Only the `zero` field of the result is differentiated. If no cotangent reaches it, zero
# tangents are returned without an adjoint solve. `x0` always receives `ZeroTangent()`.
# A warning is emitted when the forward or the adjoint solve reports non-convergence.
function ChainRulesCore.rrule(config::RuleConfig{>:HasReverseMode}, ::typeof(nlsolve), f, x0; kwargs...)
    result = nlsolve(f, x0; kwargs...)
    NLsolve.converged(result) ||
        @warn "nlsolve did not converge; the implicit gradient may be inaccurate"
    function nlsolve_pullback(Δresult)
        no_grad = (NoTangent(), ZeroTangent(), ZeroTangent())
        # Cotangents may arrive as thunks, or as structural zeros when the root is unused.
        Δres = unthunk(Δresult)
        Δres isa AbstractZero && return no_grad
        Δx = unthunk(Δres.zero)
        (Δx === nothing || Δx isa AbstractZero) && return no_grad
        x = result.zero
        _, f_pullback = rrule_via_ad(config, f, x)
        JT(v) = f_pullback(v)[2]  # Jᵀv: pullback of f w.r.t. x
        # solve Jᵀ λ = -Δx as the root of the (linear) residual Jᵀλ + Δx
        adjoint_result = nlsolve(v -> JT(v) + Δx, zero(x); kwargs...)
        NLsolve.converged(adjoint_result) ||
            @warn "adjoint nlsolve did not converge; the gradient may be inaccurate"
        λ = adjoint_result.zero
        ∂f = f_pullback(λ)[1]  # w.r.t. f itself (implicitly closed-over variables)
        return (NoTangent(), ∂f, ZeroTangent())
    end
    return result, nlsolve_pullback
end
# Problem size and strength of the quadratic term in h.
const N = 100
const nonlin = 0.1
# Sparse tridiagonal matrix: 10 on the main diagonal, -1 on the first sub- and
# super-diagonals.
const A = let offdiag = fill(-1.0, N - 1)
    spdiagm(-1 => offdiag, 0 => fill(10.0, N), 1 => offdiag)
end
# Random parameter vector at which the gradient of obj is evaluated.
const p0 = randn(N)
# Residual h(x, p) = A x + nonlin·x.^2 - p (elementwise square); solve_x finds its root in x.
h(x, p) = A * x .+ nonlin .* x .^ 2 .- p
# Root x*(p) of h(·, p), found with Anderson acceleration (history m = 10) from x = 0,
# printing the solver trace.
function solve_x(p)
    result = nlsolve(x -> h(x, p), zeros(N); method = :anderson, m = 10, show_trace = true)
    return result.zero
end
# Scalar objective: sum of the entries of the root x*(p).
obj(p) = sum(solve_x(p))
# Hand-written rrule for h: without it Zygote densifies the sparse matrix A when
# differentiating A*x (https://github.com/FluxML/Zygote.jl/issues/931).
function ChainRulesCore.rrule(::typeof(h), x, p)
    # ∂h/∂x = A + Diagonal(2nonlin .* x) and ∂h/∂p = -I; the pullback applies their
    # adjoints to ȳ lazily through thunks.
    h_pullback(ȳ) = (NoTangent(), @thunk(A'ȳ .+ (2nonlin) .* x .* ȳ), @thunk(-ȳ))
    return h(x, p), h_pullback
end
# Gradient of obj via Zygote, which dispatches to the custom nlsolve rrule.
g_auto = first(Zygote.gradient(obj, p0))
# Reference gradient: with J = A + Diagonal(2 nonlin x*) at the root x* = solve_x(p0),
# ∇obj(p0) = J⁻ᵀ 𝟙, obtained here with GMRES.
g_analytic = gmres(adjoint(A + Diagonal(2 * nonlin * solve_x(p0))), ones(N))
foreach(display, (g_auto, g_analytic))
@show sum(abs, g_auto - g_analytic) / N # 8.878502030795765e-11
https://github.com/SciML/NonlinearSolve.jl implements an extensive version of this, which was used for https://github.com/SciML/FastDEQ.jl (https://arxiv.org/abs/2201.12240).
Cool! Where's the code? I couldn't find it in NonlinearSolve.jl (which doesn't even depend on ChainRulesCore)
https://github.com/SciML/DiffEqSensitivity.jl/blob/master/src/steadystate_adjoint.jl: it uses the DiffEqSensitivity machinery (which should be renamed SciMLSensitivity.jl at this point, but I digress) to get all of the AD compatibility without Requires (if you differentiate without DiffEqSensitivity loaded, it throws an error saying so). It has heuristics for switching between Jacobian-based and Jacobian-free methods based on problem size, uses the same DiffEqSensitivity choice of Zygote/Enzyme/etc. for VJPs (though that needs to be improved for this case), etc.
Thank you @niklasschmitz for the update. It works really well on Julia 1.7.1.