@miraculixx
Last active November 7, 2024 23:45
ML drift statistics calculator and plots

DriftStats - An ML Drift Metric Calculator & Plotter

Calculating drift metrics for machine learning models is seemingly straightforward, yet surprisingly laborious and complex in practice. This library makes it simple and straightforward.

Why?

Most texts on calculating model drift focus on one specific metric, like Jensen-Shannon Distance or Chi-square. Often there are examples for single-variable datasets, explaining all the mathematical details. That's great for learning about the topic.
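For the single-variable case those texts describe, a metric such as the Jensen-Shannon distance really does take only a few lines; a minimal sketch using SciPy (the bin count and the two normal samples are illustrative choices, not anything this library prescribes):

```python
import numpy as np
from scipy.spatial.distance import jensenshannon

# Two samples from slightly shifted normal distributions
rng = np.random.default_rng(42)
baseline = rng.normal(0, 1, 1000)
target = rng.normal(0.5, 1, 1000)

# Bin both samples over a common range so the histograms are comparable
bins = np.histogram_bin_edges(np.concatenate([baseline, target]), bins=30)
p, _ = np.histogram(baseline, bins=bins, density=True)
q, _ = np.histogram(target, bins=bins, density=True)

# jensenshannon returns the JS distance (square root of the JS divergence)
jsd = jensenshannon(p, q)
print(round(jsd, 3))
```

That is fine for one variable — the point of the next paragraph is that real datasets need this done for every feature, with a type-appropriate metric for each.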

However, in practice we have datasets with many features of different types. Calculating one metric for one feature is one thing; calculating many metrics for many features and many datasets is quite another.

Automation is needed. That's what this library provides.

How it works

In a nutshell, driftstats takes a baseline and a target dataframe and calculates multiple drift metrics for all columns, like PSI, KS, JSD, Chi2, etc. It normalizes each metric to a score between 0 and 1 and calculates a boolean drift indicator. As a result we get a dataframe of drift metrics. Predefined plot functions help us plot both the drift metrics and the baseline and target distributions of any one feature.
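The per-metric normalization idea can be sketched in a few lines: each raw metric value is scaled into a 0-1 score against a saturation value and compared against a threshold to produce the boolean drift flag. The threshold and saturation values below are illustrative stand-ins, not the library's definitive choices (those are set per metric in the source):

```python
def normalize_metric(value, threshold, saturation):
    # Scale the raw metric to a 0-1 score, clipping at 1
    score = min(1.0, value / saturation)
    # Boolean drift indicator from a per-metric threshold
    drift = value > threshold
    return {"drift": drift, "metric": value, "score": score}

# e.g. a PSI of 0.3 against a 0.25 threshold and 0.5 saturation
print(normalize_metric(0.3, threshold=0.25, saturation=0.5))
# -> {'drift': True, 'metric': 0.3, 'score': 0.6}
```

Because every metric is mapped onto the same 0-1 scale, the resulting drift dataframe can be filtered and plotted uniformly regardless of which statistic produced each row.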

How to install

  1. Download all files
  2. pip install -r requirements.txt
  3. In Jupyter Lab, open the driftstats notebook for a tutorial

How to use

import numpy as np
import pandas as pd

from driftstats import DriftStatistics

# compare() expects DataFrames, so wrap the samples in single-column frames
baseline = pd.DataFrame({'feature': np.random.normal(0, 1, 1000)})
target = pd.DataFrame({'feature': np.random.normal(0.5, 1, 1000)})

calc = DriftStatistics()
drifts = calc.compare(baseline, target)
calc.plot_drift(drifts)
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import plotly.express as px
from scipy.stats import chi2_contingency, entropy, ks_2samp, wasserstein_distance


class DriftStatistics:
    # created with the help of duck.ai / GPT4o

    def __init__(self):
        pass

    def psi(self, baseline, target):
        """Calculate Population Stability Index (PSI)"""
        def calculate_psi(expected, actual, bins=10):
            hist_expected, _ = np.histogram(expected, bins=bins, density=True)
            hist_actual, _ = np.histogram(actual, bins=bins, density=True)
            hist_expected += 1e-10  # Avoid division by zero
            hist_actual += 1e-10
            psi_value = np.sum(
                (hist_expected - hist_actual) * np.log(hist_expected / hist_actual)
            )
            return psi_value

        psi_value = calculate_psi(baseline, target)
        drift_detected = psi_value > 0.25  # Threshold for drift
        score = min(1, psi_value / 0.5)  # Normalize score to 0-1
        return {"drift": drift_detected, "metric": psi_value, "score": score}

    def kl_divergence(self, baseline, target):
        """Calculate Kullback-Leibler Divergence (KL)"""
        hist_baseline, _ = np.histogram(baseline, bins=100, density=True)
        hist_target, _ = np.histogram(target, bins=100, density=True)
        hist_baseline += 1e-10  # Avoid division by zero
        hist_target += 1e-10
        kl_value = entropy(hist_baseline, hist_target)
        drift_detected = kl_value > 0.1  # Threshold for drift
        score = min(1, kl_value / 1)  # Normalize score to 0-1
        return {"drift": drift_detected, "metric": kl_value, "score": score}

    def js_divergence(self, baseline, target):
        """Calculate Jensen-Shannon Divergence (JS)"""
        hist_baseline, _ = np.histogram(baseline, bins=100, density=True)
        hist_target, _ = np.histogram(target, bins=100, density=True)
        hist_baseline += 1e-10  # Avoid division by zero
        hist_target += 1e-10
        m = 0.5 * (hist_baseline + hist_target)
        js_value = 0.5 * (entropy(hist_baseline, m) + entropy(hist_target, m))
        drift_detected = js_value > 0.05  # Threshold for drift
        score = min(1, js_value / 0.2)  # Normalize score to 0-1
        return {"drift": drift_detected, "metric": js_value, "score": score}

    def ks_test(self, baseline, target):
        """Perform Kolmogorov-Smirnov Test (KS)"""
        ks_stat, p_value = ks_2samp(baseline, target)
        drift_detected = p_value < 0.05  # Drift detected if p-value < 0.05
        score = min(1, ks_stat)  # KS statistic is between 0 and 1
        return {"drift": drift_detected, "metric": ks_stat, "score": score}

    def wasserstein(self, baseline, target):
        """Calculate Wasserstein Distance"""
        w_distance = wasserstein_distance(baseline, target)
        drift_detected = w_distance > 0.1  # Threshold for drift
        score = min(1, w_distance / 0.5)  # Normalize score to 0-1
        return {"drift": drift_detected, "metric": w_distance, "score": score}

    def chi_squared_test(self, baseline, target):
        """Perform Chi-squared test for categorical data."""
        baseline_counts = baseline.value_counts(normalize=True)
        target_counts = target.value_counts(normalize=True)
        # Align the categories of both series
        all_categories = baseline_counts.index.union(target_counts.index)
        baseline_counts = baseline_counts.reindex(all_categories, fill_value=0)
        target_counts = target_counts.reindex(all_categories, fill_value=0)
        # Create a contingency table
        contingency_table = np.array(
            [[baseline_counts[cat], target_counts[cat]] for cat in all_categories]
        )
        chi2, p_value, _, _ = chi2_contingency(contingency_table)
        drift_detected = p_value < 0.05  # Drift detected if p-value < 0.05
        score = min(1, chi2 / 10)  # Normalize score to 0-1, adjust as needed
        return {"drift": drift_detected, "metric": chi2, "score": score}

    def compute_statistics_for_dataframe(self, baseline_df, target_df):
        """Compute all statistics for each column in the DataFrames."""
        results = {}
        # Identify numeric columns
        numeric_columns = baseline_df.select_dtypes(include=np.number).columns.tolist()
        for col in baseline_df.columns:
            if col in target_df.columns:
                # Keep pandas Series so value_counts() works for categorical columns
                baseline = baseline_df[col]
                target = target_df[col]
                if col in numeric_columns:
                    stats = {
                        "PSI": self.psi(baseline, target),
                        "KL Divergence": self.kl_divergence(baseline, target),
                        "Jensen-Shannon Divergence": self.js_divergence(
                            baseline, target
                        ),
                        "KS Test": self.ks_test(baseline, target),
                        "Wasserstein Distance": self.wasserstein(baseline, target),
                    }
                else:  # Treat all other columns as categorical
                    stats = {
                        "Chi-squared Test": self.chi_squared_test(baseline, target)
                    }
                results[col] = stats
        return results

    def compare(self, baseline_df, target_df, best=False):
        """Compare two DataFrames and return a DataFrame of drift metrics per column.

        If best=True, only the highest-scoring statistic is kept for each
        column, and only if it indicates drift.
        """
        drift_results = {}
        # Identify numeric columns
        numeric_columns = baseline_df.select_dtypes(include=np.number).columns.tolist()
        for col in baseline_df.columns:
            if col in target_df.columns:
                baseline = baseline_df[col]
                target = target_df[col]
                # Collect all statistics
                stats = {}
                if col in numeric_columns:
                    stats["PSI"] = self.psi(baseline, target)
                    stats["KL Divergence"] = self.kl_divergence(baseline, target)
                    stats["Jensen-Shannon Divergence"] = self.js_divergence(
                        baseline, target
                    )
                    stats["KS Test"] = self.ks_test(baseline, target)
                    stats["Wasserstein Distance"] = self.wasserstein(baseline, target)
                else:  # Treat all other columns as categorical
                    stats["Chi-squared Test"] = self.chi_squared_test(baseline, target)
                # Find the highest score and associated statistic
                if stats and best:
                    highest_score_stat = max(
                        stats.items(), key=lambda item: item[1]["score"]
                    )
                    if highest_score_stat[1]["drift"]:
                        drift_results[col] = {
                            highest_score_stat[0]: highest_score_stat[1]
                        }
                else:
                    drift_results[col] = stats
        return pd.DataFrame(
            [
                {
                    "col": col,
                    "statistic": k,
                    "drift": v["drift"],
                    "metric": v["metric"],
                    "score": v["score"],
                }
                for col in drift_results
                for k, v in drift_results[col].items()
            ]
        )

    def plot_drift(
        self,
        drift_results,
        col=None,
        statistic=None,
        x="statistic",
        y="metric",
        color="col",
        facet_col="col",
        query=None,
        **plot_kwargs
    ):
        """Plot drift statistics for each column that detected drift."""
        flt = (drift_results["col"] == col if col else drift_results.index >= 0) & (
            drift_results["statistic"] == statistic
            if statistic
            else drift_results.index >= 0
        )
        dfx = drift_results[flt]
        dfx = dfx.query(query) if query else dfx
        fig = px.bar(dfx, x=x, y=y, color=color, facet_col=facet_col, **plot_kwargs)
        return fig

    def plot_hist(self, df1, df2, column="pop", query=None, **plot_kwargs):
        """
        Plots a combined histogram of the specified column from two DataFrames.

        Parameters:
            df1 (pd.DataFrame): The first DataFrame.
            df2 (pd.DataFrame): The second DataFrame.
            column (str): The column name to plot. Default is 'pop'.
        """
        # Add a source label to each DataFrame
        df1, df2 = df1.copy(), df2.copy()
        df1.loc[:, "source"] = "baseline"
        df2.loc[:, "source"] = "target"
        # Combine the two DataFrames
        combined_df = pd.concat([df1, df2], ignore_index=True)
        dfx = combined_df.query(query) if query else combined_df
        # Create the histogram
        fig = px.histogram(
            dfx,
            x=column,
            color="source",
            barmode="overlay",
            title="Combined Histogram of " + column,
            **plot_kwargs
        )
        fig.show()


# Example usage:
# baseline_data1 = np.random.normal(0, 1, 1000)
# target_data1 = np.random.normal(0.5, 1, 1000)
# target_data2 = np.random.normal(0.5, 1.5, 1000)
# df1 = pd.DataFrame({'feature_1': baseline_data1, 'feature_2': target_data1, 'feature_3': np.random.choice(['A', 'B', 'C'], size=1000)})
# df2 = pd.DataFrame({'feature_1': baseline_data1, 'feature_2': target_data2, 'feature_3': np.random.choice(['A', 'B', 'D'], size=1000)})
# drift_detector = DriftStatistics()
# drift_results = drift_detector.compare(df1, df2)
# print(drift_results)
# drift_detector.plot_drift(drift_results)
Copyright (c) 2024 miraculixx
Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
pandas
plotly
matplotlib
scipy