Created
April 16, 2023 16:13
-
-
Save ivallesp/1137a456af22001f644edfcb21e9110b to your computer and use it in GitHub Desktop.
Example of simple test for covariate shift
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import pandas as pd
import numpy as np
import scipy as sp
from sklearn.model_selection import cross_val_score
from sklearn.ensemble import RandomForestClassifier
from sklearn.utils import shuffle
from scipy.stats import ttest_1samp

# Simple covariate-shift check.
#
# Idea: label every row by whether it comes from the train or the test split
# and train a classifier to tell the two apart. If the classifier cannot do
# better than chance (AUC(ROC) ~ 0.5), the feature distributions of both
# splits are indistinguishable to it; a clearly higher AUC points at
# covariate shift.

# Single seed reused everywhere. The global NumPy seed alone is not guaranteed
# to reach estimators that are fitted in joblib worker processes (n_jobs=-1),
# so the seed is also passed explicitly through `random_state` below.
SEED = 655321
np.random.seed(SEED)

# Prepare a toy dataset
x_train = pd.read_csv("https://download.mlcc.google.com/mledu-datasets/california_housing_train.csv")
x_test = pd.read_csv("https://download.mlcc.google.com/mledu-datasets/california_housing_test.csv")
x = pd.concat([x_train, x_test])  # Concatenate train and test sets
x = x.drop("median_house_value", axis=1)  # Drop target variable
y = np.arange(len(x)) >= len(x_train)  # Create the is_test target variable
# Rows are stacked train-first, so shuffle before the (unshuffled) CV split.
x, y = shuffle(x, y, random_state=SEED)

# Determine the AUC(ROC) score at classifying each instance as train or test.
# Values >> 0.5 imply that covariate shift is likely a problem. Values close to
# 0.5 mean that there may not be covariate shift.
folds_aucs = cross_val_score(
    RandomForestClassifier(random_state=SEED),
    x,
    y,
    scoring="roc_auc",
    cv=10,
    n_jobs=-1,
)
# One-sample t-test of the per-fold AUCs against the chance level 0.5.
# NOTE: cross-validation folds share training data, so the fold scores are not
# strictly independent; treat the p-value as a rough indication only.
p_value = ttest_1samp(folds_aucs, 0.5).pvalue
print(f"AUC(ROC) = {folds_aucs.mean():.02f} ± {folds_aucs.std():.02f}")
print(f"p-value (H0: AUC(ROC) mean is 0.5): {p_value:.02}")

# Determine the importance of variables at classifying train and test.
# The model is refitted on all rows; the importances show which features the
# classifier relied on most to separate the two splits.
m = RandomForestClassifier(n_jobs=-1, random_state=SEED).fit(x, y)
df_importances = (
    pd.DataFrame({"variable": x.columns, "importance": m.feature_importances_})
    .sort_values(by="importance", ascending=False)
)
print("Importance of variables:")
print(df_importances.to_string(index=False))
# _________________________________________________
# Example output (exact numbers may differ between runs and library versions):
# AUC(ROC) = 0.51 ± 0.02
# p-value (H0: AUC(ROC) mean is 0.5): 0.34
# Importance of variables:
#           variable  importance
#      median_income    0.137577
#         population    0.137418
#        total_rooms    0.135345
#     total_bedrooms    0.126998
#         households    0.126550
#          longitude    0.122192
#           latitude    0.118497
# housing_median_age    0.095424
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment