Skip to content

Instantly share code, notes, and snippets.

@axsk
Last active August 29, 2022 17:42
Show Gist options
  • Save axsk/d1fb3a09e7c29bbb0e6cd3a0dd4fe00a to your computer and use it in GitHub Desktop.
Save axsk/d1fb3a09e7c29bbb0e6cd3a0dd4fe00a to your computer and use it in GitHub Desktop.
MWE: differentiating through callbacks
# Reproducer for https://github.com/SciML/SciMLSensitivity.jl/issues/720
using Zygote
using StochasticDiffEq, SciMLSensitivity
"""
    mwe1()

Reproducer for the callback-differentiation issue, with the noise switched off.

Builds the one-dimensional in-place SDE `du = p dt + 0 dW` starting from `u = [0.0]`,
with the scalar time span `100.0` and a 1×1 noise-rate prototype. A `ContinuousCallback`
calls `terminate!` when `u[1] - 1` crosses zero upward. Downward crossings are ignored
because `affect_neg!` is passed as `nothing`.

Returns the result of `Zygote.gradient` (a one-element tuple) of the first component of
the final solution state with respect to the parameter vector, evaluated at `p = [1.0]`.
The SDE is solved with `EM()` at `dt = 0.1`. Sensitivities use `InterpolatingAdjoint`
with `ReverseDiffVJP` and `noisemixing = true`.
"""
function mwe1()
    # In-place drift (constant rate `p`) and a noise term that is identically zero.
    f!(du, u, p, t) = (du .= p)
    g!(du, u, p, t) = (du .= 0.0)
    g_proto = zeros(1, 1)

    u0 = [0.0]
    tend = 100.0
    p_init = [1.0]

    # Root function for the callback: zero exactly when u[1] == 1.
    hit_one(u, t, integrator) = u[1] - 1
    # Third positional argument is affect_neg!; `nothing` means only upcrossings stop.
    stopper = ContinuousCallback(hit_one, terminate!, nothing)

    prob = SDEProblem(f!, g!, u0, tend, p_init;
                      noise_rate_prototype = g_proto, callback = stopper)
    sens_alg = InterpolatingAdjoint(; autojacvec = ReverseDiffVJP(), noisemixing = true)

    # Scalar loss: first component of the last saved state.
    final_first(ps) = solve(prob, EM(); dt = 0.1, p = ps, sensealg = sens_alg)[end][1]
    return Zygote.gradient(final_first, p_init)
end
"""
    _mwe2()

Single trial of the noisy reproducer for the callback-differentiation issue. A single
call fails in roughly 50% of runs, which is why `mwe2` calls it repeatedly.

Builds the one-dimensional in-place SDE `du = p dt + 0.1 dW` starting from `u = [0.0]`,
with the scalar time span `100.0` and a 1×1 noise-rate prototype. A `ContinuousCallback`
calls `terminate!` when `u[1] - 1` crosses zero. Here `affect_neg!` is left at its
default, so it is the same `terminate!`.

Returns the result of `Zygote.gradient` (a one-element tuple) of the first component of
the final solution state with respect to the parameter vector, evaluated at `p = [1.0]`.
The SDE is solved with `EM()` at `dt = 0.1`. Sensitivities use `InterpolatingAdjoint`
with `ReverseDiffVJP` and `noisemixing = true`.
"""
function _mwe2()
    u0 = [0.0]
    tend = 100.0
    p_init = [1.0]
    g_proto = zeros(1, 1)

    # In-place drift (constant rate `p`) and constant diffusion 0.1.
    f!(du, u, p, t) = (du .= p)
    g!(du, u, p, t) = (du .= 0.1)

    # Root function for the callback: zero exactly when u[1] == 1.
    hit_one(u, t, integrator) = u[1] - 1
    stopper = ContinuousCallback(hit_one, terminate!)

    prob = SDEProblem(f!, g!, u0, tend, p_init;
                      noise_rate_prototype = g_proto, callback = stopper)
    sens_alg = InterpolatingAdjoint(; autojacvec = ReverseDiffVJP(), noisemixing = true)

    # Scalar loss: first component of the last saved state.
    final_first(ps) = solve(prob, EM(); dt = 0.1, p = ps, sensealg = sens_alg)[end][1]
    return Zygote.gradient(final_first, p_init)
end
"""
    mwe2(ntrials::Integer = 100)

Call `_mwe2` `ntrials` times (100 by default, as before).

A single `_mwe2` call fails only intermittently, so repeating it makes the failure show
up reliably. The first error thrown by `_mwe2` propagates out of this function unchanged.
Returns `nothing`.
"""
function mwe2(ntrials::Integer = 100)
    for _ in 1:ntrials
        _mwe2()
    end
    return nothing
end
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment