Last active
August 28, 2019 22:02
-
-
Save eric-czech/260de02335007250b06582226bbe73b7 to your computer and use it in GitHub Desktop.
(Python/Sklearn) Function to run a stratified split into an arbitrary number of label subsets with desired proportions
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| from sklearn.model_selection import train_test_split | |
| import numpy as np | |
def _get_stratified_split(splits, idx, values, proportions, random_state=None):
    """Peel stratified subsets off ``idx`` one proportion at a time.

    Args:
        splits: List that the resulting index arrays are appended to (mutated
            in place and also returned).
        idx: 1D array of indexes into ``values`` that are still unassigned.
        values: Labels used for stratification, indexable by ``idx``.
        proportions: Fractions of ``idx`` that each subset should receive;
            the last one is never used directly, it simply gets whatever
            remains once all earlier subsets have been extracted.
        random_state: Passed unchanged to every ``train_test_split`` call.
    Returns:
        ``splits`` with one index array appended per entry in ``proportions``
        (a single array holding all of ``idx`` if there are fewer than 2
        proportions).
    """
    values = np.asarray(values)
    remaining = np.asarray(proportions, dtype=float)
    # Stop as soon as fewer than 2 proportions are left: the final subset is
    # simply everything that has not been assigned yet
    while len(remaining) > 1:
        # Extract the indexes for the leading proportion (stratified by label)
        p = remaining[0]
        idx_split = train_test_split(
            idx, stratify=values[idx], train_size=p, random_state=random_state
        )[0]
        splits.append(idx_split)
        # Drop the extracted indexes and rescale the remaining proportions so
        # they express fractions of the reduced set that are equivalent to the
        # desired fractions of the original set
        idx = np.setdiff1d(idx, idx_split)
        remaining = (1. / (1 - p)) * remaining[1:]
    splits.append(idx)
    return splits


def get_stratified_split(values, proportions, random_state=None):
    """Split values into arbitrary sets with target size while using stratification

    Args:
        values: Labels to use for stratification (1D)
        proportions: Sequence containing desired proportions of resulting splits;
            every entry must be positive and the entries must sum to 1
        random_state: Optional seed (or other value accepted by the
            ``random_state`` argument of ``train_test_split``) used for every
            underlying split; leave as None to keep the default unseeded behavior
    Returns:
        List of numpy arrays with length equal to length of proportions and values
        equivalent to indexes associated with each split (the number of these indexes
        should roughly account for the desired portion of samples and each split
        should be comprised of roughly equal label frequencies)
    Raises:
        ValueError: If values or proportions are not 1D, if proportions do not
            sum to 1, or if any proportion is zero or negative
    Example:
        ```
        # Binary labels array
        values = (np.arange(50) > 10).astype(int)
        # Split into 3 groups where first is largest at 60% of all elements (and next two are smaller)
        proportions = [.6, .3, .1]
        [np.unique(values[s], return_counts=True) for s in get_stratified_split(values, proportions)]
        >> [
            (array([0, 1]), array([ 7, 23])), # 30 items (60% of 50), ~80/20 class balance
            (array([0, 1]), array([ 3, 12])), # 15 items (30% of 50), 80/20 class balance
            (array([0, 1]), array([1, 4]))    # 5 items (10% of 50), 80/20 class balance
        ]
        ```
    """
    values, proportions = np.asarray(values), np.asarray(proportions)
    if values.ndim != 1:
        raise ValueError('Values must be 1D array')
    if proportions.ndim != 1:
        raise ValueError('Proportions must be 1D array')
    if not np.isclose(np.sum(proportions), 1):
        raise ValueError('Proportions must sum to 1')
    # Reject zero/negative entries up front; they can still sum to 1 but would
    # otherwise only fail later, inside the splitter, on a rescaled proportion
    if np.any(proportions <= 0):
        raise ValueError('Proportions must be positive')
    idx = np.arange(len(values))
    splits = []
    return _get_stratified_split(
        splits, idx, values, proportions, random_state=random_state
    )
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment