Created
May 14, 2021 10:15
-
-
Save mschauer/2a4ae3c54a2f73f3632d2174087c4206 to your computer and use it in GitHub Desktop.
Compute the marginal likelihood for the example in https://gist.github.com/devmotion/37d8d706938364eeb900bc3678860da6
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
using Mitosis | |
using MitosisStochasticDiffEq | |
import MitosisStochasticDiffEq as MSDE | |
using StaticArrays, LinearAlgebra | |
using OrdinaryDiffEq | |
# Coefficients of the linear guiding SDE used by `tildekappa`; B and Σ are meant to
# match the drift matrix B and the noise level sigma of the referenced model.

"Drift matrix: -0.1 on the diagonal, antisymmetric off-diagonal entries ±0.2θ."
function B(θ)
    ω = 0.2 * θ
    return [-0.1 ω; -ω -0.1]
end

"Drift offset: the zero vector in ℝ², independent of θ."
beta(θ) = zeros(2)

"Diffusion coefficient: 0.15 times the 2×2 identity, independent of θ."
Σ(θ) = 0.15 * I(2)
"""
    tildekappa(p, trange)

Build the guiding SDE kernel on the time grid `trange` for parameter `p`. It has the
affine drift `Mitosis.AffineMap(B(p), beta(p))` and the constant diffusion
`Mitosis.ConstantMap(Σ(p))`. The coefficient tuple `(B(p), beta(p), Σ(p))` is passed as
the kernel's linear-parameter argument (`plin`).
"""
function tildekappa(p, trange)
    drift = Mitosis.AffineMap(B(p), beta(p))
    diffusion = Mitosis.ConstantMap(Σ(p))
    plin = (B(p), beta(p), Σ(p))
    return MSDE.SDEKernel(drift, diffusion, trange, plin)
end
# Observations as a vector of plain 2-vectors: `Y` (from the referenced gist) is
# reinterpreted as SVector{2, Float64} elements and flattened.
# NOTE(review): assumes Y is a 2×N Float64 matrix (one observation per column) — confirm.
u = map(Vector, reinterpret(SVector{2, Float64}, Y))[:]
# Guiding kernel over the full time grid `t`. The functions below do not use it;
# `bwfilter!` builds its own per-interval kernels.
κ̃ = tildekappa(θ, t)
# Observation noise covariance: independent noise with standard deviation ς per coordinate.
ς2 = ς^2
Σobs = [ς2 0; 0 ς2]
# Target step size of the integration grid between observations. `logevidence` passes
# this global to `bwfilter!`, so it must be defined at top level.
dt0 = 0.01
# One backward message per observation; entry 1 (the root) is never written by bwfilter!.
messages = Vector{Any}(undef, length(u))
"""
    bwfilter!(messages, T, u, p, dt0, Σobs; apply_timechange=false)

Backward-filter the observations `u` made at times `T` under the guiding kernels
`tildekappa(p, ⋅)`. The message for the interval ending at `T[i]` is written into
`messages[i]` for every non-root index `i`.

The filter starts from a Gaussian at `u[end]` with covariance `Σobs`. Each interval
`(T[i-1], T[i])` is integrated backwards with `MSDE.backwardfilter` (Tsit5) on a uniform
grid with spacing close to `dt0`. The result is then fused with the observation at the
interval's left end.

# Arguments
- `messages`: preallocated vector indexed like `u`; entry 1 (the root) is left untouched.
- `T`: observation times, indexed like `u`.
- `u`: observations.
- `p`: parameter passed to `tildekappa`.
- `dt0`: target step size of the integration grid within each interval.
- `Σobs`: observation noise covariance.

# Keywords
- `apply_timechange`: forwarded to `MSDE.backwardfilter`.

# Returns
`(messages, Q)`, where `Q` is the message at `T[1]` incorporating all observations.
"""
function bwfilter!(messages, T, u, p, dt0, Σobs; apply_timechange=false)
    # `one(eltype(p))` promotes the observation to p's number type, presumably so that
    # dual numbers propagate under automatic differentiation.
    obsmessage(x) = WGaussian{(:μ, :Σ, :c)}(one(eltype(p)) * x, Σobs, 0.0)
    Q = obsmessage(u[end])
    for i in reverse(eachindex(u))
        i == 1 && continue  # skip root node (has no parent)
        ipar = i - 1
        # At least one step per interval. Without the `max`, an interval shorter than
        # dt0/2 would round to 0 steps and an infinite step size.
        nsteps = max(1, round(Int, (T[i] - T[ipar]) / dt0))
        # `range(...; length)` places both observation times exactly on the grid.
        trange = range(T[ipar], T[i]; length=nsteps + 1)
        κ̃ = tildekappa(p, trange)
        message, Q = MSDE.backwardfilter(κ̃, Q; alg=OrdinaryDiffEq.Tsit5(),
                                         apply_timechange=apply_timechange)
        messages[i] = message
        # The second element of `fuse`'s result is used as the fused message.
        Q = Mitosis.fuse(Q, obsmessage(u[ipar]))[2]
    end
    return messages, Q
end
"""
    logevidence(θ)

Return the log marginal likelihood of the observations at parameter `θ`. All
observations are backward-filtered (the call is timed with `@time`). The resulting root
message is converted to the `(:F, :Γ, :c)` parametrization and its log-density is
evaluated at `x0`.

Reads the globals `messages`, `obstime`, `u`, `dt0`, `Σobs` and `x0`.
"""
function logevidence(θ)
    root = last(@time(bwfilter!(messages, obstime, u, θ, dt0, Σobs)))
    root_canonical = convert(WGaussian{(:F, :Γ, :c)}, root)
    return Mitosis.logdensity(root_canonical, x0)
end
# Scan the log-evidence over a grid of θ values and plot it.
θr = 0.8:0.05:1.4
post = logevidence.(θr)
# NOTE(review): `lines` looks like Makie's plotting function. No Makie package is loaded
# in this file, so it must come from the surrounding session — confirm.
lines(θr, post)
# Least-squares fit  post ≈ q₁ + q₂ θ + q₃ θ²/2.  Matching it against F θ - Γ θ²/2 gives
# the Gaussian parameters F = q₂ and Γ = -q₃ used below.
design = [one.(θr) θr θr.^2/2]
qua = design \ post
# Γ = -q₃ must be positive for a valid Gaussian, i.e. the fit must be concave in θ.
qua[3] < 0 || throw(DomainError(qua[3],
    "fitted log-evidence is not concave in θ; cannot form a Gaussian approximation"))
G = Gaussian{(:F,:Γ)}(qua[2], -qua[3])
# Report mean ± standard deviation of the fitted Gaussian, next to those of θs
# (presumably posterior samples from the referenced gist).
println(mean(G), " ± ", sqrt(cov(G)[]))
println(mean(θs), " ± ", std(θs))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment