Skip to content

Instantly share code, notes, and snippets.

@MJacobs1985
Last active June 3, 2022 19:46
Show Gist options
  • Save MJacobs1985/434e98a381f7dec468a0d0fe4d65249e to your computer and use it in GitHub Desktop.
Save MJacobs1985/434e98a381f7dec468a0d0fe4d65249e to your computer and use it in GitHub Desktop.
## Visualizing Tidy Draws from brms models
# NOTE: the former `rm(list = ls())` call was removed. It deletes every object
# in the caller's global environment but leaves attached packages and options
# untouched, so it destroys work without giving a clean session. Restart R
# before running this script if a fresh session is needed.
library(magrittr)
library(dplyr)
library(purrr)
library(forcats)
library(tidyr)
library(modelr)
library(ggdist)
library(tidybayes)
library(ggplot2)
library(cowplot)
library(rstan)
library(brms)
library(ggrepel)
library(RColorBrewer)
library(gganimate)
library(posterior)
library(ggpubr)
library(MCMCglmm)
library(ggstatsplot)
library(bayesplot)
# Session-wide configuration -------------------------------------------------

# Fix the RNG seed used by the data simulation that follows.
set.seed(5)

# Let Stan use every core that parallel::detectCores() reports.
options(mc.cores = parallel::detectCores())
rstan_options(auto_write = TRUE)

# Default ggplot2 theme: the tidybayes theme with a panel border added.
theme_set(theme_tidybayes() + panel_border())
# Simulated data ---------------------------------------------------------------
# Conditions A-E are cycled `n` times; the response is drawn from normal
# distributions whose means (0, 1, 2, 1, -1) are recycled in the same order,
# with a common standard deviation of 0.5.
n <- 10
n_condition <- 5
ABC <- tibble(
  condition = rep(c("A", "B", "C", "D", "E"), times = n),
  response = rnorm(n * 5, mean = c(0, 1, 2, 1, -1), sd = 0.5)
)
# Raw data: one box per condition.
ggplot(ABC, aes(x = condition, y = response, fill = condition)) +
  geom_boxplot() +
  theme_bw()

# Dot plot of the response by condition with a robust test against the
# value 2.
ggdotplotstats(
  data = ABC,
  x = response,
  y = condition,
  test.value = 2,
  type = "robust",
  xlab = "Response",
  title = "Response by condition"
)
# All unordered pairs of the five conditions, as a list of length-2 character
# vectors in the order A-B, A-C, ..., D-E.
my_comparisons <- combn(c("A", "B", "C", "D", "E"), 2, simplify = FALSE)
# Box plot annotated with a comparison for every pair in `my_comparisons`.
ggboxplot(
  ABC,
  x = "condition",
  y = "response",
  fill = "condition"
) +
  stat_compare_means(comparisons = my_comparisons) +
  theme_bw()
# brms model: varying intercept per condition ----------------------------------

# List the priors that can be set for this formula.
get_prior(response ~ (1 | condition), data = ABC)

# Fit the model with explicit priors on the intercept, the group-level sd and
# the residual sd.
m <- brm(
  response ~ (1 | condition),
  data = ABC,
  prior = c(
    prior(normal(0, 1), class = Intercept),
    prior(student_t(3, 0, 1), class = sd),
    prior(student_t(3, 0, 1), class = sigma)
  ),
  control = list(adapt_delta = .99)
)

summary(m)
plot(m)

# Names of the variables available for extraction.
get_variables(m)
# Extracting draws in tidy format ----------------------------------------------

# Both indices of r_condition, named condition / term.
head(spread_draws(m, r_condition[condition, term]), 10)

# Same extraction; the index names are arbitrary.
head(spread_draws(m, r_condition[c, t]), 10)

# Second index left blank.
head(spread_draws(m, r_condition[condition, ]), 10)

# Variables without indices.
head(spread_draws(m, b_Intercept, sigma), 10)

# Median and quantile interval, naming the columns explicitly ...
m %>%
  spread_draws(b_Intercept, sigma) %>%
  median_qi(b_Intercept, sigma)

# ... and without naming any columns.
m %>%
  spread_draws(b_Intercept, sigma) %>%
  median_qi()
# Normal densities built from 100 sampled draws of (mu, sigma) per condition,
# with the observed responses overlaid; conditions on the y axis.
ABC %>%
  data_grid(condition) %>%
  add_epred_draws(m, dpar = c("mu", "sigma")) %>%
  sample_draws(100) %>%
  ggplot(aes(y = condition)) +
  stat_dist_slab(
    aes(dist = "norm", arg1 = mu, arg2 = sigma),
    slab_color = "gray65",
    alpha = 1 / 10,
    fill = NA
  ) +
  geom_point(
    aes(x = response),
    data = ABC,
    shape = 21,
    fill = "#9ECAE1",
    size = 2
  )

# Conditions on the x axis: the slab layer uses 100 sampled draws, the
# half-eye below it uses every .epred draw, and the observed points are nudged
# to the left.
ABC %>%
  data_grid(condition) %>%
  add_epred_draws(m, dpar = c("mu", "sigma")) %>%
  ggplot(aes(x = condition)) +
  stat_dist_slab(
    aes(dist = "norm", arg1 = mu, arg2 = sigma),
    slab_color = "gray65",
    alpha = 1 / 10,
    fill = NA,
    data = . %>% sample_draws(100),
    scale = .5
  ) +
  stat_halfeye(aes(y = .epred), side = "bottom", scale = .5) +
  geom_point(
    aes(y = response),
    data = ABC,
    shape = 21,
    fill = "#9ECAE1",
    size = 2,
    position = position_nudge(x = -.2)
  )
# One grid row per condition, reused for both sets of draws.
grid <- data_grid(ABC, condition)
means <- add_epred_draws(grid, m)
preds <- add_predicted_draws(grid, m)

# Predictive intervals, the density of the .epred draws nudged upwards, and
# the observed data.
ggplot(ABC, aes(x = response, y = condition)) +
  stat_halfeye(
    aes(x = .epred),
    data = means,
    scale = 0.6,
    position = position_nudge(y = 0.175)
  ) +
  stat_interval(aes(x = .prediction), data = preds) +
  geom_point(data = ABC) +
  scale_color_brewer()
# Differences in r_condition between conditions, ordered by reorder() on
# r_condition, with a dashed reference line at zero.
m %>%
  spread_draws(r_condition[condition, ]) %>%
  compare_levels(r_condition, by = condition) %>%
  ungroup() %>%
  mutate(condition = reorder(condition, r_condition)) %>%
  ggplot(aes(x = r_condition, y = condition)) +
  stat_halfeye() +
  geom_vline(xintercept = 0, linetype = "dashed")
# NOTE(review): the three plots in this section use the same model, data and
# layers as plots drawn earlier in the script; they are kept so the script
# output is unchanged.

# Sampled normal densities per condition (conditions on the y axis).
ABC %>%
  data_grid(condition) %>%
  add_epred_draws(m, dpar = c("mu", "sigma")) %>%
  sample_draws(100) %>%
  ggplot(aes(y = condition)) +
  stat_dist_slab(
    aes(dist = "norm", arg1 = mu, arg2 = sigma),
    slab_color = "gray65",
    alpha = 1 / 10,
    fill = NA
  ) +
  geom_point(
    aes(x = response),
    data = ABC,
    shape = 21,
    fill = "#9ECAE1",
    size = 2
  )

# Sampled densities plus half-eye of .epred (conditions on the x axis).
ABC %>%
  data_grid(condition) %>%
  add_epred_draws(m, dpar = c("mu", "sigma")) %>%
  ggplot(aes(x = condition)) +
  stat_dist_slab(
    aes(dist = "norm", arg1 = mu, arg2 = sigma),
    slab_color = "gray65",
    alpha = 1 / 10,
    fill = NA,
    data = . %>% sample_draws(100),
    scale = .5
  ) +
  stat_halfeye(aes(y = .epred), side = "bottom", scale = .5) +
  geom_point(
    aes(y = response),
    data = ABC,
    shape = 21,
    fill = "#9ECAE1",
    size = 2,
    position = position_nudge(x = -.2)
  )

# Differences in r_condition between conditions.
m %>%
  spread_draws(r_condition[condition, ]) %>%
  compare_levels(r_condition, by = condition) %>%
  ungroup() %>%
  mutate(condition = reorder(condition, r_condition)) %>%
  ggplot(aes(x = r_condition, y = condition)) +
  stat_halfeye() +
  geom_vline(xintercept = 0, linetype = "dashed")
# Ordinary least squares reference model ---------------------------------------
m_linear <- lm(response ~ condition, data = ABC)
summary(m_linear)

# plot() on an lm object draws four diagnostic panels by default, so a 2 x 2
# layout is used to show them on one page (the previous 1 x 3 layout pushed the
# fourth panel onto a new page). The previous graphics settings are restored
# afterwards so later base plots are not affected.
old_par <- par(mfrow = c(2, 2))
plot(m_linear)
par(old_par)

# Estimated marginal mean and confidence interval per condition, in broom
# column naming, tagged with the model name for the comparison plots.
linear_results <- m_linear %>%
  emmeans::emmeans(~ condition) %>%
  broom::tidy(conf.int = TRUE) %>%
  mutate(model = "OLS") %>%
  as.data.frame()
# brms fit used for the "Bayes Random" results; same formula, data, priors and
# control settings as the model `m` above.
m_r <- brm(
  response ~ (1 | condition),
  data = ABC,
  prior = c(
    prior(normal(0, 1), class = Intercept),
    prior(student_t(3, 0, 1), class = sd),
    prior(student_t(3, 0, 1), class = sigma)
  ),
  control = list(adapt_delta = .99)
)

summary(m_r)
plot(m_r)
# Posterior median and interval of the condition means from the brms fit.
#
# Two fixes relative to the previous version:
# * r_condition carries two indices (condition and term, see the
#   spread_draws() calls on `m`), so the second index is left blank with
#   `[condition, ]`. With a single index the condition labels do not match the
#   plain "A".."E" labels of the other result tables.
# * r_condition is only the offset from the overall intercept. The intercept
#   is added so `estimate` is a condition mean, on the same scale as the OLS
#   and Stan estimates it is plotted against.
bayes_results <- m_r %>%
  spread_draws(b_Intercept, r_condition[condition, ]) %>%
  median_qi(estimate = b_Intercept + r_condition) %>%
  to_broom_names() %>%
  mutate(model = "Bayes Random") %>%
  as.data.frame()

# Plain character labels so the rows bind cleanly with the other result tables.
bayes_results$condition <- as.character(bayes_results$condition)
# Compile the hand-written Stan program used for the "Bayes Fixed" results.
#
# data:        n observations, n_condition groups, an integer condition index
#              per observation and the real-valued response.
# parameters:  overall_mean, a standardised per-condition offset
#              (condition_zoffset), the residual sd (response_sd) and the sd of
#              the condition means (condition_mean_sd).
# transformed: condition_mean = overall_mean +
#              condition_zoffset * condition_mean_sd, i.e. a non-centred
#              parameterisation of the condition means.
# model:       Cauchy(0, 1) priors on the two lower-bounded sds, normal(0, 5)
#              on the overall mean, normal(0, 1) on the offsets, and a normal
#              likelihood for each response around its condition mean.
#
# NOTE(review): the declarations use the older array syntax
# (`real response[n]`); newer Stan releases expect `array[n] real response` --
# confirm against the installed rstan version before changing the string.
ABC_stan <- stan_model(model_code = 'data {
int<lower=1> n;
int<lower=1> n_condition;
int<lower=1, upper=n_condition> condition[n];
real response[n];
}
parameters {
real overall_mean;
vector[n_condition] condition_zoffset;
real<lower=0> response_sd;
real<lower=0> condition_mean_sd;
}
transformed parameters {
vector[n_condition] condition_mean;
condition_mean = overall_mean + condition_zoffset * condition_mean_sd;
}
model {
response_sd ~ cauchy(0, 1); // => half-cauchy(0, 1)
condition_mean_sd ~ cauchy(0, 1); // => half-cauchy(0, 1)
overall_mean ~ normal(0, 5);
condition_zoffset ~ normal(0, 1); // => condition_mean ~ normal(overall_mean, condition_mean_sd)
for (i in 1:n) {
response[i] ~ normal(condition_mean[condition[i]], response_sd);
}
}')
# Sample from the compiled Stan model; compose_data() turns ABC into the data
# list passed to Stan.
m_f <- sampling(
  ABC_stan,
  data = compose_data(ABC),
  control = list(adapt_delta = 0.99)
)

# Posterior median and interval of condition_mean per condition, in broom
# column naming, tagged with the model name.
bayes_results2 <- m_f %>%
  recover_types(ABC) %>%
  spread_draws(condition_mean[condition]) %>%
  median_qi(estimate = condition_mean) %>%
  to_broom_names() %>%
  mutate(model = "Bayes Fixed") %>%
  as.data.frame()

bayes_results2$condition <- as.character(bayes_results2$condition)
# Inspect the three result tables.
linear_results
bayes_results
bayes_results2

# OLS versus the Stan fit: point estimate and interval per condition.
bind_rows(linear_results, bayes_results2) %>%
  mutate(condition = fct_rev(condition)) %>%
  ggplot(
    aes(
      x = estimate,
      xmin = conf.low,
      xmax = conf.high,
      y = condition,
      color = model
    )
  ) +
  geom_pointinterval(position = position_dodge(width = .3))

# All three models side by side.
bind_rows(linear_results, bayes_results, bayes_results2) %>%
  mutate(condition = fct_rev(condition)) %>%
  ggplot(
    aes(
      x = estimate,
      xmin = conf.low,
      xmax = conf.high,
      y = condition,
      color = model
    )
  ) +
  geom_pointinterval(position = position_dodge(width = .1))
# Another sampling run of the compiled Stan model, used for the
# compare_levels() examples below.
ms <- sampling(
  ABC_stan,
  data = compose_data(ABC),
  control = list(adapt_delta = 0.99)
)
get_variables(ms)

# Differences in condition_mean between conditions, in long format.
ms %>%
  recover_types(ABC) %>%
  spread_draws(condition_mean[condition]) %>%
  compare_levels(condition_mean, by = condition) %>%
  head(10)

# The same differences passed through unspread_draws().
ms %>%
  recover_types(ABC) %>%
  spread_draws(condition_mean[condition]) %>%
  compare_levels(condition_mean, by = condition) %>%
  unspread_draws(condition_mean[condition]) %>%
  head(10)

# As above with drop_indices = TRUE, then plotted with bayesplot.
ms %>%
  recover_types(ABC) %>%
  spread_draws(condition_mean[condition]) %>%
  compare_levels(condition_mean, by = condition) %>%
  unspread_draws(condition_mean[condition], drop_indices = TRUE) %>%
  bayesplot::mcmc_areas()
# rstanarm model with condition as a predictor ---------------------------------
m_rst <- rstanarm::stan_glm(response ~ condition, data = ABC)

summary(m_rst)
plot(m_rst)
bayesplot::mcmc_dens(m_rst)
bayesplot::mcmc_combo(m_rst)

# Estimated marginal means per condition, summarised from the draws.
m_rst %>%
  recover_types(ABC) %>%
  emmeans::emmeans(~ condition) %>%
  gather_emmeans_draws() %>%
  median_qi()

# Pairwise contrasts between conditions, summarised ...
m_rst %>%
  recover_types(ABC) %>%
  emmeans::emmeans(~ condition) %>%
  emmeans::contrast(method = "pairwise") %>%
  gather_emmeans_draws() %>%
  median_qi()

# ... and plotted as one half-eye per contrast.
m_rst %>%
  recover_types(ABC) %>%
  emmeans::emmeans(~ condition) %>%
  emmeans::contrast(method = "pairwise") %>%
  gather_emmeans_draws() %>%
  ggplot(aes(x = .value, y = contrast)) +
  stat_halfeye()
# Posterior of the expected response per condition from the rstanarm fit, with
# the observed responses overlaid.
#
# The previous version piped the brms fit `m` into add_epred_draws(), which
# expects a data frame of predictor values as its first argument and the model
# as its second. The grid of conditions is now supplied as that data frame.
# NOTE(review): `dpar = c("mu", "sigma")` was dropped because distributional
# parameter columns are a brms feature; without mu/sigma columns the normal
# slab layer cannot be drawn for `m_rst`, so the .epred draws are shown
# directly with a half-eye. Confirm this is the intended plot for this model.
ABC %>%
  data_grid(condition) %>%
  add_epred_draws(m_rst) %>%
  ggplot(aes(y = condition)) +
  stat_halfeye(aes(x = .epred)) +
  geom_point(
    aes(x = response),
    data = ABC,
    shape = 21,
    fill = "#9ECAE1",
    size = 2
  )
# brms fit `m`, conditions on the x axis: normal slabs from 100 sampled draws,
# a half-eye of every .epred draw below them, and the observed points nudged
# to the left.
ABC %>%
  data_grid(condition) %>%
  add_epred_draws(m, dpar = c("mu", "sigma")) %>%
  ggplot(aes(x = condition)) +
  stat_dist_slab(
    aes(dist = "norm", arg1 = mu, arg2 = sigma),
    slab_color = "gray65",
    alpha = 1 / 10,
    fill = NA,
    data = . %>% sample_draws(100),
    scale = .5
  ) +
  stat_halfeye(aes(y = .epred), side = "bottom", scale = .5) +
  geom_point(
    aes(y = response),
    data = ABC,
    shape = 21,
    fill = "#9ECAE1",
    size = 2,
    position = position_nudge(x = -.2)
  )
# MCMCglmm fit of the same formula, piped straight into pairwise contrasts of
# the estimated marginal means and plotted as one half-eye per contrast.
MCMCglmm(response ~ condition, data = as.data.frame(ABC)) %>%
  recover_types(ABC) %>%
  emmeans::emmeans(~ condition, data = ABC) %>%
  emmeans::contrast(method = "pairwise") %>%
  gather_emmeans_draws() %>%
  ggplot(aes(x = .value, y = contrast)) +
  stat_halfeye()
# brms fit with a different set of priors, used for the posterior predictive
# checks below.
bayesfit <- brm(
  response ~ (1 | condition),
  data = ABC,
  prior = c(
    prior(normal(10, 2), class = Intercept),
    prior(gamma(3, 4), class = sd),
    prior(gamma(2, 1), class = sigma)
  ),
  control = list(adapt_delta = .99),
  chains = 4,
  cores = 6
)

summary(bayesfit)
plot(bayesfit)

# Posterior predictive checks through brms::pp_check().
pp_check(bayesfit, ndraws = 200)
pp_check(bayesfit, type = "error_hist", ndraws = 11)
pp_check(bayesfit, type = "scatter_avg", ndraws = 100)
pp_check(bayesfit, type = "stat_2d")
pp_check(bayesfit, type = "loo_pit")
# Posterior predictive checks with bayesplot -----------------------------------
# y holds the observed responses and yrep the posterior predictive draws from
# bayesfit; the calls below index yrep by row.
yrep <- posterior_predict(bayesfit)
y <- ABC$response

# Distribution of y against selected rows of yrep.
ppc_dens_overlay(y, yrep[1:25, ])
ppc_ecdf_overlay(y, yrep[sample(nrow(yrep), 25), ])
ppc_hist(y, yrep[1:8, ])
ppc_boxplot(y, yrep[1:8, ])
ppc_dens(y, yrep[200:202, ])
ppc_freqpoly(y, yrep[1:3, ], alpha = 0.1, size = 1, binwidth = 5)

# Test statistic with its default setting.
ppc_stat(y, yrep)
# First quartile of `y`; passed by name to ppc_stat() below.
q25 <- function(y) {
  quantile(y, probs = 0.25)
}
# First-quartile statistic, given by function name and as an inline function.
ppc_stat(y, yrep, stat = "q25")
ppc_stat(y, yrep, stat = function(y) quantile(y, 0.25))

# Two statistics at once, under two theme / colour-scheme combinations.
bayesplot_theme_set(ggplot2::theme_linedraw())
color_scheme_set("viridisE")
ppc_stat_2d(y, yrep, stat = c("mean", "sd"))

bayesplot_theme_set(ggplot2::theme_grey())
color_scheme_set("brewer-Paired")
ppc_stat_2d(y, yrep, stat = c("median", "mad"))

# Per-observation intervals and ribbons.
theme_set(theme_bw())
color_scheme_set("brightblue")
ppc_intervals(y, yrep)
ppc_intervals(y, yrep, size = 1.5, fatten = 0)
ppc_ribbon(y, yrep)
ppc_ribbon(y, yrep, y_draw = "points")
ppc_ribbon(y, yrep, y_draw = "both")

# 50% intervals with one x-axis tick per row of ABC, labelled by row name.
color_scheme_set("gray")
ppc_intervals(y, yrep, prob = 0.5) +
  ggplot2::scale_x_continuous(
    breaks = seq_len(nrow(ABC)),
    labels = rownames(ABC)
  ) +
  xaxis_text(angle = -70, vjust = 1, hjust = 0) +
  xaxis_title(FALSE)

# Predictive errors.
ppc_error_hist(y, yrep[1:3, ])
ppc_error_scatter(y, yrep[10:14, ])
ppc_error_scatter_avg(y, yrep)
# Leave-one-out predictive checks ----------------------------------------------
# Keep the PSIS object from loo() so its weights can be reused by the plots.
loo1 <- loo(bayesfit, save_psis = TRUE, cores = 4)
psis1 <- loo1$psis_object
lw <- weights(psis1)

# LOO probability integral transform checks.
ppc_loo_pit_overlay(y, yrep, lw = lw)
ppc_loo_pit_qq(y, yrep, lw = lw)
ppc_loo_pit_qq(y, yrep, lw = lw, compare = "normal")

# LOO predictive intervals for the first 30 observations, in data order and
# ordered by median.
keep_obs <- seq_len(30)
ppc_loo_intervals(y, yrep, psis_object = psis1, subset = keep_obs)
ppc_loo_intervals(
  y,
  yrep,
  psis_object = psis1,
  subset = keep_obs,
  order = "median"
)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment