Skip to content

Instantly share code, notes, and snippets.

@rubenarslan
Last active July 27, 2025 20:12
Show Gist options
  • Save rubenarslan/3d2c6a8c6064baaeabc9dc218d0c4928 to your computer and use it in GitHub Desktop.
Save rubenarslan/3d2c6a8c6064baaeabc9dc218d0c4928 to your computer and use it in GitHub Desktop.
thin brmsfit (brms fit objects)
#' Thin a brmsfit object post-hoc
#'
#' Keeps every `thin_by`-th post-warmup draw of each chain stored in the
#' stanfit object at `fit$fit` and rewrites the sampler metadata so that it
#' describes the thinned draws. The 'rstan' and 'cmdstanr' backends are
#' handled separately: this function treats 'rstan' chains as lists of
#' per-parameter vectors and 'cmdstanr' chains as data frames with one row
#' per draw.
#'
#' Saved warmup draws (counted by `sim$warmup2`) are never kept: thinning
#' starts at the first post-warmup draw, and the returned object reports zero
#' warmup iterations.
#'
#' @param fit A brmsfit object fitted with the 'rstan' or 'cmdstanr' backend.
#' @param thin_by A single whole number greater than 1: the factor by which
#'   to thin the post-warmup draws of every chain.
#' @return A copy of `fit` whose stanfit object holds only the retained
#'   draws, with `iter`, `warmup`, `warmup2`, `n_save`, `permutation`, `thin`
#'   and the per-chain `stan_args` updated accordingly.
thin_brmsfit <- function(fit, thin_by) {
  if (!requireNamespace("brms", quietly = TRUE)) {
    stop("The 'brms' package is required.")
  }
  if (!inherits(fit, "brmsfit")) {
    stop("Input must be a brmsfit object.")
  }
  # Check length and finiteness first so that `&&` only ever sees scalars and
  # NA / vector input produces the intended message instead of an R error.
  thin_by_ok <- is.numeric(thin_by) && length(thin_by) == 1L &&
    is.finite(thin_by) && thin_by > 1 && thin_by %% 1 == 0
  if (!thin_by_ok) {
    stop("thin_by must be an integer greater than 1.")
  }
  backend <- fit$backend
  # A missing backend entry would otherwise make `if` fail on a zero-length
  # condition.
  if (length(backend) != 1L || !backend %in% c("rstan", "cmdstanr")) {
    stop(
      "This function only supports 'rstan' and 'cmdstanr' backends.",
      call. = FALSE
    )
  }

  stanfit <- fit$fit
  sim <- stanfit@sim

  # --- 1. Determine thinning indices ---
  # n_save counts every saved draw of a chain, including saved warmup draws;
  # warmup2 counts the saved warmup draws. Their difference is the number of
  # post-warmup draws, which also stays correct for pre-thinned models.
  n_saved <- sim$n_save[1]
  if (length(n_saved) != 1L || is.na(n_saved)) {
    stop("The brmsfit object does not contain any saved draws.", call. = FALSE)
  }
  n_warmup_saved <- sim$warmup2[1]
  if (length(n_warmup_saved) != 1L || is.na(n_warmup_saved)) {
    n_warmup_saved <- 0
  }
  n_draws_available <- n_saved - n_warmup_saved
  if (thin_by > n_draws_available) {
    stop("thin_by factor is larger than the number of available post-warmup samples.")
  }
  # Offset by the saved warmup draws so only post-warmup draws are selected.
  keep_idx <- n_warmup_saved + seq(from = 1, to = n_draws_available, by = thin_by)
  n_draws_new <- length(keep_idx)

  # Missing divergence indicators are recorded as 0 (not divergent).
  zero_missing_divergences <- function(sampler_params) {
    if ("divergent__" %in% names(sampler_params)) {
      is_missing <- is.na(sampler_params$divergent__)
      sampler_params$divergent__[is_missing] <- 0
    }
    sampler_params
  }

  # --- 2. Apply thinning based on backend-specific structure ---
  if (backend == "rstan") {
    # Each chain is a list of per-parameter vectors; sampler diagnostics are
    # a list of vectors in the "sampler_params" attribute.
    thin_chain <- function(chain) {
      thinned <- lapply(chain, `[`, keep_idx)
      # lapply() keeps only the names, so restore the chain-level attributes
      # whether or not sampler diagnostics are present.
      attributes(thinned) <- attributes(chain)
      sampler_params <- attr(chain, "sampler_params", exact = TRUE)
      if (!is.null(sampler_params)) {
        sampler_params <- lapply(sampler_params, `[`, keep_idx)
        attr(thinned, "sampler_params") <- zero_missing_divergences(sampler_params)
      }
      thinned
    }
  } else {
    # cmdstanr: each chain is a data frame with one row per draw; sampler
    # diagnostics are a data frame in the "sampler_params" attribute.
    thin_chain <- function(chain) {
      thinned <- chain[keep_idx, , drop = FALSE]
      # Row-subsetting can drop custom attributes; put back whichever ones
      # were lost (sampler_params is re-attached, thinned, just below).
      old_attrs <- attributes(chain)
      lost <- setdiff(
        names(old_attrs),
        c(names(attributes(thinned)), "sampler_params")
      )
      for (attr_name in lost) {
        attr(thinned, attr_name) <- old_attrs[[attr_name]]
      }
      sampler_params <- attr(chain, "sampler_params", exact = TRUE)
      if (!is.null(sampler_params)) {
        sampler_params <- sampler_params[keep_idx, , drop = FALSE]
        attr(thinned, "sampler_params") <- zero_missing_divergences(sampler_params)
      }
      thinned
    }
  }
  sim$samples <- lapply(sim$samples, thin_chain)

  # --- 3. Update metadata (same for both backends) ---
  n_chains <- sim$chains
  sim$warmup <- 0
  sim$warmup2 <- rep(0, n_chains)
  sim$iter <- n_draws_new
  sim$n_save <- rep(n_draws_new, n_chains)
  sim$permutation <- lapply(
    seq_len(n_chains),
    function(i) seq_len(n_draws_new)
  )
  sim$thin <- sim$thin * thin_by
  stanfit@sim <- sim

  stanfit@stan_args <- lapply(stanfit@stan_args, function(chain_args) {
    old_thin <- chain_args[["thin"]]
    chain_args$warmup <- 0
    chain_args$iter <- n_draws_new
    chain_args$thin <- if (is.null(old_thin)) thin_by else old_thin * thin_by
    chain_args
  })

  # Start from an empty .MISC environment. NOTE(review): presumably this
  # discards results cached for the unthinned draws -- confirm against rstan.
  stanfit@.MISC <- new.env(parent = emptyenv())

  fit$fit <- stanfit
  fit
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment