Skip to content

Instantly share code, notes, and snippets.

@simrit1
Forked from matteocourthoud/meta_X_learner.py
Created July 14, 2022 23:23
Show Gist options
  • Save simrit1/a1fe2e85369491ecb82703f0cf2cde60 to your computer and use it in GitHub Desktop.
from sklearn.neighbors import KNeighborsRegressor
from sklearn.linear_model import LogisticRegressionCV
def X_learner(df, model, y, D, X):
    """Estimate heterogeneous treatment effects with an X-learner and plot them.

    The outcome is modelled separately on the control and treated groups,
    each unit's missing potential outcome is imputed from its nearest
    neighbour in the other arm, and the per-arm estimates are blended with an
    estimated propensity score (the X-learner of Kuenzel et al., 2019).

    Parameters
    ----------
    df : pandas.DataFrame
        Data to fit on. It is not modified. NOTE(review): ``plot_TE`` is
        called with ``true_te=True`` on a sorted copy of this frame, so it
        presumably must already carry whatever true-effect columns
        ``plot_TE`` reads -- confirm against ``plot_TE``.
    model : sklearn-style regressor
        Template for the first-stage outcome models. It is cloned once per
        arm, so the caller's estimator is left untouched.
    y : str
        Outcome column name.
    D : str
        Treatment indicator column name; rows with 0 are treated as control
        and rows with 1 as treated.
    X : str or list of str
        Covariate column name(s). A single name is treated as a
        one-element list.

    Returns
    -------
    pandas.DataFrame
        A copy of ``df`` sorted by ``X`` with these columns added:
        ``mu0_hat_`` and ``mu1_hat_`` (first-stage model predictions),
        ``y0_hat`` and ``y1_hat`` (1-nearest-neighbour imputations), and
        ``mu0_hat`` and ``mu1_hat`` (propensity-weighted potential-outcome
        estimates). The estimated effect is ``mu1_hat - mu0_hat``.
    """
    from sklearn.base import clone

    # sklearn needs 2-D inputs; a bare column name would select a 1-D Series.
    if isinstance(X, str):
        X = [X]

    # sort_values returns a new frame, so the caller's df is not modified.
    temp = df.sort_values(X)
    control = temp[D] == 0
    treated = temp[D] == 1

    # Mu: first-stage outcome models, one per arm. clone() gives each arm its
    # own unfitted estimator. Reusing `model` directly would make both names
    # alias one object (last fitted on the treated group) and would refit the
    # caller's estimator as a side effect.
    mu0 = clone(model).fit(temp.loc[control, X], temp.loc[control, y])
    temp['mu0_hat_'] = mu0.predict(temp[X])
    mu1 = clone(model).fit(temp.loc[treated, X], temp.loc[treated, y])
    temp['mu1_hat_'] = mu1.predict(temp[X])

    # Y: impute each arm's outcome with the single nearest neighbour from that
    # arm. For units that belong to the arm, this is their own observed outcome
    # (barring duplicated covariate rows).
    y0 = KNeighborsRegressor(n_neighbors=1).fit(temp.loc[control, X], temp.loc[control, y])
    temp['y0_hat'] = y0.predict(temp[X])
    y1 = KNeighborsRegressor(n_neighbors=1).fit(temp.loc[treated, X], temp.loc[treated, y])
    temp['y1_hat'] = y1.predict(temp[X])

    # Weight: e(x) is the estimated probability of D == 1 given X.
    # - The control outcome weights the imputed observation by e and the model
    #   prediction by 1 - e.
    # - The treated outcome uses the mirrored weights.
    # Together these give
    #   mu1_hat - mu0_hat = e * (mu1_hat_ - y0_hat) + (1 - e) * (y1_hat - mu0_hat_).
    e = LogisticRegressionCV().fit(y=temp[D], X=temp[X]).predict_proba(temp[X])[:, 1]
    temp['mu0_hat'] = e * temp['y0_hat'] + (1 - e) * temp['mu0_hat_']
    temp['mu1_hat'] = (1 - e) * temp['y1_hat'] + e * temp['mu1_hat_']

    # Plot
    plot_TE(temp, true_te=True)
    return temp
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment