Skip to content

Instantly share code, notes, and snippets.

@bearloga
Last active December 27, 2015 04:58
Show Gist options
  • Save bearloga/7270266 to your computer and use it in GitHub Desktop.
Save bearloga/7270266 to your computer and use it in GitHub Desktop.
par.trace.samples() runs independent MCMC chains in parallel on a multicore/multiCPU system. It breaks up a single long run into several smaller chunks and that allows it to report the simulation's progress. At the end of the run, all the chunks are combined into a single mcmc object. The mcmc objects from the chains are then combined into a sin…
# This file contains two functions:
# - combine.samples (used by par.trace.samples)
# - par.trace.samples
# and an example at the bottom.
## Author: Mikhail Popov (mikhail [at] mpopov.com)
# install.packages("rjags") # JAGS must be installed on system
# install.packages("doMC") # Unix only
## Usage:
# par.trace.samples(file, # Model file
# data, # Data list
# n.adapt,
# inits, # Function; should have 1 argument: chain number
# n.iter,
# thin,
# monitor, # Parameters to monitor
# n.breaks # Splits up the full run into smaller runs (10, 2, or 5).
# )
#
## Returns: list object
# $runtime (in seconds)
# $samples (mcmc.list)
# $coef (last values of unobserved nodes, can be used as initial values)
## Run:
library(rjags); load.module("lecuyer")
library(doMC); registerDoMC(3)
# Results in 3 independent chains (1 per CPU core)
# combine.samples: append one chunk of MCMC draws to another.
#
# Arguments:
#   x: mcmc object holding the earlier draws. A length-1 list (for example a
#      single-chain mcmc.list) is unwrapped to its only element.
#   y: mcmc object holding the later draws; unwrapped the same way.
#
# Returns: an object of class 'mcmc' whose rows are the rows of x followed by
# the rows of y. Column names are taken from x. 'mcpar' takes its first and
# third entries from x and its second entry from y.
combine.samples <- function(x,y) {
  # Unwrap only genuine length-1 lists. Testing length() alone also matched a
  # 1 x 1 mcmc matrix (one draw of one parameter) and reduced it to a bare
  # number, which lost its 'mcpar' and 'dimnames' attributes.
  if ( is.list(x) && length(x) == 1 ) x <- x[[1]]
  if ( is.list(y) && length(y) == 1 ) y <- y[[1]]
  z <- rbind(x,y)
  class(z) <- 'mcmc'
  # Carry over x's column names only: any row names on x cover just x's rows
  # and could not be assigned to the taller combined matrix.
  x.names <- attr(x,'dimnames')
  if ( !is.null(x.names) ) x.names[1] <- list(NULL)
  attr(z,'dimnames') <- x.names
  x.par <- attr(x,'mcpar')
  y.par <- attr(y,'mcpar')
  attr(z,'mcpar') <- c(x.par[1],y.par[2],x.par[3])
  z
}
# par.trace.samples: run one JAGS chain per registered parallel worker,
# sampling each chain in n.breaks chunks so that progress can be printed
# after every chunk.
#
# Arguments:
#   file:     passed to jags.model() as 'file'
#   data:     passed to jags.model() as 'data'
#   n.adapt:  number of iterations given to update() right after the model
#             is built (jags.model() itself is called with n.adapt=0)
#   inits:    NULL, or a function of the chain number i, where i runs from 1
#             to getDoParWorkers(); its value is passed to jags.model() as
#             'inits'. When NULL, chain i uses 'lecuyer::RngStream' with
#             .RNG.seed=i.
#   n.iter:   total sampling iterations per chain; each chunk asks
#             coda.samples() for n.iter/n.breaks iterations
#   thin:     passed to coda.samples() as 'thin'
#   monitor:  passed to coda.samples() as the variables to monitor
#   n.breaks: number of chunks; one of 10, 2 or 5 (10 when left at default)
#
# Returns: a list with
#   $runtime: elapsed time of the whole parallel run, from system.time()
#   $samples: mcmc.list with one combined chain per worker
#   $coef:    list of coef(model) per chain; can be used as inits
# If inits is neither NULL nor a function, a warning is raised and NA is
# returned instead.
par.trace.samples <- function(file,data,n.adapt=5e2,inits=NULL,
                              n.iter=1e3,thin=1,monitor,n.breaks=c(10,2,5)) {
  # match.arg() works on character choices, so round-trip through strings;
  # the untouched default vector resolves to its first entry, 10.
  n.breaks <- as.numeric(match.arg(as.character(n.breaks),c("10","2","5")))
  if ( is.null(inits) ) {
    inits <- function(i) list(.RNG.name='lecuyer::RngStream',.RNG.seed=i)
  } else if ( !is.function(inits) ) {
    warning("inits must be a function"); return(NA)
  }
  x <- system.time({result <- foreach(i=seq_len(getDoParWorkers())) %dopar% {
    model <- jags.model(file=file,data=data,n.adapt=0,inits=inits(i),
                        n.chains=1,quiet=TRUE)
    update(model,n.iter=n.adapt,progress.bar="none")
    cat("\nChain ",i," finished adapting.\n")
    # Break up a single sampling call into multiple sampling calls.
    # The list is allocated once up front rather than grown per chunk.
    samples <- vector("list",n.breaks)
    for ( j in seq_len(n.breaks) ) {
      # coda.samples() is run with a single chain here; [[1]] extracts it.
      samples[[j]] <- coda.samples(model,monitor,n.iter=n.iter/n.breaks,thin=thin,progress.bar="none")[[1]]
      cat("\nChain ",i," is ",round(j/n.breaks*100,3),"% done.\n",sep="")
    }
    # Combine all the mini-samples into a single sample.
    samples <- Reduce(combine.samples,samples)
    # Value of the foreach body: coefs (values of unobserved nodes) & samples.
    list(coef=coef(model),samples=samples)
  }})
  # x : times ['user' 'system' 'elapsed'] in seconds
  samples <- as.mcmc.list(lapply(result,function(r) r[['samples']]))
  coef <- lapply(result,function(r) r[['coef']]) # Can be used as inits
  list(runtime=x['elapsed'],samples=samples,coef=coef)
}
############################################################################
##################### Fake Data Test and Usage Example #####################
############################################################################
# Simulate a linear-regression data set: N observations of P predictors,
# all values rounded to 3 decimals.
set.seed(0)
N <- 200
P <- 10
x <- round(matrix(rnorm(N*P),nrow=N,byrow=TRUE),3)
# b: intercept plus P slopes used to generate y.
b <- round(rnorm(P+1,0,runif(1,1,10)),3)
y <- round(as.vector(cbind(1,x) %*% b + rnorm(N,0,runif(1,1,10))),3)
# The design matrix handed to the model carries a leading column of 1s.
Data <- list(N=N,x=cbind(1,x),y=y,P=P)
rm(x,y,N,P)
# Model
Model <- "model {
for ( i in 1:N ) {
y[i] ~ dnorm(mu[i],tau)
mu[i] <- inprod(beta[],x[i,])
}
for ( j in 1:(P+1) ) {
beta[j] ~ dnorm(0,1E-4)
}
sigma ~ dunif(0,1E2)
tau <- pow(sigma,-2)
}
"
# Given a file name, writeLines() opens and closes the file itself, so no
# connection is left open if the write fails.
writeLines(Model,"Model.txt")
x <- par.trace.samples(file="Model.txt",data=Data,monitor=c('beta'),
                       n.iter=1e5,n.adapt=1e3,n.breaks=5,thin=1e2)
print(x$runtime)
summary(x$samples)
## Continuing the unsaved chains:
# Start each chain from the values returned in x$coef by the previous run.
# NOTE(review): this passes .RNG.seed=i again, the same seeds the first run
# got by default - confirm that reseeding the chains this way is intended.
inits <- function(i) {
  c(x$coef[[i]],list(.RNG.name='lecuyer::RngStream',.RNG.seed=i))
}
x <- par.trace.samples(file="Model.txt",data=Data,monitor=c('beta'),
                       inits=inits,n.iter=1e5,n.adapt=1e3,n.breaks=5,thin=1e2)
cat(sprintf("The process took %.3f hours to complete.\n",x$runtime/3600))
summary(x$samples)
## Cleanup
rm(x,inits,Data,Model,b)
unlink("Model.txt")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment