Created
December 23, 2022 16:28
-
-
Save gongcastro/60b358a1cff657113dad343a870664d9 to your computer and use it in GitHub Desktop.
Interpreting and playing with the coefficients of logit regression
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
library(ggplot2) | |
library(dplyr) | |
library(tidyr) | |
library(purrr) | |
library(scales) | |
library(patchwork) | |
theme_set(theme_minimal()) | |
# define functions ------------------------------------------------------------- | |
# Two-parameter logistic curve bounded between 0 and 1.
#   x:     numeric vector of predictor values
#   slope: steepness of the curve
#   mid:   value of x at which the curve equals 0.5
# Returns the value of the curve at each element of x.
logistic <- function(x, slope, mid) {
  linear_predictor <- slope * (x - mid)
  odds_against <- exp(-linear_predictor)
  1 / (1 + odds_against)
}
# Mid point of a fitted logistic curve: the value of the predictor at which the
# fitted probability is 0.5 (the inflection point), i.e. -intercept / slope.
#   x:         fitted model whose coefficients are available through coef()
#   predictor: name of the slope coefficient to use
# The result keeps the "(Intercept)" name of the numerator.
get_mid <- function(x, predictor = "x") {
  betas <- coef(x)
  -betas["(Intercept)"] / betas[predictor]
}
# Intercept of the straight line with slope beta / 4 that passes through the
# point (mid, 0.5), where mid is the predictor value at which the fitted
# probability is 0.5.
# Assumes the logistic curve is bounded by 0 and 1; under that assumption 0.5 is
# always its inflection point.
#   x:         fitted model whose coefficients are available through coef()
#   predictor: name of the slope coefficient to use
get_mid_intercept <- function(x, predictor = "x") {
  betas <- coef(x)
  tangent_slope <- betas[predictor] / 4
  mid_point <- betas["(Intercept)"] / -betas[predictor]
  # a line y = a + b * x through (mid, 0.5) has a = 0.5 - b * mid
  -(tangent_slope * mid_point) + 0.5
}
# simulate data ---------------------------------------------------------------- | |
# Fix the RNG seed so that the simulated sample, and therefore the model fits
# and the figure built from it, are the same on every run.
set.seed(1234)

n <- 100 # number of participants
x <- runif(n, min = 12, max = 35) # age of participants
x_std <- scale(x)[, 1] # age of participants (standardised)
# probability of comprehension (y = 1) increases with age; it is 0.5 at x = 29
probs <- logistic(x, slope = 0.4, mid = 29)
# comprehension responses (y = 0 is "No", y = 1 is "Yes")
y <- rbinom(n = n, size = 1, prob = probs)
d <- tibble(x, x_std, y)
# fit model -------------------------------------------------------------------- | |
# logistic regressions of the response on raw age and on standardised age
fit_x <- glm(formula = y ~ x, family = binomial(link = "logit"), data = d)
fit_x_std <- glm(
  formula = y ~ x_std,
  family = binomial(link = "logit"),
  data = d
)
# add each model's fitted values (probability scale by default) to the data
preds <- mutate(
  d,
  pred_x = fitted(fit_x),
  pred_x_std = fitted(fit_x_std)
)
# Side-by-side plots of the two fitted models. Each panel shows:
#   - grey circles: observed responses, jittered vertically to reduce overlap
#   - black curve: fitted probabilities
#   - red dotted horizontal line: intercept on the probability scale
#     (y when x = 0)
#   - red dotted vertical line: mid point (x when y = 0.5)
#   - blue line: tangent to the curve at the mid point (slope = coefficient / 4)
# NOTE(review): ggplot2 >= 3.4.0 deprecates `size` for line geoms in favour of
# `linewidth`; `size` is kept so that older versions draw the same line.

# predictions of model with unstandardised age
plot_x <- ggplot(preds, aes(x, y)) +
  # plot observations
  geom_point(
    shape = 1, size = 2, stroke = 1,
    position = position_jitter(height = 0.1),
    alpha = 0.8, colour = "grey"
  ) +
  # plot predictions
  geom_line(aes(y = pred_x), size = 1) +
  # plot intercept (y when x = 0) in the probability scale
  geom_hline(
    yintercept = plogis(coef(fit_x)["(Intercept)"]),
    linetype = "dotted", colour = "red"
  ) +
  # plot mid point (x when y = 0.5)
  geom_vline(
    xintercept = get_mid(fit_x),
    linetype = "dotted", colour = "red"
  ) +
  # plot slope when x = mid point (derivative of the logistic curve there)
  geom_abline(
    slope = coef(fit_x)["x"] / 4,
    intercept = get_mid_intercept(fit_x),
    colour = "blue"
  ) +
  labs(
    x = "Predictor: age (in months)",
    y = "Response: P(comprehension)"
  )

# predictions of model with standardised age
plot_x_std <- ggplot(preds, aes(x_std, y)) +
  # plot observations
  geom_point(
    shape = 1, size = 2, stroke = 1,
    position = position_jitter(height = 0.1),
    alpha = 0.8, colour = "grey"
  ) +
  # plot predictions
  geom_line(aes(y = pred_x_std), size = 1) +
  # plot intercept (y when x = 0) in the probability scale
  geom_hline(
    yintercept = plogis(coef(fit_x_std)["(Intercept)"]),
    linetype = "dotted", colour = "red"
  ) +
  # plot mid point (x when y = 0.5)
  geom_vline(
    xintercept = get_mid(fit_x_std, "x_std"),
    linetype = "dotted", colour = "red"
  ) +
  # plot slope when x = mid point (derivative of the logistic curve there)
  geom_abline(
    slope = coef(fit_x_std)["x_std"] / 4,
    intercept = get_mid_intercept(fit_x_std, "x_std"),
    colour = "blue"
  ) +
  labs(
    x = "Predictor: age (standardised)",
    y = "Response: P(comprehension)"
  ) +
  # the y axis title is already shown on the left-hand panel
  theme(axis.title.y = element_blank())

# put both panels in one row; `&` applies the y scale to every panel
(plot_x + plot_x_std + plot_layout(nrow = 1, guides = "collect")) &
  scale_y_continuous(labels = percent, breaks = seq(0, 1, 0.25))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment