Created
November 10, 2022 11:53
-
-
Save gongcastro/0450bb0058282ad23229997b1def4aad to your computer and use it in GitHub Desktop.
ROC curves for multinomial and binomial Bayesian models in brms
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
library(dplyr) # for data wrangling | |
library(tidyr) # same | |
library(purrr) # for functional programming | |
library(rlang) # for tidyeval | |
library(ggplot2) # for dataviz | |
library(ggsci) # for nice colours | |
library(scales) # for displaying percentages | |
library(brms) # for Bayesian models | |
library(tidybayes) # for extracting posterior draws and predictions | |
library(yardstick) # for generating ROC curves | |
# you might need to install cmdstanr too | |
# set options ------------------------------------------------------------------
# Session-wide settings shared by every model and plot in this script.
options(
  brms.backend = "cmdstanr", # compile and sample through CmdStan (cmdstanr)
  mc.cores = 4               # number of cores brms uses for parallel chains
)
# Fix R's random number generator for reproducibility.
# NOTE(review): brm() also has its own `seed` argument; confirm this global seed
# is enough to make the Stan fits themselves reproducible.
set.seed(888)
# Default ggplot2 theme for all plots below.
theme_set(theme_ggdist())
# create functions -------------------------------------------------------------
# Compute one ROC curve per posterior draw of a brms model.
#
# Expected posterior predictions are drawn with tidybayes::add_epred_draws().
# Each draw is scored separately with yardstick::roc_curve(), using the
# model's response column in `newdata` as the truth.
#
# Arguments:
#   newdata: data frame to predict on; it must contain the response column.
#   object:  a brmsfit whose family is one of `supported` (see body).
#   ...:     further arguments for tidybayes::add_epred_draws(),
#            e.g. `ndraws = 50`.
#
# Returns the stacked roc_curve() output of every draw. An extra character
# column `.draw` (taken from split() names) identifies the draw.
# Errors if `object` is not a brmsfit, if its family is unsupported, or if its
# response is not a single variable.
get_roc_curve <- function(newdata, object, ...) {
  # validate the model before reading any of its components
  if (!inherits(object, "brmsfit")) {
    stop("`object` must be a brmsfit object", call. = FALSE)
  }
  supported <- c("bernoulli", "binomial", "categorical", "cumulative", "sratio", "cratio", "acat")
  model_fam <- object[["family"]][["family"]]
  if (length(model_fam) != 1L || !(model_fam %in% supported)) {
    stop(
      paste0("model family must be one of: ", paste0(supported, collapse = ", ")),
      call. = FALSE
    )
  }

  # The response variable is the left-hand side of the model formula.
  # Addition terms such as `y | weights(w)` are stripped so only `y` remains.
  resp_sym <- formula(object)[["formula"]][[2]]
  while (is.call(resp_sym) && identical(resp_sym[[1]], as.name("|"))) {
    resp_sym <- resp_sym[[2]]
  }
  if (!is.symbol(resp_sym)) {
    stop("the model response must be a single variable", call. = FALSE)
  }

  draws <- ungroup(add_epred_draws(newdata, object, ...))

  if (model_fam %in% c("binomial", "bernoulli")) {
    # as.factor() sorts the levels, so for a 0/1 response the event ("1")
    # is the second level.
    # NOTE(review): for binomial(trials) models .epred is presumably on the
    # count scale rather than a probability; confirm before using these curves.
    draws <- mutate(draws, !!resp_sym := as.factor(!!resp_sym))
    map_dfr(
      split(draws, draws$.draw), # one ROC curve per posterior draw
      ~ roc_curve(.x, truth = !!resp_sym, .epred, event_level = "second"),
      .id = ".draw"
    )
  } else {
    # yardstick expects the probability columns in the same order as the
    # levels of `truth`. Take both from the model's predicted categories
    # instead of the (unsorted) observed values.
    cat_levels <- if (is.factor(draws$.category)) {
      levels(draws$.category)
    } else {
      unique(as.character(draws$.category))
    }
    draws <- draws %>%
      mutate(!!resp_sym := factor(!!resp_sym, levels = !!cat_levels)) %>%
      # spread predictions: one probability column per category
      pivot_wider(names_from = .category, values_from = .epred)
    map_dfr(
      split(draws, draws$.draw), # one ROC curve per posterior draw
      ~ roc_curve(.x, truth = !!resp_sym, !!!syms(cat_levels)),
      .id = ".draw"
    )
  }
}
# single ROC -------------------------------------------------------------------
# fit cumulative(logit) model
fit <- brm(
  rating ~ treat + period + (1 | subject),
  data = brms::inhaler,
  family = cumulative("logit"),
  chains = 4
)

# one-vs-all ROC curve per response category, for 50 posterior draws
roc <- get_roc_curve(brms::inhaler, fit, ndraws = 50)

ggplot(roc, aes(1 - specificity, sensitivity, colour = .level)) +
  facet_wrap(~.level, labeller = labeller(.level = ~ paste("Category", .))) +
  # chance line: sensitivity == 1 - specificity
  geom_line(aes(y = 1 - specificity), linetype = "dotted", colour = "black") +
  # one faint curve per posterior draw
  geom_line(aes(group = .draw), alpha = 0.1) +
  # mean posterior prediction
  # NOTE(review): stat_summary() averages sensitivity over rows that share an
  # exact x value; draws rarely share thresholds, so this is not an
  # interpolated mean ROC curve. Confirm that this is intended.
  stat_summary(fun = mean, geom = "line", size = 1) +
  scale_colour_d3() +
  scale_x_continuous(labels = percent) +
  scale_y_continuous(labels = percent) +
  coord_equal() +
  labs(x = "1 - Specificity", y = "Sensitivity", colour = "Category") +
  theme_ggdist() +
  theme(legend.position = "none")

# make sure the output directory exists before saving into it
dir.create("img", showWarnings = FALSE, recursive = TRUE)
ggsave("img/rocs-single.png", height = 7, width = 9, dpi = 800)
# multinomial ROC --------------------------------------------------------------
# Build a "family(link)" label from a fitted brms model, e.g. "cumulative(logit)".
# Used below to name the lists of fitted models.
get_family <- function(x) {
  fam_info <- x[["family"]]
  family_name <- fam_info[["family"]]
  link_name <- fam_info[["link"]]
  paste0(family_name, "(", link_name, ")")
}
# Wrapper for fitting the script's random-intercept model with brms::brm().
#
# Arguments:
#   ...:     forwarded to brm(). Callers pass `data =` by name, so the unnamed
#            family object lands on brm()'s `family` argument.
#   formula: model formula (defaults to the one used throughout this script).
#   chains:  number of MCMC chains (default 4). Because it is a named
#            parameter, passing `chains =` no longer duplicates the argument
#            in the brm() call.
# Returns whatever brm() returns.
fit_model <- function(..., formula = rating ~ treat + period + (1 | subject), chains = 4) {
  brm(formula, chains = chains, ...)
}
# fit multinomial models
# one model per ordinal/multinomial family, each list element named "family(link)"
multinomial_fits <- list(cumulative("logit"), sratio("logit"), cratio("logit"), categorical("logit")) %>%
  map(fit_model, data = brms::inhaler) %>%
  set_names(map_chr(., get_family))

# one-vs-all ROC curves (50 posterior draws) for every model, stacked with a
# `model` column holding the list element name
roc_multinomial <- multinomial_fits %>%
  map(~ get_roc_curve(brms::inhaler, .x, ndraws = 50)) %>%
  bind_rows(.id = "model")

roc_multinomial %>%
  ggplot(aes(1 - specificity, sensitivity, colour = .level)) +
  facet_grid(model ~ .level, labeller = labeller(.level = ~ paste("Category", .))) +
  # chance line: sensitivity == 1 - specificity
  geom_line(aes(y = 1 - specificity), linetype = "dotted", colour = "black") +
  # one faint curve per model and posterior draw
  geom_line(aes(group = interaction(model, .draw)), alpha = 0.1) +
  # mean posterior prediction
  # NOTE(review): averages sensitivity only over rows sharing an exact x value;
  # this is not an interpolated mean ROC curve. Confirm that this is intended.
  stat_summary(fun = mean, geom = "line", size = 1) +
  scale_colour_d3() +
  scale_x_continuous(labels = percent) +
  scale_y_continuous(labels = percent) +
  coord_equal() +
  labs(x = "1 - Specificity", y = "Sensitivity", colour = "Category") +
  theme_ggdist() +
  theme(
    axis.text = element_text(size = 7),
    legend.position = "none"
  )
ggsave("rocs-multinomial.png", height = 7, width = 9, dpi = 800)
# binomial ROC -----------------------------------------------------------------
# Dichotomise the ordinal rating once and reuse it for both fitting and ROC
# computation: rating becomes 1L where it was 1, and 0L otherwise.
inhaler_binary <- mutate(brms::inhaler, rating = as.integer(rating == 1))

# fit one Bernoulli model per link function, each list element named "family(link)"
binomial_fits <- list(bernoulli("logit"), bernoulli("probit"), bernoulli("cloglog"), bernoulli("cauchit")) %>%
  map(fit_model, data = inhaler_binary) %>%
  set_names(map_chr(., get_family))

# ROC curves (50 posterior draws) for every model, stacked with a `model` column.
# Assigned only to `roc_binomial`, so `roc_multinomial` is left untouched.
roc_binomial <- binomial_fits %>%
  map(~ get_roc_curve(inhaler_binary, .x, ndraws = 50)) %>%
  bind_rows(.id = "model")

roc_binomial %>%
  ggplot(aes(1 - specificity, sensitivity, colour = model)) +
  facet_wrap(~model) +
  # chance line: sensitivity == 1 - specificity
  geom_line(aes(y = 1 - specificity), linetype = "dotted", colour = "black") +
  # one faint curve per model and posterior draw
  geom_line(aes(group = interaction(model, .draw)), alpha = 0.1) +
  # mean posterior prediction
  # NOTE(review): averages sensitivity only over rows sharing an exact x value;
  # this is not an interpolated mean ROC curve. Confirm that this is intended.
  stat_summary(fun = mean, geom = "line", size = 1) +
  scale_colour_d3() +
  scale_x_continuous(labels = percent) +
  scale_y_continuous(labels = percent) +
  coord_equal() +
  labs(x = "1 - Specificity", y = "Sensitivity", colour = "Model") +
  theme_ggdist() +
  theme(
    axis.text = element_text(size = 9)
  )
ggsave("rocs-binomial.png", height = 7, width = 9, dpi = 800)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
The conversation continues on Mastodon: 🐘 https://fediscience.org/@gongcastro/109353350340945974.