Last active
September 2, 2023 11:31
-
-
Save torfjelde/eac4919a271f4f7109633e55040d8377 to your computer and use it in GitHub Desktop.
Rough example of using StanDistributions.jl within Turing.jl.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
julia> using PosteriorDB, StanDistributions, Turing, BridgeStan, LinearAlgebra | |
julia> # Necessary overloads to make it work with Turing. | |
function DynamicPPL.init(rng, dist::StanDistribution, ::DynamicPPL.SampleFromPrior)
    # Produce an initial (constrained) value for `dist`.
    # `init` uses `rand` by default, but `rand` is not supported for `StanDistribution`.
    # Instead, draw a standard-normal point in the model's unconstrained space and map
    # it through BridgeStan's constraining transform.
    #
    # `param_constrain` expects a vector of length `param_unc_num(model)`, the
    # unconstrained dimension. That can differ from the constrained dimension
    # (`param_num`), e.g. for simplex or covariance-matrix parameters.
    # `length(dist)` presumably reports the constrained dimension, because `dist` is a
    # distribution over constrained values. So the draw is sized explicitly here.
    num_unconstrained = BridgeStan.param_unc_num(dist.model)
    return BridgeStan.param_constrain(dist.model, randn(rng, num_unconstrained))
end
julia> function DynamicPPL.with_logabsdet_jacobian_and_reconstruct(f, dist::StanDistribution, x)
    # Apply `f` to `x` and return it paired with a log-abs-det-Jacobian term.
    # HACK: This is cheating. The Jacobian term is not computed from `f`. It is
    # hard-coded to zero, in the element type of `x`.
    # NOTE(review): presumably only acceptable where callers do not depend on the
    # Jacobian correction — confirm before trusting densities in transformed space.
    transformed = f(x)
    logjac = zero(eltype(x))
    return transformed, logjac
end
julia> function Bijectors.logpdf_with_trans(dist::StanDistribution, x::AbstractVector, istrans::Bool)
    # Log density of `x` under `dist`. When `istrans` is true, the density is instead
    # evaluated in the unconstrained space defined by `bijector(dist)`.
    if !istrans
        return logpdf(dist, x)
    end
    # `x` always arrives constrained, even when `istrans` is true. So map it to
    # unconstrained space first, then score it under the unconstrained counterpart
    # of `dist`.
    unconstrained_point = Bijectors.transform(bijector(dist), x)
    unconstrained_dist = StanDistributions._unconstrain(dist)
    return logpdf(unconstrained_dist, unconstrained_point)
end
julia> # Set up Stan distribution. | |
# Fetch the posteriordb entry for the non-centered eight-schools model.
# Copy its Stan program to `stan_file.stan` in the current directory, overwriting any
# existing copy (`force=true`). Load the dataset as a JSON string. Then wrap both in a
# `StanDistribution`.
# NOTE(review): because the file name is fixed, a re-run reuses the same compiled
# `stan_file_model.so`. BridgeStan warns (output below) that a changed model may not
# be reloaded in that case.
pdb = PosteriorDB.database(); | |
julia> post = PosteriorDB.posterior(pdb, "eight_schools-eight_schools_noncentered"); | |
julia> stan_file = cp(PosteriorDB.path(PosteriorDB.implementation(PosteriorDB.model(post), "stan")), "stan_file.stan", force=true) | |
"stan_file.stan" | |
julia> stan_data = PosteriorDB.load(PosteriorDB.dataset(post), String) | |
"{\n \"J\": 8,\n \"y\": [28, 8, -3, 7, -1, 1, 18, 12],\n \"sigma\": [15, 10, 16, 11, 9, 11, 10, 18]\n}\n" | |
julia> dist = StanDistribution(stan_file, stan_data); | |
┌ Warning: Loading a shared object '/drive-2/Projects/public/DynamicPPL.jl/stan_file_model.so' which is already loaded. | |
│ If the file has changed since the last time it was loaded, this load may not update the library! | |
└ @ BridgeStan ~/.julia/packages/BridgeStan/r4RPQ/src/model.jl:51 | |
julia> # Create model. | |
# `demo` takes no arguments. It declares a single vector-valued random variable `x`
# drawn from the `StanDistribution` built above (the global `dist`).
@model demo() = x ~ dist | |
demo (generic function with 2 methods) | |
julia> model = demo() | |
DynamicPPL.Model{typeof(demo), (), (), (), Tuple{}, Tuple{}, DynamicPPL.DefaultContext}(demo, NamedTuple(), NamedTuple(), DynamicPPL.DefaultContext()) | |
julia> # Sample. | |
# Run 1000 Metropolis–Hastings iterations. The proposal covariance is the identity
# matrix of size `length(dist)`. Presumably this is a random-walk Gaussian proposal,
# per Turing's `MH(::AbstractMatrix)` constructor — confirm against the Turing docs.
chain = sample(model, MH(Diagonal(ones(length(dist)))), 1000) | |
Sampling 100%|██████████████████████████████████████████████████████████████████████████████████████| Time: 0:00:00 | |
Chains MCMC chain (1000×11×1 Array{Float64, 3}): | |
Iterations = 1:1:1000 | |
Number of chains = 1 | |
Samples per chain = 1000 | |
Wall duration = 0.28 seconds | |
Compute duration = 0.28 seconds | |
parameters = x[1], x[2], x[3], x[4], x[5], x[6], x[7], x[8], x[9], x[10] | |
internals = lp | |
Summary Statistics | |
parameters mean std mcse ess_bulk ess_tail rhat ess_per_sec | |
Symbol Float64 Float64 Float64 Float64 Float64 Float64 Float64 | |
x[1] 0.3018 0.8761 0.2046 19.6808 54.0652 1.1257 69.2986 | |
x[2] -0.0315 0.7574 0.1268 33.8723 38.8298 1.0469 119.2686 | |
x[3] -0.4987 0.9294 0.1768 29.4710 31.3870 1.0086 103.7712 | |
x[4] -0.1405 0.7941 0.1713 20.0057 54.8830 1.0954 70.4428 | |
x[5] -0.3183 0.6575 0.0991 45.9851 56.2364 1.0115 161.9194 | |
x[6] 0.0989 0.8708 0.1710 25.7288 33.8405 1.0148 90.5943 | |
x[7] 0.6967 0.9422 0.1373 46.4741 62.0100 1.0055 163.6412 | |
x[8] 0.1594 0.8737 0.1508 33.1242 79.3680 1.0323 116.6344 | |
x[9] 4.1155 2.1446 0.9550 5.3140 24.7503 1.1598 18.7112 | |
x[10] 4.3723 3.0414 0.5701 17.9901 45.3636 1.0880 63.3455 | |
Quantiles | |
parameters 2.5% 25.0% 50.0% 75.0% 97.5% | |
Symbol Float64 Float64 Float64 Float64 Float64 | |
x[1] -1.5304 -0.2382 0.3018 0.9397 1.9190 | |
x[2] -1.6546 -0.4930 -0.0070 0.4749 1.2945 | |
x[3] -1.9595 -1.0840 -0.5131 -0.0810 1.5832 | |
x[4] -2.1416 -0.5176 -0.1250 0.3843 1.1761 | |
x[5] -1.3227 -0.8960 -0.3740 0.1333 1.0498 | |
x[6] -1.2090 -0.5686 0.0243 0.8413 1.7140 | |
x[7] -1.2315 0.2178 0.6879 1.3352 2.4447 | |
x[8] -1.3547 -0.5421 0.3318 0.7988 1.6284 | |
x[9] 0.2154 2.5220 4.2424 5.8616 7.9825 | |
x[10] 0.4093 2.2244 3.8886 5.9015 12.4081 |
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.
Manifest.toml