Created
October 23, 2024 10:35
-
-
Save jrosell/ec36c67df99562048e439ebbcbed0f8e to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Original code: https://www.andrewheiss.com/blog/2022/09/26/guide-visualizing-types-posteriors/ | |
# Preparations ----------------------------------------------------------------- | |
# Print the package repositories configured for this session.
cat("repos: ", getOption("repos"), "\n")

# rlang performs the dependency check below; pak must also be present before
# the script continues.
if (!requireNamespace("rlang", quietly = TRUE)) {
  stop("Please, install.packages('rlang')", call. = FALSE)
}
if (!requireNamespace("pak", quietly = TRUE)) {
  stop("Please, install.packages('pak')", call. = FALSE)
}

pkgs <- c(
  "tidyverse",        # ggplot, dplyr, purrr and friends
  "patchwork",        # Combine ggplot plots
  "ggtext",           # Fancier text in ggplot plots
  "scales",           # Labeling functions
  "brms",             # Bayesian modeling through Stan
  "tidybayes",        # Manipulate Stan objects in a tidy way
  "marginaleffects",  # Calculate marginal effects
  "modelr",           # For quick model grids
  "extraDistr",       # For dprop() beta distribution with mu/phi
  "distributional",   # For plotting distributions with ggdist
  "palmerpenguins",   # Penguins dataset
  "kableExtra",       # For nicer tables
  "MetBrewer",        # Colors
  "cmdstanr",         # Stan backend
  "broom.mixed"       # Extract tidy information from mixed models
)

# Stop early (or offer to install) if anything is missing, then attach
# every package in the order listed above.
rlang::check_installed(pkgs)
purrr::walk(pkgs, library, character.only = TRUE, quietly = TRUE)

options(mc.cores = 4, brms.backend = "cmdstanr")

# Only download and build CmdStan when no existing installation is found;
# re-running the script should not trigger a fresh install attempt each time.
if (is.null(cmdstanr::cmdstan_version(error_on_NA = FALSE))) {
  cmdstanr::install_cmdstan()
}

# Seed for R's RNG, plus a separate constant intended for the Stan fits.
set.seed(1234)
bayes_seed <- 1234
# Custom ggplot themes to make pretty plots
# Get Roboto Condensed at https://fonts.google.com/specimen/Roboto+Condensed
# Get Roboto Mono at https://fonts.google.com/specimen/Roboto+Mono
clrs <- MetBrewer::met.brewer("Java")

#' Base theme shared by every plot in this script
#'
#' @param base_family Font family passed to `theme_minimal()`.
#' @return A ggplot2 theme object.
theme_pred <- function(base_family = "Roboto Condensed") {
  theme_minimal(base_family = base_family) +
    theme(
      panel.grid.minor = element_blank(),
      plot.background = element_rect(fill = "white", color = NA),
      plot.title = element_text(face = "bold"),
      strip.text = element_text(face = "bold"),
      strip.background = element_rect(fill = "grey80", color = NA),
      axis.title.x = element_text(hjust = 0),
      axis.title.y = element_text(hjust = 0),
      legend.title = element_text(face = "bold")
    )
}

# Title/subtitle styling common to theme_pred_dist() and theme_pred_range():
# a markdown-rendered plain-weight title and a smaller, left-aligned subtitle
# in the monospace family.
theme_pred_titles <- function(base_family, mono_family) {
  theme(
    plot.title = element_markdown(family = base_family, face = "plain"),
    plot.subtitle = element_text(
      family = mono_family, size = rel(0.9), hjust = 0
    )
  )
}

#' Theme for single-distribution (halfeye) plots
#'
#' Builds on `theme_pred()` and additionally hides the y-axis text and the
#' horizontal grid lines.
#'
#' @param base_family Font family for the base theme and the title.
#' @param mono_family Font family for the subtitle.
#' @return A ggplot2 theme object.
theme_pred_dist <- function(base_family = "Roboto Condensed",
                            mono_family = "Roboto Mono") {
  theme_pred(base_family) +
    theme_pred_titles(base_family, mono_family) +
    theme(
      axis.text.y = element_blank(),
      panel.grid.major.y = element_blank(),
      panel.grid.minor.y = element_blank()
    )
}

#' Theme for plots across a range of predictor values
#'
#' Builds on `theme_pred()`; keeps the y-axis text and major grid lines.
#'
#' @param base_family Font family for the base theme and the title.
#' @param mono_family Font family for the subtitle.
#' @return A ggplot2 theme object.
theme_pred_range <- function(base_family = "Roboto Condensed",
                             mono_family = "Roboto Mono") {
  theme_pred(base_family) +
    theme_pred_titles(base_family, mono_family) +
    theme(panel.grid.minor.y = element_blank())
}

# Make text geoms use the same font as the themes.
update_geom_defaults("text", list(family = "Roboto Condensed", lineheight = 1))
# Data ------------------------------------------------------------------------- | |
penguins <- penguins |> | |
drop_na(sex) |> | |
mutate(is_gentoo = species == "Gentoo") |> | |
mutate(bill_ratio = bill_depth_mm / bill_length_mm) |> | |
print() | |
# Normal Gaussian model -------------------------------------------------------- | |
ggplot(penguins, aes(x = flipper_length_mm, y = body_mass_g)) + | |
geom_point(size = 1, alpha = 0.7) + | |
geom_smooth(method = "lm", color = clrs[5], se = FALSE) + | |
scale_y_continuous(labels = label_comma()) + | |
coord_cartesian(ylim = c(2000, 6000)) + | |
labs(x = "Flipper length (mm)", y = "Body mass (g)") + | |
theme_pred() | |
# Gaussian regression of body mass on flipper length.
# The seed is passed explicitly: without it brms leaves the choice of seed to
# Stan, so the posterior draws (and every summary/plot derived from them)
# would differ from run to run.
model_normal <- brm(
  bf(body_mass_g ~ flipper_length_mm),
  family = gaussian(),
  data = penguins,
  seed = bayes_seed
)
# Calculate posterior distribution and get the means for the parameters -------- | |
broom.mixed::tidy(model_normal) |> | |
bind_cols(parameter = c("α", "β", "σ")) |> | |
select(parameter, term, estimate, std.error, conf.low, conf.high) | |
# # A tibble: 3 × 6 | |
# parameter term estimate std.error conf.low conf.high | |
# <chr> <chr> <dbl> <dbl> <dbl> <dbl> | |
# 1 α (Intercept) -5874. 314. -6488. -5266. | |
# 2 β flipper_length_mm 50.2 1.56 47.1 53.2 | |
# 3 σ sd__Observation 395. 15.5 366. 427. | |
# Prediction of the outcome based on a single value of flipper lengths --------- | |
penguins_avg_flipper <- penguins |> | |
summarize(flipper_length_mm = mean(flipper_length_mm)) |> | |
print() | |
normal_linpred <- model_normal |> | |
linpred_draws(newdata = penguins_avg_flipper) |> | |
print() | |
normal_epred <- model_normal |> | |
epred_draws(newdata = penguins_avg_flipper) |> | |
print() | |
normal_predicted <- model_normal |> | |
predicted_draws( | |
newdata = penguins_avg_flipper, | |
seed = 12345 | |
) |> | |
print() | |
summary_normal_linpred <- normal_linpred |> | |
ungroup() |> | |
summarize(across(.linpred, lst(mean, sd, median), .names = "{.fn}")) | |
summary_normal_epred <- normal_epred |> | |
ungroup() |> | |
summarize(across(.epred, lst(mean, sd, median), .names = "{.fn}")) | |
summary_normal_predicted <- normal_predicted |> | |
ungroup() |> | |
summarize(across(.prediction, lst(mean, sd, median), .names = "{.fn}")) | |
tribble( | |
~Function, ~`Model element`, | |
"<code>posterior_linpred()</code>", "\\(\\mu\\) in the model", | |
"<code>posterior_epred()</code>", "\\(\\operatorname{E(y)}\\) and \\(\\mu\\) in the model", | |
"<code>posterior_predict()</code>", "Random draws from posterior \\(\\operatorname{Normal}(\\mu_i, \\sigma)\\)" | |
) |> | |
bind_cols(bind_rows(summary_normal_linpred, summary_normal_epred, summary_normal_predicted)) |> | |
kbl(escape = FALSE) |> | |
kable_styling() | |
p1 <- ggplot(normal_linpred, aes(x = .linpred)) + | |
stat_halfeye(fill = clrs[3]) + | |
scale_x_continuous(labels = label_comma()) + | |
coord_cartesian(xlim = c(4100, 4300)) + | |
labs(x = "Body mass (g)", y = NULL, | |
title = "**Linear predictor** <span style='font-size: 14px;'>*µ* in the model</span>", | |
subtitle = "posterior_linpred(..., tibble(flipper_length_mm = 201))") + | |
theme_pred_dist() + | |
theme(plot.title = element_markdown()) | |
p2 <- ggplot(normal_epred, aes(x = .epred)) + | |
stat_halfeye(fill = clrs[2]) + | |
scale_x_continuous(labels = label_comma()) + | |
coord_cartesian(xlim = c(4100, 4300)) + | |
labs(x = "Body mass (g)", y = NULL, | |
title = "**Expectation of the posterior** <span style='font-size: 14px;'>E[*y*] and *µ* in the model</span>", | |
subtitle = "posterior_epred(..., tibble(flipper_length_mm = 201))") + | |
theme_pred_dist() | |
p3 <- ggplot(normal_predicted, aes(x = .prediction)) + | |
stat_halfeye(fill = clrs[1]) + | |
scale_x_continuous(labels = label_comma()) + | |
coord_cartesian(xlim = c(2900, 5500)) + | |
labs(x = "Body mass (g)", y = NULL, | |
title = "**Posterior predictions** <span style='font-size: 14px;'>Random draws from posterior Normal(*µ*, *σ*)</span>", | |
subtitle = "posterior_predict(..., tibble(flipper_length_mm = 201))") + | |
theme_pred_dist() | |
(p1 / plot_spacer() / p2 / plot_spacer() / p3) + | |
plot_layout(heights = c(0.3, 0.05, 0.3, 0.05, 0.3)) + | |
plot_annotation("For posterior_linpred() and posterior_epred(), the standard error is tiny and the range of plausible predicted values is really narrow.\nFor posterior_predict(), the standard error is substantially bigger, and the corresponding range of predicted values is huge.") | |
# Both posterior_linpred() and posterior_epred() correspond to the mu part of the model | |
# The distribution of the mu part of the model here does not incorporate information about sigma.
# That’s why the distribution is so narrow. | |
linpred_manual <- model_normal |> | |
spread_draws(b_Intercept, b_flipper_length_mm) |> | |
mutate(mu = b_Intercept + | |
(b_flipper_length_mm * penguins_avg_flipper$flipper_length_mm)) | |
p1_manual <- linpred_manual |> | |
ggplot(aes(x = mu)) + | |
stat_halfeye(fill = colorspace::lighten(clrs[3], 0.5)) + | |
scale_x_continuous(labels = label_comma()) + | |
coord_cartesian(xlim = c(4100, 4300)) + | |
labs(x = "Body mass (g)", y = NULL, | |
title = "**Linear predictor** <span style='font-size: 14px;'>*µ* in the model</span>", | |
subtitle = "b_Intercept + (b_flipper_length_mm * 201)") + | |
theme_pred_dist() + | |
theme(plot.title = element_markdown()) | |
p1_manual | p1 | |
# The results from posterior_predict(), on the other hand, correspond to the y part of the model.
# Officially, they are draws from a random normal distribution using both the estimated mu and the estimated sigma. These results contain the full uncertainty of the posterior distribution of penguin weight. | |
set.seed(12345) # To get the same results as posterior_predict() from earlier | |
postpred_manual <- model_normal |> | |
spread_draws(b_Intercept, b_flipper_length_mm, sigma) |> | |
mutate(mu = b_Intercept + | |
(b_flipper_length_mm * | |
penguins_avg_flipper$flipper_length_mm), # This is posterior_linpred() | |
y_new = rnorm(n(), mean = mu, sd = sigma)) # This is posterior_predict() | |
p3_manual <- postpred_manual |> | |
ggplot(aes(x = y_new)) + | |
stat_halfeye(fill = colorspace::lighten(clrs[1], 0.5)) + | |
scale_x_continuous(labels = label_comma()) + | |
coord_cartesian(xlim = c(2900, 5500)) + | |
labs(x = "Body mass (g)", y = NULL, | |
title = "**Posterior predictions** <span style='font-size: 14px;'>Random draws from posterior Normal(*µ*, *σ*)</span>", | |
subtitle = "rnorm(b_Intercept + (b_flipper_length_mm * 201), sigma)") + | |
theme_pred_dist() + | |
theme(plot.title = element_markdown()) | |
p3_manual | p3 | |
# Plug a flipper length of 201 mm into the posterior estimates of the intercept
# and slope to calculate the mu part of the model, then draw one new
# observation per posterior draw from Normal(mu, sigma).
epred_manual <- model_normal |>
  spread_draws(b_Intercept, b_flipper_length_mm, sigma) |>
  mutate(
    # This is posterior_linpred()
    mu = b_Intercept +
      (b_flipper_length_mm * penguins_avg_flipper$flipper_length_mm),
    # This is posterior_predict()
    y_new = rnorm(n(), mean = mu, sd = sigma)
  ) |>
  print()

# This is posterior_epred()
epred_manual |>
  summarize(epred = mean(y_new))
## # A tibble: 1 × 1
##   epred
##   <dbl>
## 1 4204.

# It's essentially the same as the actual posterior_epred()
normal_epred |>
  ungroup() |>
  summarize(epred = mean(.epred))
## # A tibble: 1 × 1
##   epred
##   <dbl>
## 1 4206.
# Posterior predictions across a range of possible flipper lengths ------------- | |
p1 <- penguins |> | |
data_grid(flipper_length_mm = seq_range(flipper_length_mm, n = 100)) |> | |
add_linpred_draws(model_normal, ndraws = 100) |> | |
ggplot(aes(x = flipper_length_mm)) + | |
stat_lineribbon(aes(y = .linpred), .width = 0.95, | |
alpha = 0.5, color = clrs[3], fill = clrs[3]) + | |
geom_point(data = penguins, aes(y = body_mass_g), size = 1, alpha = 0.7) + | |
scale_y_continuous(labels = label_comma()) + | |
coord_cartesian(ylim = c(2000, 6000)) + | |
labs(x = "Flipper length (mm)", y = "Body mass (g)", | |
title = "**Linear predictor** <span style='font-size: 14px;'>*µ* in the model</span>", | |
subtitle = "posterior_linpred()") + | |
theme_pred_range() | |
p2 <- penguins |> | |
data_grid(flipper_length_mm = seq_range(flipper_length_mm, n = 100)) |> | |
add_epred_draws(model_normal, ndraws = 100) |> | |
ggplot(aes(x = flipper_length_mm)) + | |
stat_lineribbon(aes(y = .epred), .width = 0.95, | |
alpha = 0.5, color = clrs[2], fill = clrs[2]) + | |
geom_point(data = penguins, aes(y = body_mass_g), size = 1, alpha = 0.7) + | |
scale_y_continuous(labels = label_comma()) + | |
coord_cartesian(ylim = c(2000, 6000)) + | |
labs(x = "Flipper length (mm)", y = "Body mass (g)", | |
title = "**Expectation of the posterior** <span style='font-size: 14px;'>E[*y*] and *µ* in the model</span>", | |
subtitle = "posterior_epred()") + | |
theme_pred_range() | |
p3 <- penguins |> | |
data_grid(flipper_length_mm = seq_range(flipper_length_mm, n = 100)) |> | |
add_predicted_draws(model_normal, ndraws = 100) |> | |
ggplot(aes(x = flipper_length_mm)) + | |
stat_lineribbon(aes(y = .prediction), .width = 0.95, | |
alpha = 0.5, color = clrs[1], fill = clrs[1]) + | |
geom_point(data = penguins, aes(y = body_mass_g), size = 1, alpha = 0.7) + | |
scale_y_continuous(labels = label_comma()) + | |
coord_cartesian(ylim = c(2000, 6000)) + | |
labs(x = "Flipper length (mm)", y = "Body mass (g)", | |
title = "**Posterior predictions** <span style='font-size: 14px;'>Random draws from posterior Normal(*µ*, *σ*)</span>", | |
subtitle = "posterior_predict()") + | |
theme_pred_range() | |
(p1 / plot_spacer() / p2 / plot_spacer() / p3) + | |
plot_layout(heights = c(0.3, 0.05, 0.3, 0.05, 0.3)) | |
# Logistic regression example -------------------------------------------------- | |
# Generalized linear models: logistic, probit, ordered logistic, exponential, Poisson, negative binomial, etc. | |
# Use special link functions (e.g. logit, log, etc.) to transform the likelihood of an outcome into a scale that is more amenable to linear regression. | |
ggplot(penguins, aes(x = bill_length_mm, y = as.numeric(is_gentoo))) + | |
geom_dots(aes(side = ifelse(is_gentoo, "bottom", "top")), | |
pch = 19, color = "grey20", scale = 0.2) + | |
geom_smooth(method = "glm", method.args = list(family = binomial(link = "logit")), | |
color = clrs[5], se = FALSE) + | |
scale_y_continuous(labels = label_percent()) + | |
labs(x = "Bill length (mm)", y = "Probability of being a Gentoo") + | |
theme_pred() | |
# Logistic (Bernoulli, logit link) regression of Gentoo membership on bill
# length. The seed is passed explicitly so the posterior draws are
# reproducible across runs.
model_logit <- brm(
  bf(is_gentoo ~ bill_length_mm),
  family = bernoulli(link = "logit"),
  data = penguins,
  seed = bayes_seed
)
# Make a little dataset of just the average bill length and Extract different types of posteriors | |
penguins_avg_bill <- penguins |> | |
summarize(bill_length_mm = mean(bill_length_mm)) | |
logit_linpred <- model_logit |> | |
linpred_draws(newdata = penguins_avg_bill) | |
logit_epred <- model_logit |> | |
epred_draws(newdata = penguins_avg_bill) | |
logit_predicted <- model_logit |> | |
predicted_draws(newdata = penguins_avg_bill) | |
summary_logit_linpred <- logit_linpred |> | |
ungroup() |> | |
summarize(across(.linpred, lst(mean, sd, median), .names = "{.fn}")) | |
summary_logit_epred <- logit_epred |> | |
ungroup() |> | |
summarize(across(.epred, lst(mean, sd, median), .names = "{.fn}")) | |
summary_logit_predicted <- logit_predicted |> | |
ungroup() |> | |
summarize(across(.prediction, lst(mean), .names = "{.fn}")) | |
tribble( | |
~Function, ~`Model element`, ~Values, | |
"<code>posterior_linpred()</code>", "\\(\\operatorname{logit}(\\pi)\\) in the model", "Logits or log odds", | |
"<code>posterior_linpred(transform = TRUE)</code> or <code>posterior_epred()</code>", "\\(\\operatorname{E(y)}\\) and \\(\\pi\\) in the model", "Probabilities", | |
"<code>posterior_predict()</code>", "Random draws from posterior \\(\\operatorname{Binomial}(1, \\pi)\\)", "0s and 1s" | |
) |> | |
bind_cols(bind_rows(summary_logit_linpred, summary_logit_epred, summary_logit_predicted)) |> | |
kbl(escape = FALSE) |> | |
kable_styling() | |
# The results from posterior_epred() and posterior_linpred() are on different scales | |
# posterior_epred() provides results on the probability scale, un-logiting and back-transforming the results from posterior_linpred() (which provides results on the logit scale). | |
# Technically, posterior_epred() isn’t just the back-transformed linear predictor (if you want that, you can use posterior_linpred(..., transform = TRUE)). More formally, posterior_epred() returns the expected values of the posterior, or E(y), or the average of the posterior’s averages. But as with Gaussian regression, for mathy reasons this average-of-averages happens to be the same as the back-transformed linear predictor (pi in the model).
p1 <- ggplot(logit_linpred, aes(x = .linpred)) + | |
stat_halfeye(fill = clrs[3]) + | |
coord_cartesian(xlim = c(-1.5, -0.2)) + | |
labs(x = "Logit-transformed probability of being a Gentoo", y = NULL, | |
title = "**Linear predictor** <span style='font-size: 14px;'>logit(*π*) in the model</span>", | |
subtitle = "posterior_linpred(..., tibble(bill_length_mm = 44))") + | |
theme_pred_dist() | |
p2 <- ggplot(logit_epred, aes(x = .epred)) + | |
stat_halfeye(fill = clrs[2]) + | |
scale_x_continuous(labels = label_percent()) + | |
coord_cartesian(xlim = c(0.2, 0.45)) + | |
labs(x = "Probability of being a Gentoo", y = NULL, | |
title = "**Expectation of the posterior** <span style='font-size: 14px;'>E[*y*] and *π* in the model</span>", | |
subtitle = "posterior_epred(..., tibble(bill_length_mm = 44))") + | |
theme_pred_dist() | |
p3 <- logit_predicted |> | |
count(is_gentoo = .prediction) |> | |
mutate(prop = n / sum(n), | |
prop_nice = label_percent(accuracy = 0.1)(prop)) |> | |
ggplot(aes(x = factor(is_gentoo), y = n)) + | |
geom_col(fill = clrs[1]) + | |
geom_text(aes(label = prop_nice), nudge_y = -300, color = "white", size = 3) + | |
scale_x_discrete(labels = c("Not Gentoo (0)", "Gentoo (1)")) + | |
scale_y_continuous(labels = label_comma()) + | |
labs(x = "Prediction of being a Gentoo", y = NULL, | |
title = "**Posterior predictions** <span style='font-size: 14px;'>Random draws from posterior Binomial(1, *π*)</span>", | |
subtitle = "posterior_predict(..., tibble(bill_length_mm = 44))") + | |
theme_pred_range() + | |
theme(panel.grid.major.x = element_blank()) | |
(p1 / plot_spacer() / p2 / plot_spacer() / p3) + | |
plot_layout(heights = c(0.3, 0.05, 0.3, 0.05, 0.3)) | |
# Posterior predictions across a range of bill lengths ----------------------- | |
pred_logit_gentoo <- tibble(bill_length_mm = c(35, 45, 55)) |> | |
add_predicted_draws(model_logit, ndraws = 500) | |
pred_logit_gentoo_summary <- pred_logit_gentoo |> | |
group_by(bill_length_mm) |> | |
summarize(prop = mean(.prediction), | |
prop_nice = paste0(label_percent(accuracy = 0.1)(prop), "\nGentoos")) | |
p1 <- penguins |> | |
data_grid(bill_length_mm = seq_range(bill_length_mm, n = 100)) |> | |
add_linpred_draws(model_logit, ndraws = 100) |> | |
ggplot(aes(x = bill_length_mm)) + | |
stat_lineribbon(aes(y = .linpred), .width = 0.95, | |
alpha = 0.5, color = clrs[3], fill = clrs[3]) + | |
coord_cartesian(xlim = c(30, 60)) + | |
labs(x = "Bill length (mm)", y = "Logit-transformed\nprobability of being a Gentoo", | |
title = "**Linear predictor posterior** <span style='font-size: 14px;'>logit(*π*) in the model</span>", | |
subtitle = "posterior_linpred()") + | |
theme_pred_range() | |
p2 <- penguins |> | |
data_grid(bill_length_mm = seq_range(bill_length_mm, n = 100)) |> | |
add_epred_draws(model_logit, ndraws = 100) |> | |
ggplot(aes(x = bill_length_mm)) + | |
geom_dots(data = penguins, aes(y = as.numeric(is_gentoo), x = bill_length_mm, | |
side = ifelse(is_gentoo, "bottom", "top")), | |
pch = 19, color = "grey20", scale = 0.2) + | |
stat_lineribbon(aes(y = .epred), .width = 0.95, | |
alpha = 0.5, color = clrs[2], fill = clrs[2]) + | |
scale_y_continuous(labels = label_percent()) + | |
coord_cartesian(xlim = c(30, 60)) + | |
labs(x = "Bill length (mm)", y = "Probability of\nbeing a Gentoo", | |
title = "**Expectation of the posterior** <span style='font-size: 14px;'>E[*y*] and *π* in the model</span>", | |
subtitle = "posterior_epred()") + | |
theme_pred_range() | |
p3 <- ggplot(pred_logit_gentoo, aes(x = factor(bill_length_mm), y = .prediction)) + | |
geom_point(position = position_jitter(width = 0.2, height = 0.1, seed = 1234), | |
size = 0.75, alpha = 0.3, color = clrs[1]) + | |
geom_text(data = pred_logit_gentoo_summary, aes(y = 0.5, label = prop_nice), size = 3) + | |
scale_y_continuous(breaks = c(0, 1), labels = c("Not\nGentoo", "Gentoo")) + | |
labs(x = "Bill length (mm)", y = "Prediction of\nbeing a Gentoo", | |
title = "**Posterior predictions** <span style='font-size: 14px;'>Random draws from posterior Binomial(1, *π*)</span>", | |
subtitle = "posterior_predict()") + | |
theme_pred_range() + | |
theme(panel.grid.major.x = element_blank(), | |
panel.grid.major.y = element_blank(), | |
axis.text.y = element_text(angle = 90, hjust = 0.5)) | |
(p1 / p2 / p3) | |
# Beta regression example ------------------------------------------------------ | |
# Regression models often focus solely on the location parameter of the model (e.g., mu in Normal(mu, sigma); pi in Binomial(n, pi)). However, it is also possible to specify separate predictors for the scale or shape parameters of models (e.g., sigma in Normal(mu, sigma), phi in Beta(mu, phi)). In the world of brms, these are called distributional models.
ggplot(penguins, aes(x = flipper_length_mm, y = bill_ratio)) + | |
geom_point(size = 1, alpha = 0.7) + | |
geom_smooth(method = "lm", color = clrs[5], se = FALSE) + | |
labs(x = "Flipper length (mm)", y = "Ratio of bill depth / bill length") + | |
theme_pred() | |
# Distributional beta regression: the mean of bill_ratio and the precision
# parameter phi each get their own linear predictor on flipper length.
# The seed is passed explicitly so the posterior draws are reproducible
# across runs.
model_beta <- brm(
  bf(bill_ratio ~ flipper_length_mm,
     phi ~ flipper_length_mm),
  family = Beta(),
  init = "0",
  data = penguins,
  prior = c(prior(normal(0, 1), class = "b"),
            prior(exponential(1), class = "b", dpar = "phi", lb = 0)),
  seed = bayes_seed
)
penguins_avg_flipper <- penguins |> | |
summarize(flipper_length_mm = mean(flipper_length_mm)) | |
# Extract different types of posteriors | |
beta_linpred <- model_beta |> | |
linpred_draws(newdata = penguins_avg_flipper) | |
beta_linpred_phi <- model_beta |> | |
linpred_draws(newdata = penguins_avg_flipper, dpar = "phi") | |
beta_linpred_trans <- model_beta |> | |
linpred_draws(newdata = penguins_avg_flipper, transform = TRUE) | |
beta_linpred_phi_trans <- model_beta |> | |
linpred_draws(newdata = penguins_avg_flipper, dpar = "phi", transform = TRUE) | |
beta_epred <- model_beta |> | |
epred_draws(newdata = penguins_avg_flipper) | |
beta_predicted <- model_beta |> | |
predicted_draws(newdata = penguins_avg_flipper) | |
# Beta posteriors -------------------------------- | |
summary_beta_linpred <- beta_linpred |> | |
ungroup() |> | |
summarize(across(.linpred, lst(mean, sd, median), .names = "{.fn}")) | |
summary_beta_linpred_phi <- beta_linpred_phi |> | |
ungroup() |> | |
summarize(across(phi, lst(mean, sd, median), .names = "{.fn}")) | |
summary_beta_linpred_phi_trans <- beta_linpred_phi_trans |> | |
ungroup() |> | |
summarize(across(phi, lst(mean, sd, median), .names = "{.fn}")) | |
summary_beta_epred <- beta_epred |> | |
ungroup() |> | |
summarize(across(.epred, lst(mean, sd, median), .names = "{.fn}")) | |
summary_beta_predicted <- beta_predicted |> | |
ungroup() |> | |
summarize(across(.prediction, lst(mean, sd, median), .names = "{.fn}")) | |
tribble( | |
~Function, ~`Model element`, ~Values, | |
"<code>posterior_linpred()</code>", "\\(\\operatorname{logit}(\\mu)\\) in the model", "Logits or log odds", | |
"<code>posterior_linpred(transform = TRUE)</code> or <code>posterior_epred()</code>", "\\(\\operatorname{E(y)}\\) and \\(\\mu\\) in the model", "Probabilities", | |
'<code>posterior_linpred(dpar = "phi")</code>', "\\(\\log(\\phi)\\) in the model", "Logged precision values", | |
'<code>posterior_linpred(dpar = "phi", transform = TRUE)</code>', "\\(\\phi\\) in the model", "Unlogged precision values", | |
"<code>posterior_predict()</code>", "Random draws from posterior \\(\\operatorname{Beta}(\\mu, \\phi)\\)", "Values between 0–1" | |
) |> | |
bind_cols(bind_rows(summary_beta_linpred, summary_beta_epred, | |
summary_beta_linpred_phi, summary_beta_linpred_phi_trans, | |
summary_beta_predicted)) |> | |
kbl(escape = FALSE) |> | |
kable_styling() | |
p1 <- ggplot(beta_linpred, aes(x = .linpred)) + | |
stat_halfeye(fill = clrs[3]) + | |
labs(x = "Logit-scale ratio of bill depth / bill length", y = NULL, | |
title = "**Linear predictor** <span style='font-size: 14px;'>logit(*µ*) in the model</span>", | |
subtitle = "posterior_linpred(\n ..., tibble(flipper_length_mm = 201))\n") + | |
theme_pred_dist() | |
p1a <- ggplot(beta_linpred_phi, aes(x = phi)) + | |
stat_halfeye(fill = colorspace::lighten(clrs[3], 0.3)) + | |
labs(x = "Log-scale precision parameter", y = NULL, | |
title = "**Precision parameter** <span style='font-size: 14px;'>log(*φ*) in the model</span>", | |
subtitle = 'posterior_linpred(\n ..., tibble(flipper_length_mm = 201),\n dpar = "phi")') + | |
theme_pred_dist() | |
p2 <- ggplot(beta_epred, aes(x = .epred)) + | |
stat_halfeye(fill = clrs[2]) + | |
labs(x = "Ratio of bill depth / bill length", y = NULL, | |
title = "**Expectation of the posterior** <span style='font-size: 14px;'>E[*y*] or *µ* in the model</span>", | |
subtitle = "posterior_epred(\n ..., tibble(flipper_length_mm = 201)) # or \nposterior_linpred(..., transform = TRUE)") + | |
theme_pred_dist() | |
p2a <- ggplot(beta_linpred_phi_trans, aes(x = phi)) + | |
stat_halfeye(fill = colorspace::lighten(clrs[2], 0.4)) + | |
labs(x = "Precision parameter", y = NULL, | |
title = "**Precision parameter** <span style='font-size: 14px;'>*φ* in the model</span>", | |
subtitle = 'posterior_linpred(\n ..., tibble(flipper_length_mm = 201),\n dpar = "phi", transform = TRUE)\n') + | |
theme_pred_dist() | |
p3 <- ggplot(beta_predicted, aes(x = .prediction)) + | |
stat_halfeye(fill = clrs[1]) + | |
coord_cartesian(xlim = c(0.2, 0.6)) + | |
labs(x = "Ratio of bill depth / bill length", y = NULL, | |
title = "**Posterior predictions** <span style='font-size: 14px;'>Random draws from posterior Beta(*µ*, *φ*)</span>", | |
subtitle = "posterior_predict()") + | |
theme_pred_dist() | |
layout <- " | |
AB | |
CC | |
DE | |
FF | |
GG | |
" | |
p1 + p1a + plot_spacer() + p2 + p2a + plot_spacer() + p3 + | |
plot_layout(design = layout, heights = c(0.3, 0.05, 0.3, 0.05, 0.3)) | |
# Posterior predictions for these different parameters across a range of flipper lengths | |
p1 <- penguins |> | |
data_grid(flipper_length_mm = seq_range(flipper_length_mm, n = 100)) |> | |
add_linpred_draws(model_beta, ndraws = 100) |> | |
ggplot(aes(x = flipper_length_mm)) + | |
geom_point(data = penguins, aes(y = qlogis(bill_ratio)), size = 1, alpha = 0.7) + | |
stat_lineribbon(aes(y = .linpred), .width = 0.95, | |
alpha = 0.5, color = clrs[3], fill = clrs[3]) + | |
coord_cartesian(xlim = c(170, 230)) + | |
labs(x = "Flipper length (mm)", y = "Logit-scale ratio of\nbill depth / bill length", | |
title = "**Linear predictor posterior** <span style='font-size: 14px;'>logit(*µ*) in the model</span>", | |
subtitle = "posterior_linpred()") + | |
theme_pred_range() | |
p1a <- penguins |> | |
data_grid(flipper_length_mm = seq_range(flipper_length_mm, n = 100)) |> | |
add_linpred_draws(model_beta, ndraws = 100, dpar = "phi") |> | |
ggplot(aes(x = flipper_length_mm)) + | |
stat_lineribbon(aes(y = phi), .width = 0.95, alpha = 0.5, | |
color = colorspace::lighten(clrs[3], 0.3), fill = colorspace::lighten(clrs[3], 0.3)) + | |
coord_cartesian(xlim = c(170, 230)) + | |
labs(x = "Flipper length (mm)", y = "Log-scale\nprecision parameter", | |
title = "**Precision parameter** <span style='font-size: 14px;'>log(*φ*) in the model</span>", | |
subtitle = 'posterior_linpred(dpar = "phi")') + | |
theme_pred_range() | |
p2 <- penguins |> | |
data_grid(flipper_length_mm = seq_range(flipper_length_mm, n = 100)) |> | |
add_epred_draws(model_beta, ndraws = 100) |> | |
ggplot(aes(x = flipper_length_mm)) + | |
geom_point(data = penguins, aes(y = bill_ratio), size = 1, alpha = 0.7) + | |
stat_lineribbon(aes(y = .epred), .width = 0.95, | |
alpha = 0.5, color = clrs[2], fill = clrs[2]) + | |
coord_cartesian(xlim = c(170, 230)) + | |
labs(x = "Flipper length (mm)", y = "Ratio of\nbill depth / bill length", | |
title = "**Expectation of the posterior** <span style='font-size: 14px;'>E[*y*] or *µ* in the model</span>", | |
subtitle = 'posterior_epred()\nposterior_linpred(transform = TRUE)') + | |
theme_pred_range() | |
p2a <- penguins |> | |
data_grid(flipper_length_mm = seq_range(flipper_length_mm, n = 100)) |> | |
add_epred_draws(model_beta, ndraws = 100, dpar = "phi") |> | |
ggplot(aes(x = flipper_length_mm)) + | |
stat_lineribbon(aes(y = phi), .width = 0.95, alpha = 0.5, | |
color = colorspace::lighten(clrs[2], 0.4), fill = colorspace::lighten(clrs[2], 0.4)) + | |
coord_cartesian(xlim = c(170, 230)) + | |
labs(x = "Flipper length (mm)", y = "Precision parameter", | |
title = "**Precision parameter** <span style='font-size: 14px;'>*φ* in the model</span>", | |
subtitle = 'posterior_linpred(dpar = "phi",\n transform = TRUE)') + | |
theme_pred_range() | |
p3 <- penguins |> | |
data_grid(flipper_length_mm = seq_range(flipper_length_mm, n = 100)) |> | |
add_predicted_draws(model_beta, ndraws = 500) |> | |
ggplot(aes(x = flipper_length_mm)) + | |
geom_point(data = penguins, aes(y = bill_ratio), size = 1, alpha = 0.7) + | |
stat_lineribbon(aes(y = .prediction), .width = 0.95, | |
alpha = 0.5, color = clrs[1], fill = clrs[1]) + | |
coord_cartesian(xlim = c(170, 230)) + | |
labs(x = "Flipper length (mm)", y = "Ratio of\nbill depth / bill length", | |
title = "**Posterior predictions** <span style='font-size: 14px;'>Random draws from posterior Beta(*µ*, *φ*)</span>", | |
subtitle = "posterior_predict()") + | |
theme_pred_range() | |
layout <- " | |
AB | |
CC | |
DE | |
FF | |
GG | |
" | |
p1 + p1a + plot_spacer() + p2 + p2a + plot_spacer() + p3 + | |
plot_layout(design = layout, heights = c(0.3, 0.05, 0.3, 0.05, 0.3)) | |
# Playing with posterior beta parameters | |
mu <- summary_beta_epred$mean | |
phi <- summary_beta_linpred_phi_trans$mean | |
ggplot(penguins, aes(x = bill_ratio)) + | |
geom_density(aes(fill = "Actual data"), color = NA) + | |
stat_function( | |
aes(fill = glue::glue("Beta(µ = {round(mu, 3)}, φ = {round(phi, 2)})")), | |
geom = "area", fun = ~ extraDistr::dprop(., mean = mu, size = phi), | |
alpha = 0.7 | |
) + | |
scale_fill_manual(values = c(clrs[5], clrs[1]), name = NULL) + | |
xlim(c(0.2, 0.65)) + | |
labs(x = "Ratio of bill depth / bill length", y = NULL, | |
title = "**Analytical posterior predictions** <span style='font-size: 14px;'>Average posterior *µ* and *φ* from the model</span>") + | |
theme_pred_dist() + | |
theme(legend.position = c(0, 0.9), | |
legend.justification = "left", | |
legend.key.size = unit(0.75, "lines")) | |
#' Convert a mean/precision beta parameterisation to shape parameters
#'
#' Computes `shape1 = mu * phi` and `shape2 = (1 - mu) * phi`. The arithmetic
#' is vectorised, so `mu` and `phi` may be vectors.
#'
#' @param mu Numeric mean parameter.
#' @param phi Numeric precision parameter.
#' @return A named list with elements `shape1` and `shape2`.
muphi_to_shapes <- function(mu, phi) {
  list(
    shape1 = mu * phi,
    shape2 = (1 - mu) * phi
  )
}
beta_posteriors <- tibble(flipper_length_mm = c(180, 200, 220)) |> | |
add_linpred_draws(model_beta, ndraws = 500, dpar = TRUE, transform = TRUE) |> | |
group_by(flipper_length_mm) |> | |
summarize(across(c(mu, phi), ~mean(.))) |> | |
ungroup() |> | |
mutate(shapes = map2(mu, phi, ~as_tibble(muphi_to_shapes(.x, .y)))) |> | |
unnest(shapes) |> | |
mutate(nice_label = glue::glue("Beta(µ = {round(mu, 3)}, φ = {round(phi, 2)})")) | |
# Here are the parameters we'll use | |
# We need to convert the mu and phi values to shape1 and shape2 so that we can | |
# use dist_beta() to plot the halfeye distributions correctly | |
beta_posteriors | |
## # A tibble: 3 × 6 | |
## flipper_length_mm mu phi shape1 shape2 nice_label | |
## <dbl> <dbl> <dbl> <dbl> <dbl> <glue> | |
## 1 180 0.485 57.3 27.8 29.5 Beta(µ = 0.485, φ = 57.29) | |
## 2 200 0.400 104. 41.5 62.4 Beta(µ = 0.4, φ = 103.92) | |
## 3 220 0.320 191. 61.3 130. Beta(µ = 0.32, φ = 191.31) | |
penguins |> | |
data_grid(flipper_length_mm = seq_range(flipper_length_mm, n = 100)) |> | |
add_predicted_draws(model_beta, ndraws = 500) |> | |
ggplot(aes(x = flipper_length_mm)) + | |
geom_point(data = penguins, aes(y = bill_ratio), size = 1, alpha = 0.7) + | |
stat_halfeye(data = beta_posteriors, aes(ydist = dist_beta(shape1, shape2), y = NULL), | |
side = "bottom", fill = clrs[1], alpha = 0.75) + | |
stat_lineribbon(aes(y = .prediction), .width = 0.95, | |
alpha = 0.1, color = clrs[1], fill = clrs[1]) + | |
geom_text(data = beta_posteriors, | |
aes(x = flipper_length_mm, y = 0.9, label = nice_label), | |
hjust = 0.5) + | |
coord_cartesian(xlim = c(170, 230)) + | |
labs(x = "Flipper length (mm)", y = "Ratio of\nbill depth / bill length", | |
title = "**Analytical posterior predictions** <span style='font-size: 14px;'>Average posterior *µ* and *φ* from the model</span>") + | |
theme_pred_range() |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment