Last active
October 15, 2021 02:05
-
-
Save BioSciEconomist/6a2953e639393d924fc7529f5508d5c2 to your computer and use it in GitHub Desktop.
Simulate data where SHAP values are not causal
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# *----------------------------------------------------------------- | |
# | PROGRAM NAME: beyond SHAP.py | |
# | DATE: 10/14/21 | |
# | CREATED BY: MATT BOGARD | |
# | PROJECT FILE: | |
# *---------------------------------------------------------------- | |
# | PURPOSE: simulate SHAP values that are not causal | |
# *---------------------------------------------------------------- | |
# this code is based on: "Be Careful When Interpreting Predictive Models in Search of Causal Insights"
# by Scott Lundberg; see: https://towardsdatascience.com/be-careful-when-interpreting-predictive-models-in-search-of-causal-insights-e68626e664b6
# Original Code: https://shap.readthedocs.io/en/latest/example_notebooks/overviews/Be%20careful%20when%20interpreting%20predictive%20models%20in%20search%20of%20causal%C2%A0insights.html | |
# see also: https://towardsdatascience.com/explain-your-model-with-the-shap-values-bc36aac4de3d | |
# https://towardsdatascience.com/shap-explained-the-way-i-wish-someone-explained-it-to-me-ab81cc69ef30 | |
import econml
import numpy as np
import pandas as pd
import scipy.special
import scipy.stats
import shap
import sklearn
import sklearn.model_selection
import xgboost
# | |
# generate data | |
# | |
class FixableDataFrame(pd.DataFrame):
    """DataFrame in which selected columns are pinned to fixed values.

    Helper class for manipulating generative models: whenever a column whose
    name is a key of ``fixed`` is assigned, the assigned value is immediately
    overwritten with the corresponding fixed value, so every column computed
    afterwards is derived from the pinned value.
    """

    def __init__(self, *args, fixed=None, **kwargs):
        """Create the frame.

        Args:
            *args: Positional arguments forwarded to ``pd.DataFrame``.
            fixed: Optional mapping of column name -> value to force for that
                column. ``None`` (the default) pins nothing.
            **kwargs: Keyword arguments forwarded to ``pd.DataFrame``.
        """
        # A fresh dict per instance: a mutable default argument would be
        # shared by every frame created without ``fixed``.
        # Written straight into the instance ``__dict__`` (bypassing normal
        # attribute assignment) before the DataFrame itself is initialised.
        self.__dict__["__fixed_var_dictionary"] = {} if fixed is None else fixed
        super().__init__(*args, **kwargs)

    def __setitem__(self, key, value):
        """Assign ``value`` to ``key``, then re-assign the fixed value if any.

        Only string keys are checked against the fixed-value mapping.
        Returns whatever the underlying ``DataFrame.__setitem__`` returned for
        the last assignment performed.
        """
        out = super().__setitem__(key, value)
        fixed = self.__dict__["__fixed_var_dictionary"]
        if isinstance(key, str) and key in fixed:
            # The regular assignment above runs first; presumably so that a
            # scalar fixed value is broadcast over the rows it established.
            out = super().__setitem__(key, fixed[key])
        return out
# generate the data | |
def generator(n, fixed=None, seed=0):
    """The generative model for our subscriber retention example.

    Args:
        n: Number of simulated customers (rows) to draw.
        fixed: Optional mapping of column name -> value. Each listed column
            is pinned to that value (via ``FixableDataFrame``) as soon as it
            is assigned, so later columns are computed from the pinned value.
            ``None`` (the default) pins nothing.
        seed: Seed passed to ``np.random.seed`` before sampling; pass ``None``
            to leave the global NumPy random state untouched.

    Returns:
        A ``FixableDataFrame`` with one row per customer and the columns
        "Sales calls", "Interactions", "Economy", "Last upgrade",
        "Product need", "Discount", "Monthly usage", "Ad spend",
        "Bugs faced", "Bugs reported" and "Did renew".
    """
    if fixed is None:
        fixed = {}
    if seed is not None:
        np.random.seed(seed)

    expit = scipy.special.expit
    X = FixableDataFrame(fixed=fixed)

    # the number of sales calls made to this customer
    X["Sales calls"] = np.random.uniform(0, 4, size=(n,)).round()

    # the number of interactions with this customer: the sales calls plus a
    # Poisson(0.2) count
    X["Interactions"] = X["Sales calls"] + np.random.poisson(0.2, size=(n,))

    # the health of the regional economy this customer is a part of
    X["Economy"] = np.random.uniform(0, 1, size=(n,))

    # the time since the last product upgrade when this customer came up for renewal
    X["Last upgrade"] = np.random.uniform(0, 20, size=(n,))

    # how much the user perceives that they need the product
    X["Product need"] = X["Sales calls"] * 0.1 + np.random.normal(0, 1, size=(n,))

    # the fractional discount offered to this customer upon renewal
    X["Discount"] = (
        (1 - expit(X["Product need"])) * 0.5
        + 0.5 * np.random.uniform(0, 1, size=(n,))
    ) / 2

    # what fraction of the days in the last period the user was actively
    # using the product
    X["Monthly usage"] = expit(
        X["Product need"] * 0.3 + np.random.normal(0, 1, size=(n,))
    )

    # how much ad money we spent per user targeted at this user (or a group
    # this user is in)
    # NOTE(review): the bounds are passed as (0.99, 0.9), i.e. low > high,
    # which NumPy documents as unsupported. Left as-is so seeded output does
    # not change -- confirm whether (0.9, 0.99) was intended.
    X["Ad spend"] = (
        X["Monthly usage"] * np.random.uniform(0.99, 0.9, size=(n,))
        + (X["Last upgrade"] < 1)
        + (X["Last upgrade"] < 2)
    )

    # how many bugs did this user encounter since their last renewal
    # (one vectorised draw with a per-row rate instead of a Python-level loop)
    X["Bugs faced"] = np.random.poisson(X["Monthly usage"].to_numpy() * 2)

    # how many bugs did the user report?
    X["Bugs reported"] = (X["Bugs faced"] * expit(X["Product need"])).round()

    # did the user renew? (probability of renewal)
    X["Did renew"] = expit(7 * (
        0.18 * X["Product need"]
        + 0.08 * X["Monthly usage"]
        + 0.1 * X["Economy"]
        + 0.05 * X["Discount"]
        + 0.05 * np.random.normal(0, 1, size=(n,))
        + 0.05 * (1 - X["Bugs faced"] / 20)
        + 0.005 * X["Sales calls"]
        + 0.015 * X["Interactions"]
        + 0.1 / (X["Last upgrade"] / 4 + 0.25)
        + X["Ad spend"] * 0.0 - 0.45  # ad spend is given zero weight here
    ))

    # As in real life, make a random draw to get either 0 or 1 for whether the
    # customer did or did not renew. Comment out this line to keep the label
    # as the renewal probability instead, which gives less noisy plots with
    # the same basic results.
    X["Did renew"] = scipy.stats.bernoulli.rvs(X["Did renew"])

    return X
def user_retention_dataset(n=10000):
    """The observed data for model training.

    Args:
        n: Number of customers to simulate with ``generator``. Defaults to
            10000, the size that was previously hard-coded.

    Returns:
        A tuple ``(X, y)`` where ``y`` is the "Did renew" column and ``X`` is
        the generated frame without "Did renew", "Product need" and
        "Bugs faced" (the label plus two columns that are withheld from the
        model).
    """
    X_full = generator(n)
    y = X_full["Did renew"]
    X = X_full.drop(columns=["Did renew", "Product need", "Bugs faced"])
    return X, y
# | |
# fit xgboost model | |
# | |
def fit_xgboost(X, y, random_state=None):
    """Train an XGBoost model with early stopping.

    The data is split into train and test parts with
    ``sklearn.model_selection.train_test_split`` (default proportions); the
    test part is used only as the early-stopping evaluation set.

    Args:
        X: Feature matrix.
        y: Labels aligned with ``X``.
        random_state: Optional seed forwarded to ``train_test_split`` so the
            split is reproducible. ``None`` (the default) keeps the previous
            unseeded behaviour.

    Returns:
        The booster returned by ``xgboost.train``.
    """
    X_train, X_test, y_train, y_test = sklearn.model_selection.train_test_split(
        X, y, random_state=random_state
    )
    dtrain = xgboost.DMatrix(X_train, label=y_train)
    dtest = xgboost.DMatrix(X_test, label=y_test)

    params = {
        "eta": 0.001,
        "subsample": 0.5,
        "max_depth": 2,
        "objective": "reg:logistic",
    }
    # The round limit is a very large upper bound; training is ended by
    # early stopping on the held-out split (20 rounds without improvement).
    model = xgboost.train(
        params,
        dtrain,
        num_boost_round=200000,
        evals=[(dtest, "test")],
        early_stopping_rounds=20,
        verbose_eval=False,
    )
    return model
X, y = user_retention_dataset()  # define data
model = fit_xgboost(X, y)  # fit model

# calculate SHAP values
explainer = shap.Explainer(model)
shap_values = explainer(X)

# plot SHAP values (bar plot with features grouped by hierarchical clustering)
clust = shap.utils.hclust(X, y, linkage="complete")
shap.plots.bar(shap_values, clustering=clust, clustering_cutoff=1)

# summary plot
shap.summary_plot(shap_values, X)

# show that SHAP gets correlations directionally wrong (based on theoretically
# simulated values). Features are selected by name rather than by position so
# the plots stay correct if the column set or order of X changes.
for feature in ("Bugs reported", "Discount", "Ad spend", "Economy"):
    shap.plots.scatter(shap_values[:, feature])
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment