@derekpowell
Last active May 22, 2023 20:36
Wrapper for brm() that supports caching of BRMS models
cbrm <- function(formula,
                 data,
                 family = gaussian(),
                 prior = NULL,
                 autocor = NULL,
                 cov_ranef = NULL,
                 sample_prior = c("no", "yes", "only"),
                 sparse = FALSE,
                 knots = NULL,
                 stan_funs = NULL,
                 fit = NA,
                 save_ranef = TRUE,
                 save_mevars = FALSE,
                 save_all_pars = FALSE,
                 inits = "random",
                 chains = 4,
                 iter = 2000,
                 warmup = floor(iter / 2),
                 thin = 1,
                 cores = getOption("mc.cores", 1L),
                 control = NULL,
                 algorithm = c("sampling", "meanfield", "fullrank"),
                 future = getOption("future", FALSE),
                 silent = TRUE,
                 seed = NA,
                 save_model = NULL,
                 save_dso = TRUE,
                 file = NULL) {
  # fall back to a default sampler control list when none is supplied
  if (is.null(control)) {
    control <- list(adapt_delta = .80)
  }

  model_exists <- FALSE

  # look for existing cached model
  if (!(is.null(file))) {
    model_dir <- dirname(file)
    cached_file <- basename(file)
    full_filename <- file

    # generate the Stan code and data for the requested model
    new_model_code <- brms::make_stancode(
      formula = formula,
      family = family,
      prior = prior,
      data = data,
      sample_prior = sample_prior
    )
    new_model_data <- brms::make_standata(
      formula = formula,
      family = family,
      prior = prior,
      data = data,
      sample_prior = sample_prior
    )

    # check for match between inputs of loaded and requested models
    if (cached_file %in% list.files(model_dir)) {
      loaded_model <- readRDS(full_filename)
      code_match <- identical(new_model_code, brms::stancode(loaded_model))
      data_match <- identical(new_model_data, brms::standata(loaded_model))
      control_match <- identical(
        brms::control_params(loaded_model, pars = names(control)),
        control
      )
      # recover iter and warmup of the cached model from its posterior samples
      loaded_samples <- brms::posterior_samples(loaded_model, add_chain = TRUE)
      iter_match <- identical(iter, max(loaded_samples$iter))
      warmup_match <- identical(warmup, min(loaded_samples$iter) - 1)

      matching_tests <- c(code_match, data_match, control_match, iter_match, warmup_match)
      if (all(matching_tests)) {
        model_exists <- TRUE
      }
    }
  }
  if (model_exists) {
    # all checks passed: reuse the cached fit
    output <- loaded_model
  } else {
    output <- brms::brm(
      formula = formula,
      data = data,
      family = family,
      prior = prior,
      autocor = autocor,
      cov_ranef = cov_ranef,
      sample_prior = sample_prior,
      sparse = sparse,
      knots = knots,
      stan_funs = stan_funs,
      fit = fit,
      save_ranef = save_ranef,
      save_mevars = save_mevars,
      save_all_pars = save_all_pars,
      inits = inits,
      chains = chains,
      iter = iter,
      warmup = warmup,
      thin = thin,
      cores = cores,
      control = control,
      algorithm = algorithm,
      future = future,
      silent = silent,
      seed = seed,
      save_model = save_model,
      save_dso = save_dso
    )
    # cache the new fit only when a file was requested
    # (cached_file and full_filename are undefined when file = NULL)
    if (!is.null(file)) {
      saveRDS(output, file = full_filename)
    }
  }
  return(output)
}
@derekpowell
Author

derekpowell commented Apr 10, 2018

A drop-in replacement for brm() that fits and caches BRMS models.

Developed for BRMS v2.2.0 or later

BRMS is awesome, but sampling models can be very time-consuming. The latest version of BRMS includes an optional file argument to save and load models, but this can undermine reproducibility (having to comment/uncomment code or delete files, potentially missing meaningful changes in other areas, etc.). This function caches the results of BRMS models and only reruns the time-intensive sampling when the model specification or inputs change. In contrast to the default behavior of brm() with a file argument, it does not load the saved file if the model specification or input data have changed.

When run, the function checks for an existing model file with a matching filename and compares the inputs of that model with the current inputs to determine whether the requested model has already been computed. If so, it returns the cached model. If not, it reruns the model and caches the new result. Just as with brm(), file caching is managed with the optional file argument. If file = NULL, the caching steps are skipped.
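The check-then-reuse logic described above can be sketched without BRMS. This is only an illustration of the pattern, not the code cbrm() uses: cached_compute(), slow_sum(), and the stored inputs/result layout are hypothetical names made up for this example, and cbrm() compares generated Stan code and data rather than raw inputs.

```r
# Hypothetical helper: cache the result of an expensive computation and
# reuse it only when the stored inputs are identical to the current ones.
cached_compute <- function(inputs, file, compute) {
  if (file.exists(file)) {
    cached <- readRDS(file)
    # reuse the cache only if the inputs match exactly
    if (identical(cached$inputs, inputs)) {
      return(cached$result)
    }
  }
  # no usable cache: recompute and overwrite the cache file
  result <- compute(inputs)
  saveRDS(list(inputs = inputs, result = result), file = file)
  result
}

# count how often the "expensive" step actually runs
n_runs <- 0
slow_sum <- function(x) {
  n_runs <<- n_runs + 1
  sum(x$values)
}

cache_file <- tempfile(fileext = ".rds")
first  <- cached_compute(list(values = 1:10), cache_file, slow_sum)  # computes
second <- cached_compute(list(values = 1:10), cache_file, slow_sum)  # reuses cache
third  <- cached_compute(list(values = 1:11), cache_file, slow_sum)  # inputs changed, recomputes
```

Because the comparison uses identical(), any difference in the inputs, including a difference in type, invalidates the cache.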

To load up the function:

devtools::source_gist(id = "f1994c0f8325abbc5d300600744af39d", filename="cbrm.R")

Notes and caveats

  1. Cached results are used whenever the requested and cached Stan model code, Stan data, control, iter, and warmup arguments all match. Changes to other arguments (for example chains, thin, or seed) are not checked and will not trigger a rerun, so care is needed. On the other hand, non-visible changes, such as a change in brms version, might cause models to rerun.
  2. cbrm() and brm() are now entirely interchangeable, save that cbrm() overrides the default behavior of the file argument.

Todo list

  1. Add tryCatch() logic to catch failing cache file writes and still return model
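One possible shape for that todo item, as a minimal sketch: safe_save_rds() is a hypothetical helper that is not part of cbrm(). It wraps saveRDS() in tryCatch() so that a failed write becomes a warning rather than an error, which would let the calling function still return the fitted model.

```r
# Hypothetical helper (not part of cbrm): try to write the cache file, but
# turn a failed write into a warning so the caller can still return its result.
safe_save_rds <- function(object, file) {
  tryCatch(
    {
      saveRDS(object, file = file)
      TRUE
    },
    error = function(e) {
      warning("Could not write cache file '", file, "': ", conditionMessage(e),
              call. = FALSE)
      FALSE
    }
  )
}

# a write to a temporary file should succeed
ok <- safe_save_rds(list(a = 1), tempfile(fileext = ".rds"))

# a write into a directory that does not exist fails without stopping execution
missing_path <- file.path(tempfile("missing_dir"), "model.rds")
bad <- suppressWarnings(safe_save_rds(list(a = 1), missing_path))
```

Inside cbrm(), the saveRDS() call could be swapped for such a helper; the return value is only there so callers can tell whether the cache was written.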

Examples and tests

library(brms)
data("iris")

devtools::source_gist(id = "f1994c0f8325abbc5d300600744af39d", filename="cbrm.R")

# reference model
my_simple_model <- cbrm(
  Sepal.Width ~ Sepal.Length,
  data = iris[1:100, ],
  file = "my_simple_model.rds"
)

# should reload cached model
my_simple_model <- cbrm(
  Sepal.Width ~ Sepal.Length,
  data = iris[1:100, ],
  file = "my_simple_model.rds"
)

# should rerun model with changed formula, prior, etc.
my_simple_model <- cbrm(
  Sepal.Width ~ Sepal.Length + Species,
  prior = set_prior("normal(0,1)", class="b"),
  data = iris[1:100, ],
  file = "my_simple_model.rds"
)

# should rerun model with changed data
my_simple_model <- cbrm(
  Sepal.Width ~ Sepal.Length,
  data = iris[5:104, ],
  file = "my_simple_model.rds"
)
