Skip to content

Instantly share code, notes, and snippets.

@torfjelde
Last active September 18, 2020 11:20
Show Gist options
  • Save torfjelde/d1888cd3b1dcff0aad560d5c2f03cf8b to your computer and use it in GitHub Desktop.
Save torfjelde/d1888cd3b1dcff0aad560d5c2f03cf8b to your computer and use it in GitHub Desktop.
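Example of writing out `Turing.sample` by hand: initialize the sampler, take and save each MCMC step yourself, then bundle the saved transitions into the desired output format.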
julia> using Random
julia> using Turing
julia> rng = MersenneTwister(42);
julia> @model function demo(x)
# Prior on `s`; `√s` is used as the standard deviation below, so `s` plays the role of a variance.
s ~ InverseGamma(2, 3)
# Prior on the mean `m`, with scale tied to `s`.
m ~ Normal(0, √s)
# Likelihood: every observation in `x` is modelled as Normal(m, √s).
for i in eachindex(x)
x[i] ~ Normal(m, √s)
end
end
demo (generic function with 1 method)
julia> xs = randn(100) .+ 1;
julia> model = demo(xs);
julia> n = 100 # total number of MCMC steps to take, adaptation steps included
100
julia> alg = NUTS(0.65) # sampling algorithm to use (NUTS; 0.65 is presumably the target acceptance rate)
NUTS{Turing.Core.ForwardDiffAD{40},(),AdvancedHMC.DiagEuclideanMetric}(-1, 0.65, 10, 1000.0, 0.0)
julia> spl = DynamicPPL.Sampler(alg, model);
┌ Info: Found initial step size
└ ϵ = 0.8
julia> # 1. Initialize the sampler; `n` is passed so it knows the total run length.
# NOTE(review): the original note said this "will also do adaptation"; presumably it sets up
# the adaptation phase, with the adaptation itself happening in the `step!` calls — confirm
# for your Turing version.
AbstractMCMC.sample_init!(rng, model, spl, n);
julia> # 2. Take the first step. No previous transition is passed here, unlike the
# `step!` calls in the sampling loop further down.
transition = AbstractMCMC.step!(rng, model, spl, n)
Turing.Inference.HamiltonianTransition{NamedTuple{(:s, :m),Tuple{Tuple{Array{Float64,1},Array{String,1}},Tuple{Array{Float64,1},Array{String,1}}}},NamedTuple{(:n_steps, :is_accept, :acceptance_rate, :log_density, :hamiltonian_energy, :hamiltonian_energy_error, :max_hamiltonian_energy_error, :tree_depth, :numerical_error, :step_size, :nom_step_size),Tuple{Int64,Bool,Float64,Float64,Float64,Float64,Float64,Int64,Bool,Float64,Float64}},Float64}((s = ([1.1419439896329209], ["s"]), m = ([-0.18388345765143033], ["m"])), -207.3741582911855, (n_steps = 1, is_accept = true, acceptance_rate = 0.0, log_density = -207.3741582911855, hamiltonian_energy = 209.6099530918558, hamiltonian_energy_error = 0.0, max_hamiltonian_energy_error = 1646.4591248230568, tree_depth = 0, numerical_error = true, step_size = 0.8, nom_step_size = 0.8))
julia> # 3. Initialize container to hold transitions
# (the first `transition` is passed in, presumably so the container can be set up to match it).
transitions = AbstractMCMC.transitions_init(transition, model, spl, n);
julia> # 4. Add the first step taken, at index 1
AbstractMCMC.transitions_save!(transitions, 1, transition, model, spl, n);
julia> # `transition.θ` holds the sampled parameters, `DynamicPPL.getlogp(transition)` holds the log joint density.
# Take the remaining `n - 1` steps (step 1 was taken above), for `n` steps in total.
# Each `step!` call receives the previous `transition`.
for i = 2:n
transition = AbstractMCMC.step!(rng, model, spl, n, transition)
# Save the transition (`i` represents step-index here)
AbstractMCMC.transitions_save!(transitions, i, transition, model, spl, n)
# Do whatever you want with the `transition`, e.g. plotting it
# ...
end
julia> # 5. Bundle the transitions, here into a `Vector{NamedTuple}` (a `Chains` variant follows below).
# The adaptation samples are discarded, so the number of retained samples has to be recomputed.
# NOTE(review): `alg` printed `-1` as its first field when constructed, yet this evaluates to 50;
# presumably `sample_init!` filled in `alg.n_adapts` — confirm for your Turing version.
n_without_adapt = n - alg.n_adapts
50
julia> chain = AbstractMCMC.bundle_samples(rng, model, spl, n_without_adapt, transitions, Vector{NamedTuple})
50-element Array{NamedTuple{(:s, :m, :lp),Tuple{Array{Float64,1},Array{Float64,1},Float64}},1}:
(s = [1.507184732983516], m = [1.0356258671542418], lp = -151.41988024461492)
(s = [1.484206208007892], m = [1.1602274966417248], lp = -152.36295261917112)
(s = [0.9338965104551517], m = [0.7631141910005066], lp = -152.2794729888841)
(s = [0.8748130604568505], m = [0.7561100961393935], lp = -153.42017856264803)
(s = [0.8919208735374546], m = [0.9603554746508606], lp = -150.6681551475251)
(s = [1.429356650688761], m = [0.9582921877494085], lp = -150.59036772786897)
(s = [1.2349684890768557], m = [0.8203860117396458], lp = -150.23810256786362)
(s = [1.0145019439666316], m = [0.7859877810935259], lp = -151.01138091940277)
(s = [1.2810444611081253], m = [1.0173312089782505], lp = -149.72555269605294)
(s = [1.2051416011067522], m = [0.9800854825710026], lp = -149.3122555263368)
(s = [1.0596312001963288], m = [1.0609548061900078], lp = -149.71309016061124)
(s = [0.9031732383433777], m = [1.1122144234802736], lp = -151.74321150147227)
(s = [1.5156105561494102], m = [0.9605394339230326], lp = -151.32009692441935)
(s = [1.4616487524950548], m = [0.8989587969302851], lp = -150.99463104071495)
(s = [1.1924313418138075], m = [1.1696287720182417], lp = -151.06823704003565)
(s = [1.352386448077], m = [0.8340193398214966], lp = -150.65327793568161)
(s = [1.3252316313731927], m = [0.7123487656669], lp = -152.25655886123408)
(s = [1.1407953978522807], m = [1.1420370503599286], lp = -150.59180719495984)
(s = [0.9631951700812363], m = [1.0635913604155043], lp = -150.34335070484224)
(s = [1.453684821093764], m = [0.8069315441649687], lp = -151.63485364680574)
(s = [0.8355621315964834], m = [0.8810537594186493], lp = -152.10498798819356)
(s = [1.5521578595590961], m = [0.8750421882165322], lp = -151.90876769025115)
(s = [1.3094190193437958], m = [0.8654561264445818], lp = -150.13292869776274)
(s = [0.9556931717777034], m = [1.0635578474016454], lp = -150.41679517046654)
(s = [1.125334749308243], m = [0.9923144066378371], lp = -149.20822932783676)
(s = [1.0428471943474957], m = [1.068439648862657], lp = -149.8505818346328)
(s = [1.3331010523696905], m = [0.8226790999089237], lp = -150.65441949177156)
(s = [1.3293518529741088], m = [1.238101167642786], lp = -152.75163021486273)
(s = [1.0282435867240292], m = [1.0202074299162287], lp = -149.5362203903304)
(s = [0.939837577906819], m = [1.0109527322570468], lp = -150.169271997404)
(s = [1.2611754520664036], m = [0.9158934282556855], lp = -149.60417423866926)
(s = [0.9373430421673838], m = [1.1338176604774994], lp = -151.64139017823854)
(s = [0.9373430421673838], m = [1.1338176604774994], lp = -151.64139017823854)
(s = [1.2848043883870464], m = [0.8731898496862049], lp = -149.9486140932034)
(s = [1.2848043883870464], m = [0.8731898496862049], lp = -149.9486140932034)
(s = [1.078208798318388], m = [1.0263183197207986], lp = -149.39895718480605)
(s = [1.045233519177509], m = [1.0769694379652777], lp = -149.93056561657153)
(s = [1.208585591727083], m = [0.7837725422847548], lp = -150.65726332462089)
(s = [1.209345867180395], m = [0.9390242477552445], lp = -149.3378459846963)
(s = [0.9585165680493164], m = [1.034252277435292], lp = -150.124087258444)
(s = [0.9231667013759385], m = [1.0681605344213074], lp = -150.8413702898093)
(s = [1.007486102189982], m = [1.0576800676379319], lp = -149.9336293836161)
(s = [1.0330965312157063], m = [1.1534325811297004], lp = -151.1226386004739)
(s = [0.9213584889683054], m = [1.1448078545099563], lp = -152.06813143725682)
(s = [0.9213584889683054], m = [1.1448078545099563], lp = -152.06813143725682)
(s = [0.922586460156493], m = [1.1124397315282506], lp = -151.4646269775953)
(s = [1.0005355580905382], m = [0.9466019426759983], lp = -149.54373716403447)
(s = [1.2826795056371791], m = [0.8458701431333486], lp = -150.16138613045362)
(s = [0.981483126455662], m = [1.0794272367207176], lp = -150.35759914620317)
(s = [1.3454763130648564], m = [0.8590214936522462], lp = -150.3921587244739)
julia> # Or if you want a `MCMCChains.Chains` object instead, pass `MCMCChains.Chains` as the
# output type; same transitions and sample count as above. This rebinds `chain`, replacing
# the `Vector{NamedTuple}` result from the previous call.
chain = AbstractMCMC.bundle_samples(rng, model, spl, n_without_adapt, transitions, MCMCChains.Chains)
Chains MCMC chain (50×14×1 Array{Float64,3}):
Iterations = 1:50
Thinning interval = 1
Chains = 1
Samples per chain = 50
parameters = m, s
internals = acceptance_rate, hamiltonian_energy, hamiltonian_energy_error, is_accept, log_density, lp, max_hamiltonian_energy_error, n_steps, nom_step_size, numerical_error, step_size, tree_depth
Summary Statistics
parameters mean std naive_se mcse ess rhat
Symbol Float64 Float64 Float64 Missing Float64 Float64
m 0.9808 0.1319 0.0187 missing 39.1866 1.0475
s 1.1444 0.2050 0.0290 missing 41.8236 1.0566
Quantiles
parameters 2.5% 25.0% 50.0% 75.0% 97.5%
Symbol Float64 Float64 Float64 Float64 Float64
m 0.7577 0.8732 1.0016 1.0748 1.1675
s 0.8787 0.9564 1.1018 1.3033 1.5137
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment