Grouped train test split used by Dynatrace - SAL - LIT.AI.JKU in the NAD 2021 challenge
# Copyright 2021
# Dynatrace Research
# SAL Silicon Austria Labs
# LIT Artificial Intelligence Lab
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
from typing import Union
from itertools import chain

import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split

log = logging.getLogger(__name__)
def grouped_train_test_split(*arrays,
                             groups: Union[np.ndarray, pd.Series],
                             labels: Union[np.ndarray, pd.Series],
                             test_size: float = 0.25, random_state: int = None,
                             max_reshuffle: int = 100, test_size_eps: float = 0.2):
    """
    Very similar to `sklearn.model_selection.train_test_split` (it even has the same syntax),
    but this splits the `groups` and remaps them, resulting in a split that always keeps
    samples of the same group together in one fold.

    Outputs are similar to sklearn, so use it like this:

    >>> X = np.arange(100)                   # some feature vector
    >>> y = np.random.randint(2, size=100)   # two randomly assigned classes [0, 1]
    >>> g = np.random.randint(10, size=100)  # ten randomly assigned groups
    >>>
    >>> X_train, X_test, y_train, y_test = grouped_train_test_split(X, y, groups=g, labels=y, test_size=0.25)

    :param arrays: Allowed inputs are numpy arrays or pandas series.
    :param groups: A one-dimensional array or series of the same length as the inputs that holds group labels,
        which are used to always keep the groups together within each split.
    :param labels: A one-dimensional array or series of the same length as the inputs that holds class labels,
        which are used to preserve the class label distribution in the train and test splits.
    :param test_size: The size of the test split as a number between 0 and 1. (default: 0.25)
    :param random_state: Some integer for deterministic sampling. (default: random integer)
    :param max_reshuffle: In order to get the desired splitting ratio, we try a few different
        random splits until we get it right (we have to try multiple times because only after the
        re-mapping do we know the true split ratio). This parameter limits how often we retry. (default: 100)
    :param test_size_eps: The tolerance that we allow between `test_size` and the true test size. (default: 0.2)
    """
    if random_state is None:
        random_state = np.random.randint(1_000_000)

    # shapes must match
    assert all(len(arr) == len(groups) for arr in arrays)
    assert all(len(arr) == len(labels) for arr in arrays)

    groups = pd.Series(groups)
    labels = pd.Series(labels)
    nclasses = labels.nunique()

    # compute the median label within each group (just as a quick'n'dirty majority vote)
    groups_and_labels = pd.DataFrame({"group": groups, "stratum": labels})
    majorities = groups_and_labels.groupby("group").median().astype(int)

    # this number shall be close to `test_size` eventually
    true_test_size = np.full(nclasses, -1)
    train_mask, test_mask = None, None

    # loop until we stay within the desired bounds
    # or reach the reshuffle limit ...
    nshuffle = 0
    while np.any(np.abs(true_test_size - test_size) > test_size_eps) \
            and nshuffle <= max_reshuffle:
        nshuffle += 1

        # first, split the group labels and try to stratify a little bit
        train_groups, test_groups = train_test_split(majorities.index, test_size=test_size,
                                                     stratify=majorities["stratum"],
                                                     random_state=random_state + nshuffle)

        # remap the masks
        train_mask, test_mask = groups.isin(train_groups), groups.isin(test_groups)

        # compute the resulting label distribution
        train_distr = labels[train_mask].value_counts()
        test_distr = labels[test_mask].value_counts()

        # preserve the number of classes in each split
        splits_have_all_classes = len(train_distr) == len(test_distr) == nclasses
        # check the difference in the distributions
        # (as long as we at least got all classes in each split as well)
        if splits_have_all_classes:
            true_test_size = test_distr / (test_distr + train_distr)
        else:
            true_test_size = np.full(nclasses, -1)

        # TODO: we could save the closest split we make instead of the last one here ...

    # warn if we still couldn't achieve the desired split
    if np.any(np.abs(true_test_size - test_size) > test_size_eps):
        log.warning(f"the true test sizes of the grouped split are still {true_test_size} "
                    f"after {max_reshuffle} tries to re-shuffle the groups "
                    f"(you wanted {test_size:.2f} +/- {test_size_eps:.2f})")
    else:
        log.info(f"achieved true test sizes: {true_test_size}")

    return list(chain.from_iterable(
        (arr[train_mask], arr[test_mask])
        for arr in arrays
    ))
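For a quick sanity check, here is a minimal usage sketch. The data and variable names below are synthetic and purely illustrative (they are not part of the gist): it splits a feature matrix together with its labels and group assignments, then verifies that no group appears in both folds.

import numpy as np

rng = np.random.default_rng(42)
X = rng.normal(size=(1000, 8))   # synthetic feature matrix
y = rng.integers(2, size=1000)   # binary class labels
g = rng.integers(50, size=1000)  # group assignment per sample

# every positional array is split, so we also get the group ids per fold back
X_train, X_test, y_train, y_test, g_train, g_test = grouped_train_test_split(
    X, y, g, groups=g, labels=y, test_size=0.25)

# each group lives in exactly one fold
assert not set(g_train) & set(g_test)

# the realized test fraction is only approximately 0.25,
# because whole groups are moved between the folds
print(len(X_test) / (len(X_train) + len(X_test)))

Note that the realized split ratio depends on the group sizes, which is exactly why the function reshuffles up to `max_reshuffle` times and only accepts a split within `test_size_eps` of the requested `test_size`.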
numpy>=1.19
pandas>=1.2
scikit-learn>=0.23