Skip to content

Instantly share code, notes, and snippets.

@eric-czech
Last active August 28, 2019 22:02
Show Gist options
  • Select an option

  • Save eric-czech/260de02335007250b06582226bbe73b7 to your computer and use it in GitHub Desktop.

Select an option

Save eric-czech/260de02335007250b06582226bbe73b7 to your computer and use it in GitHub Desktop.
(Python/Sklearn) Function to run stratified split into arbitrary number of label subsets with some desired proportions
from sklearn.model_selection import train_test_split
import numpy as np
def _get_stratified_split(splits, idx, values, proportions, random_state=None):
    """Append stratified index subsets of ``idx`` to ``splits`` and return it.

    Args:
        splits: List that receives one index array per entry in ``proportions``
            (it is mutated in place and also returned)
        idx: 1D integer array of candidate indexes into ``values``
        values: 1D array of labels used for stratification
        proportions: Sequence of fractions of ``idx`` desired for each split;
            expected to be positive and to sum to 1
        random_state: Passed unchanged to every ``train_test_split`` call
            (None, an int seed or a ``numpy.random.RandomState``)
    Returns:
        ``splits`` with ``len(proportions)`` index arrays appended (a single
        array holding all of ``idx`` if there are fewer than 2 proportions)
    """
    idx = np.asarray(idx)
    values = np.asarray(values)
    proportions = np.asarray(proportions, dtype=float)

    # remaining[i] is the share of the original set not yet handed out when
    # split i is drawn; dividing by it rescales the desired proportion to a
    # fraction of the items that are still available. Working from the
    # original proportions avoids compounding rounding error at every step.
    remaining = np.cumsum(proportions[::-1])[::-1]

    # Float products can land a hair below the intended integer (for example
    # 0.29 * 100 evaluates to 28.999999999999996), which would silently drop
    # one sample once floored, so a small tolerance is added before flooring.
    tolerance = 1e-9

    # The last proportion needs no draw: it receives whatever is left over.
    for p, mass in zip(proportions[:-1], remaining[:-1]):
        n_take = int(np.floor(len(idx) * p / mass + tolerance))
        idx_split = train_test_split(
            idx,
            stratify=values[idx],
            train_size=n_take,
            random_state=random_state,
        )[0]
        splits.append(idx_split)
        # Remove the items just assigned so later splits cannot reuse them
        idx = np.setdiff1d(idx, idx_split)

    splits.append(idx)
    return splits


def get_stratified_split(values, proportions, random_state=None):
    """Split values into arbitrary sets with target size while using stratification

    Args:
        values: Labels to use for stratification
        proportions: Sequence containing desired proportions of resulting splits
            (all positive, summing to 1)
        random_state: Optional seed or ``numpy.random.RandomState`` forwarded to
            ``train_test_split`` so that the splits can be reproduced
    Returns:
        List of numpy arrays with length equal to length of proportions and values
        equivalent to indexes associated with each split (the number of these indexes
        should roughly account for the desired portion of samples and each split
        should be comprised of roughly equal label frequencies)
    Raises:
        ValueError: If values or proportions are not 1D, or if proportions do
            not sum to 1 or contain non-positive entries
    Example:
        ```
        # Binary labels array
        values = (np.arange(50) > 10).astype(int)
        # Split into 3 groups where first is largest at 60% of all elements (and next two are smaller)
        proportions = [.6, .3, .1]
        [np.unique(values[s], return_counts=True) for s in get_stratified_split(values, proportions)]
        >> [
            (array([0, 1]), array([ 7, 23])), # 30 items (60% of 50), ~80/20 class balance
            (array([0, 1]), array([ 3, 12])), # 15 items (30% of 50), 80/20 class balance
            (array([0, 1]), array([1, 4]))    # 5 items (10% of 50), 80/20 class balance
        ]
        ```
    """
    values = np.asarray(values)
    proportions = np.asarray(proportions, dtype=float)
    if values.ndim != 1:
        raise ValueError('Values must be 1D array')
    if proportions.ndim != 1:
        raise ValueError('Proportions must be 1D array')
    if not np.isclose(np.sum(proportions), 1):
        raise ValueError('Proportions must sum to 1')
    # A zero or negative share cannot be drawn as a stratified subset; fail
    # here with a clear message instead of deep inside the splitting code.
    if np.any(proportions <= 0):
        raise ValueError('Proportions must be positive')
    idx = np.arange(len(values))
    return _get_stratified_split([], idx, values, proportions, random_state=random_state)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment