@jrevels
Last active December 5, 2019 02:40
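using ForwardDiff, ReverseDiff, BenchmarkTools

n = 100
x = rand(n)  # point at which the Jacobian of `cumprod` is evaluated
v = rand(n)  # vector to pull back through the Jacobian
# Record the tape for `cumprod` once (on a dummy input of the same size) and
# compile it, so that the benchmark below only measures replaying the tape.
tape = ReverseDiff.compile(ReverseDiff.GradientTape(cumprod, rand(n)))
vJ = rand(n)'  # preallocated output buffer (a row vector, since vJ = v' * J)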

# New method: compute vJ = v' * J(x) by replaying a prerecorded, compiled tape.
function vecjacobian_new!(vJ, v, tape, x)
    input = ReverseDiff.input_hook(tape)
    output = ReverseDiff.output_hook(tape)
    ReverseDiff.unseed!(input)       # clear any "leftover" derivatives from previous calls
    ReverseDiff.value!(input, x)     # load the new input values into the tape
    ReverseDiff.forward_pass!(tape)  # recompute all intermediate values for `x`
    ReverseDiff.deriv!(output, v)    # seed the output adjoint with `v`
    ReverseDiff.reverse_pass!(tape)  # propagate back to the input: deriv(input) == J' * v
    # Note: we could just say `ReverseDiff.deriv(input)` *is* our `vJ`, in which
    # case we could remove this line, and the caller could query `vJ` from the
    # tape directly via `ReverseDiff.deriv(ReverseDiff.input_hook(tape))`.
    copyto!(vJ, ReverseDiff.deriv(input))
    return nothing
end
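
# Old method: record a fresh (uncompiled) tape for `f` on every call.
function vecjacobian!(vJ::AbstractArray{<:Number}, v, f, x::AbstractArray{<:Number})
    tp = ReverseDiff.InstructionTape()
    tx = ReverseDiff.track(x, tp)        # wrap `x` so operations on it are recorded to `tp`
    ty = f(tx)                           # forward pass: run `f` while recording the tape
    ReverseDiff.increment_deriv!(ty, v)  # seed the output adjoint with `v`
    ReverseDiff.reverse_pass!(tp)
    copyto!(vJ, ReverseDiff.deriv(tx))
    return nothing
end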
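
# Hand-derived reference for sanity-checking the tape-based results (a sketch,
# not part of the original gist; `cumprod_vjp!` is a hypothetical helper, not a
# ReverseDiff function). For y = cumprod(x), dy_i/dx_j is the product of
# x[1:i] with x[j] left out when j <= i, and 0 otherwise. Writing
# P_{j-1} = x_1 * ... * x_{j-1} and S_j = v_j + x_{j+1} * S_{j+1} (S_n = v_n),
# the j-th entry of v' * J is P_{j-1} * S_j, computed below in O(n) without AD.
function cumprod_vjp!(vJ, v, x)
    n = length(x)
    # Backward sweep: store S_j = v_j + x_{j+1} * S_{j+1} in vJ[j].
    s = zero(eltype(vJ))
    for j in n:-1:1
        s = j == n ? v[j] : v[j] + x[j+1] * s
        vJ[j] = s
    end
    # Forward sweep: multiply each S_j by the prefix product P_{j-1}.
    p = one(eltype(vJ))
    for j in 1:n
        vJ[j] *= p
        p *= x[j]
    end
    return nothing
end
# Usage, e.g. after `vecjacobian!(vJ, v, cumprod, x)`:
#     ref = similar(vJ); cumprod_vjp!(ref, v, x)
# `ref` and `vJ` should then agree up to floating-point rounding (`ref ≈ vJ`).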

#=
Results:

julia> @btime vecjacobian_new!($vJ, $v, $tape, $x) evals=1
4.571 μs (0 allocations: 0 bytes)

julia> @btime vecjacobian!($vJ, $v, $cumprod, $x) evals=1
26.581 μs (706 allocations: 28.92 KiB)

Remember, the speed-up from `ReverseDiff.compile` here isn't "free"; we just
moved tape construction/compilation time out of our benchmark. Also keep in mind
that using `ReverseDiff.compile` assumes that the target function's computation
graph is static: the compiled tape replays exactly the operations recorded when
it was built, regardless of the new input values.

Thus, `ReverseDiff.compile` only makes sense when you're computing
vector-Jacobian products for the same `f` over and over, and `f` does not
feature data-dependent control flow. Otherwise, just use the old method
(`vecjacobian!`), which re-records the tape on every call.
=#
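
# Optional correctness check (a sketch, not part of the original benchmark; it
# relies on the definitions above). Both methods should return v' * J, where J
# is the dense Jacobian of `cumprod` at `x`. Here J is computed with
# `ForwardDiff.jacobian` (ForwardDiff is already loaded above), and both
# tape-based results are compared against it up to floating-point rounding.
J = ForwardDiff.jacobian(cumprod, x)
vJ_old = similar(vJ)
vecjacobian!(vJ_old, v, cumprod, x)
vJ_new = similar(vJ)
vecjacobian_new!(vJ_new, v, tape, x)
@assert vJ_old ≈ v' * J
@assert vJ_new ≈ v' * J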