Last active
April 17, 2019 22:46
-
-
Save k-barton/85068ed689a7d5012713bd45651ce30e to your computer and use it in GitHub Desktop.
JAGS/BUGS helper functions (define model as R code, execute JAGS code in R /simulate/, etc.)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
### BUGS/JAGS helper functions | |
### Example usage: | |
# | |
## Define JAGS model (the `model` section is required; `var` and `data` are optional)
#mocode <- bugsmodel({ | |
# alpha ~ dunif(0, 1) | |
# for(i in 1:10) { | |
# beta[i] ~ dunif(0, 1) | |
# } | |
# }, | |
# var = list(a, b[nn, mm]), | |
# data = { foo <- 10; bar <- 1:10 } | |
#) | |
# | |
#as.character(mocode) # get model code | |
# | |
## Run BUGS/JAGS code in R (i.e. simulate). | |
# 'data' may provide all/some of the used variables | |
# e <- run(mocode, data = list()) | |
# print(e) | |
# | |
## Example 2: | |
### use `bugs` wrapper for use with 'rjags' | |
# | |
#mocode <- bugsmodel({ | |
# mu ~ dunif(-10, 10) | |
# sigma ~ dunif(0, 100) | |
# for(i in 1:N) { | |
# x[i] ~ dnorm(mu, sigma) | |
# } | |
# }) | |
# | |
## simulate: | |
#run(mocode, data = list( | |
# x = rnorm(10, mean = 6.66, sd = 1.1), N = 10, | |
# mu = 0, sigma = .1 | |
#)) | |
# | |
# | |
#mo <- | |
#bugs(data = list( | |
# x = rnorm(10, mean = 6.66, sd = 1.1), N = 10 | |
# ), inits = list( | |
# mu = 0, sigma = .1 | |
# ), parameters = c("mu", "sigma"), | |
# mocode, n.chains = 1, n.thin = 10, n.iter = 1000, n.burnin = 100) | |
# | |
#summary(mo$samples) | |
# | |
#=============================================================================== | |
# Environment holding R equivalents of JAGS/BUGS built-in functions, so that
# model code translated by replaceStochasticNodes() can be evaluated in R.
# Its parent is the attached "stats" package, so the stats d*/p*/q*/r*
# distribution functions are visible as well.
.jagsenv <-
  as.environment(list(
    # Subscripted assignment that grows the target array on demand: the
    # largest index along each dimension becomes the new extent, and newly
    # created cells are filled with NA.
    "[<-" = function(x, ..., value) {
      d <- sapply(list(...), max, na.rm = TRUE)
      g <- dim(x)
      if (any(g == 0)) {
        # empty array: allocate it afresh with the requested extents
        x <- array(NA_real_, dim = d)
      } else if (any(d > g)) {
        # grow: copy the old cells into the leading corner of a larger array
        d <- pmax(d, g)
        nx <- array(NA_real_, dim = d)
        i <- do.call(base::"[",
                     c(list(array(seq_along(nx), dim = d)),
                       lapply(g, seq.int)))
        nx[i] <- x
        x <- nx
      }
      base::`[<-`(x, ..., value)
    },
    # categorical distribution random function
    rcat = function(n, p) sample.int(length(p), n, replace = TRUE, prob = p),
    # Bernoulli distribution
    rbern = function(n, p) stats::rbinom(n, 1L, p),
    dbern = function(x, p) stats::dbinom(x, 1L, p),
    pbern = function(q, p) stats::pbinom(q, 1L, p),
    qbern = function(x, p) stats::qbinom(x, 1L, p),
    # binomial in JAGS parameter order dbin(p, n); the r-form is what
    # replaceStochasticNodes() produces: rbin(1, p, n)
    rbin = function(n, p, size) stats::rbinom(n, size, p),
    # JAGS spells the chi-squared distribution "chisqr"; the r-form is needed
    # so that `x ~ dchisqr(k)` can be simulated
    rchisqr = function(n, df) rchisq(n, df),
    dchisqr = function(x, df) dchisq(x, df),
    pchisqr = function(q, df) pchisq(q, df),
    qchisqr = function(p, df) qchisq(p, df),
    # link functions, their inverses, and `f<-` replacement forms so that a
    # link LHS such as `logit(p) <- eta` assigns the inverse-link of `eta`
    logit = function(x) stats::qlogis(x),
    ilogit = function(x) stats::plogis(x),
    "logit<-" = function(x, value) stats::plogis(value),
    "log<-" = function(x, value) exp(value),
    cloglog = function(x) log(-log(1 - x)),
    icloglog = function(x) {
      pmax(pmin(-expm1(-exp(x)), 1 - .Machine$double.eps),
           .Machine$double.eps)
    },
    "cloglog<-" = function(x, value) {
      pmax(pmin(-expm1(-exp(value)), 1 - .Machine$double.eps),
           .Machine$double.eps)
    },
    probit = function(x) stats::qnorm(x),
    phi = function(x) stats::pnorm(x),
    "probit<-" = function(x, value) stats::pnorm(value),
    # JAGS test functions return numeric 1/0 element-wise; `identical()` was
    # neither element-wise nor type-tolerant (identical(1L, 1) is FALSE)
    equals = function(x, y) (x == y) * 1,
    arccos = acos,
    arccosh = acosh,
    arcsin = asin,
    arcsinh = asinh,
    logfact = lfactorial,
    loggam = lgamma,
    step = function(x) (x >= 0) * 1,
    pow = function(x, z) x^z,
    inverse = function(x) solve(x),
    logdet = function(m) determinant(m, logarithm = TRUE)$modulus[[1L]],
    # inner product as a plain scalar (`%*%` gives a 1x1 matrix, and a
    # matrix product for matrix arguments)
    inprod = function(x1, x2) sum(x1 * x2),
    interp.lin = function(e, v1, v2) approx(v1, v2, e)$y
  ))
parent.env(.jagsenv) <- as.environment("package:stats")
exprapply <-
  function(expr, what, FUN, ..., symbols = FALSE, parent = NULL) {
    # Recursively walk the language object `expr` (call, expression vector or
    # pairlist) and replace every call whose head is a symbol named in `what`
    # by FUN(call, ..., parent = <enclosing object>).  If `what` contains NA,
    # FUN is applied to every call.  With `symbols = TRUE`, leaf symbols or
    # constants equal to an element of `what` are passed to FUN as well.
    # Children are processed before their parent call.
    self <- sys.function()
    if ((ispairlist <- is.pairlist(expr)) || is.expression(expr)) {
      for (i in seq_along(expr)) {
        expr[i] <- list(self(expr[[i]], what, FUN, ...,
                             symbols = symbols, parent = expr))
      }
      return(if (ispairlist) as.pairlist(expr) else expr)
    }
    n <- length(expr)
    if (n == 0L) {
      return(expr)
    } else if (n == 1L) {
      if (!is.call(expr)) {
        if (symbols && (anyNA(what) || any(expr == what))) {
          expr <- FUN(expr, ..., parent = parent)
        }
        return(expr)
      }
    } else {
      # `function` calls carry a srcref as 4th element; drop it so that only
      # the formals and body are visited
      if (identical(expr[[1L]], as.name("function")) && n == 4L) {
        n <- 3L
        expr[[4L]] <- NULL
      }
      for (i in seq.int(2L, n)) {
        y <- self(expr[[i]], what, FUN, ...,
                  symbols = symbols, parent = expr)
        # an empty argument (e.g. the gap in `x[, 1]`) comes back as the
        # missing-argument marker; leave such slots untouched
        if (!missing(y)) {
          expr[i] <- list(y)
        }
      }
    }
    # compare the call head as a symbol; `==` on a non-symbol head (a
    # closure, or a call such as `f(a)` in `f(a)(b)`) errors or deparses
    fn <- expr[[1L]]
    if (anyNA(what) || (is.symbol(fn) && as.character(fn) %in% what)) {
      expr <- FUN(expr, ..., parent = parent)
    }
    expr
  }
replaceStochasticNodes <-
  function(expr) {
    # Turn every stochastic node `lhs ~ d<distr>(args)` in `expr` into the
    # deterministic R assignment `lhs <- r<distr>(1, args)`, i.e. one random
    # draw.  An `r<distr>` function is assumed to exist for each `d<distr>`
    # used; if R does not provide it, it should be defined (see .jagsenv).
    # NOTE(review): truncation written as `d<distr>(..) %% T(a, b)` is not
    # handled here -- the `%%` call itself would be renamed.
    exprapply(expr, "~", function(x, ...) {
      x[[1L]] <- as.name("<-")
      rhs <- x[[3L]]
      f <- as.character(rhs[[1L]])
      substr(f, 1L, 1L) <- "r"
      # build r<distr>(1, <original args>) directly; subsetting the call with
      # an NA index gave the inserted `1` an NA name whenever any argument
      # was named, producing an invalid `NA = 1` argument
      x[[3L]] <- as.call(c(list(as.name(f), 1), as.list(rhs)[-1L]))
      x
    })
  }
createVars <-
  function(expr) {
    # Find every node assigned in `expr` (the LHS of `<-` or `~`) and create
    # an empty placeholder for it: numeric(0) for unindexed or singly-indexed
    # nodes, and a zero-extent array for nodes indexed with 2+ subscripts.
    # The dynamic `[<-` in .jagsenv later grows these on assignment.
    # Prints the list of defined variables; returns the named list invisibly.
    vars <- integer(0L)
    exprapply(expr, c("<-", "~"), function(x, ...) {
      v <- x[[2L]]
      # unwrap a link-function LHS, e.g. `logit(p[i]) <- eta` defines `p`
      while (is.call(v) && length(v) == 2L &&
             !identical(v[[1L]], as.name("["))) {
        v <- v[[2L]]
      }
      if (is.call(v) && identical(v[[1L]], as.name("["))) {
        d <- length(v) - 2L
        v <- v[[2L]]
      } else {
        d <- 1L
      }
      a <- deparse(v, control = NULL)
      # keep the largest number of subscripts seen for this name
      vars[a] <<- max(vars[a], d, na.rm = TRUE)
      x
    })
    vars <- lapply(vars, function(d) {
      if (d == 1) numeric(0L) else array(numeric(0L), dim = rep(0, d))
    })
    kinds <- vapply(vars, function(x) {
      if (!is.array(x)) {
        ""
      } else if (is.matrix(x)) {
        " = matrix"
      } else {
        paste0(" = array<", length(dim(x)), ">")
      }
    }, character(1L))
    cat("Variables defined: \n")
    if (length(vars) > 0L) {
      cat(paste0("* ", names(vars), kinds, "\n"), sep = "")
    }
    invisible(vars)
  }
bugsmodel <-
  function(model, var, data) {
    # Capture JAGS/BUGS code written as unevaluated R expressions and build
    # both its R representation and the JAGS source text.
    #
    # model: the `model` section, e.g. `{ mu ~ dnorm(0, 1) }` (required)
    # var:   optional `list(a, b[n, m], ...)` of declared variables
    # data:  optional `data` section, written like `model`
    # Returns an expression of class "bugscode" with elements `var`, `data`
    # and `model` (only those supplied), plus the JAGS text in attribute
    # "text".  The arguments are never evaluated; only the call is inspected.
    cl <- match.call()
    # `var = list()` declares nothing and is treated like a missing `var`
    has.var <- !missing(var) && is.call(cl$var) && length(cl$var) > 1L
    has.data <- !missing(data)
    # Deparse `x` as a JAGS section named `a`.  JAGS truncation/censoring
    # `dist(..) T(a, b)` is not valid R, so it is written `dist(..) %% T(a, b)`
    # and the `%%` before `T(` / `I(` is removed again here.
    .asstr <- function(a, x) {
      gsub(" *%% *(?=[IT] *\\()", " ",
           paste(a, paste0(deparse(x, control = NULL), collapse = "\n")),
           perl = TRUE)
    }
    # wrap a bare statement or symbol in `{ }`
    .braced <- function(x) {
      if (is.call(x) && identical(x[[1L]], as.name("{"))) x else call("{", x)
    }
    rval <- expression()
    text <- character(1L + has.data + has.var)
    if (has.var) {
      var <- cl$var
      var[[1L]] <- as.symbol("expression")
      a <- vapply(as.list(var)[-1L], deparse, character(1L),
                  control = NULL, nlines = 1L)
      text[1L] <- paste0("var ", paste0(a, collapse = ", "), ";")
      rval$var <- var
    }
    if (has.data) {
      data <- .braced(cl$data)
      text[1L + has.var] <- .asstr("data", data)
      rval$data <- data
    }
    model <- .braced(cl$model)
    text[1L + has.data + has.var] <- .asstr("model", model)
    rval$model <- model
    attr(rval, "text") <- paste0(text, collapse = "\n")
    class(rval) <- "bugscode"
    rval
  }
as.character.bugscode <-
  function(x, ...) {
    # The JAGS source text is kept in the "text" attribute of a bugscode
    # object; any further arguments are ignored.
    code_text <- attr(x, which = "text")
    code_text
  }
print.bugscode <-
  function(x, ...) {
    # Print the JAGS source text of a bugscode object and return `x`
    # invisibly.  A trailing newline is written so that the console prompt
    # does not end up on the last line of the model code.
    cat(as.character(x), "\n", sep = "")
    invisible(x)
  }
run <-
  function(model, data = list(), envir = .jagsenv) {
    # Simulate a bugscode model in plain R.  Stochastic nodes `x ~ d<distr>()`
    # become draws `x <- r<distr>(1, ...)`, and the `data` then `model`
    # sections are evaluated in a fresh environment whose parent is `envir`.
    #
    # model: object returned by bugsmodel()
    # data:  named list (or environment) of known values; it is copied, so a
    #        caller's environment is never modified
    # envir: parent environment for evaluation (the JAGS helper functions)
    # Returns a named list of all variables after the simulation.
    dataenv <- list2env(as.list(data, all.names = TRUE), parent = envir)
    # pre-create every node assigned in the data or model section that is
    # not supplied in `data`
    code <- if (is.null(model$data)) {
      model$model
    } else {
      call("{", model$data, model$model)
    }
    vars <- createVars(code)
    vars <- vars[!(names(vars) %in% names(dataenv))]
    for (a in names(vars)) assign(a, vars[[a]], envir = dataenv)
    if (!is.null(model$data)) {
      eval(replaceStochasticNodes(model$data), dataenv)
    }
    eval(replaceStochasticNodes(model$model), dataenv)
    as.list(dataenv)
  }
jags.model2 <- function(file, data = NULL, inits, n.chains = 1, n.adapt = 1000,
                        quiet = FALSE) {
  # Wrapper around rjags::jags.model() that also accepts model code directly.
  # Code is fed to JAGS through a text connection (closed on exit) when
  # `file` is a bugscode object, or a character vector that is not a single
  # path to an existing file (e.g. several lines of model code).  Otherwise
  # `file` is passed through unchanged.  Returns the rjags model object.
  is.code <- inherits(file, "bugscode") ||
    (is.character(file) && (length(file) != 1L || !file.exists(file)))
  if (is.code) {
    file <- textConnection(as.character(file))
    on.exit(close(file))
  }
  rjags::jags.model(file, data, inits, n.chains, n.adapt, quiet)
}
bugs <-
  function(data, inits, parameters, file, n.chains = 1, n.thin = 1, n.iter,
           n.burnin = 1000, debug = FALSE, ...) {
    # Compile a JAGS model with rjags and optionally draw posterior samples.
    #
    # file:      path to a model file, or model code (a bugscode object or a
    #            character vector of code lines), read via a text connection
    # n.burnin:  passed to rjags as the adaptation length (`n.adapt`); no
    #            separate post-adaptation burn-in is run
    # n.iter:    if a positive number, that many iterations are sampled
    #            (thinned by `n.thin`) for the nodes named in `parameters`
    # debug, ...: accepted but currently ignored
    # Returns list(samples = <coda samples>, model = <jags model>) when
    # sampling, otherwise just the jags model.
    if (!is.character(file) || length(file) > 1L || !file.exists(file)) {
      file <- textConnection(as.character(file))
      on.exit(close(file))
    }
    model <- rjags::jags.model(
      data = data, inits = inits,
      file = file, n.chains = n.chains, n.adapt = n.burnin
    )
    if (!missing(n.iter) && is.numeric(n.iter) && length(n.iter) == 1L &&
        isTRUE(n.iter > 0)) {
      samp <- rjags::coda.samples(model, n.iter = n.iter, thin = n.thin,
                                  variable.names = parameters)
      list(samples = samp, model = model)
    } else {
      model
    }
  }
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Example usage
Define JAGS model ('model' is required, 'data' and 'var' are optional).
`model` and `data` are JAGS code passed as R expressions (i.e. plain code, not a quoted string). `var` is a `list` of variable names, each optionally followed by brackets giving its dimensions.

Run BUGS/JAGS code in R (i.e. simulate data).
The `data` argument may provide all or some of the variables used; any that are missing are defined automatically.

Example: using the `bugs` wrapper with 'rjags'.

Simulate with R:
Note that the code must follow R's sequential evaluation order, not the order-independent style typical of BUGS, e.g.:
wrong: `mu` is defined after it has been used
correct:

Sample with JAGS: