Example of writing out `Turing.sample` by hand.
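For reference, everything below unrolls by hand what a single high-level call to `sample` does. Assuming the usual Turing interface and the same `rng`, `model`, and `n` defined in the transcript, that call would look roughly like this (a sketch, not part of the original session):

chain = sample(rng, model, NUTS(0.65), n)  # one-shot equivalent of steps 1-5 below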
julia> using Random

julia> using Turing

julia> rng = MersenneTwister(42);

julia> @model function demo(x)
           s ~ InverseGamma(2, 3)
           m ~ Normal(0, √s)
           for i in eachindex(x)
               x[i] ~ Normal(m, √s)
           end
       end
demo (generic function with 1 method)

julia> xs = randn(100) .+ 1;

julia> model = demo(xs);
julia> n = 100  # number of MCMC steps to take
100

julia> alg = NUTS(0.65)  # sampling algorithm to use
NUTS{Turing.Core.ForwardDiffAD{40},(),AdvancedHMC.DiagEuclideanMetric}(-1, 0.65, 10, 1000.0, 0.0)

julia> spl = DynamicPPL.Sampler(alg, model);
┌ Info: Found initial step size
└   ϵ = 0.8
julia> # 1. Initialize the sampler (will also do adaptation)
       AbstractMCMC.sample_init!(rng, model, spl, n);

julia> # 2. Take the first step
       transition = AbstractMCMC.step!(rng, model, spl, n)
Turing.Inference.HamiltonianTransition{NamedTuple{(:s, :m),Tuple{Tuple{Array{Float64,1},Array{String,1}},Tuple{Array{Float64,1},Array{String,1}}}},NamedTuple{(:n_steps, :is_accept, :acceptance_rate, :log_density, :hamiltonian_energy, :hamiltonian_energy_error, :max_hamiltonian_energy_error, :tree_depth, :numerical_error, :step_size, :nom_step_size),Tuple{Int64,Bool,Float64,Float64,Float64,Float64,Float64,Int64,Bool,Float64,Float64}},Float64}((s = ([1.1419439896329209], ["s"]), m = ([-0.18388345765143033], ["m"])), -207.3741582911855, (n_steps = 1, is_accept = true, acceptance_rate = 0.0, log_density = -207.3741582911855, hamiltonian_energy = 209.6099530918558, hamiltonian_energy_error = 0.0, max_hamiltonian_energy_error = 1646.4591248230568, tree_depth = 0, numerical_error = true, step_size = 0.8, nom_step_size = 0.8))

julia> # 3. Initialize container to hold transitions
       transitions = AbstractMCMC.transitions_init(transition, model, spl, n);

julia> # 4. Add the first step taken
       AbstractMCMC.transitions_save!(transitions, 1, transition, model, spl, n);
julia> # `transition.θ` holds the sampled parameters, and `DynamicPPL.getlogp(transition)` holds the log-joint.
       # Step `n` times.
       for i = 2:n
           transition = AbstractMCMC.step!(rng, model, spl, n, transition)
           # Save the transition (`i` is the step index here)
           AbstractMCMC.transitions_save!(transitions, i, transition, model, spl, n)
           # Do whatever you want with the `transition`, e.g. plot it
           # ...
       end
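
# As an aside: if you want to use each draw as it is produced, the loop above could
# instead be written along these lines (a sketch, not part of the original session;
# `logps` is an illustrative name, and `DynamicPPL.getlogp` is the accessor mentioned
# in the comment above):
logps = Float64[]
for i = 2:n
    transition = AbstractMCMC.step!(rng, model, spl, n, transition)
    AbstractMCMC.transitions_save!(transitions, i, transition, model, spl, n)
    push!(logps, DynamicPPL.getlogp(transition))  # log-joint of the i-th draw
end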
julia> # 5. Bundle the transitions into a `Chains` object
       n_without_adapt = n - alg.n_adapts  # we've discarded the adaptation samples, so we need to recompute the number of samples
50
julia> chain = AbstractMCMC.bundle_samples(rng, model, spl, n_without_adapt, transitions, Vector{NamedTuple})
50-element Array{NamedTuple{(:s, :m, :lp),Tuple{Array{Float64,1},Array{Float64,1},Float64}},1}:
 (s = [1.507184732983516], m = [1.0356258671542418], lp = -151.41988024461492)
 (s = [1.484206208007892], m = [1.1602274966417248], lp = -152.36295261917112)
 (s = [0.9338965104551517], m = [0.7631141910005066], lp = -152.2794729888841)
 (s = [0.8748130604568505], m = [0.7561100961393935], lp = -153.42017856264803)
 (s = [0.8919208735374546], m = [0.9603554746508606], lp = -150.6681551475251)
 (s = [1.429356650688761], m = [0.9582921877494085], lp = -150.59036772786897)
 (s = [1.2349684890768557], m = [0.8203860117396458], lp = -150.23810256786362)
 (s = [1.0145019439666316], m = [0.7859877810935259], lp = -151.01138091940277)
 (s = [1.2810444611081253], m = [1.0173312089782505], lp = -149.72555269605294)
 (s = [1.2051416011067522], m = [0.9800854825710026], lp = -149.3122555263368)
 (s = [1.0596312001963288], m = [1.0609548061900078], lp = -149.71309016061124)
 (s = [0.9031732383433777], m = [1.1122144234802736], lp = -151.74321150147227)
 (s = [1.5156105561494102], m = [0.9605394339230326], lp = -151.32009692441935)
 (s = [1.4616487524950548], m = [0.8989587969302851], lp = -150.99463104071495)
 (s = [1.1924313418138075], m = [1.1696287720182417], lp = -151.06823704003565)
 (s = [1.352386448077], m = [0.8340193398214966], lp = -150.65327793568161)
 (s = [1.3252316313731927], m = [0.7123487656669], lp = -152.25655886123408)
 (s = [1.1407953978522807], m = [1.1420370503599286], lp = -150.59180719495984)
 (s = [0.9631951700812363], m = [1.0635913604155043], lp = -150.34335070484224)
 (s = [1.453684821093764], m = [0.8069315441649687], lp = -151.63485364680574)
 (s = [0.8355621315964834], m = [0.8810537594186493], lp = -152.10498798819356)
 (s = [1.5521578595590961], m = [0.8750421882165322], lp = -151.90876769025115)
 (s = [1.3094190193437958], m = [0.8654561264445818], lp = -150.13292869776274)
 (s = [0.9556931717777034], m = [1.0635578474016454], lp = -150.41679517046654)
 (s = [1.125334749308243], m = [0.9923144066378371], lp = -149.20822932783676)
 (s = [1.0428471943474957], m = [1.068439648862657], lp = -149.8505818346328)
 (s = [1.3331010523696905], m = [0.8226790999089237], lp = -150.65441949177156)
 (s = [1.3293518529741088], m = [1.238101167642786], lp = -152.75163021486273)
 (s = [1.0282435867240292], m = [1.0202074299162287], lp = -149.5362203903304)
 (s = [0.939837577906819], m = [1.0109527322570468], lp = -150.169271997404)
 (s = [1.2611754520664036], m = [0.9158934282556855], lp = -149.60417423866926)
 (s = [0.9373430421673838], m = [1.1338176604774994], lp = -151.64139017823854)
 (s = [0.9373430421673838], m = [1.1338176604774994], lp = -151.64139017823854)
 (s = [1.2848043883870464], m = [0.8731898496862049], lp = -149.9486140932034)
 (s = [1.2848043883870464], m = [0.8731898496862049], lp = -149.9486140932034)
 (s = [1.078208798318388], m = [1.0263183197207986], lp = -149.39895718480605)
 (s = [1.045233519177509], m = [1.0769694379652777], lp = -149.93056561657153)
 (s = [1.208585591727083], m = [0.7837725422847548], lp = -150.65726332462089)
 (s = [1.209345867180395], m = [0.9390242477552445], lp = -149.3378459846963)
 (s = [0.9585165680493164], m = [1.034252277435292], lp = -150.124087258444)
 (s = [0.9231667013759385], m = [1.0681605344213074], lp = -150.8413702898093)
 (s = [1.007486102189982], m = [1.0576800676379319], lp = -149.9336293836161)
 (s = [1.0330965312157063], m = [1.1534325811297004], lp = -151.1226386004739)
 (s = [0.9213584889683054], m = [1.1448078545099563], lp = -152.06813143725682)
 (s = [0.9213584889683054], m = [1.1448078545099563], lp = -152.06813143725682)
 (s = [0.922586460156493], m = [1.1124397315282506], lp = -151.4646269775953)
 (s = [1.0005355580905382], m = [0.9466019426759983], lp = -149.54373716403447)
 (s = [1.2826795056371791], m = [0.8458701431333486], lp = -150.16138613045362)
 (s = [0.981483126455662], m = [1.0794272367207176], lp = -150.35759914620317)
 (s = [1.3454763130648564], m = [0.8590214936522462], lp = -150.3921587244739)
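
# Since each element of this `chain` is a NamedTuple, posterior summaries can be
# computed directly. A sketch (not part of the original session); note that `t.m`
# and `t.s` are length-1 vectors, hence the `first`:
using Statistics
mean(first(t.m) for t in chain)  # posterior mean of m
mean(first(t.s) for t in chain)  # posterior mean of s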
julia> # Or if you want a `MCMCChains.Chains` object instead:
       chain = AbstractMCMC.bundle_samples(rng, model, spl, n_without_adapt, transitions, MCMCChains.Chains)
Chains MCMC chain (50×14×1 Array{Float64,3}):

Iterations        = 1:50
Thinning interval = 1
Chains            = 1
Samples per chain = 50
parameters        = m, s
internals         = acceptance_rate, hamiltonian_energy, hamiltonian_energy_error, is_accept, log_density, lp, max_hamiltonian_energy_error, n_steps, nom_step_size, numerical_error, step_size, tree_depth
Summary Statistics
  parameters    mean     std  naive_se    mcse      ess    rhat
      Symbol Float64 Float64   Float64 Missing  Float64 Float64

           m  0.9808  0.1319    0.0187 missing  39.1866  1.0475
           s  1.1444  0.2050    0.0290 missing  41.8236  1.0566

Quantiles
  parameters    2.5%   25.0%   50.0%   75.0%   97.5%
      Symbol Float64 Float64 Float64 Float64 Float64

           m  0.7577  0.8732  1.0016  1.0748  1.1675
           s  0.8787  0.9564  1.1018  1.3033  1.5137
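
# The resulting `Chains` object supports the usual MCMCChains operations, for example
# (a sketch assuming the standard MCMCChains API, which is loaded with Turing here;
# output omitted):
using Statistics
chain[:m]        # the draws of `m`
mean(chain)      # posterior means of all parameters
describe(chain)  # summary statistics and quantiles, as printed above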