@vincentarelbundock
Last active November 9, 2022 03:16
# Example of examining a continuous x categorical interaction using emmeans,
# and an attempt at doing the same using marginaleffects.
# Author: Cameron Patrick <[email protected]>
library(tidyverse)
library(emmeans)
library(marginaleffects)
# use the mtcars data, set up am as a factor
data(mtcars)
mc <- mtcars %>%
  mutate(am = factor(am))
# fit a linear model to mpg with wt x am interaction
m <- lm(mpg ~ wt*am, data = mc)
summary(m)
### Some questions we may want to answer from this model, using emmeans.
# 1. means for each level of am at mean wt.
emmeans(m, "am")
marginalmeans(m, variables = "am")
predictions(m, newdata = datagrid(am = 0:1))
# 2. means for each level of am at wt = 2.5, 3, 3.5.
emmeans(m, c("am", "wt"), at = list(wt = c(2.5, 3, 3.5)))
predictions(m, newdata = datagrid(am = 0:1, wt = c(2.5, 3, 3.5)))
# 3. means for wt = 2.5, 3, 3.5, averaged over levels of am (implicitly!).
emmeans(m, "wt", at = list(wt = c(2.5, 3, 3.5)))
# same thing, but the averaging is more explicit, using the `by` argument
predictions(
  m,
  newdata = datagrid(am = 0:1, wt = c(2.5, 3, 3.5)),
  by = "wt")
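# A minimal base-R sketch of what "averaged over levels of am" means above.
# This is an illustration added for clarity, not part of either package;
# `m_chk`, `nd`, and `avg_by_wt` are hypothetical names, and the model is
# refit here so the snippet runs on its own. It assumes equal weight for
# each level of am.
m_chk <- lm(mpg ~ wt * am, data = transform(mtcars, am = factor(am)))
nd <- expand.grid(am = factor(c(0, 1)), wt = c(2.5, 3, 3.5))
nd$estimate <- predict(m_chk, newdata = nd)
# one row per wt value: the mean of the am = 0 and am = 1 predictions
avg_by_wt <- aggregate(estimate ~ wt, data = nd, FUN = mean)
avg_by_wt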
# 4. graphical version of 2.
emmip(m, am ~ wt, at = list(wt = c(2.5, 3, 3.5)), CIs = TRUE)
plot_cap(m, condition = c("wt", "am"))
# 5. compare levels of am at specific values of wt.
# this is a bit ugly because the emmeans defaults for pairs() are silly.
# infer = TRUE: enable confidence intervals.
# adjust = "none": begone, Tukey.
# reverse = TRUE: contrasts as (later level) - (earlier level)
pairs(emmeans(m, "am", by = "wt", at = list(wt = c(2.5, 3, 3.5))),
      infer = TRUE, adjust = "none", reverse = TRUE)
comparisons(
  m,
  variables = "am",
  newdata = datagrid(wt = c(2.5, 3, 3.5)))
# 6. plot of pairwise comparisons
plot(pairs(emmeans(m, "am", by = "wt", at = list(wt = c(2.5, 3, 3.5))),
           infer = TRUE, adjust = "none", reverse = TRUE))
# Since `wt` is numeric, the default is to plot it as a continuous variable on
# the x-axis. But note that this is the **exact same info** as in the emmeans plot.
plot_cco(m, effect = "am", condition = "wt")
# You can of course customize everything: set draw = FALSE and feed the raw data to ggplot2.
p <- plot_cco(
  m,
  effect = "am",
  condition = list(wt = c(2.5, 3, 3.5)),
  draw = FALSE)
ggplot(p, aes(y = condition1, x = comparison, xmin = conf.low, xmax = conf.high)) +
  geom_pointrange()
# 7. slope of wt for each level of am
emtrends(m, "am", "wt")
marginaleffects(m, newdata = datagrid(am = 0:1))
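# A minimal base-R sketch of where these slopes come from (an added
# illustration; `m_chk`, `slope_am0`, and `slope_am1` are hypothetical names,
# and the model is refit so the snippet runs on its own). With R's default
# treatment contrasts, the slope of wt is the `wt` coefficient when am = 0,
# and that coefficient plus the `wt:am1` interaction when am = 1.
m_chk <- lm(mpg ~ wt * am, data = transform(mtcars, am = factor(am)))
b <- coef(m_chk)
slope_am0 <- unname(b["wt"])
slope_am1 <- unname(b["wt"] + b["wt:am1"])
c(am0 = slope_am0, am1 = slope_am1)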
### Now try to do the same using marginaleffects
# 1. okay this seems easy
marginalmeans(m, "am")
# plots of means are nice, but no way to turn them sideways?
plot_cap(m, "am")
# These are standard ggplot2 objects, so all the ggplot2 functions work out-of-the-box.
plot_cap(m, "am") + ggplot2::coord_flip()
# 2. but output is slightly ugly compared to emmeans/marginalmeans
# If you supply a function to datagrid(), it is applied to that variable
# automatically, so you can write `am = unique` instead of `am = levels(mc$am)`.
predictions(
  m,
  newdata = datagrid(
    am = levels(mc$am),
    wt = c(2.5, 3, 3.5)
  )
)
predictions(
  m,
  newdata = datagrid(
    am = unique,
    wt = c(2.5, 3, 3.5)
  )
)
# 3. means for wt = 2.5, 3, 3.5, averaged over levels of am via `by`.
predictions(
  m,
  newdata = datagrid(
    am = levels(mc$am),
    wt = c(2.5, 3, 3.5)
  ),
  by = "wt"
)
# 4. no way to get plots of 2 easily using plot_cap()?
# Again, those are standard ggplot2 objects, so you can easily combine them with {patchwork} or {cowplot} or whatever you are used to.
# 5. compare levels of am at specific values of wt.
# this kinda works but get a warning about 'am' being a character (it's not
# though, it's a factor in the original data frame) and all contrasts appear
# twice for some reason??
# Contrasts appear twice, because you created a grid with 6 rows: all
# combinations of am in 0:1 and the three elements of wt. Look at the last two
# columns of output. You can see this by calling datagrid() on its own:
datagrid(model = m, am = unique, wt = c(2.5, 3, 3.5))
# Since your grid has 6 rows, `marginaleffects` will compute the slope at each
# of those points in `newdata`, so you get 6 estimates, including some
# duplicates (that you yourself created and requested)
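# A minimal base-R sketch of the point above (an added illustration; `grid`
# is a hypothetical name and expand.grid() stands in for datagrid() here):
# crossing 2 levels of am with 3 values of wt gives 6 rows, so each wt value
# appears once per level of am, and the am contrast at that wt is reported
# once per row.
grid <- expand.grid(am = c("0", "1"), wt = c(2.5, 3, 3.5))
nrow(grid)      # number of rows in the full grid
table(grid$wt)  # how many times each wt value appears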
# Alternative 1: omit am from datagrid() altogether.
marginaleffects(
  m,
  newdata = datagrid(wt = c(2.5, 3, 3.5)),
  variables = "am")
# Alternative 2: include it, but then use `by` to marginalize
marginaleffects(
  m,
  newdata = datagrid(am = unique, wt = c(2.5, 3, 3.5)),
  by = "wt",
  variables = "am")
# 6. no way to plot this easily?
plot_cme(m, effect = "am", condition = "wt")
# 7. slope of wt for each level of am.
# this feels like what marginaleffects was really designed to do but also
# somehow way more complicated than I expected?
# simpler code:
marginaleffects(
  m,
  variables = "wt",
  newdata = datagrid(am = unique))