Covariate adjustment using nonparametric regression (Wager et al. 2016, PNAS)
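The class below implements the paper's cross-estimation idea: fit outcome models μ̂0 and μ̂1 on the control and treated arms with K-fold cross-fitting, form the pooled prediction ŷ(x) = (n1/n)·μ̂0(x) + (n0/n)·μ̂1(x) (note the flipped weights, which keep the estimator unbiased under randomization), and take τ̂ as the mean of Y − ŷ among treated minus the same among controls, with the variance built from the two within-arm residual variances.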
import numpy as np
import pandas as pd
from scipy.stats import norm
from sklearn.base import clone
from sklearn.model_selection import KFold
# learners
from xgboost import XGBRegressor
from glum import GeneralizedLinearRegressorCV
from sklearn.kernel_ridge import KernelRidge
# %%
class RegAdjustment:
    """Cross-fitted regression adjustment for randomized experiments
    (Wager, Du, Taylor, and Tibshirani 2016, PNAS)."""

    def __init__(self, model, conf_level=0.95):
        self.model = model
        self.conf_level = conf_level
        self._critv = norm.ppf(1 - (1 - self.conf_level) / 2)
    def fit(self, X, Y, W, nfolds=2, hdlin=True):
        if not hasattr(self.model, "fit") or not hasattr(self.model, "predict"):
            raise ValueError("Model must have 'fit' and 'predict' methods.")
        if (
            not isinstance(X, np.ndarray)
            or not isinstance(Y, np.ndarray)
            or not isinstance(W, np.ndarray)
        ):
            raise ValueError("X, Y, and W must be numpy arrays.")
        if not np.all(np.isin(W, [0, 1])):
            raise ValueError("Treatment assignment W must be encoded as a 0-1 vector.")
        n = X.shape[0]
        n0, n1 = (1 - W).sum(), W.sum()
        self.yhat_0, self.yhat_1 = np.empty(n), np.empty(n)
        cv = KFold(n_splits=nfolds, shuffle=True, random_state=42)
        for train, test in cv.split(X):
            # fit separate outcome models on control and treated training units;
            # clone() gives each arm its own estimator object, since fitting
            # self.model twice in place would overwrite the control fit
            id_0, id_1 = W[train] == 0, W[train] == 1
            μ0mod = clone(self.model).fit(X[train][id_0], Y[train][id_0])
            μ1mod = clone(self.model).fit(X[train][id_1], Y[train][id_1])
            self.yhat_0[test] = μ0mod.predict(X[test, :])
            self.yhat_1[test] = μ1mod.predict(X[test, :])
        if hdlin:
            # pooled prediction with flipped weights (sec 5 of Wager et al, eq 23):
            # ŷ(x) = (n1/n)·μ̂0(x) + (n0/n)·μ̂1(x)
            yhat_bar = (np.sum(W == 1) / n * self.yhat_0 +
                        np.sum(W == 0) / n * self.yhat_1)
            self.tau_hat = np.mean((Y - yhat_bar)[W == 1]) - np.mean((Y - yhat_bar)[W == 0])
            # variance: sum of within-arm variances of the adjusted outcomes
            self.var_hat = (np.var((Y - yhat_bar)[W == 1], ddof=1) / np.sum(W == 1) +
                            np.var((Y - yhat_bar)[W == 0], ddof=1) / np.sum(W == 0))
        else:  # augmented form (formula 23); yields very similar answers
            # τ̂ = mean(μ̂1(X) - μ̂0(X))
            #     + (1/n1) ∑_{W==1} (Y_i - μ̂1(X_i)) - (1/n0) ∑_{W==0} (Y_i - μ̂0(X_i))
            self.tau_hat = (
                np.mean(self.yhat_1 - self.yhat_0) +
                np.mean(Y[W == 1] - self.yhat_1[W == 1]) -
                np.mean(Y[W == 0] - self.yhat_0[W == 0])
            )
            self.var_hat = (
                np.var(Y[W == 1] - self.yhat_1[W == 1], ddof=1) / n1 +
                np.var(Y[W == 0] - self.yhat_0[W == 0], ddof=1) / n0
            )
        self.se = np.sqrt(self.var_hat)
        self.ci = [
            self.tau_hat - self._critv * self.se,
            self.tau_hat + self._critv * self.se,
        ]
        return self

    def summary(self):
        return {
            "tau": self.tau_hat,
            "se": self.se,
            "conf_int": self.ci,
            "conf_level": self.conf_level,
        }
# %% lalonde data
import empirical_calibration as ec
from formulaic import Formula

treated = ec.data.lalonde.experimental_treated()
control = ec.data.lalonde.experimental_control()
df = pd.concat([treated, control])
df["unemployed1974"] = np.where(df["earnings1974"] == 0, 1, 0)
df["unemployed1975"] = np.where(df["earnings1975"] == 0, 1, 0)
fml = """age + education + black + hispanic + married + nodegree + earnings1974 +
earnings1975 + unemployed1974 + unemployed1975"""
X = Formula(fml).get_model_matrix(df).values
y, w = df.earnings1978.values, df.treatment.values
# %% OLS benchmark (robust SEs)
import statsmodels.formula.api as smf
smf.ols(f"earnings1978 ~ {'+'.join(set(df.columns) - {'earnings1978'})}",
        data=df).fit(cov_type="HC1").summary().tables[1]
# coef      std err
# treatment 1670.7088 680.201
# %% lasso
mod = RegAdjustment(GeneralizedLinearRegressorCV())
mod.fit(X, y, w)
mod.summary()
# {'tau': 1583.4302398269924,
#  'se': 697.1604092961959,
#  'conf_int': [217.02094615924534, 2949.8395334947395],
#  'conf_level': 0.95}
# %% kernel ridge - pretty good
mod2 = RegAdjustment(KernelRidge(alpha=0.1, kernel='rbf', gamma=1.0))
mod2.fit(X, y, w, hdlin=False)
mod2.summary()
# {'tau': 1688.2209460671193,
#  'se': 763.8963108129432,
#  'conf_int': [191.01168895073556, 3185.430203183503],
#  'conf_level': 0.95}
# %% boosting - needs tuning / too little data
mod3 = RegAdjustment(XGBRegressor())
mod3.fit(X, y, w, hdlin=False)
mod3.summary()
# {'tau': 743.8672561200384,
#  'se': 853.245299813594,
#  'conf_int': [-928.4628014926863, 2416.197313732763],
#  'conf_level': 0.95}
# %%
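# %% sanity check on synthetic data: a minimal sketch where the true ATE is known
# simulate a randomized experiment with ATE = 2.0 and check the estimate lands
# near the truth; the DGP and the LinearRegression learner are illustrative
# choices, not part of the analysis above
from sklearn.linear_model import LinearRegression

rng = np.random.default_rng(0)
ns, ps = 2000, 10
Xs = rng.normal(size=(ns, ps))
Ws = rng.binomial(1, 0.5, size=ns).astype(float)  # Bernoulli(0.5) assignment
Ys = Xs @ rng.normal(size=ps) + 2.0 * Ws + rng.normal(size=ns)
RegAdjustment(LinearRegression()).fit(Xs, Ys, Ws).summary()
# tau should be close to 2.0, with the CI covering it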