Created November 20, 2018 22:32
Save jrevels/664e2926c01abb15ac6d92fd4a4788c8 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
using ForwardDiff, ReverseDiff, BenchmarkTools
using Test
"""
    MyJacobianWrapper(f, t)

Callable wrapper that turns an in-place function `f(du, u, p, t)` into an
out-of-place function of `(u, p)` with `t` held fixed: calling `ff(u, p)`
allocates `du`, fills it via `f`, and returns it.

Declared `mutable`, so the fields (e.g. `t`) can be reassigned after construction.
"""
mutable struct MyJacobianWrapper{fType,tType} <: Function
    f::fType
    t::tType
end

# Call form `ff(u, p)`: allocate an output buffer with the same shape as `u`,
# let `ff.f` fill it in place (passing the stored `ff.t`), and return it.
#
# The buffer's element type is the promotion of `eltype(u)` and `eltype(p)`, not
# just `eltype(p)`, so the result can hold whatever `f` computes when either
# argument carries the wider number type (e.g. Float64 state with integer
# parameters, or dual numbers when differentiating with respect to `u`).
function (ff::MyJacobianWrapper)(u, p)
    T = promote_type(eltype(u), eltype(p))
    du = similar(p, T, size(u))
    ff.f(du, u, p, ff.t)
    return du
end
# Random demo inputs: parameters (a, b, c) and state (x, y) for `fun`.
p, u = rand(3), rand(2)
"""
    fun(du, u, p, t)

In-place right-hand side: given parameters `p = (a, b, c)` and state
`u = (x, y)`, store `a*x - b*x*y` in `du[1]` and `-c*y + x*y` in `du[2]`
(a Lotka–Volterra-shaped system). The time `t` is accepted but unused.
Returns `nothing`.
"""
function fun(du, u, p, t)
    a, b, c = p
    x, y = u
    dx = a * x - b * x * y
    dy = -c * y + x * y
    du[1] = dx
    du[2] = dy
    return nothing
end
# Out-of-place view of `fun` with the time argument fixed at 1; call it once on
# the demo inputs as a smoke test.
pf = MyJacobianWrapper(fun, 1)
pf(u, p)

# Record `pf` on placeholder inputs with the same lengths as (u, p) (2 and 3),
# then compile the tape so it can be replayed at new input values.
tape = let recorded = ReverseDiff.GradientTape(pf, (rand(2), rand(3)))
    ReverseDiff.compile(recorded)
end

# 1×3 row-vector (adjoint) buffer for the vector–Jacobian product; its random
# initial values are only placeholders.
vJ = adjoint(rand(3))
# vecjacobian!(vJ, v, tape, (u, p))
#
# Replay the compiled ReverseDiff `tape` at the point `(u, p)` and copy the
# derivative accumulated on the `p` input (i.e. vᵀ · ∂f/∂p, with `f` the recorded
# function) into `vJ`. Returns `nothing`.
#
# The reverse pass also accumulates a derivative on the `u` input; it is not
# copied out here. `input_hook(tape)` returns the (u, p) tracked inputs as a
# tuple, so that derivative is `ReverseDiff.deriv(ReverseDiff.input_hook(tape)[1])`.
function vecjacobian!(vJ, v, tape, (u, p))
    # Clear derivatives left over from a previous call, then load the new value.
    reset_input!(tracked, x) = (ReverseDiff.unseed!(tracked); ReverseDiff.value!(tracked, x))

    tracked_u, tracked_p = ReverseDiff.input_hook(tape)
    tracked_out = ReverseDiff.output_hook(tape)

    reset_input!(tracked_u, u)
    reset_input!(tracked_p, p)

    ReverseDiff.forward_pass!(tape)                 # re-evaluate the tape at (u, p)
    ReverseDiff.increment_deriv!(tracked_out, v)    # seed the output adjoint with v
    ReverseDiff.reverse_pass!(tape)                 # push v back to the inputs

    copyto!(vJ, ReverseDiff.deriv(tracked_p))
    return nothing
end
# Sanity checks: the tape-based product must match vᵀ·J, where J is the
# ForwardDiff Jacobian of `pf` with respect to the parameters. This is checked
# at the demo state `u` and at a different state [1, 1]; the seed vector v is
# the random `u` in both cases. `@test` throws on failure when run as a script,
# whereas a bare `≈` expression would silently discard a mismatch.
vecjacobian!(vJ, u, tape, (u, p))
@test vJ ≈ u' * ForwardDiff.jacobian(q -> pf(u, q), p)

vecjacobian!(vJ, u, tape, ([1.0, 1.0], p))
@test vJ ≈ u' * ForwardDiff.jacobian(q -> pf([1.0, 1.0], q), p)
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.