# Created August 8, 2019 19:02
# GitHub gist oxinabox/d6e07dc58b1e0b10b5e15b23a5b0346a:
# A Simple scalar ForwardDiff using ChainRules + DualNumbers
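#==
A Simple scalar ForwardDiff using ChainRules + DualNumbers
---
No promises are made about its correctness or safety.
In fact, it probably errors for super standard cases.
But this is just to explain how it can work: functions with a ChainRules
`frule` use it directly, and functions without one are differentiated by
using Cassette to recurse into their bodies.
==#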
## Setup
# Activate a project environment in this script's directory and add the three
# packages the code below loads. `pkg"..."` runs a Pkg REPL-mode command.
using Pkg: Pkg, @pkg_str
Pkg.activate(@__DIR__)
pkg"add ChainRules"
pkg"add DualNumbers"
pkg"add Cassette"
## Main Code
using DualNumbers
using ChainRules
using Cassette

# Cassette context used when differentiating a function that has no
# ChainRules rule: `Cassette.recurse(diffctx2, f, args...)` runs `f` with
# every call inside it intercepted by the `overdub` method below.
Cassette.@context DiffCtx2
const diffctx2 = DiffCtx2()

# Within `diffctx2`, route every intercepted call through `duel_based_grad`.
# Each nested call then either hits a ChainRules `frule` or is recursed into
# further. The method must dispatch on the context type declared above.
Cassette.overdub(::DiffCtx2, f, args...) = duel_based_grad(f, args...)
"""
    duel_based_grad(f, x...)

Evaluate `f` on arguments `x` that may be `Dual` numbers, propagating their
dual (tangent) parts forward. Return the result rebuilt as
`Dual.(yr, yd)`, where `yr` is the real part and `yd` the dual part.

If `ChainRules.frule(f, realpart.(x)...)` returns a rule, it supplies the
primal output and a function mapping the input dual parts to the output dual
part. If it returns `nothing`, `f` is run on the `Dual` arguments via
`Cassette.recurse(diffctx2, ...)`, so calls inside `f` are intercepted too.

The `@show` lines print a trace of which path each call takes.
"""
function duel_based_grad(f, x...)
    @show f
    xr = realpart.(x)
    xd = dualpart.(x)
    rule = ChainRules.frule(f, xr...)
    if rule === nothing
        @show "no rule"
        # No rule, so do nontrivial AD. `x` has dual parts, so evaluating `f`
        # on it already performs forward-mode AD. A plain `y = f(x...)` would
        # do, but we want contained calls to be substituted as well, hence
        # `Cassette.recurse`.
        y = Cassette.recurse(diffctx2, f, x...)
        @show y
        yr = realpart.(y)
        # The dual part of `y` is already the pushforward of `xd` through `f`,
        # so it must not be multiplied by `xd` again. Doing that double-counts
        # the input tangent and, with several arguments, broadcasts into one
        # output per input instead of a single result.
        yd = dualpart.(y)
    else
        @show "hit rule"
        # `rule` destructures into the primal output and a callable that maps
        # the input dual parts to the output dual part.
        yr, yd_rule = rule
        yd = yd_rule(xd...)
    end
    @show yr
    @show yd
    return Dual.(yr, yd)
end
"""
    grad(f, x)

Compute `f(x)` and its derivative at `x` by seeding `x` with a dual part of
`1.0` and pushing it through `duel_based_grad`.

Returns a named tuple `(; result, derivative)`.

Throws an `ArgumentError` if `f` does not produce exactly one output.
"""
function grad(f, x)
    println("-"^40)  # visual separator before this call's trace output
    y = duel_based_grad(f, Dual(x, 1.0))
    # Checked explicitly rather than with `@assert`, which may be disabled.
    length(y) == 1 || throw(ArgumentError(
        "grad expects `f` to return a single scalar, got $(length(y)) outputs"
    ))
    y = first(y)
    return (; result=realpart(y), derivative=dualpart(y))
end
## Demo
# Each call also prints a separator and the `@show` trace from
# `duel_based_grad`. The recorded outputs below keep only the final line.
@show grad(-, 20.1)
#== Output
grad(-, 20.1) = (result = -20.1, derivative = -1.0)
==#
@show grad(x->3*x, 20.1)
#== Output
grad(x->3x, 20.1) = (result = 60.300000000000004, derivative = 3.0)
==#
# (GitHub gist page footer: sign up for free, or sign in if you already have
# an account, to comment on this gist.)