Skip to content

Instantly share code, notes, and snippets.

@baggepinnen
Created April 9, 2020 01:00
Show Gist options
  • Save baggepinnen/5d1e41bfbbdde4e7104266085f716d28 to your computer and use it in GitHub Desktop.
Save baggepinnen/5d1e41bfbbdde4e7104266085f716d28 to your computer and use it in GitHub Desktop.
Soss Markov Chain model
using Soss
"""
    MarkovChain(pars, dist)

Container pairing a set of chain parameters `pars` with the distribution `dist` used to
score the first element of a trajectory.

In this file `pars` is a NamedTuple of parameters shared by every step, and `dist` is a
Soss joint distribution for one transition. Its `.model` field is re-instantiated to build
each subsequent step.
"""
struct MarkovChain{P,D}
    pars::P   # parameters shared by every transition
    dist::D   # distribution for the first transition
end
"""
    logpdf(chain::MarkovChain, x::AbstractVector)

Return the log-density of the observed trajectory `x` under `chain`.

How the terms are built:
- The first element of `x` is scored under `chain.dist`.
- Each later element is scored under a step distribution rebuilt from the previous step's
  `.model`, called with `(pars = chain.pars, state = <previous element of x>)`. Each
  transition is therefore conditioned on the previous observation.

The terms are summed.

Throws `ArgumentError` if `x` is empty.
"""
function Distributions.logpdf(chain::MarkovChain, x::AbstractVector)
    isempty(x) && throw(ArgumentError("cannot evaluate logpdf of an empty trajectory"))
    # Iterate rather than recurse: one stack frame per element would overflow the
    # stack for long trajectories. `first`/`Iterators.drop` also work for any indexing.
    prev = first(x)
    dist = chain.dist
    total = logpdf(dist, prev)
    for xi in Iterators.drop(x, 1)
        # Rebuild the step distribution conditioned on the previous state. Its type may
        # differ from `chain.dist` (different argument types), so `dist` is rebound here.
        dist = dist.model((pars = chain.pars, state = prev))
        total += logpdf(dist, xi)
        prev = xi
    end
    return total
end
# Top-level generative model for the trajectory `x`. The argument `p` is passed in by the
# caller; in the example below it is a probability vector drawn from a Dirichlet.
# The model draws:
#   - the step scales σI, σIr, σR,
#   - an initial state (I, Ir, R),
#   - the scalar q.
# It then scores `x` under a `MarkovChain` whose first transition is `mstep` started
# from that initial state.
# NOTE(review): `mstep` is assigned further down the file. Presumably it is only looked up
# when the model is evaluated, so it must be defined before `model` is used — confirm.
# NOTE(review): no `rand` method for `MarkovChain` is defined in this file, so drawing `x`
# from the prior (e.g. `rand(model(p=p), 10)`) likely fails unless one exists elsewhere.
model = Soss.@model p begin
σI ~ Gamma(100) # These are all bad choices and have to be tuned with prior predictive checks
σIr ~ Gamma(50)
σR ~ truncated(Normal(0,0.03), 1e-6, Inf) # Random walk for R
Ir ~ truncated(Normal(1,1), 0, Inf)
I ~ Gamma(10)
R ~ Normal(3,0.5) # A guess
q ~ Uniform(0,1)
s0 = @namedtuple(I,Ir,R) # The initial state
pars = @namedtuple(σI, σIr, σR, p, q) # Parameters stay constant between iterations
# The observed trajectory, scored by the `Distributions.logpdf(::MarkovChain, ...)` method
x ~ MarkovChain(pars, mstep(pars=pars, state=s0))
end
# One transition of the Markov chain, as a Soss model with arguments `pars` and `state`.
# `pars` is the NamedTuple of constant parameters built in `model` (σI, σIr, σR, p, q).
# `state` is the current state, read here through its fields `I`, `Ir` and `R`.
# The model draws increments sR, sI, sIr and computes the updated counts I, R, Ir
# from them.
#
# NOTE(review): `np` and `k` on the `sI` line are not defined anywhere in this file, so
# evaluating this model will presumably throw UndefVarError unless they are defined
# elsewhere. `np` looks like it should be `length(p)` and `k` a time-step index — TODO confirm.
# NOTE(review): every term of that sum multiplies the same `I0`, so it reduces to
# `sum(p[1:min(np,k-1)]) * I0`. Presumably lagged I values were intended — confirm.
# NOTE(review): the `logpdf(::MarkovChain, ...)` method in this file re-instantiates this
# model with `state = <previous observation>`. The example data builds each observation
# as `(Ir=...,)` with no `I` or `R` field, so `state.I` / `state.R` would fail from the
# second step on. Verify the intended state representation.
mstep = Soss.@model pars,state begin
# Parameters (unpacked from `pars`; they stay constant between steps)
σI = pars.σI
σIr = pars.σIr
σR = pars.σR
p = pars.p
q = pars.q
# Starting counts (the current state)
I0 = state.I
Ir0 = state.Ir
R0 = state.R
# Transitions between states: R takes a Normal random-walk step with scale σR,
# and the I increment is centred on R0 times a p-weighted sum (see NOTE above)
sR ~ Normal(R0, σR)
sI ~ Normal(R0 * sum(p[i]*I0 for i = 1:min(np,k-1)), σI)
# Disabled alternative using a vector-valued q over nq lags (nq is not defined here)
# sIr ~ Normal(sum(q[i]*I0 for i = 1:nq), σIr)
sIr ~ Normal(q*I0, σIr)
# Updated counts: previous value plus the sampled increment
I = I0 + sI
R = R0 + sR
Ir = Ir0 + sIr
end;
# Example usage on synthetic data.
# Fake Ir observations: exponential growth over 30 steps, rounded to whole counts.
fake_Ir_data = round.(exp.(0.2 .* (1:30)))
# Observed trajectory: one NamedTuple per time step, bound to the model's `x`.
data = (x=[(Ir=i,) for i in fake_Ir_data],)
p = rand(Dirichlet(4, 1)) # A random vector that sums to 1
# NOTE(review): `pm` is not used below; it passes `x` as an argument although `model`
# only declares `p` — confirm whether it is still needed.
pm = model(p=p, x=data.x)
# `m` must be bound before sampling from it. Sampling first raised UndefVarError
# on a fresh session, because `m` did not exist yet.
m = model(p=p)
prior = rand(m, 10)
post = dynamicHMC(m, data, 100)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment