Created on October 17, 2022 at 09:33.
-
-
Save yptheangel/64ee5ffe067b32825a5b1e7c5c655d16 to your computer and use it in GitHub Desktop.
shap_streamlit_xgb
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import shap | |
import streamlit as st | |
import streamlit.components.v1 as components | |
import xgboost | |
import matplotlib.pyplot as plt | |
# Prefer st.cache_data (Streamlit >= 1.18); fall back to the deprecated
# st.cache only on older releases that do not provide cache_data.
_st_cache = st.cache_data if hasattr(st, "cache_data") else st.cache


@_st_cache
def load_data():
    """Load the demo regression dataset, cached across Streamlit reruns.

    Returns:
        The ``(X, y)`` pair produced by ``shap.datasets.boston()``: the
        feature table ``X`` and the target values ``y`` used for training.
    """
    # NOTE(review): shap.datasets.boston appears to be deprecated/removed in
    # newer SHAP releases (shap.datasets.california() is the usual
    # replacement) — confirm the pinned shap version still provides it.
    return shap.datasets.boston()
def st_shap(plot, height=None):
    """Embed a SHAP JavaScript visualization in the Streamlit page.

    The SHAP JS bundle goes into the document head so the plot's markup in
    the body can render inside the Streamlit HTML component.

    Args:
        plot: A SHAP visualizer exposing ``.html()``, such as the object
            returned by ``shap.force_plot``.
        height: Frame height forwarded to ``components.html``; ``None``
            defers to that function's own default.
    """
    js_bundle = shap.getjs()
    plot_markup = plot.html()
    page = f"<head>{js_bundle}</head><body>{plot_markup}</body>"
    components.html(page, height=height)
st.title("SHAP in Streamlit")

# --- Train an XGBoost model on the (cached) dataset -----------------------
X, y = load_data()
model = xgboost.train(
    params={"learning_rate": 0.01},
    dtrain=xgboost.DMatrix(X, label=y),
    num_boost_round=100,
)

# --- Explain the model's predictions with SHAP ----------------------------
# TreeExplainer accepts LightGBM, CatBoost, scikit-learn and Spark tree
# models with this same call syntax.
explainer = shap.TreeExplainer(model)
shap_values = explainer.shap_values(X)
# --- Plots -----------------------------------------------------------------
# Force plot explaining the prediction for the first row of X.
st_shap(shap.force_plot(explainer.expected_value, shap_values[0, :], X.iloc[0, :]))

# Summary plot over the whole dataset. Draw with show=False and hand the
# resulting figure to st.pyplot explicitly. Calling st.pyplot() without a
# figure relies on deprecated global-figure behaviour. The
# 'deprecation.showPyplotGlobalUse' option that used to silence that warning
# has been removed in newer Streamlit releases, where setting it raises.
shap.summary_plot(shap_values, X, show=False)
st.pyplot(plt.gcf(), bbox_inches="tight")
# Clear the current figure so the next plot on this rerun starts empty.
plt.clf()
# shap_object = shap.Explanation(base_values = shap_values[0][0].base_values, | |
# values = shap_values[0].values, | |
# feature_names = X.columns, | |
# data = shap_values[0].data) | |
# class ShapObject: | |
# def __init__(self, base_values, data, values, feature_names): | |
# self.base_values = base_values # Single value | |
# self.data = data # Raw feature values for 1 row of data | |
# self.values = values # SHAP values for the same row of data | |
# self.feature_names = feature_names # Column names | |
# row = 10 | |
# shap_object = ShapObject(base_values = explainer.expected_value[1], | |
# values = explainer.shap_values(X)[1][row,:], | |
# feature_names = X.columns, | |
# data = X.iloc[row,:]) | |
# shap_object = ShapObject(base_values = shap_values[0][0].base_values, | |
# values = shap_values[0].values, | |
# feature_names = X.columns, | |
# data = shap_values[0].data) | |
# shap.plots.waterfall_plot(shap_object) | |
# shap.plots.waterfall(shap_object) | |
# shap.plots.waterfall(shap_values[0]) | |
# shap.plots._waterfall.waterfall_legacy(explainer.expected_value, shap_values) | |
# st.pyplot(bbox_inches='tight') | |
# plt.clf() | |
# Stacked force plot covering every training-set prediction, shown in a
# 400-pixel-tall frame.
st_shap(
    shap.force_plot(explainer.expected_value, shap_values, X),
    height=400,
)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.