Created
April 9, 2020 01:00
-
-
Save baggepinnen/5d1e41bfbbdde4e7104266085f716d28 to your computer and use it in GitHub Desktop.
Soss Markov Chain model
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
using Soss
"""
    MarkovChain(pars, dist)

A chain of states described by fixed parameters and a distribution for its first state.

# Fields
- `pars`: parameters that are passed unchanged to every step of the chain.
- `dist`: distribution used to score the first state of a sequence; `logpdf` calls
  `dist.model` to build the distribution of each following state.
"""
struct MarkovChain{P,D}
    pars::P
    dist::D
end
"""
    Distributions.logpdf(chain::MarkovChain, x::AbstractVector)

Return the log-density of the state sequence `x` under `chain`.

The first state is scored with `chain.dist`. Each later state is scored with the
distribution obtained by calling the current distribution's `model` on
`(pars = chain.pars, state = previous_state)`. The per-step log-densities are summed.

# Throws
- `ArgumentError` if `x` is empty.
"""
function Distributions.logpdf(chain::MarkovChain, x::AbstractVector)
    # Checked explicitly so that an empty sequence fails loudly instead of reading
    # `x[1]` with bounds checks disabled.
    isempty(x) && throw(ArgumentError("cannot score an empty state sequence"))

    pars = chain.pars
    dist = chain.dist
    prev = first(x)
    total = logpdf(dist, prev)

    # Iterate rather than recurse on ever-shorter views: the call depth stays constant
    # no matter how long the sequence is.
    for state in Iterators.drop(x, 1)
        dist = dist.model((pars = pars, state = prev))
        total += logpdf(dist, state)
        prev = state
    end
    return total
end
# Top-level model: draws the noise scales, the initial state components and `q`, then
# observes `x` as a `MarkovChain` whose steps are generated by `mstep`.
# `p` is the only model argument and is supplied by the caller.
model = Soss.@model p begin
    σI ~ Gamma(100) # These are all bad choices and have to be tuned with prior predictive checks
    σIr ~ Gamma(50)
    σR ~ truncated(Normal(0,0.03), 1e-6, Inf) # Random walk for R
    Ir ~ truncated(Normal(1,1), 0, Inf)
    I ~ Gamma(10)
    R ~ Normal(3,0.5) # A guess
    q ~ Uniform(0,1)
    s0 = @namedtuple(I,Ir,R) # The initial state
    pars = @namedtuple(σI, σIr, σR, p, q) # Parameters stay constant between iterations
    # The chain starts from the distribution of one `mstep` taken from `s0`.
    x ~ MarkovChain(pars, mstep(pars=pars, state=s0))
end
# One transition of the chain: given the constant parameters `pars` and the previous
# `state` (fields `I`, `Ir`, `R`), sample the increments `sI`, `sIr`, `sR` and add them
# to the previous counts.
mstep = Soss.@model pars,state begin
    # Parameters
    σI = pars.σI
    σIr = pars.σIr
    σR = pars.σR
    p = pars.p
    q = pars.q
    # Starting counts
    I0 = state.I
    Ir0 = state.Ir
    R0 = state.R
    # Transitions between states
    sR ~ Normal(R0, σR)
    # The sum previously ran over `1:min(np,k-1)`, but neither `np` nor `k` is defined
    # anywhere. The state carries no infection history or time index, so every weight
    # in `p` is applied to the current count `I0`.
    # NOTE(review): confirm this matches the intended lag structure.
    sI ~ Normal(R0 * sum(p[i]*I0 for i in eachindex(p)), σI)
    # Vector-valued `q` alternative, kept for reference:
    # sIr ~ Normal(sum(q[i]*I0 for i = 1:nq), σIr)
    sIr ~ Normal(q*I0, σIr)
    # Updated counts
    I = I0 + sI
    R = R0 + sR
    Ir = Ir0 + sIr
end;
# Synthetic observations: exp(0.2 t) for t = 1, ..., 30, rounded to whole numbers.
fake_Ir_data = round.(exp.(0.2 .* (1:30)))
# One `(Ir = value,)` named tuple per time step, stored under the model variable `x`.
data = (x=[(Ir=i,) for i in fake_Ir_data],)
p = rand(Dirichlet(4, 1)) # A random vector that sums to 1
# NOTE(review): `pm` is not used anywhere below — confirm whether it is still needed.
pm = model(p=p, x=data.x)
# `m` has to be defined before it is sampled from; it was previously assigned only
# after the `rand(m, 10)` call.
m = model(p=p)
prior = rand(m, 10)
# Presumably 100 is the number of posterior draws — TODO confirm against Soss.
post = dynamicHMC(m, data, 100)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment