mattsgithub/8dfe22cf849dfd7d184c9a58f1043411
Last active December 4, 2019 18:58
Prepares Pandas DataFrame for Counterfactual Predictions
import pandas as pd


def get_pandas_df_for_counterfactual_prediction(df,
                                                df_treatments,
                                                id_column='customer_id',
                                                suffix='_assigned'):
    """Returns a pandas dataframe prepared for counterfactual predictions.

    Args:
        df (pd.DataFrame):
            The dataframe from which to prepare counterfactuals.
        df_treatments (pd.DataFrame):
            A dataframe of treatments. In experiments, we might
            assign several different variable values. This treatment
            dataframe contains all possible assignments, one per row.
        id_column (str):
            The column name in df that represents the unit we are doing
            analysis on.
        suffix (str):
            Treatment columns in df will be renamed with this suffix. This
            is useful for retaining the original treatment columns for
            further analysis.

    Returns:
        pd.DataFrame: One row per (unit, treatment) pair, that is
        len(df) * len(df_treatments) rows, ordered by id_column.

    Raises:
        ValueError: If a column of df_treatments is not found in df.

    Examples:
        Suppose df is:

            customer_id  t1   t2
                      1   1  0.1
                      2   3  0.2

        Suppose df_treatments is:

            t1   t2
             1  0.1
             1  0.2
             3  0.1
             3  0.2

        Then get_pandas_df_for_counterfactual_prediction(df, df_treatments)
        returns:

            customer_id  t1_assigned  t2_assigned  t1   t2
                      1            1          0.1   1  0.1
                      1            1          0.1   1  0.2
                      1            1          0.1   3  0.1
                      1            1          0.1   3  0.2
                      2            3          0.2   1  0.1
                      2            3          0.2   1  0.2
                      2            3          0.2   3  0.1
                      2            3          0.2   3  0.2
    """
    intersecting_cols = df_treatments.columns.intersection(df.columns)
    cols_not_found_in_df = df_treatments.columns.difference(intersecting_cols)
    if len(cols_not_found_in_df) > 0:
        raise ValueError('The columns {} in df_treatments '
                         'are not found in df'.format(list(cols_not_found_in_df)))

    n_rows = df.shape[0]
    n_treatments = df_treatments.shape[0]

    # Keep the observed treatment values under suffixed column names
    df = df.rename(columns={c: c + suffix for c in df_treatments.columns})

    # Order units by id; a stable sort keeps rows with equal ids in input order
    df = df.sort_values(id_column, kind='mergesort').reset_index(drop=True)

    # Generate a row for each counterfactual prediction: repeat every unit
    # row once per treatment (row 0 n_treatments times, then row 1, ...)
    df = df.loc[df.index.repeat(n_treatments)].reset_index(drop=True)

    # Tile the treatment table so each block of repeated unit rows lines up
    # with one full copy of df_treatments
    df_treatments = pd.concat([df_treatments] * n_rows, ignore_index=True)

    # Both frames now share a 0..n_rows*n_treatments-1 index, so they align
    return pd.concat([df, df_treatments], axis=1)
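The same expansion can be written more compactly with pandas' cross merge (`DataFrame.merge(..., how='cross')`, available since pandas 1.2), which builds the Cartesian product of two frames and preserves the row order of the left frame. The sketch below is an alternative formulation, not the gist's own code; `expand_counterfactuals_cross` is a hypothetical name, and the example data is the one from the docstring.

```python
import pandas as pd


def expand_counterfactuals_cross(df, df_treatments,
                                 id_column='customer_id',
                                 suffix='_assigned'):
    """Hypothetical helper: counterfactual expansion via a cross merge."""
    missing = df_treatments.columns.difference(df.columns)
    if len(missing) > 0:
        raise ValueError(f'The columns {list(missing)} in df_treatments '
                         'are not found in df')
    # Keep the observed treatment values under suffixed names
    df = df.rename(columns={c: c + suffix for c in df_treatments.columns})
    # Stable sort so units appear in id order
    df = df.sort_values(id_column, kind='mergesort')
    # Cartesian product: each unit row followed by every treatment row, in
    # the order of df_treatments; the result gets a fresh RangeIndex
    return df.merge(df_treatments, how='cross')


df = pd.DataFrame({'customer_id': [1, 2], 't1': [1, 3], 't2': [0.1, 0.2]})
df_treatments = pd.DataFrame({'t1': [1, 1, 3, 3], 't2': [0.1, 0.2, 0.1, 0.2]})
result = expand_counterfactuals_cross(df, df_treatments)
print(result)
```

Because the merge pairs rows directly instead of aligning two separately built frames on their index, it stays correct even when several rows of `df` share the same id.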
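After training a heterogeneous treatment effects model, you will want to prepare a dataframe for counterfactual predictions by enumerating all possible treatment assignments: one row for every combination of unit and candidate treatment, so the model can predict each unit's outcome under every assignment.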
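A sketch of the downstream step: score every counterfactual row, pick the best treatment per customer, and compare it with the prediction under the treatment actually assigned (which is where the suffixed columns pay off). `predict_outcome` and the covariate `x` are hypothetical stand-ins for a trained model and its features; the expansion uses pandas' `merge(how='cross')` so the block runs on its own.

```python
import pandas as pd


def predict_outcome(frame):
    # Hypothetical stand-in for a trained heterogeneous treatment effects
    # model: the effect of t1 depends on the covariate x
    return frame['x'] * frame['t1'] - frame['t2']


# Toy data shaped like the docstring example, plus one covariate x
df = pd.DataFrame({'customer_id': [1, 2], 'x': [1.0, -1.0],
                   't1': [1, 3], 't2': [0.1, 0.2]})
df_treatments = pd.DataFrame({'t1': [1, 1, 3, 3], 't2': [0.1, 0.2, 0.1, 0.2]})

# Counterfactual frame: observed treatments kept under the '_assigned'
# suffix, each customer paired with every candidate treatment row
cf = (df.rename(columns={c: c + '_assigned' for c in df_treatments.columns})
        .merge(df_treatments, how='cross'))
cf['y_hat'] = predict_outcome(cf)

# Prediction under the treatment each customer actually received
is_assigned = (cf['t1'] == cf['t1_assigned']) & (cf['t2'] == cf['t2_assigned'])
baseline = cf.loc[is_assigned].set_index('customer_id')['y_hat']

# Treatment row with the highest predicted outcome for each customer
best = cf.loc[cf.groupby('customer_id')['y_hat'].idxmax()].set_index('customer_id')
best['lift'] = best['y_hat'] - baseline  # aligned on customer_id
print(best[['t1', 't2', 'y_hat', 'lift']])
```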