Skip to content

Instantly share code, notes, and snippets.

@jrevels
Created November 20, 2018 22:32
Show Gist options
  • Save jrevels/664e2926c01abb15ac6d92fd4a4788c8 to your computer and use it in GitHub Desktop.
Save jrevels/664e2926c01abb15ac6d92fd4a4788c8 to your computer and use it in GitHub Desktop.
using ForwardDiff, ReverseDiff, BenchmarkTools
"""
    MyJacobianWrapper(f, t)

Callable wrapper that turns an in-place function `f(du, u, p, t)` into an
out-of-place function of `(u, p)`, with the time argument `t` stored in the
wrapper. The type is declared `mutable`, so its fields (e.g. `t`) can be
reassigned after construction.
"""
mutable struct MyJacobianWrapper{fType,tType} <: Function
    f::fType
    t::tType
end

"""
    (wrapper::MyJacobianWrapper)(u, p)

Allocate a fresh output array with the size of `u` and the element type of `p`
(via `similar(p, size(u))`), fill it by calling `wrapper.f(out, u, p, wrapper.t)`,
and return it. Taking the element type from `p` lets the output hold whatever
number type `p` carries (e.g. AD number types when differentiating w.r.t. `p`).
"""
function (wrapper::MyJacobianWrapper)(u, p)
    out = similar(p, size(u))
    wrapper.f(out, u, p, wrapper.t)
    return out
end
# Random parameter vector (length 3) and state vector (length 2); the tuple is
# evaluated left to right, so `p` is drawn before `u`.
p, u = rand(3), rand(2)
"""
    fun(du, u, p, t)

In-place right-hand side for a two-state system with parameters `p = (a, b, c)`
and state `u = (x, y)`:

    du[1] = a*x - b*x*y
    du[2] = -c*y + x*y

Overwrites `du[1]` and `du[2]` and returns `nothing`. `t` is part of the
`f(du, u, p, t)` calling convention but is not used.
"""
function fun(du, u, p, t)
    α, β, γ = p
    x₁, x₂ = u
    # Association order matches `a*x - b*x*y` exactly, i.e. (β * x₁) * x₂.
    du[1] = α * x₁ - β * x₁ * x₂
    du[2] = -γ * x₂ + x₁ * x₂
    return nothing
end
# Wrap `fun` with time fixed at `t = 1` so it can be called as `pf(u, p)`.
pf = MyJacobianWrapper(fun, 1)
pf(u, p)  # evaluate once at the current (u, p)
# Record `pf` on a ReverseDiff tape over the two inputs (u, p), then compile it.
# The random input values are placeholders: `vecjacobian!` overwrites them with
# `ReverseDiff.value!` before each forward pass.
tape = ReverseDiff.GradientTape(pf, (rand(2), rand(3))) |> ReverseDiff.compile
# Row-vector buffer (adjoint of a length-3 vector) that receives v' * J_p.
vJ = adjoint(rand(3))
"""
    vecjacobian!(vJ, v, tape, (u, p))

Compute the vector–Jacobian product `v' * J_p` using a compiled ReverseDiff
`tape`, where `J_p` is the Jacobian of the tape's output with respect to its
second input `p`, evaluated at `(u, p)`. The result is copied into `vJ`.
Only the derivative with respect to `p` is copied out; the one for `u` is not.

Assumes `tape` was recorded with exactly two inputs `(u, p)` (e.g.
`ReverseDiff.GradientTape(f, (u, p))`) and that `v` has the shape of the tape's
output — TODO confirm for other callers. Returns `nothing`.
"""
function vecjacobian!(vJ, v, tape, (u, p))
    tu, tp = ReverseDiff.input_hook(tape)
    tracked_out = ReverseDiff.output_hook(tape)
    # For each tracked input: clear any "leftover" derivatives from previous
    # calls, then load the new point.
    for (tracked, x) in ((tu, u), (tp, p))
        ReverseDiff.unseed!(tracked)
        ReverseDiff.value!(tracked, x)
    end
    # Replay the recorded computation at the new inputs.
    ReverseDiff.forward_pass!(tape)
    # Seed the output with `v` and propagate it back to the inputs.
    ReverseDiff.increment_deriv!(tracked_out, v)
    ReverseDiff.reverse_pass!(tape)
    # `deriv(tp)` now holds v' * J_p. This copy is optional: a caller could read
    # the result directly via `ReverseDiff.deriv(ReverseDiff.input_hook(tape)[2])`.
    copyto!(vJ, ReverseDiff.deriv(tp))
    return nothing
end
# Check the ReverseDiff vector–Jacobian product against ForwardDiff's dense
# Jacobian with respect to `p`, first at the sampled `u`, then at
# `u = [1.0, 1.0]`, reusing the same compiled tape. The seed vector `v` is the
# sampled `u` in both checks. Each check evaluates to `true` on success and
# throws otherwise, so it is enforced when run as a script, not just displayed
# at the REPL.
vecjacobian!(vJ, u, tape, (u, p))
vJ ≈ u'ForwardDiff.jacobian(p -> pf(u, p), p) ||
    error("vecjacobian! disagrees with ForwardDiff at the sampled u")
vecjacobian!(vJ, u, tape, ([1.0, 1.0], p))
vJ ≈ u'ForwardDiff.jacobian(p -> pf([1.0, 1.0], p), p) ||
    error("vecjacobian! disagrees with ForwardDiff at u = [1.0, 1.0]")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment