
@axsk
Last active August 29, 2022 18:39
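# MWE for https://github.com/FluxML/Zygote.jl/issues/1294
using Zygote
using StochasticDiffEq, SciMLSensitivity
using StatsBase
import Lux

function mwe()
    x0 = rand(1)  # initial state
    p0 = rand(1)  # parameters to differentiate with respect to
    # in-place drift and diffusion, both constant 1
    drift(du, u, p, t) = (du .= 1)
    noise(du, u, p, t) = (du .= 1)
    # a scalar tspan of 1. is shorthand for (0.0, 1.0)
    prob = SDEProblem(drift, noise, x0, 1., p0)
    sensealg = InterpolatingAdjoint(autojacvec=ReverseDiffVJP())
    # gradient w.r.t. p of the sample variance of the final state over
    # 3 Euler-Maruyama trajectories; @showgrad prints the gradient
    # accumulated into the remade problem
    Zygote.gradient(p0) do p
        var(solve(Zygote.@showgrad(remake(prob, p=p)), EM(), dt=.1, sensealg=sensealg)[end][1] for i in 1:3)
    end
end

mwe()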
axsk commented Aug 29, 2022