# Gist tjmahr/00c31f55d39726c3133a79c192b52318 (last active May 13, 2024)
# Using rvar to compute marginal means from a mixed effects model

#> Local .Rprofile detected at C:\Users\mahr\Documents\WiscRepos\2021-04-kh-td-intel-and-rate\.Rprofile

library(faux)
#> 
#> ************
#> Welcome to faux. For support and examples visit:
#> https://debruine.github.io/faux/
#> - Get and set global package options with: faux_options()
#> ************
library(tidyverse)

# Simulate a binary-outcome dataset with by-subject random intercepts and
# slopes, following the faux mixed-effects vignette:
# https://debruine.github.io/faux/articles/sim_mixed.html#simulating-data
set.seed(20240510)

# Design
subj_n <- 40     # number of subjects
trials <- 5      # number of trials (items) per condition

# Population-level (fixed) effects on the logit scale
b0 <- 1          # intercept: control condition (cond_t = 0)
b1 <- 0.5        # fixed effect of the test condition

# By-subject random effects
u0s_sd <- 0.5    # SD of the random intercepts
u1s_sd <- 0.2    # SD of the random b1 slopes
r01s <- 0.2      # correlation between random intercepts and slopes

data <- add_random(subj = subj_n, item = trials) |>
  # cond varies within subjects; cond_t is its 0/1 dummy code
  add_within("subj", cond = c("control", "test")) |>
  add_recode("cond", "cond_t", control = 0, test = 1) |>
  # one correlated (u0s, u1s) pair of deviations per subject
  add_ranef("subj", u0s = u0s_sd, u1s = u1s_sd, .cors = r01s) |>
  mutate(
    # linear predictor on the logit scale, then one Bernoulli draw per row
    lin_dv = b0 + u0s + (b1 + u1s) * cond_t,
    obs_dv = rbinom(n(), 1, plogis(lin_dv))
  )

# ggplot(data) +
#   aes(x = cond) +
#   geom_point(aes(y = obs_dv)) +
#   stat_summary(aes( y = obs_dv, group = subj))
#
# data |>
#   group_by(subj, cond) |>
#   summarise(
#     m = mean(obs_dv)
#   ) |>
#   ggplot(aes(x = cond, y = m)) +
#   geom_line(aes(group = subj))

# Collapse the trial-level 0/1 outcomes into one binomial row per subject and
# condition (successes out of trials), so the model can use a binomial
# likelihood instead of one Bernoulli term per trial.
data_fast <- data |>
  group_by(subj, cond) |>
  summarise(
    n_success = sum(obs_dv),
    trials = length(obs_dv),
    .groups = "drop"
  )


library(brms)
#> Loading required package: Rcpp
#> Loading 'brms' package (version 2.21.0). Useful instructions
#> can be found by typing help('brms'). A more detailed introduction
#> to the package is available through vignette('brms_overview').
#> 
#> Attaching package: 'brms'
#> The following object is masked from 'package:stats':
#> 
#>     ar
# Binomial mixed model: successes out of `trials` per subject x condition,
# with by-subject random intercepts and condition slopes.
f <- bf(
  n_success | trials(trials) ~ cond + (1 + cond | subj),
  family = binomial()
)

# Priors, written as strings for set_prior()
p <- c(
  set_prior("normal(0, 1)", class = "b"),   # population-level effects
  set_prior("normal(0, 2)", class = "sd"),  # random-effect SDs
  set_prior("lkj(2)", class = "cor")        # random-effect correlations
)

# Show the complete prior specification (user-set and default priors)
validate_prior(prior = p, formula = f, data = data_fast)
#>                 prior     class      coef group resp dpar nlpar lb ub
#>          normal(0, 1)         b                                      
#>          normal(0, 1)         b  condtest                            
#>  student_t(3, 0, 2.5) Intercept                                      
#>  lkj_corr_cholesky(2)         L                                      
#>  lkj_corr_cholesky(2)         L            subj                      
#>          normal(0, 2)        sd                                  0   
#>          normal(0, 2)        sd            subj                  0   
#>          normal(0, 2)        sd  condtest  subj                  0   
#>          normal(0, 2)        sd Intercept  subj                  0   
#>        source
#>          user
#>  (vectorized)
#>       default
#>          user
#>  (vectorized)
#>          user
#>  (vectorized)
#>  (vectorized)
#>  (vectorized)

# Fit the model with the cmdstanr backend. The fitted object is cached in
# "test-model.rds". file_refit = "on_change" makes brms refit when the data
# or the generated Stan code change (e.g., after editing the simulation
# settings), instead of silently reloading a stale cached fit.
m <- brm(
  f,
  data = data_fast,
  prior = p,
  file = "test-model",
  file_refit = "on_change",
  backend = "cmdstanr"
)

library(posterior)
#> This is posterior version 1.5.0
#> 
#> Attaching package: 'posterior'
#> The following objects are masked from 'package:stats':
#> 
#>     mad, sd, var
#> The following objects are masked from 'package:base':
#> 
#>     %in%, match

# Inverse-logit (logistic) transform that dispatches on its input.
#
# x: a numeric vector/array (default method) or a posterior::rvar.
# Returns the values mapped from the logit scale to the probability scale.
inv_logit <- function(x) UseMethod("inv_logit")

inv_logit.default <- function(x) stats::plogis(x)

# rvar method: transform the underlying draws array directly. Passing
# nchains preserves the input's chain structure. Otherwise rvar() resets it
# to a single chain, which distorts chain-based diagnostics (rhat, ess).
inv_logit.rvar <- function(x) {
  rvar(stats::plogis(draws_of(x)), nchains = nchains(x))
}

# Ground truth: simulate 1000 subjects straight from the data-generating
# parameters. Each draw is an (intercept, slope) pair on the logit scale.
# The second element is then replaced by intercept + slope, so each rvar
# holds (control, test) subject values.
real_mean <- c(b0, b1)
# lme4::sdcor2cov() reads SDs from the diagonal, correlation off-diagonal
real_cov <- lme4::sdcor2cov(matrix(c(u0s_sd, r01s, r01s, u1s_sd), ncol = 2))

# Version 1: mvtnorm, parameterized by a covariance matrix
real_marginals <- local({
  sims <- mvtnorm::rmvnorm(1000, mean = real_mean, sigma = real_cov)
  sims[, 2] <- sims[, 1] + sims[, 2]
  rvar(sims)
})

# Version 2: faux, parameterized by SDs and a correlation matrix
real_marginals2 <- local({
  sims <- rnorm_multi(
    1000,
    mu = real_mean,
    sd = c(u0s_sd, u1s_sd),
    r = matrix(c(1, r01s, r01s, 1), ncol = 2),
    as.matrix = TRUE
  )
  sims[, 2] <- sims[, 1] + sims[, 2]
  rvar(sims)
})




# Marginal means, approach 1:
#
# For every posterior draw, simulate many new subjects on the model scale
# from the population distribution, then average over them to integrate out
# the random effects.

# Covariance matrix of the by-subject (intercept, slope) random effects
posterior_cov <- rvar(VarCorr(m, summary = FALSE)$subj$cov)

# The multivariate distribution is over the (intercept, slope) coefficients,
# so its mean is given by the fixed effects
posterior_mean <- rvar(fixef(m, summary = FALSE))

# 1000 new subjects per draw. The columns start as (intercept, slope); the
# second column is then turned into the test-condition value.
new_subj <- rdo(
  mvtnorm::rmvnorm(1000, mean = posterior_mean, sigma = posterior_cov)
)
new_subj[, 2] <- new_subj[, 1] + new_subj[, 2]


# Alternative approach: the random effects have mean 0, so we can add them
# to population-level predictions for each condition.
#
# The predictions are on the (control, test) scale, but the random-effect
# covariance is on the (intercept, slope) scale. A new subject's control
# value gets u0 and their test value gets u0 + u1. So the covariance must
# first be mapped with A = [1 0; 1 1] to cov_cond = A Sigma A'. Otherwise
# the test condition would get variance var(u1) instead of var(u0 + u1).
one_subject <- data_fast |>
  # Any existing subject works because re_formula = NA ignores subj. Take
  # the first one rather than hard-coding an id whose format comes from faux.
  filter(subj == first(subj)) |>
  mutate(subj = "fake")

posterior_mean_alt <- posterior_linpred(
  m,
  newdata = one_subject,
  re_formula = NA,
  allow_new_levels = TRUE
) |>
  # one element per row of one_subject, i.e. per condition (control, test)
  rvar(dim = c(2, 1))

# rows: control = (1, 0) . (u0, u1); test = (1, 1) . (u0, u1)
effects_to_conds <- matrix(c(1, 1, 0, 1), ncol = 2)
posterior_cov_alt <- effects_to_conds %**% posterior_cov %**%
  t(effects_to_conds)

new_subj_alt <- rdo(
  mvtnorm::rmvnorm(1000, mean = posterior_mean_alt, sigma = posterior_cov_alt)
)


# Approach 1b: same as approach 1, but with faux::rnorm_multi(), which takes
# SDs and a correlation matrix instead of a covariance matrix.
posterior_sd <- rvar(VarCorr(m, summary = FALSE)$subj$sd)
posterior_cor <- rvar(VarCorr(m, summary = FALSE)$subj$cor)

new_subj2 <- rdo(
  faux::rnorm_multi(
    1000,
    mu = posterior_mean,
    sd = posterior_sd,
    r = posterior_cor,
    as.matrix = TRUE
  )
)
# second column: intercept + slope, i.e. the test-condition value
new_subj2[, 2] <- new_subj2[, 1] + new_subj2[, 2]

# Approach 2: Simulate a new sample of data with new participants.
#
# Every observed subject is relabeled with its own new id ("fake" + original
# id). brms then treats each one as a distinct, previously unseen subject
# and gives each its own random effects. A single constant label would make
# all rows share ONE new subject.
data_fake_sample <- data_fast |>
  mutate(subj = paste0("fake", subj))

# data_fast is sorted by subj then cond, so the rows come in (control, test)
# pairs. Each draw's predictions therefore reshape to [2, subjects], and t()
# makes that [subjects, 2].
n_fake_subj <- n_distinct(data_fake_sample$subj)

# Approach 2a: Reuse preexisting random effects. Each new subject borrows the
# random effects of a randomly chosen observed subject, so we are kind of
# doing a posterior bootstrapping thing.
posterior_epred <- posterior_epred(
  m,
  newdata = data_fake_sample,
  allow_new_levels = TRUE,
  sample_new_levels = "uncertainty"
)
posterior_epred <- posterior_epred |>
  rvar(dim = c(2, n_fake_subj)) |>
  t()

# Approach 2b: Simulate new random effects from the fitted Gaussian
# distribution. This is like approach 1 but with one new subject per observed
# subject instead of 1000. (The call below still resolves to brms'
# posterior_epred(): R skips non-function bindings when looking up a call.)
posterior_epred2 <- posterior_epred(
  m,
  newdata = data_fake_sample,
  allow_new_levels = TRUE,
  sample_new_levels = "gaussian"
) |>
  rvar(dim = c(2, n_fake_subj)) |>
  t()


# Compare the (control, test) marginal means across approaches.
# - real_* rows: the 1000 simulated subjects are the draws, so mean()
#   (a summary over draws) gives a constant.
# - Simulation rows: rvar_mean() averages over simulated subjects within
#   each posterior draw.
# - epred rows: these are expected counts out of `trials`, so they are
#   divided by the `trials` setting used to simulate the data to put them on
#   the probability scale.
rbind(
  real_mtvnorm = real_marginals |> inv_logit() |> mean(),
  real_rnorm_multi = real_marginals2 |> inv_logit() |> mean(),
  mvtnorm = new_subj |> inv_logit() |> rvar_apply(2, rvar_mean),
  mvtnorm_alt = new_subj_alt |> inv_logit() |> rvar_apply(2, rvar_mean),
  rnorm_multi = new_subj2 |> inv_logit() |> rvar_apply(2, rvar_mean),
  # Bigger standard errors: these average over far fewer simulated subjects
  # per draw than the rows above.
  epred_default = (posterior_epred / trials) |> rvar_apply(2, rvar_mean),
  epred_gaussian = (posterior_epred2 / trials) |> rvar_apply(2, rvar_mean)
)
#> rvar<4000>[7,2] mean ± sd:
#>                  [,1]          [,2]         
#> real_mtvnorm     0.73 ± 0.000  0.81 ± 0.000 
#> real_rnorm_multi 0.72 ± 0.000  0.80 ± 0.000 
#> mvtnorm          0.73 ± 0.033  0.80 ± 0.032 
#> mvtnorm_alt      0.73 ± 0.033  0.81 ± 0.032 
#> rnorm_multi      0.73 ± 0.033  0.80 ± 0.032 
#> epred_default    0.73 ± 0.079  0.80 ± 0.094 
#> epred_gaussian   0.73 ± 0.084  0.80 ± 0.104

# Created on 2024-05-13 with reprex v2.1.0