Last active
March 8, 2024 20:56
-
-
Save StaffanBetner/7c9fef5ba146db9dd26393a409bdadb8 to your computer and use it in GitHub Desktop.
Sample Bootstrap Weights within Stan (once for every iteration)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#ifndef DIRICHLET_RNG_WRAPPER_HPP | |
#define DIRICHLET_RNG_WRAPPER_HPP | |
#include <stan/math.hpp> | |
#include <boost/random/mersenne_twister.hpp> | |
#include <chrono> | |
#include <Eigen/Dense> | |
#include <iostream> | |
// Shared iteration counter: advanced by add_iter(), read by get_iter() and by
// dirichlet_rng_wrapper() to decide whether a new draw is needed.
// NOTE(review): `static` gives every translation unit that includes this
// header its own copy; that is fine while only the Stan model TU includes it,
// otherwise consider a C++17 `inline int` variable so all TUs share one.
static int itct{0};
// Advance the shared iteration counter by one. The Stan model calls this once
// per generated-quantities evaluation so that the next call to
// dirichlet_rng_wrapper() produces a fresh draw. `pstream__` is unused.
inline void add_iter(std::ostream* pstream__) {
  ++itct;
}
// Retrieve the current count | |
inline int get_iter(std::ostream* pstream__) { | |
return itct; | |
} | |
// Generate a Dirichlet(alpha) draw that stays fixed within an iteration.
//
// Every call made while the shared counter `itct` is unchanged returns the
// same cached draw. A new draw is made only after add_iter() has advanced the
// counter, or if alpha's length differs from the cached draw's. The counter is
// advanced outside this function. `pstream__` is unused.
//
// Declared `inline` because it is defined in a header: without it, including
// the header in more than one translation unit violates the ODR.
//
// NOTE(review): the function-local statics below are unsynchronized; this
// assumes each chain has its own copy of this state (e.g. one process per
// chain). Confirm before running several chains as threads in one process.
inline Eigen::VectorXd dirichlet_rng_wrapper(const Eigen::VectorXd& alpha, std::ostream* pstream__) {
  // One generator for the whole run, seeded a single time. Re-creating a
  // generator from the clock for every draw would give a sequence of closely
  // spaced seeds instead of continuing one well-mixed stream.
  static boost::random::mt19937 rng(static_cast<boost::random::mt19937::result_type>(
      std::chrono::system_clock::now().time_since_epoch().count()));
  static Eigen::VectorXd last_draw;  // cached draw for iteration `last_itct`
  static int last_itct = -1;         // itct starts at 0, so the first call always draws
  if (itct != last_itct || last_draw.size() != alpha.size()) {
    // New iteration (or a differently sized alpha): draw new weights.
    last_draw = stan::math::dirichlet_rng(alpha, rng);
    last_itct = itct;
  }
  return last_draw;
}
#endif // DIRICHLET_RNG_WRAPPER_HPP |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// generated with brms 2.20.6 | |
functions {
  // External C++ functions, implemented in the user header passed to
  // cmdstan_model(). ~*~THIS IS NEW~*~
  void add_iter();                             // advance the iteration counter
  int get_iter();                              // read the iteration counter
  vector dirichlet_rng_wrapper(vector alpha);  // Dirichlet draw, cached per iteration
}
data {
  int<lower=1> N;  // total number of observations
  vector[N] Y;  // response variable
  array[N] int<lower=-1,upper=2> cens;  // indicates censoring
  // Should the likelihood be ignored? Read by the model block; brms supplies
  // it in the standata list, so it must be declared here.
  int prior_only;
}
transformed data {
  // Dirichlet concentration of all ones, i.e. uniform over the simplex
  // (Bayesian bootstrap weights). ~*~THIS IS NEW~*~
  vector[N] alpha = ones_vector(N);
}
parameters {
  // Intercept on the log scale: the model block uses mu = exp(Intercept) as
  // the Weibull mean. (brms naming: temporary intercept for centered predictors.)
  real Intercept;
  // Weibull shape, constrained positive.
  real<lower=0> shape;
}
transformed parameters {
  // Log prior density, kept as a named quantity so it is written to the output:
  // Intercept ~ student_t(3, 6.3, 2.5), shape ~ gamma(0.01, 0.01).
  real lprior = student_t_lpdf(Intercept | 3, 6.3, 2.5)
                + gamma_lpdf(shape | 0.01, 0.01);
}
model {
  // likelihood including constants
  if (!prior_only) {
    // Bayesian-bootstrap weights: a Dirichlet(alpha) draw rescaled to sum to N
    // (mean weight 1). dirichlet_rng_wrapper returns the same draw on every
    // evaluation until its iteration counter changes. ~*~THIS IS NEW~*~
    vector[N] weights = dirichlet_rng_wrapper(alpha) * N;
    // initialize linear predictor term
    vector[N] mu = rep_vector(0.0, N);
    // brms parameterizes the Weibull by its mean mu; Stan's scale parameter is
    // mu / Gamma(1 + 1/shape). The divisor does not depend on n, so compute it
    // once instead of once per observation and censoring branch.
    real gamma_factor = tgamma(1 + 1 / shape);
    mu += Intercept;
    mu = exp(mu);
    for (n in 1:N) {
      real scale_n = mu[n] / gamma_factor;
      // special treatment of censored data
      // (cens == 2, interval censoring, is allowed by the data bounds but
      // contributes nothing here)
      if (cens[n] == 0) {
        // observed event: log density
        target += weights[n] * weibull_lpdf(Y[n] | shape, scale_n);
      } else if (cens[n] == 1) {
        // right-censored: log survival function
        target += weights[n] * weibull_lccdf(Y[n] | shape, scale_n);
      } else if (cens[n] == -1) {
        // left-censored: log CDF
        target += weights[n] * weibull_lcdf(Y[n] | shape, scale_n);
      }
    }
  }
  // priors including constants
  target += lprior;
}
generated quantities {
  // Intercept reported under brms' naming convention.
  real b_Intercept = Intercept;
  // Advance the C++ iteration counter so that the next model-block evaluation
  // receives fresh bootstrap weights. ~*~THIS IS NEW~*~
  // NOTE(review): generated quantities are presumably evaluated only for
  // iterations that are written out, so with warmup draws not saved the
  // weights may stay fixed throughout warmup. Confirm against the sampler.
  add_iter();
}
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
--- | |
title: "Bootstrapping within Stan" | |
output: html_notebook | |
--- | |
```{r}
# Helper functions (pkg_load, and presumably qcache used below) come from:
# https://gist.github.com/StaffanBetner/d632bd70686aebd8e488bdc1a454c8e2
source("https://gist.githubusercontent.com/StaffanBetner/d632bd70686aebd8e488bdc1a454c8e2/raw/")
pkg_load(
  tidyverse, rio, magrittr, janitor, qs, fwb,
  posterior, ggdist, brms, cmdstanr, rstan, here
)
data("bearingcage", package = "fwb")
```
```{r}
# Inspect the data set.
bearingcage
```
Let's start with a classic example:
```{r}
# Maximum-likelihood Weibull fit; survreg's intercept is log(eta) = log(scale).
fit <- survival::survreg(survival::Surv(hours, failure) ~ 1,
                         data = bearingcage,
                         dist = "weibull")
fit
```
```{r}
summary(fit)
```
```{r}
# Normal approximation to the sampling distribution of the MLE. Element 1 is
# the intercept (log(eta)); element 2 is Log(scale), mapped below to the
# Weibull shape beta = 1 / scale.
# `TRUE` rather than `T`: `T` is an ordinary binding that can be reassigned.
MASS::mvrnorm(n = 8000, mu = summary(fit)$table[, 1], Sigma = vcov(fit), empirical = TRUE) %>%
  rvar() ->
  mle_normal_approx
mle_normal_approx[2] <- 1 / exp(mle_normal_approx[2])
mle_normal_approx %>%
  setNames(c("log(eta)", "beta")) %>%
  t() %>%
  as_tibble() %>%
  mutate(method = "mle", .before = `log(eta)`) ->
  mle_normal_approx
```
```{r}
mle_normal_approx
```
```{r}
# Bootstrap statistic for fwb: refit the Weibull model with observation
# weights `w` and return log(eta) (survreg intercept) and shape beta = 1/scale.
weibull_est <- function(data, w) {
  surv_fit <- survival::survreg(
    survival::Surv(hours, failure) ~ 1,
    data = data,
    weights = w,
    dist = "weibull"
  )
  log_eta <- unname(coef(surv_fit))
  shape <- 1 / surv_fit$scale
  c("log(eta)" = log_eta, beta = shape)
}
```
```{r}
# Fractional weighted bootstrap with 8000 replicates (cached via qcache).
fwb_est <- fwb(bearingcage, statistic = weibull_est,
               R = 8000, verbose = TRUE) %>%
  qcache("fwb_est")
```
```{r}
# Stack the bootstrap replicates on top of the MLE normal approximation.
estimates <- fwb_est$t %>%
  rvar() %>%
  t() %>%
  as_tibble() %>%
  mutate(method = "bootstrap", .before = `log(eta)`) %>%
  bind_rows(mle_normal_approx)
estimates
```
```{r}
estimates %>%
  pivot_longer(cols = -1) %>%
  ggplot(aes(xdist = value, y = method, group = method)) +
  facet_wrap(~name) +
  stat_slab()
```
Next, let's add the Bayesian case(s):
```{r}
fit_brm <- brm(formula = hours | cens(!failure) ~ 1,
               family = weibull,
               data = bearingcage,
               backend = "cmdstanr",
               cores = 4,
               iter = 3000,
               warmup = 1000,
               file = "fit_brm",
               control = list(adapt_delta = .95))
```
```{r}
fit_brm_rvars <- as_draws_rvars(fit_brm)
```
```{r}
# brms parameterises the Weibull by its mean: mu = exp(b_Intercept) and
# scale = mu / gamma(1 + 1/shape). survreg's intercept (the "log(eta)" column)
# is log(scale), so convert before comparing the methods.
(estimates %>%
    add_row(method = "hmc",
            `log(eta)` = fit_brm_rvars$b_Intercept - lgamma(1 + 1 / fit_brm_rvars$shape),
            beta = fit_brm_rvars$shape) ->
    estimates)
```
```{r}
estimates %>%
  pivot_longer(cols = -1) %>%
  ggplot(aes(xdist = value, y = method, group = method)) +
  facet_wrap(~name) +
  stat_slab()
```
# HMC bootstrapping (here be dragons) | |
```{r}
# Save the brms-generated Stan code as a starting point;
# modified_stan_code.stan is assumed to be a hand-edited copy of it.
make_stancode(
  formula = hours | cens(!failure) + weights(1) ~ 1,
  family = weibull,
  data = bearingcage,
  save_model = "original_stan_code.stan"
)
```
```{r}
# Compile the modified model together with the C++ header that implements
# add_iter(), get_iter() and dirichlet_rng_wrapper().
modified_model <- cmdstan_model("modified_stan_code.stan",
                                user_header = here("iterfuns.hpp"))
```
```{r}
# RStan code:
# stan_model(
#   stanc_ret = stanc("modified_stan_code.stan", allow_undefined = TRUE),
#   includes = paste0('\n#include "', here('iterfuns.hpp'), '"\n')
# ) -> outcome_model
```
<!-- When supplying weights externally: cmdstanr appears to reset memory when it goes from warmup to sampling, so only max(iter_sampling, iter_warmup) bootstrap draws are needed. rstan shares memory across all chains and between warmup and sampling, and also uses one additional iteration (apparently during initialization). -->
```{r}
# The modified model draws its own bootstrap weights, so drop the `weights`
# element from the brms data. Select it by name rather than by position so a
# change in standata's element order cannot silently drop the wrong entry.
standata_bootstrap <- make_standata(formula = hours | cens(!failure) + weights(1) ~ 1,
                                    family = weibull,
                                    data = bearingcage)
standata_bootstrap <- standata_bootstrap[names(standata_bootstrap) != "weights"]
```
```{r}
outcome_samples_cmdstan <- modified_model$sample(
  data = standata_bootstrap,
  chains = 4,
  iter_warmup = 1000,
  iter_sampling = 2000,
  adapt_delta = 0.9995,
  refresh = 50L,
  parallel_chains = 3
) %>%
  qcache("outcome_samples_cmdstan")
```
NO DIVERGENCES!! HOLY MACARONI! | |
```{r}
# Wrap the CmdStan output as an rstan fit and attach it to an empty brmsfit.
rstan_fit <- rstan::read_stan_csv(outcome_samples_cmdstan$output_files())
attributes(rstan_fit)$CmdStanModel <- outcome_samples_cmdstan
modified_model_brms <- brm(formula = hours | cens(!failure) + weights(1) ~ 1,
                           family = weibull,
                           data = bearingcage,
                           empty = TRUE)
modified_model_brms$fit <- rstan_fit
modified_model_brms <- rename_pars(modified_model_brms)
modified_model_brms <- modified_model_brms %>% qcache("modified_model_brms")
```
```{r}
hmc_bootstrap_rvars <- as_draws_rvars(outcome_samples_cmdstan$draws())
```
```{r}
# As for the plain HMC fit: b_Intercept is log of the Weibull mean, so
# log(eta) = b_Intercept - lgamma(1 + 1/shape).
(estimates %>%
    add_row(method = "hmc bootstrap",
            `log(eta)` = hmc_bootstrap_rvars$b_Intercept - lgamma(1 + 1 / hmc_bootstrap_rvars$shape),
            beta = hmc_bootstrap_rvars$shape) ->
    estimates)
```
```{r}
# Final comparison: reorder rows to mle, bootstrap, hmc, hmc bootstrap, then
# reverse the factor levels so the first row appears at the top of the y-axis.
estimates %>%
  slice(2, 1, 3, 4) %>%
  mutate(method = fct_rev(factor(method, levels = method))) %>%
  pivot_longer(cols = -1) %>%
  ggplot(aes(xdist = value, y = method, group = method)) +
  facet_wrap(~name, scales = "free_x") +
  stat_slab() +
  theme_ggdist() +
  labs(x = NULL, title = "Bearing Cage data (Weibull model)", y = "Estimation Method")
```
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment