# Forked from bearloga/Parallel RJags with Progress Monitoring.R
# Gist araastat/7270662, created November 1, 2013 19:31.
# This file contains two functions:
# - combine.samples (used by par.trace.samples)
# - par.trace.samples
# and an example at the bottom.
## Author: Mikhail Popov (mikhail [at] mpopov.com)
# install.packages("rjags") # JAGS must be installed on the system
# install.packages("doMC") # Unix only
## Usage:
# par.trace.samples(file,     # model file
#                   data,     # data list
#                   n.adapt,
#                   inits,    # Function with 1 argument: the chain number
#                   n.iter,
#                   thin,
#                   monitor,  # Parameters to monitor
#                   n.breaks  # Splits up the full run into smaller runs.
#                   )
#
## Returns: list object
# $runtime (in seconds)
# $samples (mcmc.list)
# $coef (last values of unobserved nodes, can be used as initial values)
## Run:
# Load rjags and the JAGS "lecuyer" module, which provides the
# 'lecuyer::RngStream' generator used for the per-chain initial values.
library(rjags)
load.module("lecuyer")
# Register the parallel backend used by %dopar%.
library(doMC)
registerDoMC(3)
# Results in 3 independent chains (1 per CPU core)
# Append the iterations of one MCMC sample to another.
#
# x, y: mcmc objects (iterations in rows, monitored variables in columns),
#       or length-one lists wrapping such an object (e.g. a one-chain
#       mcmc.list). y is assumed to continue the run that produced x.
#
# Returns an object of class 'mcmc' holding the rows of x followed by the
# rows of y. Column names are taken from x. The 'mcpar' attribute is built
# from the first element of x's mcpar, the second element of y's mcpar and
# the third element of x's mcpar (start, end and thinning interval in coda).
combine.samples <- function(x, y) {
  # Unwrap only genuine one-element lists. Testing length() alone would also
  # match a 1 x 1 mcmc matrix (one variable, one iteration) and reduce it to
  # a bare number, losing its column name and its mcpar attribute.
  if (is.list(x) && length(x) == 1L) x <- x[[1]]
  if (is.list(y) && length(y) == 1L) y <- y[[1]]

  # Strip the class and make sure a single-variable sample stored as a plain
  # vector is treated as one column rather than as one row by rbind().
  as_sample_matrix <- function(s) {
    m <- unclass(s)
    if (is.null(dim(m))) {
      m <- matrix(as.vector(m), ncol = 1L)
    }
    m
  }
  x_mat <- as_sample_matrix(x)
  y_mat <- as_sample_matrix(y)

  z <- rbind(x_mat, y_mat)
  # Keep x's column names only; row names of x (if any) cannot describe the
  # longer combined sample.
  dimnames(z) <- list(NULL, colnames(x_mat))
  attr(z, "mcpar") <- c(
    attr(x, "mcpar")[1],
    attr(y, "mcpar")[2],
    attr(x, "mcpar")[3]
  )
  class(z) <- "mcmc"
  z
}
# Run one JAGS chain per registered parallel worker, sampling in several
# smaller calls so that progress can be reported along the way.
#
# file:     model file passed to jags.model().
# data:     data passed to jags.model().
# n.adapt:  number of iterations run through update() before sampling.
# inits:    NULL, or a function(i) returning the initial values for chain
#           i, where i runs from 1 to getDoParWorkers(). When NULL, each
#           chain gets the 'lecuyer::RngStream' generator seeded with i.
# n.iter:   total number of sampling iterations per chain.
# thin:     thinning interval passed to coda.samples().
# monitor:  names of the variables to monitor.
# n.breaks: number of coda.samples() calls the run is split into; a single
#           positive whole number. Defaults to 10.
#
# Returns a list with
#   $runtime  elapsed time in seconds, as measured by system.time()
#   $samples  mcmc.list with one chain per worker
#   $coef     list with coef(model) of every chain (can be used as inits)
# If inits is neither NULL nor a function, a warning is raised and NA is
# returned.
par.trace.samples <- function(file, data, n.adapt = 5e2, inits = NULL,
                              n.iter = 1e3, thin = 1, monitor,
                              n.breaks = c(10, 2, 5)) {
  # The untouched default vector means "use its first element". Any other
  # value must be one positive whole number. match.arg() is deliberately
  # not used here: its partial matching turned n.breaks = 1 into 10 and it
  # rejected every break count other than 10, 2 and 5.
  n.breaks <- as.numeric(n.breaks)
  if (identical(n.breaks, c(10, 2, 5))) {
    n.breaks <- 10
  }
  if (length(n.breaks) != 1L || is.na(n.breaks) ||
      n.breaks < 1 || n.breaks != round(n.breaks)) {
    stop("n.breaks must be a single positive whole number", call. = FALSE)
  }

  if (is.null(inits)) {
    inits <- function(i) list(.RNG.name = "lecuyer::RngStream", .RNG.seed = i)
  } else if (!is.function(inits)) {
    warning("inits must be a function")
    return(NA)
  }

  # Iterations requested from each coda.samples() call.
  iter.per.break <- n.iter / n.breaks
  if (n.iter %% n.breaks != 0) {
    warning("n.iter is not a multiple of n.breaks; ",
            "each break requests a fractional number of iterations",
            call. = FALSE)
  }

  timing <- system.time({
    result <- foreach(i = seq_len(getDoParWorkers())) %dopar% {
      model <- jags.model(file = file, data = data, n.adapt = 0,
                          inits = inits(i), n.chains = 1, quiet = TRUE)
      update(model, n.iter = n.adapt, progress.bar = "none")
      cat("\nChain ", i, " finished adapting.\n")
      # Break up a single sampling call into multiple sampling calls.
      pieces <- vector("list", n.breaks)
      for (j in seq_len(n.breaks)) {
        pieces[[j]] <- coda.samples(model, monitor, n.iter = iter.per.break,
                                    thin = thin, progress.bar = "none")[[1]]
        cat("\nChain ", i, " is ", round(j / n.breaks * 100, 3),
            "% done.\n", sep = "")
      }
      # Hand back the coefs (values of unobserved nodes) together with all
      # the mini-samples combined into a single sample.
      list(coef = coef(model), samples = Reduce(combine.samples, pieces))
    }
  })
  # timing : times ['user' 'system' 'elapsed'] in seconds
  chains <- as.mcmc.list(lapply(result, function(r) r[["samples"]]))
  last.values <- lapply(result, function(r) r[["coef"]]) # Can be used as inits
  list(runtime = timing["elapsed"], samples = chains, coef = last.values)
}
############################################################################
##################### Fake Data Test and Usage Example #####################
############################################################################
# Simulate a linear regression with N observations and P predictors.
set.seed(0)
N <- 200
P <- 10
x <- round(matrix(rnorm(N * P), nrow = N, byrow = TRUE), 3)
b <- round(rnorm(P + 1, 0, runif(1, 1, 10)), 3)
y <- round(as.vector(cbind(1, x) %*% b + rnorm(N, 0, runif(1, 1, 10))), 3)
Data <- list(N = N, x = cbind(1, x), y = y, P = P)
rm(x, y, N, P)
# Model
Model <- "model {
for ( i in 1:N ) {
y[i] ~ dnorm(mu[i],tau)
mu[i] <- inprod(beta[],x[i,])
}
for ( j in 1:(P+1) ) {
beta[j] ~ dnorm(0,1E-4)
}
sigma ~ dunif(0,1E2)
tau <- pow(sigma,-2)
}
"
# Write the model to disk so it can be passed by file name.
writeLines(Model, "Model.txt")
x <- par.trace.samples(file = "Model.txt", data = Data, monitor = c("beta"),
                       n.iter = 1e5, n.adapt = 1e3, n.breaks = 5, thin = 1e2)
print(x$runtime)
summary(x$samples)
## Continuing the unsaved chains:
# Keep the last values of the first run; x is overwritten by the second run.
last.coef <- x$coef
# The initial values must form one flat named list, so the saved node values
# are merged with the RNG settings rather than nested inside them.
inits <- function(i) {
  c(last.coef[[i]], list(.RNG.name = "lecuyer::RngStream", .RNG.seed = i))
}
# Pass the function itself: par.trace.samples calls it once per chain.
x <- par.trace.samples(file = "Model.txt", data = Data, monitor = c("beta"),
                       inits = inits, n.iter = 1e5, n.adapt = 1e3,
                       n.breaks = 5, thin = 1e2)
cat(sprintf("The process took %.3f hours to complete.\n", x$runtime / 3600))
summary(x$samples)
## Cleanup
rm(x, inits, last.coef, Data, Model, b)
unlink("Model.txt")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment