Last active
January 17, 2023 19:26
-
-
Save romainmartinez/d1aa798896d2f8cde62e40a3e59ec4a5 to your computer and use it in GitHub Desktop.
Sensitivity analysis of a (scikit-learn) machine learning model
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import math

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
from sklearn.datasets import make_regression
from xgboost import XGBRegressor
# Synthetic regression problem: 500 samples, 4 features of which 2 are informative.
# The fixed seed makes the generated data identical from one run to the next.
X, y = make_regression(
    n_samples=500,
    n_features=4,
    n_informative=2,
    noise=0.3,
    random_state=42,
)
# Label the columns so that features can be selected by name.
X = pd.DataFrame(X, columns=['A', 'B', 'C', 'D'])

# Fit a gradient-boosted regressor with the library's default settings.
model = XGBRegressor()
model.fit(X, y)
class Simulate:
    """One-at-a-time sensitivity analysis of a single observation.

    Each variable listed in ``var`` is changed in turn by a given percentage
    while every other column keeps its observed value. The model prediction
    for each changed observation is then compared with the prediction for the
    untouched one.

    Parameters
    ----------
    obs : pandas.DataFrame
        A one-row frame holding the observation to perturb
        (for instance ``X.iloc[[i]]``).
    var : list of str
        Labels of the columns of ``obs`` to perturb.
    """

    def __init__(self, obs, var):
        self.obs = obs
        self.var = var

    def simulate_increase(self, model, percentage):
        """Predict the target after scaling each variable by ``percentage`` percent.

        Every value ``x`` becomes ``x + x * percentage / 100``. Note that for a
        negative ``x`` and a positive ``percentage`` this moves the value
        further below zero.

        Parameters
        ----------
        model : object
            Anything exposing ``predict(frame)`` that returns one prediction
            per row of ``frame``.
        percentage : float
            Relative change applied to each variable, in percent.

        Returns
        -------
        pandas.DataFrame
            One row per entry of ``var`` with columns ``test`` (the variable
            label), ``simulated`` (prediction after the change) and
            ``baseline`` (prediction for the unchanged observation).

        Raises
        ------
        ValueError
            If ``obs`` does not contain exactly one row.
        """
        n_rows = len(self.obs)
        if n_rows != 1:
            # Fail early with a clear message instead of an opaque pandas
            # length-mismatch error further down.
            raise ValueError(f'obs must contain exactly one row, got {n_rows}')

        baseline = model.predict(self.obs)[0]

        simulated = {}
        for name in self.var:
            # Work on a copy so the stored observation is never modified.
            perturbed = self.obs.copy()
            perturbed[name] = perturbed[name] + perturbed[name] * (percentage / 100)
            simulated[name] = model.predict(perturbed)[0]

        result = pd.DataFrame(
            {'test': list(simulated), 'simulated': list(simulated.values())}
        )
        result['baseline'] = baseline
        return result

    @staticmethod
    def _y_limits(simulated, baseline, margin=0.1):
        """Return integer ``(lower, upper)`` y-limits enclosing all bars and the baseline.

        The margin is proportional to the magnitude of each extreme and always
        points outward, and the limits are rounded outward as well, so
        negative or small values are never clipped.
        """
        lowest = min(min(simulated), baseline)
        highest = max(max(simulated), baseline)
        lower = math.floor(lowest - abs(lowest) * margin)
        upper = math.ceil(highest + abs(highest) * margin)
        if lower == upper:
            # Only reachable when every value is zero; avoid a zero-height axis.
            lower, upper = lower - 1, upper + 1
        return lower, upper

    @staticmethod
    def plot_simulation(d, **kwargs):
        """Draw the simulated predictions as bars with the baseline as a dashed line.

        Parameters
        ----------
        d : pandas.DataFrame
            Frame with ``test``, ``simulated`` and ``baseline`` columns, as
            returned by ``simulate_increase``.
        **kwargs
            ``title`` (optional) is used as the axes title; other keys are
            ignored.

        The figure is displayed with ``plt.show()``; nothing is returned.
        """
        baseline = d['baseline'].values[0]

        _, ax = plt.subplots()
        sns.barplot(x='test', y='simulated', data=d, palette='deep', ax=ax)
        # The label on the line itself provides the legend entry.
        ax.axhline(
            baseline, color='grey', linestyle='--', linewidth=2, label='baseline'
        )
        ax.set_ylim(Simulate._y_limits(d['simulated'].tolist(), baseline))

        ax.set_xlabel('Simulated variables')
        ax.set_ylabel('Target value')
        ax.set_title(kwargs.get('title'))
        ax.legend()
        ax.grid(axis='y', linewidth=.3)

        sns.despine(offset=10, trim=True)
        plt.tight_layout()
        plt.show()
# Demo: change features A, B and C of one observation by 5 % each, one at a
# time, and plot the resulting predictions against the baseline.
VAR_OPTIMIZE = ['A', 'B', 'C']  # columns to perturb
PERC = 5  # relative change, in percent
ROW = X.iloc[[29]]  # double brackets keep the selection a one-row DataFrame

S = Simulate(obs=ROW, var=VAR_OPTIMIZE)
d = S.simulate_increase(model=model, percentage=PERC)
S.plot_simulation(
    d,
    title=f'Impact of a {PERC}% increase of {VAR_OPTIMIZE} in target value',
)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Please also include these two imports:
import matplotlib.pyplot as plt
import seaborn as sns