# Forked from romainmartinez/sensitivity_analysis_example.py
# Created: July 25, 2021 14:55
# Gist: jsnouffer/5fb41948b236e8c8fd6ac1df65cf3705 (can be saved locally or
# opened with GitHub Desktop).
# Sensitivity analysis of a (scikit-learn-compatible) machine learning model.
# Note: GitHub warns that this file may contain hidden or bidirectional Unicode
# text that can be interpreted or compiled differently than it appears. To
# review it, open the file in an editor that reveals hidden Unicode characters.
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
from sklearn.datasets import make_regression
from xgboost import XGBRegressor

# Synthetic regression problem: 4 features, only 2 of which carry signal.
# A fixed random_state makes the data (and therefore the observation picked
# by index further down) reproducible from run to run.
X, y = make_regression(n_samples=500, n_features=4, n_informative=2,
                       noise=0.3, random_state=0)
X = pd.DataFrame(X, columns=['A', 'B', 'C', 'D'])

# Model whose sensitivity to feature changes is analysed below.
model = XGBRegressor()
model.fit(X, y)
class Simulate:
    """Perturb selected features of one observation and compare predictions.

    Parameters
    ----------
    obs : pandas.DataFrame
        The observation to perturb. ``simulate_increase`` requires exactly
        one row (e.g. ``X.iloc[[29]]``, whose double brackets keep a frame).
    var : iterable of str
        Column names of ``obs`` to perturb, one at a time.
    """

    def __init__(self, obs, var):
        self.obs = obs
        self.var = var

    def simulate_increase(self, model, percentage):
        """Predict the target after increasing each variable by ``percentage`` %.

        Each variable in ``self.var`` is scaled on its own (all other columns
        keep their observed values) and the model's prediction is recorded.

        Parameters
        ----------
        model : object
            Fitted estimator exposing ``predict(DataFrame)``.
        percentage : float
            Relative change applied to each variable; ``5`` means +5 %.

        Returns
        -------
        pandas.DataFrame
            One row per variable with columns ``test`` (variable name),
            ``simulated`` (prediction after the change) and ``baseline``
            (prediction for the unmodified observation).

        Raises
        ------
        ValueError
            If ``self.obs`` does not contain exactly one row.
        """
        # Only one baseline value is reported and each prediction is stored
        # under a single 'simulated' index entry, so multi-row input cannot
        # be represented; fail with a clear message instead of a pandas error.
        if len(self.obs) != 1:
            raise ValueError(
                f'obs must contain exactly one row, got {len(self.obs)}')

        baseline = model.predict(self.obs)
        plus = {}
        for ivar in self.var:
            # Copy so each variable is perturbed in isolation.
            X_plus = self.obs.copy()
            X_plus[ivar] = X_plus[ivar] + X_plus[ivar] * (percentage / 100)
            plus[ivar] = model.predict(X_plus)

        # {var: [pred]} -> one row per variable with columns test / simulated.
        b = pd.DataFrame(
            plus, index=['simulated']
        ).T.reset_index().rename(columns={'index': 'test'})
        b['baseline'] = baseline[0]
        return b

    @staticmethod
    def _padded_ylim(values, baseline, pad=0.1):
        """Return ``(lower, upper)`` y-limits covering ``values`` and ``baseline``.

        Each end is pushed outward by ``pad`` times its own magnitude, so the
        padding is correct for negative as well as positive targets. If both
        ends are zero, a +/-1 window is used so the axis is not degenerate.
        """
        lo = float(min(min(values), baseline))
        hi = float(max(max(values), baseline))
        lower = lo - abs(lo) * pad
        upper = hi + abs(hi) * pad
        if lower == upper:
            lower, upper = lower - 1.0, upper + 1.0
        return lower, upper

    @staticmethod
    def plot_simulation(d, **kwargs):
        """Bar-plot simulated predictions against the baseline prediction.

        Parameters
        ----------
        d : pandas.DataFrame
            Output of ``simulate_increase`` (columns ``test``, ``simulated``,
            ``baseline``).
        **kwargs
            ``title`` (optional): title of the axes.
        """
        baseline = d['baseline'].values[0]

        _, ax = plt.subplots()
        sns.barplot(x='test', y='simulated', data=d, palette='deep', ax=ax)
        # Labelling the reference line directly puts it in the legend; no
        # dummy off-screen line is needed.
        ax.axhline(baseline, color='grey', linestyle='--', linewidth=2,
                   label='baseline')
        # Limits include the baseline and pad correctly for negative targets.
        ax.set_ylim(Simulate._padded_ylim(d['simulated'], baseline))
        ax.set_xlabel('Simulated variables')
        ax.set_ylabel('Target value')
        ax.set_title(kwargs.get('title'))
        ax.legend()
        ax.grid(axis='y', linewidth=.3)
        sns.despine(offset=10, trim=True)
        plt.tight_layout()
        plt.show()
# Driver: raise each of features A, B and C of observation 29 by PERC percent
# (one at a time) and plot the resulting predictions against the baseline.
PERC = 5
VAR_OPTIMIZE = ['A', 'B', 'C']
ROW = X.iloc[[29]]  # double brackets keep a one-row DataFrame, not a Series

S = Simulate(obs=ROW, var=VAR_OPTIMIZE)
d = S.simulate_increase(model=model, percentage=PERC)
S.plot_simulation(
    d,
    title=f'Impact of a {PERC}% increase of {VAR_OPTIMIZE} in target value',
)
# GitHub gist page footer: "Sign up for free to join this conversation on
# GitHub. Already have an account? Sign in to comment."