Last active
July 27, 2025 20:12
-
-
Save rubenarslan/3d2c6a8c6064baaeabc9dc218d0c4928 to your computer and use it in GitHub Desktop.
thin brmsfit (brms fit objects)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#' Thin a brmsfit object post-hoc
#'
#' This function thins a brmsfit object after fitting. Models fit with the
#' 'rstan' or 'cmdstanr' backend are supported. Depending on the backend, the
#' per-chain draws inside the underlying stanfit object are stored either as a
#' list of vectors or as a data.frame; each chain is thinned according to the
#' structure it actually has.
#'
#' Only post-warmup draws are retained. If the model was fit with
#' `save_warmup = TRUE`, the saved warmup draws (counted by `sim$warmup2`) are
#' skipped before thinning, so they never leak into the returned draws. The
#' first post-warmup draw of every chain is always kept.
#'
#' @param fit A brmsfit object.
#' @param thin_by A single integer greater than 1: the factor by which to thin
#'   the post-warmup chains.
#' @return A new, thinned brmsfit object that contains no warmup draws.
thin_brmsfit <- function(fit, thin_by) {
  # Validate the cheap arguments first so bad input fails fast.
  if (!inherits(fit, "brmsfit")) {
    stop("Input must be a brmsfit object.")
  }
  # The length/finiteness checks come first: NA, Inf or a vector would
  # otherwise make `if` fail with an obscure "missing value" error.
  if (!is.numeric(thin_by) || length(thin_by) != 1 ||
      !is.finite(thin_by) || thin_by <= 1 || thin_by %% 1 != 0) {
    stop("thin_by must be an integer greater than 1.")
  }
  if (!requireNamespace("brms", quietly = TRUE)) {
    stop("The 'brms' package is required.")
  }
  backend <- fit$backend
  # A missing backend entry would make `%in%` return logical(0).
  if (length(backend) != 1 || !backend %in% c("rstan", "cmdstanr")) {
    stop(
      "This function only supports 'rstan' and 'cmdstanr' backends.",
      call. = FALSE
    )
  }

  # Determine thinning indices ----
  sim <- fit$fit@sim
  if (length(sim$n_save) == 0 || length(sim$samples) == 0) {
    stop("The brmsfit object contains no posterior draws.", call. = FALSE)
  }
  # n_save counts every saved iteration of a chain, including warmup draws
  # kept via save_warmup = TRUE; warmup2 counts those saved warmup draws,
  # which come first in each chain. Post-warmup draws are the remainder.
  n_saved <- sim$n_save[1]
  n_warmup_saved <- if (length(sim$warmup2) > 0) sim$warmup2[1] else 0
  n_draws_available <- n_saved - n_warmup_saved
  if (thin_by > n_draws_available) {
    stop("thin_by factor is larger than the number of available post-warmup samples.")
  }
  keep_idx <- seq(from = n_warmup_saved + 1, to = n_saved, by = thin_by)
  n_draws_new <- length(keep_idx)

  # Thin every chain (draws and sampler diagnostics) ----
  fit_new <- fit
  fit_new$fit@sim$samples <- lapply(
    sim$samples,
    .thin_chain,
    keep_idx = keep_idx
  )

  # Update metadata (same for both backends) ----
  n_chains <- sim$chains
  fit_new$fit@sim$warmup <- 0
  fit_new$fit@sim$warmup2 <- rep(0, n_chains)
  fit_new$fit@sim$iter <- n_draws_new
  fit_new$fit@sim$n_save <- rep(n_draws_new, n_chains)
  # Identity permutation per chain: draws keep their original order.
  fit_new$fit@sim$permutation <- rep(list(seq_len(n_draws_new)), n_chains)
  fit_new$fit@sim$thin <- sim$thin * thin_by
  fit_new$fit@stan_args <- lapply(fit_new$fit@stan_args, function(args) {
    args$warmup <- 0
    args$iter <- n_draws_new
    args$thin <- if (is.null(args$thin)) thin_by else args$thin * thin_by
    # Warmup draws were dropped above, so none are saved any more.
    if (!is.null(args$save_warmup)) {
      args$save_warmup <- FALSE
    }
    args
  })

  # rstan caches derived results (e.g. summaries) in the .MISC environment;
  # replace it so nothing computed from the unthinned draws is reused.
  fit_new$fit@.MISC <- new.env(parent = emptyenv())
  fit_new
}

# Subset draws stored either as a data.frame (rows = iterations) or as a
# list of equal-length vectors to the iterations in `keep_idx`.
.thin_draw_set <- function(x, keep_idx) {
  if (is.data.frame(x)) {
    x[keep_idx, , drop = FALSE]
  } else {
    lapply(x, `[`, keep_idx)
  }
}

# Thin one chain of `stanfit@sim$samples`, including its "sampler_params"
# attribute, carrying over the chain's remaining attributes unchanged.
.thin_chain <- function(chain, keep_idx) {
  thinned <- .thin_draw_set(chain, keep_idx)
  # Subsetting drops custom attributes (e.g. "args", "elapsed_time");
  # restore every attribute the thinned object does not already have.
  chain_attrs <- attributes(chain)
  for (nm in setdiff(names(chain_attrs), names(attributes(thinned)))) {
    attr(thinned, nm) <- chain_attrs[[nm]]
  }
  sampler_params <- attr(chain, "sampler_params")
  if (!is.null(sampler_params)) {
    sampler_params <- .thin_draw_set(sampler_params, keep_idx)
    divergent <- sampler_params[["divergent__"]]
    if (!is.null(divergent)) {
      # Defensive: treat a missing divergence flag as "no divergence".
      divergent[is.na(divergent)] <- 0
      sampler_params[["divergent__"]] <- divergent
    }
    attr(thinned, "sampler_params") <- sampler_params
  }
  thinned
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment