# zigzag on Innovation
# GitHub gist mschauer/ddcbf15aba6023e2395d677a683a2646, created April 8, 2021.
import Random
using SparseArrays

using ForwardDiff
using ForwardDiff: Dual, value, partials
using ZigZagBoomerang
# Dual number with tag `Nothing` and a single partial; the duals created in
# `partiali` (`Dual(x, 1.0)`) are of this type.
const D1 = Dual{Nothing,Float64,1}
# The state type 𝕏 (defined outside this script) with every entry promoted to D1.
const D𝕏 = typeof(broadcast(D1, zero(𝕏)))
"""
    partiali(F, z, i)

Return the partial derivative of the scalar function `F` with respect to entry
`i` of `z`. It uses forward-mode AD with a one-partial `Dual` number.

`z` is mutated temporarily. Entry `i` is seeded with partial `1.0` before `F` is
evaluated. Afterwards it is reset to partial `0.0`, and this reset also happens
when `F` throws, so `z` is never left seeded.

Assumes `z[i]` reads back a plain number (as `VNoise.getindex` does by returning
`value(...)`), so that wrapping it in a fresh `Dual` replaces any earlier seed.
"""
function partiali(F, z, i)
    z[i] = Dual(z[i], 1.0)
    try
        return partials(F(z))[]
    finally
        # clear the seed even if F errored, so later derivatives are not polluted
        z[i] = Dual(z[i], 0.0)
    end
end
"""
    dfwguidtree!(X, guidedsegs, messages, tree::Tree, f, g, θ, Z; apply_time_change=false)

Run `MSDE.forwardguiding` for every non-root node `i` of `tree`. Each segment
starts from the parent's value `X[tree.Par[i]]` and is driven by the noise `Z[i-1]`.

For each non-root node `i`, the function overwrites:
- `X[i]` with the end state of the guided segment, minus its last component,
  converted to `D𝕏`;
- `guidedsegs[i]` with the full guided solution.

It also accumulates `ll[i] = (segment term) + ll[parent] * tree.lastone[i]`.

Returns `(X, guidedsegs, ll, 𝐋)`, where `𝐋 = sum(ll[tree.lids])`.

NOTE(review): `X[parent]` and `ll[parent]` are read when node `i` is visited.
This assumes node 1 is the root and that parents come before their children in
`eachindex(tree.T)` — confirm for the `Tree` type in use.
"""
function dfwguidtree!(X, guidedsegs, messages, tree::Tree, f, g, θ, Z; apply_time_change=false)
    ll = zeros(D1, tree.n)
    for i in eachindex(tree.T)
        if i == 1
            continue  # root node: it has no parent, hence no incoming segment
        end
        kernel = MSDE.SDEKernel(f, g, messages[i].ts, θ)
        par = tree.Par[i]
        sol, llseg = MSDE.forwardguiding(kernel, messages[i], (X[par], 0.0), Z[i-1];
            inplace=false, save_noise=true, apply_timechange=apply_time_change)
        ll[i] = llseg + ll[par] * tree.lastone[i]
        X[i] = D𝕏(sol[end][1:end-1])
        guidedsegs[i] = sol
    end
    return X, guidedsegs, ll, sum(ll[tree.lids])
end
# --- Settings ---
θinit = ([4.5, 0.1], σ0)  # a 2-tuple: initial parameter vector and σ0
iters = 5_000
# Disabled: θs, guidedsegs, frac_accepted = mcmc(tree, Xd, f, g, θlin, θinit; iters=iters)
σprop = 0.05
precisionatleaves = 1e-4
dt = 0.01  # step size handed to bwfiltertree!

# Leaf messages: a WGaussian with μ = Xd[i] and Σ = precisionatleaves * I(d) at every
# leaf, `missing` at all other nodes.
# NOTE(review): despite its name, precisionatleaves is passed as the covariance Σ.
Q = [i in tree.lids ? WGaussian{(:μ,:Σ,:c)}(Vector(Xd[i]), precisionatleaves * Matrix(I(d)), 0.0) : missing for i in tree.ids]
Q, messages = bwfiltertree!(Q, tree, θlin, dt)
# Passed as `f` to MSDE.SDEKernel (presumably the SDE drift). It computes
# tanh elementwise of Diagonal(θ[1]) * M * u and returns the result as an SVector.
# Linear variant: f(u, θ, t) = Diagonal(θ[1]) * M * u
function f(u, θ, t)
    w = tanh.(Diagonal(θ[1]) * M * u)
    return SVector(w...)
end
# Storage for the forward pass: one guided segment and one value per node.
# Leaf values are fixed to the observations Xd.
guidedsegs = Vector{Any}(undef, tree.n)
X = zeros(D𝕏, tree.n)
foreach(id -> (X[id] = Xd[id]), tree.lids)
"""
    VNoise{T,R,S} <: AbstractVector{T}

Flat vector view of nested noise storage. Entry `i` is
`value(Z[a][b][c])` with `(a, b, c) = α[i]`: reading strips the dual part, and
writing replaces that single component in place.
"""
struct VNoise{T,R,S} <: AbstractVector{T}
    α::R  # index map: α[i] is the triple (a, b, c) addressing Z[a][b][c]
    Z::S  # nested storage, indexed as Z[a][b][c]
end

VNoise(::Type{T}, α::R, Z::S) where {T,R,S} = VNoise{T,R,S}(α, Z)

Base.size(z::VNoise) = (length(z.α),)
# Linear indexing with an Int-only getindex/setindex! is the documented
# AbstractArray interface; ranges, CartesianIndex etc. then go through Base's
# generic fallbacks instead of being pushed into the α lookup.
Base.IndexStyle(::Type{<:VNoise}) = IndexLinear()
# Deep copy: setindex! mutates the nested storage Z, so copies must not share it.
Base.copy(z::VNoise) = deepcopy(z)

function Base.getindex(z::VNoise, i::Int)
    a, b, c = z.α[i]
    return value(z.Z[a][b][c])
end

function Base.setindex!(z::VNoise, x, i::Int)
    a, b, c = z.α[i]
    # the innermost element is replaced, not mutated: `setindex` returns a new value
    z.Z[a][b] = setindex(z.Z[a][b], x, c)
    return x
end
"""
    dinnov([rng::AbstractRNG,] t)

Sample a Brownian path on the time grid `t`. The path is zero at `t[1]` and has
independent Gaussian increments `sqrt(t[j+1] - t[j]) * randn(rng, 𝕏_)`. The path
values are converted entrywise to `D1` and wrapped as `myNoiseGrid(t, values)`.

`rng` defaults to `Random.default_rng()`, the global default RNG that the
one-argument form has always drawn from. Pass a seeded RNG for reproducible runs.
"""
function dinnov(rng::Random.AbstractRNG, t)
    Δt = diff(t)  # named Δt so it is not confused with the global step `dt`
    w = [sqrt(δ) * randn(rng, 𝕏_) for δ in Δt]
    brownian_values = cumsum(pushfirst!(w, zero(𝕏_)))  # path starts at zero
    return myNoiseGrid(t, map(x -> D1.(x), brownian_values))
end
dinnov(t) = dinnov(Random.default_rng(), t)
# Driving noise for each non-root node i, stored at Z[i-1] and sampled on
# messages[i].ts.
Z = map(i -> dinnov(messages[i].ts), 2:tree.n)
# Flat index map for VNoise: the triple (i, j, k) addresses Z[i][j][k]. j varies
# fastest, then i, then k.
# NOTE(review): k runs over 1:2, which assumes a 2-dimensional state.
α = [(i, j, k) for k in 1:2 for i in eachindex(Z) for j in eachindex(Z[i])]
z = VNoise(Float64, α, Z)
# Objective as a function of the noise: run the guided forward pass driven by z.Z
# and return 𝐋, the 4th output of dfwguidtree! (the sum of `ll` at the leaves).
F(z) = dfwguidtree!(X, guidedsegs, messages, tree, f, g, θ0, z.Z)[4]

"""
    ∇ϕi(z, i)

Partial derivative with respect to `z[i]` of the sampler's potential, which is a
tridiagonal quadratic term minus `F(z)`. The partial of `F` is computed by
forward-mode AD in `partiali`.

NOTE(review): for Brownian increments of variance `dt`, the interior prior
gradient would be `(2/dt) * (z[i] - z[i-1]/2 - z[i+1]/2)`, but this uses
`sqrt(2)/dt` — confirm. Also, with the `α` ordering, `z[i-1]` and `z[i+1]` can
belong to a different segment or coordinate.
"""
function ∇ϕi(z, i)
    # equals the global N (= length(x0)) for the sampler state, but does not
    # require N to be defined before this function is first called
    n = length(z)
    # ternaries keep every term Float64 (`cond && x` would yield Bool `false`)
    left = i > 1 ? 0.5z[i-1] : 0.0
    right = i < n ? 0.5z[i+1] : 0.0
    return sqrt(2) / dt * (z[i] - left - right) - partiali(F, z, i)
end
# --- ZigZag sampler over the flattened noise vector z ---
t0 = 0.0
x0 = deepcopy(z)           # starting state with its own copy of the nested noise
N = length(x0)
v0 = rand((-1.0, 1.0), N)  # initial velocities, each ±1 at random
c = 1.0
T = 30.0                   # final time of the sampler run
# Γ is a sparse tridiagonal matrix passed to `ZigZag`. Background: the covariance
# of Brownian motion at times ts is [min(s, t) for s in ts, t in ts], and its
# inverse (the precision matrix) is tridiagonal.
# NOTE(review): interior diagonal entries are 2, but Γ[1, 1] = 1 + 1/dt and
# Γ[N, N] = 1 — confirm these boundary entries are on the intended scale.
Γ = sparse(SymTridiagonal(fill(2.0, N), fill(-0.9999, N - 1)))
Γ[1, 1] = 1 + 1 / dt
Γ[N, N] = 1
tr, _, (acc, num), cs = @time ZigZagBoomerang.spdmp(∇ϕi, t0, x0, v0, T, fill(c, N),
    ZigZag(Γ, 0 * x0); adapt=true, structured=false)
# Discretise the ZigZag trace with step T/100, take the noise of the final sample,
# and rerun the guided forward pass with it.
ts, xs = ZigZagBoomerang.sep(discretise(tr, T / 100))
guidedZ = last(xs).Z
X, guidedsegs, ll, 𝐋 = dfwguidtree!(X, guidedsegs, messages, tree, f, g, θ0, guidedZ)
# --- Plot: observed leaf values (black dots) and guided segments, one panel per coordinate ---
cols = repeat([:blue, :red, :magenta, :orange], 2)
pl1 = scatter(tree.T[tree.lids], getindex.(Xd, 1)[tree.lids], color=:black, legend=false)
pl2 = scatter(tree.T[tree.lids], getindex.(Xd, 2)[tree.lids], color=:black, legend=false)
for i in 2:tree.n
    global gg = guidedsegs[i]
    col = sample(cols)  # one random colour per segment, shared by both panels
    for (panel, k) in ((pl1, 1), (pl2, 2))
        plot!(panel, gg.t, value.(getindex.(gg.u, k)), color=col)
    end
    # gg ends up holding the last segment's noise grid; it was only used by a
    # disabled plot of the driving noise.
    gg = guidedZ[i - 1]
end
pl = plot(pl1, pl2, layout=(2, 1), legend=false)
display(pl)
png("zigzag.png")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment