Last active
May 15, 2025 18:40
-
-
Save albertbuchard/8c7fb0b8faae20f04ccc03f39794d2fb to your computer and use it in GitHub Desktop.
O3 Answers: Pre-weighting ≡ IPAW (in lmer)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
This example demonstrates that fitting a weighted linear mixed-effects model using lmer(weights = …)
is algebraically equivalent to pre-multiplying the outcome and every design column (the fixed-effect
predictors, the intercept, and the random-effect terms) by the square root of the weights and then
fitting an unweighted model on the transformed data.

The test checks that this "pre-weight and refit" strategy reproduces the fitted values of the weighted
model once the refit's fitted values are divided by the square root of the weights. The strategy is
especially useful when packages like clubSandwich prohibit prior weights in robust variance estimation.

The synthetic data are realistic longitudinal data with dropout and continuous covariates; the model
fitted to them has both fixed and random effects.
""" | |
from typing import List | |
import numpy as np | |
import pandas as pd | |
import pytest | |
import rpy2.robjects as ro | |
from rpy2.interactive.packages import importr | |
from rpy2.robjects import pandas2ri | |
pandas2ri.activate() | |
# ---------- helpers ---------------------------------------------------------- | |
def ensure_r_packages(pkgs: List[str]) -> None:
    """Make sure every R package named in *pkgs* is installed.

    The R ``utils`` package is loaded, a CRAN mirror is selected
    non-interactively, and any requested package that R does not report as
    installed is installed from ``https://cloud.r-project.org``.

    Args:
        pkgs: Names of the R packages that must be available.
    """
    r_utils = importr("utils")
    r_utils.chooseCRANmirror(ind=1, graphics=False)

    # Names of everything R currently reports as installed.
    already_present = set(ro.r('installed.packages()[,"Package"]'))
    missing = [name for name in pkgs if name not in already_present]
    if not missing:
        return

    r_utils.install_packages(
        ro.StrVector(missing),
        repos="https://cloud.r-project.org",
    )
# realistic synthetic data --------------------------------------------------- | |
def _realistic_dataset(
    n_patients: int = 8, n_sessions: int = 5, seed: int = 2025
) -> pd.DataFrame:
    """Build a reproducible synthetic longitudinal dataset in long format.

    Every patient receives one age (integer in [25, 50)) and one sex code
    (0 or 1), and is observed for sessions ``1..last`` where ``last`` is drawn
    from ``{n_sessions - 1, n_sessions}``, so some patients miss the final
    session. Each observation gets two standard-normal covariates and an
    outcome ``y`` that is a linear combination of the covariates, the session
    number and Gaussian noise with standard deviation 2. A gamma-distributed
    weight (shape 2.0, scale 0.8) is then attached to every row.

    Args:
        n_patients: Number of patients to simulate.
        n_sessions: Maximum number of sessions per patient.
        seed: Seed for the NumPy random generator. The default, 2025, is the
            value that used to be hard-coded, so calls without ``seed`` keep
            producing the same data.

    Returns:
        A DataFrame with one row per (patient, session) and the columns
        ``id``, ``session``, ``age``, ``sex``, ``stress``,
        ``ysq_all_schema_mean_score``, ``y`` and ``ipaw``. The columns are
        present even when no rows are generated.
    """
    columns = [
        "id",
        "session",
        "age",
        "sex",
        "stress",
        "ysq_all_schema_mean_score",
        "y",
    ]
    rng = np.random.default_rng(seed)
    rows = []
    # NOTE: the order of the rng calls below determines the generated values;
    # reordering them changes the dataset obtained for a given seed.
    for pid in range(1, n_patients + 1):
        age = rng.integers(25, 50)
        sex = rng.integers(0, 2)
        # Upper bound is exclusive: last is n_sessions - 1 or n_sessions.
        last = rng.integers(n_sessions - 1, n_sessions + 1)
        for sess in range(1, last + 1):
            stress = rng.normal()
            ysq = rng.normal()
            y = (
                30
                - 0.6 * age
                + 3.5 * sex
                - 1.2 * stress
                + 2.0 * ysq
                + 1.8 * sess
                + rng.normal(scale=2)
            )
            rows.append(
                dict(
                    id=pid,
                    session=sess,
                    age=age,
                    sex=sex,
                    stress=stress,
                    ysq_all_schema_mean_score=ysq,
                    y=y,
                )
            )
    # Explicit columns keep the schema stable when `rows` is empty.
    df = pd.DataFrame(rows, columns=columns)
    df["ipaw"] = rng.gamma(shape=2.0, scale=0.8, size=len(df))
    return df
# integration test ----------------------------------------------------------- | |
@pytest.mark.integration
def test_preweight_refit_equivalence_full():
    """Check that a weighted lmer fit and a pre-weighted refit agree.

    The synthetic data are pushed into R as ``dat``. R then fits
    (1) ``lmer`` with ``weights = ipaw`` and (2) an unweighted ``lmer`` on
    columns multiplied by ``sqrt(ipaw)``, with an explicit intercept column
    and no implicit intercept. Both fits use ``REML = FALSE``. Fitted values
    of the second model are divided by ``sqrt(ipaw)`` and compared with those
    of the first; the largest absolute difference must be below 1e-8.
    """
    ensure_r_packages(["dplyr", "lme4"])

    # Expose the synthetic data to R under the name `dat`.
    ro.globalenv["dat"] = pandas2ri.py2rpy(_realistic_dataset())

    r_script = """
    suppressPackageStartupMessages({library(dplyr); library(lme4)})
    ## weighted fit -------------------------------------------------
    fit_w <- lmer(
    y ~ session + age + sex + stress + ysq_all_schema_mean_score +
    (session | id),
    data = dat, weights = ipaw, REML = FALSE
    )
    yhat_w <- fitted(fit_w)
    ## pre-weighted refit (explicit intercept) --------------------
    dat_pw <- dat %>%
    mutate(
    w_sqrt = sqrt(ipaw),
    intercept_star = w_sqrt,
    y_star = w_sqrt * y,
    session_star = w_sqrt * session,
    age_star = w_sqrt * age,
    sex_star = w_sqrt * sex,
    stress_star = w_sqrt * stress,
    ysq_star = w_sqrt * ysq_all_schema_mean_score
    )
    fit_pw <- lmer(
    y_star ~ -1 + intercept_star + session_star + age_star +
    sex_star + stress_star + ysq_star +
    (0 + intercept_star + session_star | id),
    data = dat_pw, REML = FALSE
    )
    ## compare fitted values --------------------------------------
    yhat_pw_back <- fitted(fit_pw) / dat_pw$w_sqrt
    max_abs_diff <- max(abs(yhat_w - yhat_pw_back))
    """
    ro.r(r_script)

    # The R script leaves the largest discrepancy in `max_abs_diff`.
    r_result = ro.globalenv["max_abs_diff"]
    diff = float(r_result[0])
    assert diff < 1e-8, f"max |Δ fitted| = {diff:.3e}"
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment