@twolodzko
Last active October 1, 2018 10:36
Split to train and test samples by clusters
import numpy as np


def train_test_split(*arrays, test_size, random_state, clusters):
    '''Split to train and test samples by clusters

    Parameters
    ----------
    *arrays : np.ndarray
        arrays to split, each the same length as `clusters`
    test_size : float, 0 < test_size < 1
        fraction of clusters to include in test set
    random_state : int
        seed for np.random.RandomState
    clusters : array
        array of cluster labels used to split the data

    Returns
    -------
    list
        train and test parts of each input array, in input order

    Examples
    --------
    >>> x = np.array([1,2,3,4,5,6])
    >>> y = np.array([1,1,1,0,0,0])
    >>> c = np.array([1,1,2,2,3,3])
    >>> x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=1/3,
    ...     random_state=42, clusters=c)
    >>> x_train, x_test, y_train, y_test
    (array([3, 4, 5, 6]), array([1, 2]), array([1, 0, 0, 0]), array([1, 1]))
    '''
    rng = np.random.RandomState(random_state)
    unique_clusters = np.unique(clusters)
    n = len(unique_clusters)
    # int() truncates, so the number of test clusters is rounded down
    n_test = int(n*test_size)
    # shuffle the cluster labels and take the first n_test as the test set
    test_clusters = set(rng.permutation(unique_clusters)[:n_test])
    # boolean mask: True for rows whose cluster was assigned to the test set
    test_idx = np.array([v in test_clusters for v in clusters])
    out = []
    for arr in arrays:
        out.append(arr[~test_idx])
        out.append(arr[test_idx])
    return out
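
A cluster-based split is meant to keep every cluster entirely in the train set or entirely in the test set. The sketch below is one way to check that property. It is not part of the gist: `cluster_split` is a hypothetical, self-contained variant of the function above that swaps the list comprehension for `np.isin` and converts its inputs with `np.asarray`, and the sample data is made up for illustration.

```python
import numpy as np


def cluster_split(*arrays, test_size, random_state, clusters):
    # Hypothetical variant of the gist's function (assumption: same logic,
    # but np.isin builds the mask and inputs are coerced to arrays).
    clusters = np.asarray(clusters)
    rng = np.random.RandomState(random_state)
    unique_clusters = np.unique(clusters)
    n_test = int(len(unique_clusters) * test_size)
    test_clusters = rng.permutation(unique_clusters)[:n_test]
    test_idx = np.isin(clusters, test_clusters)
    out = []
    for arr in arrays:
        arr = np.asarray(arr)
        out.extend([arr[~test_idx], arr[test_idx]])
    return out


# Made-up data: four clusters of three rows each.
x = np.arange(12)
c = np.repeat([10, 20, 30, 40], 3)

# Passing the cluster labels as one of the arrays returns them split too.
x_train, x_test, c_train, c_test = cluster_split(
    x, c, test_size=0.5, random_state=0, clusters=c)

# Cluster labels shared by both sides; empty when the split does not leak.
overlap = set(c_train) & set(c_test)
```

Splitting the label array alongside the data makes the leak check a plain set intersection, and it also shows that `test_size` is a fraction of clusters, not of rows.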