Skip to content

Instantly share code, notes, and snippets.

@simonbyrne
Created October 15, 2012 19:48
Show Gist options
  • Save simonbyrne/3894849 to your computer and use it in GitHub Desktop.
Save simonbyrne/3894849 to your computer and use it in GitHub Desktop.
A BUGS-type sampler in Julia
require("extras/distributions.jl")
require("slicesampler.jl")
import Distributions.*
# Expand a BUGS-style model block into one sweep of slice-sampling updates.
#
# `model` is a `begin ... end` block whose statements have the form
#     name = distribution
# The joint log-density is the sum of `logpdf(distribution, name)` over all
# statements.  Every `name` for which `isbound(name)` is false while the macro
# is being expanded gets an update statement
#     name = slice_sampler(name, name -> <joint log-density>)
# Names that are already bound get no update statement.  The generated block
# is returned through `esc`, so it is evaluated in the caller's scope.
#
# Raises an error at expansion time if a statement in the block is not an
# assignment to a plain variable name.
macro buildslicesampler(model)
    # pull apart the model into variables and distributions
    # (values are `Any` so that a distribution may also be given by name)
    variables = Dict{Symbol,Any}()
    unboundvars = Set{Symbol}()
    for line in model.args
        # line-number bookkeeping is not part of the model
        if !isa(line, Expr) || line.head == symbol("line")
            continue
        end
        # misreading some other expression as `name = distribution` would
        # silently build a wrong sampler, so reject it up front
        if line.head != symbol("=") || !isa(line.args[1], Symbol)
            error("@buildslicesampler: expected statements of the form `variable = distribution`")
        end
        variables[line.args[1]] = line.args[2]
        if !isbound(line.args[1])
            add(unboundvars,line.args[1])
        end
    end
    # construct log-density expression: +(logpdf(d1,v1), logpdf(d2,v2), ...)
    ldargs = Array(Any,1+length(variables))
    ldargs[1] = symbol("+")
    k = 2
    for (v,d) in variables
        ldargs[k] = :(logpdf($d,$v))
        k += 1
    end
    ld = Expr(symbol("call"),ldargs,Any)
    # update each unbound variable; the closure argument deliberately reuses
    # the variable's own name so that `ld` is evaluated at the proposed value
    updateargs = Array(Any,length(unboundvars))
    updateargs[:] = [:($v=slice_sampler($v,$v->$ld)) for v in unboundvars]
    update = Expr(symbol("block"),updateargs,Any)
    # uncomment these to see what variables the sampler is updating
    #print("\n")
    #print(update)
    #print("\n")
    esc(update)
end
# Demo driver: runs the generated sampler twice on the model
#   x = Normal(0,1), y = Normal(x,1)
# and prints the mean and standard deviation of the recorded chains.
# N is the number of sweeps per run; xs and ys store the value of x and y
# after each sweep (both arrays are reused by the second run).
N = 100000
xs = Array(Float64,N)
ys = Array(Float64,N)
# Run 1: x and y are locals of the `let` block, set to the chain's initial
# values. @buildslicesampler only generates updates for names for which
# isbound() is false at expansion time -- presumably that holds for both x
# and y here because the global `y` is only assigned further down; verify if
# this script is re-run in a session where `y` already exists.
let x = 0.2, y = 0.1
for i = 1:N
# one sweep of the generated updates
@buildslicesampler begin
x = Normal(0.,1.)
y = Normal(x,1.)
end
# record the state of the chain after this sweep
xs[i] = x
ys[i] = y
end
end
print("Unobserved y\n")
print("x:",(mean(xs),std(xs)),"\n")
print("y:",(mean(ys),std(ys)),"\n")
# Run 2: `y` is assigned at top level before the next `let` is expanded, so
# the macro's isbound() test should now see it as bound and generate an
# update for x only (NOTE(review): relies on isbound() inspecting the
# top-level binding -- confirm). This assignment must stay between the runs.
y = 0.5
let x = 0.2
for i = 1:N
@buildslicesampler begin
x = Normal(0.,1.)
y = Normal(x,1.)
end
# ys[i] records whatever y currently is; if y is not updated by the macro
# this is the constant assigned above
xs[i] = x
ys[i] = y
end
end
print("Observed y=",y,"\n")
print("x:",(mean(xs),std(xs)),"\n")
print("y:",(mean(ys),std(ys)),"\n")
# Ported from Radford Neal's R code, with a few things missing
# taken from https://github.com/doobwa/mcmc.jl/blob/master/src/slicesampler.jl
# Arguments:
#
# x0 Initial point
# g Function returning the log of the probability density (plus constant)
# w Size of the steps for creating interval (default 1)
# m Limit on steps
# lower Lower bound on support of the distribution (default -Inf)
# upper Upper bound on support of the distribution (default +Inf)
# gx0 g(x0)
#
# Perform one slice-sampling update starting from `x0` and return the new
# point.  `g` must return a Float64 (the log-density up to a constant); `gx0`
# is its value at `x0`.
#
# Steps: draw the log slice level below `gx0`, place an interval of width `w`
# at random around `x0`, expand it in steps of `w` (at most `m-1` steps in
# total, never past `lower`/`upper`), then draw uniformly from the interval,
# shrinking it towards `x0` after every rejected draw.
#
# Raises an error if `w <= 0`, `m <= 0`, `upper < lower`, or `gx0` is NaN.
function slice_sampler(x0::Float64, g::Function, w::Float64, m::Int64, lower::Float64, upper::Float64, gx0::Float64)
    if w <= 0
        error("Step size w must be positive")
    end
    if m <= 0
        error("Limit on steps must be positive")
    end
    if upper < lower
        error("Upper limit must be above lower limit")
    end
    # With a NaN log-density every `>= logy` comparison below is false, so
    # the rejection loop would never accept a point; fail loudly instead.
    if isnan(gx0)
        error("Log-density at the initial point is NaN")
    end
    # Determine the slice level, in log terms.
    logy::Float64 = gx0 - rand(Exponential(1.0))
    # Find the initial interval to sample from.
    u::Float64 = rand() * w
    L::Float64 = x0 - u
    R::Float64 = x0 + (w-u)
    # Expand the interval until its ends are outside the slice, or until
    # the limit on steps is reached.  The m-1 available steps are split at
    # random between the left (J) and right (K) ends.
    J::Int64 = floor(rand() * m)
    K::Int64 = (m-1) - J
    while J > 0 && !(L <= lower || g(L)::Float64 <= logy)
        L -= w
        J -= 1
    end
    while K > 0 && !(R >= upper || g(R)::Float64 <= logy)
        R += w
        K -= 1
    end
    # Shrink interval to lower and upper bounds.
    L = L < lower ? lower : L
    R = R > upper ? upper : R
    # Sample from the interval, shrinking it on each rejection.
    x1::Float64 = 0.0 # need to initialize it in this scope first
    while true
        x1 = rand() * (R-L) + L
        if g(x1)::Float64 >= logy
            break
        end
        # rejected: move the end on x1's side of x0 in to x1
        if x1 > x0
            R = x1
        else
            L = x1
        end
    end
    return x1
end
# Convenience method for callers that have not yet evaluated the log-density
# at the starting point: computes `g(x0)` and forwards everything to the
# seven-argument method.
function slice_sampler(x0::Float64, g::Function, w::Float64, m::Int64, lower::Float64, upper::Float64)
    return slice_sampler(x0, g, w, m, lower, upper, g(x0))
end
# Convenience method with default tuning: step size 0.5, at most 10000
# expansion steps, and unbounded support (-Inf, Inf).  `gx0` is the
# already-computed value of `g` at `x0`.
slice_sampler(x0::Float64, g::Function, gx0::Float64) = slice_sampler(x0, g, 0.5, 10000, -Inf, Inf, gx0)
# Simplest entry point: default tuning (step size 0.5, at most 10000
# expansion steps, support (-Inf, Inf)); forwards to the six-argument method.
slice_sampler(x0::Float64, g::Function) = slice_sampler(x0, g, 0.5, 10000, -Inf, Inf)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment