Created
October 15, 2012 19:48
-
-
Save simonbyrne/3894849 to your computer and use it in GitHub Desktop.
A BUGS-style sampler in Julia
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
require("extras/distributions.jl") | |
require("slicesampler.jl") | |
import Distributions.* | |
# Build one sweep of univariate slice-sampling updates for a BUGS-like model, e.g.
#
#     @buildslicesampler begin
#         x = Normal(0.,1.)
#         y = Normal(x,1.)
#     end
#
# Each `var = dist` line adds the term logpdf(dist, var) to the joint log-density.
# A variable for which `isbound` is true when the macro expands is left untouched
# (treated as observed); every other model variable `v` is replaced by
# slice_sampler(v, v -> joint log-density). The generated block is wrapped in
# `esc`, so `logpdf`, `slice_sampler` and the model variables resolve in the
# caller's scope.
macro buildslicesampler(model)
    # Accept a single statement as well as a `begin ... end` block.
    lines = model.head == :block ? model.args : Any[model]

    # Pull apart the model into variables and their distributions. Values are
    # kept as `Any` so a distribution may be a call expression or a plain name.
    variables = Dict{Symbol,Any}()
    unboundvars = Set{Symbol}()
    for line in lines
        # Skip non-Expr entries and line-number metadata; everything else must
        # be an assignment `variable = distribution`.
        if typeof(line) != Expr || line.head == :line
            continue
        end
        if line.head != :(=)
            error("@buildslicesampler: expected `variable = distribution`, got ", line)
        end
        v = line.args[1]
        variables[v] = line.args[2]
        # NOTE(review): the observed/unobserved split relies on `isbound`
        # seeing the caller's bindings at expansion time -- confirm for the
        # Julia version in use.
        if !isbound(v)
            add(unboundvars, v)
        end
    end

    # Joint log-density: +(logpdf(d1,v1), logpdf(d2,v2), ...).
    ldargs = Array(Any, 1+length(variables))
    ldargs[1] = :+
    ldargs[2:end] = [:(logpdf($d,$v)) for (v,d) in variables]
    ld = Expr(:call, ldargs, Any)

    # One update per unobserved variable. In `$v -> $ld` the lambda argument
    # shadows v, so the density is evaluated at the proposed value while the
    # other variables keep their current values.
    updateargs = Array(Any, length(unboundvars))
    updateargs[:] = [:($v = slice_sampler($v, $v -> $ld)) for v in unboundvars]
    update = Expr(:block, updateargs, Any)
    # To see which variables the sampler updates, print `update` here.

    esc(update)
end
# Demo: slice-sample the model x ~ Normal(0,1), y ~ Normal(x,1) twice -- first
# with y unobserved, then with y observed at 0.5 -- and print the sample mean and
# standard deviation of each chain.
N = 100000
xs = Array(Float64,N)
ys = Array(Float64,N)

# Run 1: x and y are `let`-locals (their let values are the chain's initial
# state), presumably not `isbound` at expansion time, so both are updated.
let x = 0.2, y = 0.1
    for i in 1:N
        @buildslicesampler begin
            x = Normal(0.,1.)
            y = Normal(x,1.)
        end
        xs[i] = x
        ys[i] = y
    end
end
println("Unobserved y")
println("x:", (mean(xs), std(xs)))
println("y:", (mean(ys), std(ys)))

# Run 2: y is now a global bound before the macro below expands; presumably it is
# treated as observed, so only x is sampled and ys records the fixed value.
y = 0.5
let x = 0.2
    for i in 1:N
        @buildslicesampler begin
            x = Normal(0.,1.)
            y = Normal(x,1.)
        end
        xs[i] = x
        ys[i] = y
    end
end
println("Observed y=", y)
println("x:", (mean(xs), std(xs)))
println("y:", (mean(ys), std(ys)))
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Univariate slice sampler (stepping-out, then shrinkage).
# Ported from Radford Neal's R code, with a few things missing;
# taken from https://github.com/doobwa/mcmc.jl/blob/master/src/slicesampler.jl
#
# Arguments:
#
#   x0     Initial point; must lie within [lower, upper]
#   g      Function returning the log of the probability density (plus constant);
#          its result is asserted to be a Float64
#   w      Size of the steps for creating the interval; must be positive and
#          finite (the convenience methods in this file use 0.5)
#   m      Limit on the number of stepping-out steps; must be positive (the
#          convenience methods in this file use 10000)
#   lower  Lower bound on support of the distribution (convenience default -Inf)
#   upper  Upper bound on support of the distribution (convenience default +Inf)
#   gx0    g(x0), passed in to avoid re-evaluating g at the initial point
#
# Returns the new point x1, the first draw with g(x1) >= the sampled slice level.
# Raises an error if w is not positive and finite, if m is not positive, if
# upper < lower, or if x0 is outside [lower, upper]; NaN values are rejected.
function slice_sampler(x0::Float64, g::Function, w::Float64, m::Int64, lower::Float64, upper::Float64, gx0::Float64)
    # Negated comparisons so NaN arguments are rejected here instead of sending
    # the shrinkage loop below into an endless cycle.
    if !(0 < w < Inf)
        error("Slice width w must be positive and finite")
    end
    if m <= 0
        error("Limit on steps must be positive")
    end
    if !(lower <= upper)
        error("Upper limit must be above lower limit")
    end
    # An x0 outside the support would yield an interval with R < L and samples
    # outside [lower, upper].
    if !(lower <= x0 <= upper)
        error("Initial point x0 must lie within [lower, upper]")
    end

    # Determine the slice level, in log terms.
    logy::Float64 = gx0 - rand(Exponential(1.0))

    # Find the initial interval to sample from; it contains x0.
    u::Float64 = rand() * w
    L::Float64 = x0 - u
    R::Float64 = x0 + (w-u)

    # Expand the interval until its ends are outside the slice, or until the
    # limit on steps is reached. The m-1 allowed steps are split at random
    # between the two ends.
    J::Int64 = floor(rand() * m)
    K::Int64 = (m-1) - J
    while J > 0
        if L <= lower || g(L)::Float64 <= logy
            break
        end
        L -= w
        J -= 1
    end
    while K > 0
        if R >= upper || g(R)::Float64 <= logy
            break
        end
        R += w
        K -= 1
    end

    # Shrink interval to lower and upper bounds.
    L = L < lower ? lower : L
    R = R > upper ? upper : R

    # Sample from the interval, shrinking it toward x0 on each rejection.
    x1::Float64 = 0.0 # need to initialize it in this scope first
    gx1::Float64 = 0.0
    while true
        x1 = rand() * (R-L) + L
        gx1 = g(x1)::Float64
        if gx1 >= logy
            break
        end
        if x1 > x0
            R = x1
        else
            L = x1
        end
    end
    return x1
end
# Entry point for callers that have not evaluated the log-density yet: compute
# g(x0) and hand it to the seven-argument method together with w, m and bounds.
slice_sampler(x0::Float64, g::Function, w::Float64, m::Int64, lower::Float64, upper::Float64) =
    slice_sampler(x0, g, w, m, lower, upper, g(x0))
# Convenience method for callers that already know gx0 = g(x0): uses step width
# 0.5, at most 10000 stepping-out steps, and an unbounded support (-Inf, Inf).
function slice_sampler(x0::Float64, g::Function, gx0::Float64)
    w = 0.5
    maxsteps = 10000
    lo, hi = -Inf, Inf
    return slice_sampler(x0, g, w, maxsteps, lo, hi, gx0)
end
# Simplest entry point: step width 0.5, at most 10000 stepping-out steps and an
# unbounded support; the six-argument method evaluates g(x0) before sampling.
slice_sampler(x0::Float64, g::Function) = slice_sampler(x0, g, 0.5, 10000, -Inf, Inf)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment