Last active
May 12, 2021 15:32
-
-
Save mschauer/d1b95bc7031eb858e94de9fb86622c75 to your computer and use it in GitHub Desktop.
Elementary example of sampling an SDE's parameter and latent path, using forward simulation and a random walk on the driving Wiener process ("Crank–Nicolson" scheme).
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
using AdvancedMH | |
using Distributions | |
using Random | |
using MCMCChains | |
using StructArrays | |
"""
    MvWiener(u0, trange)

Matrix-variate law of a multivariate Wiener process observed on the time grid
`trange`. A draw is a matrix with one row per component and one column per entry
of `trange`; the number of components is `only(size(u0))`.
"""
struct MvWiener{U,R} <: ContinuousMatrixDistribution
    u0::U       # vector; its length fixes the number of rows of a draw
    trange::R   # time grid; column `i` of a draw corresponds to `trange[i]`
end
"""
    MvWienerStep(wiener, W, ρ)

Proposal kernel on Wiener paths. A draw mixes the current path `W` with a fresh,
independent draw from `wiener`, using the weights `ρ` and `sqrt(1 - ρ^2)`
("Crank–Nicolson" / autoregressive random walk).
"""
struct MvWienerStep{Wn,P,R} <: ContinuousMatrixDistribution
    wiener::Wn  # reference Wiener law that supplies the fresh innovation
    W::P        # current path, same shape as a draw from `wiener`
    ρ::R        # memory parameter; the sampler takes sqrt(1 - ρ^2), so |ρ| ≤ 1
end
"""
    MvSDE(u0, b, σ, θ)

Specification of the SDE `dX = b(t, X, θ) dt + σ(t, X, θ) dW` started at `u0`,
as consumed by `solve`. Both `b` and `σ` are called as `f(t, x, θ)`.
"""
struct MvSDE{U,B,S,Θ} <: ContinuousMatrixDistribution
    u0::U   # initial state (vector)
    b::B    # drift function b(t, x, θ)
    σ::S    # diffusion function σ(t, x, θ); multiplies the Wiener increment
    θ::Θ    # parameter passed to both b and σ
end
# Shape of a draw: (number of components, number of time points).
function Base.size(w::MvWiener)
    ncomponents = only(size(w.u0))
    npoints = length(w.trange)
    return (ncomponents, npoints)
end

# A proposal step produces paths shaped like those of its reference Wiener law.
Base.size(step::MvWienerStep) = Base.size(step.wiener)
"""
    Distributions._rand!(rng, wiener::MvWiener, W::DenseMatrix)

Fill `W` in place with a path of the Wiener process on `wiener.trange` that starts
at `wiener.u0`, and return `W`. Row `k` holds component `k`; column `i` holds the
value at time `trange[i]`.

Throws `DimensionMismatch` if `size(W) != size(wiener)`.
"""
function Distributions._rand!(
    rng::Random.AbstractRNG, wiener::MvWiener, W::DenseMatrix,
)
    size(W) == size(wiener) || throw(DimensionMismatch(
        "W has size $(size(W)), expected $(size(wiener))"))
    trange = wiener.trange
    randn!(rng, W)
    # Scale column i by sqrt(Δt_i). The first column gets 0, so after the cumulative
    # sum the path is exactly zero at trange[1].
    sdt = zeros(length(trange))
    for i in 2:length(trange)
        sdt[i] = sqrt(trange[i] - trange[i-1])
    end
    W .*= sdt'
    cumsum!(W, W, dims=2)
    # Start the path at u0 (row k is shifted by u0[k]).
    W .+= wiener.u0
    return W
end
"""
    Distributions._rand!(rng, step::MvWienerStep, W::DenseMatrix)

Fill `W` in place with a Crank–Nicolson (autoregressive) proposal and return it:

    W_new = u0 + ρ * (step.W - u0) + sqrt(1 - ρ^2) * Z

Here `u0 = step.wiener.u0` and `Z` is a fresh path from `step.wiener`, shifted to
start at zero. For `u0 == 0` this is `ρ * step.W + sqrt(1 - ρ^2) * Z`.

If `step.W` is distributed according to `step.wiener`, then so is `W_new`, because
the Gaussian covariances add up as ρ²K + (1 - ρ²)K = K. `ρ` must satisfy `|ρ| ≤ 1`
for `sqrt(1 - ρ^2)` to be real.
"""
function Distributions._rand!(
    rng::Random.AbstractRNG, step::MvWienerStep, W::DenseMatrix,
)
    Distributions._rand!(rng, step.wiener, W)
    # Remove the starting value so the innovation Z is zero-start, whatever the
    # reference sampler does with u0.
    z0 = W[:, 1]
    W .-= z0
    u0 = step.wiener.u0
    ρ = step.ρ
    W .= u0 .+ ρ .* (step.W .- u0) .+ sqrt(1 - ρ^2) .* W
    return W
end
# Sanity check of the samplers on 10_000 independent components observed at
# t = 0, 0.5, 2. Across components (dims=1), the mean should be ≈ 0 and the variance
# ≈ trange. One proposal step should leave both unchanged.
wiener = MvWiener(zeros(10000), [0.,0.5, 2.0])
W = rand!(wiener, zeros(10000, 3))   # in-place variant
W = rand(wiener)                     # allocating variant
Wstep = rand(MvWienerStep(wiener, W, 0.9))
@show mean(W, dims=1) var(W, dims=1)
@show mean(Wstep, dims=1) var(Wstep, dims=1)
"""
    solve(P::MvSDE, ts, W)

Euler–Maruyama discretisation of the SDE `P` on the time grid `ts`, driven by the
path `W` (one column per grid point). Returns a matrix `xs` with
`xs[:, 1] == P.u0` and `xs[:, i]` the approximation at `ts[i]`.

The drift and diffusion for the step `i-1 → i` are evaluated at the left end point
`(ts[i-1], xs[:, i-1])`.

Throws `DimensionMismatch` if `length(ts) != size(W, 2)`.
"""
function solve(P::MvSDE, ts, W)
    npoints = size(W, 2)
    length(ts) == npoints || throw(DimensionMismatch(
        "ts has $(length(ts)) time points but W has $npoints columns"))
    x = copy(P.u0)
    xs = zeros(only(size(P.u0)), npoints)
    xs[:, 1] = x
    for i in 2:npoints
        tprev = ts[i-1]          # left end point: the time must advance with i
        dt = ts[i] - tprev
        dW = @views W[:, i] - W[:, i-1]
        x = x + P.b(tprev, x, P.θ) * dt + P.σ(tprev, x, P.θ) * dW
        xs[:, i] = x
    end
    return xs
end
# Linear drift A*x with A = [-0.1 0.2θ; -0.2θ -0.1]: the same damping on each
# coordinate plus a rotation whose rate scales with θ.
function b(t, x, θ)
    damping = -0.1
    A = [damping θ*0.2; -θ*0.2 damping]
    return A * x
end

# Constant scalar diffusion coefficient (ignores time, state and parameter).
σ(_t, _x, _θ) = 0.15
# Ground truth: simulate one latent path X with θ = θ0, then observe every 10th
# grid point with Gaussian noise of standard deviation ς.
θ0 = 1.0
t = 0:0.05:20.0
wiener = MvWiener(zeros(2), t)
W0 = rand(wiener)                               # true driving noise
X = solve(MvSDE([1.0, 1.0], b, σ, θ0), t, W0)   # true latent path
ς = 0.2
obs = eachindex(t)[1:10:end]                    # indices of the observed time points
Y = let Xobs = X[:, obs]
    Xobs + ς * randn(size(Xobs))
end
# Initial driving-noise path for the sampler, drawn from the same Wiener law.
W = rand(wiener)
# Gaussian log-likelihood (up to an additive constant) of the observations Y, given
# the pair (θ, W). The latent path is obtained by solving the SDE with driving noise
# W. No prior terms for θ or W appear here.
function density(params)
    θ, W = params
    Xθ = solve(MvSDE([1.0, 1.0], b, σ, θ), t, W)
    residuals = Y - Xθ[:, obs]
    return -0.5 * sum(residuals .^ 2) / ς^2
end
# Smoke test: log-likelihood at a perturbed θ with the initial W. Shown explicitly
# so the value is visible when the file runs as a script, not only in the REPL.
@show density((θ0 + 0.1, W))
model = DensityModel(density)
# Proposal law for the driving noise: one Crank–Nicolson step away from the current
# path `W` with memory 0.98, using the global reference law `wiener`.
FW(W) = MvWienerStep(wiener, W, 0.98)
# Gaussian random-walk proposal for θ: centred at the current value, step size 0.05.
function Fθ(x)
    return Normal(x, 0.05)
end
# Declare both state-dependent proposals symmetric, presumably so AdvancedMH drops
# the proposal-density terms from the acceptance ratio.
# NOTE(review): the meaning of StaticProposal's type parameters has changed across
# AdvancedMH versions — confirm these methods actually match the proposals built
# from Fθ and FW.
function AdvancedMH.is_symmetric_proposal(::StaticProposal{typeof(FW)})
    return true
end

function AdvancedMH.is_symmetric_proposal(::StaticProposal{typeof(Fθ)})
    return true
end
# Wrap the proposal *functions*; presumably AdvancedMH calls them on the current
# state to obtain the proposal law (the author notes this is a surprising API).
θprop, Wprop = StaticProposal(Fθ), StaticProposal(FW)
θ = 0.95   # initial value of θ for the chain
# Workaround for is_symmetric_proposal not taking effect (per the original author):
# give both path laws a constant log-density. Any proposal-density terms then cancel
# in the MH ratio, and acceptance depends only on `density`. For the Crank–Nicolson
# step this yields the prior-reversible acceptance rule, since `density` omits the
# Wiener prior.
function Distributions.logpdf(d::MvWiener, x::AbstractMatrix{Float64})
    return 0.0
end

function Distributions.logpdf(d::MvWienerStep, x::AbstractMatrix{Float64})
    return 0.0
end
# Run N Metropolis–Hastings iterations that update θ and the driving noise W
# component-wise, returning the chain as a vector of named tuples (θ = …, W = …).
N = 100_000
chain = sample(
    model,
    MetropolisHastings([θprop, Wprop]),
    N;
    init_params = (θ = θ, W = W),
    param_names = ["θ", "W"],
    chain_type = Vector{NamedTuple},
)
# Posterior summary for θ. The first half of the chain is treated as burn-in, the
# same convention used for the orange posterior-mean line in the plots.
θs = first.(chain)
@show mean(θs), std(θs)              # whole chain, including burn-in
θpost = θs[length(θs) ÷ 2:end]
@show mean(θpost), std(θpost)        # after discarding burn-in
# Diagnostics figure with three panels, saved to fig.png.
# NOTE(review): `Figure`, `Axis`, `lines!` and `save` come from Makie, but no Makie
# package is loaded in the visible source — load a backend (e.g. CairoMakie) first.
fig = Figure(resolution=(2000,500))
using Colors

burnin = length(chain) ÷ 2                       # first half discarded as burn-in
WM = mean(getindex.(chain[burnin:end], :W))      # posterior mean of the driving noise

# Panel 1: θ trace with the true value (red) and the post-burn-in mean (orange).
ax1 = Axis(fig[1, 1])
lines!(ax1, θs)
lines!(ax1, fill(θ0, N), color=:red)
lines!(ax1, fill(mean(θs[burnin:end]), N), color=:orange)

# Panel 2: each noise component over time, true (red) vs posterior mean (orange);
# component 1 is drawn thicker.
ax2 = Axis(fig[1, 2])
lines!(ax2, W0[1, :], color=:red, linewidth=3)
lines!(ax2, WM[1, :], color=:orange, linewidth=3)
lines!(ax2, W0[2, :], color=:red)
lines!(ax2, WM[2, :], color=:orange)

# Panel 3: thinned path samples coloured by iteration, with the true path (red)
# and the posterior mean (orange) drawn on top.
ax3 = fig[1, 3] = Axis(fig)
for i in reverse(1:5000:N)
    lines!(ax3, chain[i].W, color=fill(i, length(t)), colorrange=(1, N))
end
WT = chain[end].W   # final sample, kept for interactive inspection
lines!(ax3, W0, color=:red, linewidth=4.0)
lines!(ax3, WM, color=:orange, linewidth=4.0)

fig
save("fig.png", fig)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment