Created June 18, 2020 20:27
Gist: YingboMa/c22dcf8239a62e01b27ac679dfe5d4c5
```julia
using ForwardDiff

# goo is the reference implementation; foo computes the same function but
# gets a custom forward-mode rule for vectors of ForwardDiff.Dual numbers below.
goo((x, y, z),) = [x^2*z, x*y*z, abs(z)-y]
foo((x, y, z),) = [x^2*z, x*y*z, abs(z)-y]

function foo(u::Vector{ForwardDiff.Dual{T,V,P}}) where {T,V,P}
    # unpack: AoS -> SoA
    vs = ForwardDiff.value.(u)
    # you can play with the dimension here, sometimes it makes sense to transpose
    ps = mapreduce(ForwardDiff.partials, hcat, u)
    # get f(vs)
    val = foo(vs)
    # get J(f, vs) * ps (cheating). Write your custom rule here
    jvp = ForwardDiff.jacobian(goo, vs) * ps
    # pack: SoA -> AoS
    return map(val, eachrow(jvp)) do v, p
        ForwardDiff.Dual{T}(v, p...) # T is the tag
    end
end

ForwardDiff.gradient(u->sum(cumsum(foo(u))), [1, 2, 3]) == ForwardDiff.gradient(u->sum(cumsum(goo(u))), [1, 2, 3])
```
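The `# get J(f, vs) * ps (cheating)` line still asks ForwardDiff for the Jacobian. A minimal sketch of what a hand-written rule could look like, assuming we derive the Jacobian of `goo` by hand; `goo_jacobian` is a hypothetical helper, not part of ForwardDiff. It uses `sign(z)` as the derivative of `abs(z)`, which is 0 at `z == 0`, so it may disagree with ForwardDiff exactly at that point.

```julia
using ForwardDiff

goo((x, y, z),) = [x^2*z, x*y*z, abs(z)-y]

# Hypothetical hand-derived Jacobian of goo.
# Rows are the outputs, columns are the partials w.r.t. x, y, z.
# d|z|/dz is taken as sign(z) (an assumption at z == 0).
function goo_jacobian((x, y, z),)
    return [2x*z  0     x^2;
            y*z   x*z   x*y;
            0     -1    sign(z)]
end

# Sanity check against ForwardDiff away from the kink at z == 0
v = [1.0, 2.0, 3.0]
goo_jacobian(v) ≈ ForwardDiff.jacobian(goo, v)
```

Once it agrees with ForwardDiff, `ForwardDiff.jacobian(goo, vs)` inside the `foo` overload could be swapped for `goo_jacobian(vs)`, so no nested dual evaluation is needed.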
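It always makes sense to transpose `ps` (or just do `stack(ForwardDiff.partials, u; dims=1)`). `mapreduce(ForwardDiff.partials, hcat, u)` puts the partials of `u[i]` in column `i`, but `J * ps` contracts over the inputs, so they need to be in row `i`. The current implementation is only correct for the test case because we are taking the gradient with respect to the argument of `foo` directly, in which case `ps` is an identity matrix, so `transpose(ps) == ps`.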
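A sketch of the overload with the orientation suggested above, assuming Julia ≥ 1.9 for `Base.stack`. The Jacobian still comes from the same `ForwardDiff.jacobian(goo, vs)` shortcut; only the layout of `ps` changes. The check differentiates through `A * u` with a non-symmetric matrix `A` (an arbitrary choice for illustration), so the partials reaching `foo` are the rows of `A` rather than identity seeds. With the original `hcat` layout the rule would compute `J * transpose(A)` instead of `J * A` there.

```julia
using ForwardDiff

goo((x, y, z),) = [x^2*z, x*y*z, abs(z)-y]
foo((x, y, z),) = [x^2*z, x*y*z, abs(z)-y]

function foo(u::Vector{ForwardDiff.Dual{T,V,P}}) where {T,V,P}
    # unpack: AoS -> SoA
    vs = ForwardDiff.value.(u)
    # Row i holds the P partials of u[i], so ps is length(u) × P (Julia ≥ 1.9).
    # Pre-1.9 alternative: [ForwardDiff.partials(u[i])[j] for i in eachindex(u), j in 1:P]
    ps = stack(ForwardDiff.partials, u; dims=1)
    val = foo(vs)
    # J is length(val) × length(u), so J * ps is length(val) × P
    jvp = ForwardDiff.jacobian(goo, vs) * ps
    # pack: SoA -> AoS, one Dual per output carrying the caller's tag T
    return map(val, eachrow(jvp)) do v, p
        ForwardDiff.Dual{T}(v, p...)
    end
end

# Non-symmetric A makes the partials seen by foo non-identity,
# so the orientation of ps actually matters in this check.
A = [1.0 2.0 0.0; 0.0 1.0 3.0; 4.0 0.0 1.0]
u0 = [1.0, 2.0, 3.0]
g_custom = ForwardDiff.gradient(u -> sum(cumsum(foo(A * u))), u0)
g_ref = ForwardDiff.gradient(u -> sum(cumsum(goo(A * u))), u0)
g_custom ≈ g_ref
```

The original direct test (gradient with respect to the argument of `foo`) still passes with this version, since transposing an identity matrix changes nothing.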