Skip to content

Instantly share code, notes, and snippets.

@matteocourthoud
Created July 11, 2022 08:29
Show Gist options
Save matteocourthoud/7fe7e8760d9a01c5ab89d69949c0e10d to your computer and use it in GitHub Desktop.
from sklearn.neighbors import KNeighborsRegressor
from sklearn.linear_model import LogisticRegressionCV
def X_learner(df, model, y, D, X):
    """Fit the X-learner building blocks and plot the blended outcome curves.

    For each treatment arm this fits ``model`` as an outcome model and a
    1-nearest-neighbour regressor on the observed outcomes. It then estimates
    the propensity score with ``LogisticRegressionCV`` and blends the two
    per-arm predictions with propensity weights before passing the frame to
    ``plot_TE``.

    Parameters
    ----------
    df : object or pandas.DataFrame
        Either an object exposing ``generate_data(true_te=True)``, or a
        DataFrame that already holds the data. The original body ignored
        this argument and always called a global ``dgp.generate_data``.
        NOTE(review): a DataFrame presumably needs whatever true-effect
        column ``plot_TE(..., true_te=True)`` reads -- confirm.
    model : estimator
        Regressor with ``fit``/``predict``. It is refit in place, once per arm.
    y : str
        Outcome column name.
    D : str
        Treatment column; rows equal to 0 are control, rows equal to 1 treated.
    X : list of str
        Covariate column(s). Presumably a list, so that ``temp[X]`` is the 2-D
        input scikit-learn expects -- confirm against callers.

    Returns
    -------
    pandas.DataFrame
        The data sorted by ``X`` with the added columns ``mu0_hat_``,
        ``mu1_hat_``, ``y0_hat``, ``y1_hat``, ``mu0_hat`` and ``mu1_hat``.
        The original version returned ``None``.
    """
    # Honour the first argument instead of reading a global ``dgp``.
    # ``callable`` guards against a DataFrame column that happens to be
    # named "generate_data", since column access would return a Series.
    generate = getattr(df, "generate_data", None)
    data = generate(true_te=True) if callable(generate) else df
    # sort_values returns a new frame, so a caller's DataFrame is not mutated.
    temp = data.sort_values(X)

    features = temp[X]
    control = temp[D] == 0
    treated = temp[D] == 1

    # Outcome models mu0 / mu1. ``model.fit`` refits the same estimator object,
    # so each arm is predicted immediately after its own fit; keeping both
    # fitted "models" around would leave them aliasing one object.
    temp['mu0_hat_'] = model.fit(temp.loc[control, X], temp.loc[control, y]).predict(features)
    temp['mu1_hat_'] = model.fit(temp.loc[treated, X], temp.loc[treated, y]).predict(features)

    # 1-nearest-neighbour lookup of the observed outcome within each arm.
    temp['y0_hat'] = (
        KNeighborsRegressor(n_neighbors=1)
        .fit(temp.loc[control, X], temp.loc[control, y])
        .predict(features)
    )
    temp['y1_hat'] = (
        KNeighborsRegressor(n_neighbors=1)
        .fit(temp.loc[treated, X], temp.loc[treated, y])
        .predict(features)
    )

    # Propensity score: column 1 of predict_proba is the second sorted class,
    # i.e. the treated class when D takes values {0, 1}.
    e = LogisticRegressionCV().fit(y=temp[D], X=features).predict_proba(features)[:, 1]

    # Blend: mu0_hat weights the 1-NN imputation by e and the outcome model by
    # 1 - e; mu1_hat uses the mirrored weights.
    temp['mu0_hat'] = e * temp['y0_hat'] + (1 - e) * temp['mu0_hat_']
    temp['mu1_hat'] = (1 - e) * temp['y1_hat'] + e * temp['mu1_hat_']

    plot_TE(temp, true_te=True)
    return temp
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment