# jrevels/c165ed338cc7159085238aa54a763fe2 (GitHub Gist, last active December 5, 2019 02:40)
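using ForwardDiff, ReverseDiff, BenchmarkTools

n = 100
x = rand(n)   # input point
v = rand(n)   # vector that left-multiplies the Jacobian

# Record the tape once using a dummy input of the right size, then compile it.
# Compilation assumes `cumprod`'s computation graph does not depend on input values.
tape = ReverseDiff.compile(ReverseDiff.GradientTape(cumprod, rand(n)))

vJ = rand(n)'  # preallocated output (row vector), overwritten by the functions below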

# Vector-Jacobian product vJ = vᵀ * J(x) using the prerecorded, compiled `tape`.
# No tape is recorded per call; only the stored instructions are replayed.
function vecjacobian_new!(vJ, v, tape, x)
    input = ReverseDiff.input_hook(tape)
    output = ReverseDiff.output_hook(tape)
    ReverseDiff.unseed!(input)        # clear any "leftover" derivatives from previous calls
    ReverseDiff.value!(input, x)      # load the new input values into the tape
    ReverseDiff.forward_pass!(tape)   # recompute all intermediate values for `x`
    ReverseDiff.deriv!(output, v)     # seed the output derivatives with `v`
    ReverseDiff.reverse_pass!(tape)   # propagate back; deriv(input) now holds vᵀ * J
    # Note: we could simply treat `ReverseDiff.deriv(input)` *as* our `vJ`, in which
    # case this copy could be removed and the caller could query `vJ` from the tape
    # directly via `ReverseDiff.deriv(ReverseDiff.input_hook(tape))`.
    copyto!(vJ, ReverseDiff.deriv(input))
    return nothing
end
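
# Sanity-check sketch (not part of the original benchmark): for `cumprod`,
# y[i] = x[1] * ... * x[i], so ∂y[i]/∂x[j] = y[i] / x[j] for j ≤ i and 0 otherwise,
# assuming no entry of `x` is zero. Hence (vᵀJ)[j] = sum(v[i] * y[i] for i in j:n) / x[j].
# `cumprod_vjp`, `x_chk`, and `v_chk` are hypothetical names used only here; the
# closed form is compared against a dense Jacobian from `ForwardDiff.jacobian`.
using ForwardDiff

function cumprod_vjp(v::AbstractVector, x::AbstractVector)
    y = cumprod(x)
    s = reverse(cumsum(reverse(v .* y)))  # s[j] = sum(v[i] * y[i] for i in j:n)
    return s ./ x                          # divides by x, so x must contain no zeros
end

x_chk = rand(5) .+ 0.5                       # shifted away from zero
v_chk = rand(5)
J_chk = ForwardDiff.jacobian(cumprod, x_chk) # dense 5×5 Jacobian, J[i, j] = ∂y[i]/∂x[j]
vJ_dense = vec(v_chk' * J_chk)               # vᵀJ as a plain Vector
vJ_closed = cumprod_vjp(v_chk, x_chk)
println(isapprox(vJ_dense, vJ_closed))

# The closed form needs no AD at all, which makes it a cheap reference for checking
# the output of `vecjacobian_new!` when `f` is `cumprod`.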
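
# Vector-Jacobian product without a precompiled tape: records a fresh
# `InstructionTape` by running `f` on every call.
function vecjacobian!(vJ::AbstractArray{<:Number}, v, f, x::AbstractArray{<:Number})
    tp = ReverseDiff.InstructionTape()
    tx = ReverseDiff.track(x, tp)         # tracked copy of `x` that records onto `tp`
    ty = f(tx)                            # run `f`, recording its operations
    ReverseDiff.increment_deriv!(ty, v)   # seed the output derivatives with `v`
    ReverseDiff.reverse_pass!(tp)         # deriv(tx) now holds vᵀ * J
    copyto!(vJ, ReverseDiff.deriv(tx))
    return nothing
end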
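
# Cross-check sketch (not from the original): seeding the output derivatives with `v`
# and running the reverse pass, as `vecjacobian!` does, computes the gradient of the
# scalar function x -> sum(v .* f(x)), because ∇(vᵀ f(x)) = Jᵀ v, whose entries are
# those of vᵀJ. So the same quantity can be obtained from `ReverseDiff.gradient` and
# compared with a dense Jacobian. `x_chk2` and `v_chk2` are hypothetical names.
using ReverseDiff, ForwardDiff

x_chk2 = rand(6)
v_chk2 = rand(6)
vJ_grad = ReverseDiff.gradient(x -> sum(v_chk2 .* cumprod(x)), x_chk2)
vJ_jac = vec(v_chk2' * ForwardDiff.jacobian(cumprod, x_chk2))
println(isapprox(vJ_grad, vJ_jac))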

#=
Results:

julia> @btime vecjacobian_new!($vJ, $v, $tape, $x) evals=1
  4.571 μs (0 allocations: 0 bytes)

julia> @btime vecjacobian!($vJ, $v, $cumprod, $x) evals=1
  26.581 μs (706 allocations: 28.92 KiB)

Remember, the speed-up from `ReverseDiff.compile` here isn't "free"; we just
moved tape construction and compilation time out of the benchmark. Also keep in
mind that `ReverseDiff.compile` assumes the target function's computation graph
is static, i.e. the same sequence of operations runs for every input.

Thus, `ReverseDiff.compile` only makes sense when you're computing
vector-Jacobian products for the same `f` over and over, and `f` has no
data-dependent control flow. Otherwise, use the non-compiled `vecjacobian!`,
which re-records the tape on every call.
=#
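
# Illustration of the static-graph caveat above (a sketch, not from the original):
# `branchy` is a hypothetical function with data-dependent control flow. Its tape is
# recorded at an input where x[1] > 0, so only the `3 * x[1]` branch is on the tape;
# replaying the compiled tape at an input with x[1] < 0 still differentiates that branch.
using ReverseDiff

branchy(x) = x[1] > 0 ? 3 * x[1] : x[1] * x[1]

branchy_tape = ReverseDiff.compile(ReverseDiff.GradientTape(branchy, [1.0]))

x_neg = [-2.0]
g_compiled = ReverseDiff.gradient!(similar(x_neg), branchy_tape, x_neg)  # replays the recorded branch
g_fresh = ReverseDiff.gradient(branchy, x_neg)                           # re-records, takes the other branch

println(g_compiled)  # [3.0]: derivative of the recorded `3 * x[1]` branch (wrong here)
println(g_fresh)     # [-4.0]: correct derivative of x[1]^2 at x[1] = -2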