Skip to content

Instantly share code, notes, and snippets.

@cpfiffer
Created March 3, 2020 14:19
Show Gist options
  • Save cpfiffer/11193ec190b7bf0a18aedad9e8b43d4f to your computer and use it in GitHub Desktop.
Save cpfiffer/11193ec190b7bf0a18aedad9e8b43d4f to your computer and use it in GitHub Desktop.
Stream Turing samples to disk as they are drawn.
using Turing, MCMCChains
using AbstractMCMC
using JLD2, FileIO
import Random: GLOBAL_RNG
# Define the model: each y[i] ~ Normal(μ, s), with a Normal(0, 1) prior on μ
# and an InverseGamma(2, 3) prior on s (used here directly as the scale).
#
# The definition gets its own name (`normal_model`) so that binding the
# instantiated model to `model` below does not overwrite it. Reusing a single
# name for both makes re-running this script in the same session fail, and on
# DynamicPPL versions where `@model` defines an ordinary function the
# reassignment is an immediate redefinition error.
@model normal_model(y) = begin
    μ ~ Normal(0, 1)
    s ~ InverseGamma(2, 3)
    for i in eachindex(y)
        y[i] ~ Normal(μ, s)
    end
end

# Generate synthetic data: 10 draws from a standard normal.
y = rand(Normal(0, 1), 10)

# Instantiate the model on the data and build the sampler.
# Note that a sampler must be explicitly constructed here --
# using spl = HMC(0.1, 7) will not work.
model = normal_model(y)
spl = Turing.Sampler(HMC(0.1, 7), model)
# Start from a fresh samples file so draws from a previous run are not mixed in.
isfile("samples.jld2") && rm("samples.jld2")

# Draw samples one at a time, writing each to disk as soon as it is produced.
# `Iterators.take` stops the sampler after exactly `max_draws` steps, so the
# file ends up holding draws "1" through "$max_draws".
max_draws = 100
draws = Iterators.take(AbstractMCMC.steps!(model, spl), max_draws)
for (iteration, t) in enumerate(draws)
    # Open in append mode and close again after every write: closing the file
    # at the end of the do-block persists each sample before the next draw, so
    # already-written samples survive if sampling is interrupted later.
    # Samples are keyed by their draw number as a string ("1", "2", ...).
    jldopen("samples.jld2", "a+") do file
        file["$iteration"] = t
    end
end
# This function accepts a JLD file of samples and a range object,
# and returns the relevant samples stored in an MCMCChains object.
function restore_range(file, range)
samples = load(file)
string_range = map(x -> string(x), range)
ts = [samples[s] for s in string_range]
return AbstractMCMC.bundle_samples(GLOBAL_RNG, model, spl, length(ts), ts, Chains)
end
# Restore every sample written above (draws 1 through max_draws) as a Chains object.
chain = restore_range("samples.jld2", 1:max_draws)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment