Created
March 3, 2020 14:19
-
-
Save cpfiffer/11193ec190b7bf0a18aedad9e8b43d4f to your computer and use it in GitHub Desktop.
Stream Turing samples to disk as they are drawn.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
using Turing, MCMCChains | |
using AbstractMCMC | |
using JLD2, FileIO | |
import Random: GLOBAL_RNG | |
# Create a model.
# Priors: μ ~ Normal(0, 1) and s ~ InverseGamma(2, 3).
# Likelihood: each observation y[i] ~ Normal(μ, s). Note that `s` is passed as
# the second argument of `Normal`, i.e. as its scale (standard deviation), not
# as a variance.
@model model(y) = begin
    μ ~ Normal(0, 1)
    s ~ InverseGamma(2, 3)
    # `eachindex` iterates exactly the valid indices of `y`.
    for i in eachindex(y)
        y[i] ~ Normal(μ, s)
    end
end
| |
# Generate synthetic data.
y = rand(Normal(0, 1), 10)
# Instantiate the model on the data. This rebinds the global `model` from the
# model definition to the instantiated model; `restore_range` below reads this
# global, so the name is kept.
model = model(y)
# Note that a sampler must be explicitly constructed here --
# using spl = HMC(0.1, 7) will not work.
spl = Turing.Sampler(HMC(0.1, 7), model)
# Remove any samples file left over from a previous run. `force = true` makes
# this a no-op when the file does not exist.
rm("samples.jld2"; force = true)
# Iteratively draw samples. `Iterators.take` bounds the otherwise open-ended
# sampling iterator to exactly `max_draws` draws (zero when max_draws <= 0).
max_draws = 100
draws = Iterators.take(AbstractMCMC.steps!(model, spl), max_draws)
for (iteration, t) in enumerate(draws)
    # Store each sample t in the JLD2 file, keyed by its 1-based draw number as
    # a string ("1", "2", ...). The file is reopened in append mode for every
    # draw and closed when the do-block exits, so each sample is written to
    # disk before the next one is drawn.
    jldopen("samples.jld2", "a+") do file
        file["$iteration"] = t
    end
end
"""
    restore_range(file, range; mdl = model, sampler = spl)

Read the samples stored under the keys `string.(range)` in the JLD2 file `file`
and bundle them into an MCMCChains `Chains` object.

Only the requested entries are read from disk, rather than loading every
stored sample first.

# Keywords
- `mdl`: model passed to `AbstractMCMC.bundle_samples`. Defaults to the global
  `model`, which matches the original two-argument behavior.
- `sampler`: sampler passed to `AbstractMCMC.bundle_samples`. Defaults to the
  global `spl`.
"""
function restore_range(file, range; mdl = model, sampler = spl)
    # Samples were written under their draw number converted to a string.
    ts = jldopen(file, "r") do f
        [f[key] for key in string.(range)]
    end
    return AbstractMCMC.bundle_samples(GLOBAL_RNG, mdl, sampler, length(ts), ts, Chains)
end
# Load every sample written by the sampling loop (draws 1 through max_draws)
# into a single chain.
chain = restore_range("samples.jld2", 1:max_draws)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment