Created
March 27, 2016 00:23
-
-
Save axsk/7cbffdbc2c9b7dae1077 to your computer and use it in GitHub Desktop.
cvodes-autodiff
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
using Sundials, ForwardDiff | |
import Sundials: realtype, N_Vector | |
# Payload handed to CVODES as opaque user data: the user's in-place right-hand side
# `f(t, y, p, dy)` together with its parameter vector `p`.
# Kept as a mutable `type` so it can be passed by reference and recovered from a raw
# pointer with `unsafe_pointer_to_objref` (see `unzip`).
type FAndP
    f::Function
    p::Vector{Float64}
end

# Recover the `(f, p)` pair from the user-data pointer CVODES passes to its callbacks.
# The `::FAndP` assertion fails loudly if the pointer refers to some other object.
function unzip(fp::Ptr{Void})
    payload = unsafe_pointer_to_objref(fp)::FAndP
    return payload.f, payload.p
end
# CVODES right-hand-side callback.
# Unpacks the (f, p) payload from the user-data pointer `fp` and lets the user's
# in-place `f` write the derivative into `dy`. `Sundials.asarray` presumably wraps
# the N_Vector storage without copying (the write-back to CVODES relies on this).
# Always reports success (0) to CVODES.
function cvodesfun(t, y, dy, fp)
    rhs, params = unzip(fp)
    rhs(t, Sundials.asarray(y), params, Sundials.asarray(dy))
    return Int32(0)
end
# Build a Jacobian evaluator for the in-place RHS `f(t, y, p, dy)`.
#
# The returned closure `(y0, p) -> ...` differentiates `f` with ForwardDiff with respect
# to the stacked vector `[y0; p]`, writing the ny × (ny + np) result into a matrix
# allocated once here. Its return value is whatever the ForwardDiff `j!` returns
# (presumably the filled matrix; confirm for the ForwardDiff version in use).
# The buffers are untyped (`Vector(n)`, element type Any) so they can hold
# ForwardDiff's dual numbers during differentiation.
function differentiator(f, ny, np)
    ybuf = Vector(ny)
    pbuf = Vector(np)
    dybuf = Vector(ny)
    jac = Matrix{Float64}(ny, ny + np)

    # Split the stacked input back into state and parameters and evaluate f.
    function stacked_rhs(x)
        ybuf[:] = x[1:ny]
        pbuf[:] = x[ny+1:ny+np]
        f(0, ybuf, pbuf, dybuf)  # TODO fix time dependence (t is pinned to 0)
        return dybuf
    end

    jac! = ForwardDiff.jacobian(stacked_rhs, mutates=true)
    return (y0, p) -> jac!(jac, vcat(y0, p))
end
"""
    sensrhsfn(ns, t, y, ydot, yS, ySdot, user_data, tmp1, tmp2)

CVODES sensitivity right-hand side (`CVSensRhsFn`) computed with forward-mode AD.

For every parameter `i` this evaluates the forward sensitivity equation

    dS_i/dt = (∂f/∂y) * S_i + ∂f/∂p_i

The Jacobian `[∂f/∂y  ∂f/∂p]` of the user's in-place RHS `f(t, y, p, dy)` is taken with
`ForwardDiff` at the current `(t, y, p)`; `f` and `p` come from the `FAndP` payload in
`user_data`. Results are written into the vectors CVODES passes in `ySdot`.
`ydot`, `tmp1` and `tmp2` are unused. Returns `Int32(0)` (success) to CVODES.
"""
function sensrhsfn(ns::Int32, t::realtype, y::N_Vector, ydot::N_Vector, yS::N_Vector, ySdot::N_Vector, user_data::Ptr{Void}, tmp1::N_Vector, tmp2::N_Vector)
    f, p = unzip(user_data)
    ycur = Sundials.asarray(y)
    ny = length(ycur)
    np = Int(ns)

    # Differentiate f with respect to the stacked vector [y; p], so one Jacobian yields
    # both ∂f/∂y (first ny columns) and ∂f/∂p (remaining np columns).
    # `similar(x, ny)` gives dy the same element type as x, so it can hold dual numbers.
    # The current `t` is passed through, so time-dependent RHS functions are handled.
    function rhs_of_stacked(x)
        dy = similar(x, ny)
        f(t, x[1:ny], x[ny+1:ny+np], dy)
        return dy
    end
    J = ForwardDiff.jacobian(rhs_of_stacked, vcat(ycur, p))
    Jy = J[:, 1:ny]

    # In the C API, yS and ySdot are `N_Vector*` arrays of length ns, not single vectors.
    # They are declared N_Vector to match the cfunction signature, so reinterpret here.
    ySvecs = pointer_to_array(convert(Ptr{N_Vector}, yS), np)
    ySdotvecs = pointer_to_array(convert(Ptr{N_Vector}, ySdot), np)

    for i in 1:np
        Si = Sundials.asarray(ySvecs[i])
        # Write into the existing vector: CVODES owns ySdot[i], so the N_Vector pointer
        # itself must not be replaced.
        Sundials.asarray(ySdotvecs[i])[:] = Jy * Si + J[:, ny+i]
    end
    return Int32(0)
end
# Throw if a CVODES call returned a negative (failure) flag; otherwise pass the flag through.
function _cvcheck(flag, name)
    flag < 0 && error("$name failed with flag $flag")
    return flag
end

"""
    cvodes(f, y0, p, ts; reltol=1e-8, abstol=1e-6, autodiff=true)

Integrate `dy/dt = f(t, y, p)` with CVODES (BDF, Newton iteration, dense linear solver)
and compute the forward sensitivities of the solution with respect to every entry of `p`.

`f` must have the in-place signature `f(t, y, p, dy)`, writing the derivative into `dy`.

# Arguments
- `y0`: initial state at time `ts[1]`.
- `p`: model parameters; sensitivities are taken with respect to each entry.
- `ts`: output times; `ts[1]` is the initial time.

# Keywords
- `reltol`, `abstol`: scalar state tolerances. Sensitivity tolerances are estimated by
  CVODES (`CVodeSensEEtolerances`), and sensitivities are excluded from error control.
- `autodiff`: if `true`, the sensitivity RHS is `sensrhsfn` (ForwardDiff Jacobian).
  Otherwise CVODES' internal centered finite differences are used.

# Returns
`(solution, sens)`, where `solution[k, :]` is the state at `ts[k]` and `sens[k, :, i]`
is the sensitivity ∂y/∂p_i at `ts[k]`.

# Throws
An error if the solver cannot be allocated or any CVODES call returns a negative flag.
"""
function cvodes(f::Function, y0::Vector{Float64}, p::Vector{Float64}, ts::Vector{Float64}; reltol=1e-8, abstol=1e-6, autodiff=true)
    ny = length(y0)
    np = length(p)
    nt = length(ts)

    # CVODES keeps only raw pointers to the user data and to the sensitivity vectors.
    # Bind them (and the arrays backing the N_Vectors) to locals so the Julia objects
    # stay reachable for the whole solve instead of being collectable temporaries.
    # NOTE(review): on Julia >= 0.7 this needs an explicit GC.@preserve.
    userdata = FAndP(f, p)
    ySdata = [zeros(Float64, ny) for i in 1:np]
    ySnv = [Sundials.nvector(v) for v in ySdata]
    yS = pointer(ySnv)

    ### CVode settings ###
    cvode_mem = Sundials.CVodeCreate(Sundials.CV_BDF, Sundials.CV_NEWTON)
    cvode_mem == C_NULL && error("CVodeCreate failed to allocate the solver")
    # NOTE(review): cvode_mem is never released with CVodeFree; add it once the
    # wrapper's calling convention for CVodeFree is confirmed.
    _cvcheck(Sundials.CVodeInit(cvode_mem, cvodesfun, ts[1], y0), "CVodeInit")
    _cvcheck(Sundials.CVodeSetUserData(cvode_mem, userdata), "CVodeSetUserData")
    _cvcheck(Sundials.CVodeSStolerances(cvode_mem, reltol, abstol), "CVodeSStolerances")
    _cvcheck(Sundials.CVDense(cvode_mem, ny), "CVDense")

    ### Sensitivity settings ###
    if autodiff
        # CVODES needs a C function pointer, not the Julia function object.
        sensrhsfnptr = cfunction(sensrhsfn, Int32, (Int32, realtype, N_Vector, N_Vector, N_Vector, N_Vector, Ptr{Void}, N_Vector, N_Vector))
        _cvcheck(Sundials.CVodeSensInit(cvode_mem, np, Sundials.CV_SIMULTANEOUS, sensrhsfnptr, yS), "CVodeSensInit")
    else
        # A NULL sensitivity RHS selects CVODES' internal difference quotients.
        _cvcheck(Sundials.CVodeSensInit(cvode_mem, np, Sundials.CV_SIMULTANEOUS, Ptr{Void}(0), yS), "CVodeSensInit")
        _cvcheck(Sundials.CVodeSetSensDQMethod(cvode_mem, Sundials.CV_CENTERED, 0.0), "CVodeSetSensDQMethod")
    end
    # pbar holds the parameter scaling factors and must be nonzero.
    # Use the parameter values themselves, falling back to 1 for zero-valued parameters.
    pbar = [q == 0 ? 1.0 : q for q in p]
    _cvcheck(Sundials.CVodeSetSensParams(cvode_mem, p, pbar, Ptr{Int32}(0)), "CVodeSetSensParams")
    _cvcheck(Sundials.CVodeSetSensErrCon(cvode_mem, 0), "CVodeSetSensErrCon")
    _cvcheck(Sundials.CVodeSensEEtolerances(cvode_mem), "CVodeSensEEtolerances")

    ### Output buffers ###
    solution = zeros(nt, ny)
    solution[1, :] = y0
    # Initial sensitivities are zero (ySdata starts at zero), so row 1 needs no filling.
    sens = zeros(nt, ny, np)
    tout = [0.0]        # time actually reached by the solver
    yout = copy(y0)

    for k in 2:nt
        _cvcheck(Sundials.CVode(cvode_mem, ts[k], yout, tout, Sundials.CV_NORMAL), "CVode (t = $(ts[k]))")
        _cvcheck(Sundials.CVodeGetSens(cvode_mem, tout, yS), "CVodeGetSens (t = $(ts[k]))")
        solution[k, :] = yout
        for i in 1:np
            sens[k, :, i] = Sundials.asarray(ySnv[i])
        end
    end
    return (solution, sens)
end
"""
    f(t, y, p, dy)

Example in-place right-hand side for `cvodes`: the ODE `dy/dt = p` in two dimensions.
It writes `p[1]` and `p[2]` into `dy[1]` and `dy[2]`; `t` and `y` are ignored.
Returns `nothing`, since callers read the result from `dy`.
"""
function f(t, y, p, dy)
    dy[1] = p[1]
    dy[2] = p[2]
    return nothing
end
# Demo: integrate dy/dt = p from y(0) = [0.5, 0] at 10 equally spaced times on [0, 1],
# computing sensitivities with respect to p = [1, 2] via automatic differentiation.
let tgrid = collect(linspace(0, 1, 10))
    cvodes(f, [0.5, 0.0], [1.0, 2.0], tgrid; autodiff=true)
end
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment