Skip to content

Instantly share code, notes, and snippets.

@mustafa-qamaruddin
Created May 4, 2019 17:28
Show Gist options
Save mustafa-qamaruddin/61de643493d5507c95fbe899e5fa3298 to your computer and use it in GitHub Desktop.
class BlockingTimeSeriesSplit:
    """Cross-validator that cuts ordered samples into disjoint, consecutive blocks.

    The samples are divided into ``n_splits`` consecutive blocks, each of
    size ``len(X) // n_splits``. Within each block:

    - the first ``int(train_ratio * block_size)`` samples form the training set;
    - ``margin`` samples are then skipped;
    - the rest of the block forms the test set.

    Blocks never overlap, so no fold trains on samples that another fold
    tests on. Within a block, test samples always come after the training
    samples.

    The last ``len(X) % n_splits`` samples (the remainder of the integer
    division) are not assigned to any fold.

    The class exposes the ``split`` / ``get_n_splits`` methods that
    scikit-learn's ``cv=`` arguments accept.

    Args:
        n_splits: Number of blocks (folds) to generate. Must be at least 1
            and at most ``len(X)``.
        margin: Number of samples skipped between the end of the training
            slice and the start of the test slice in each block. A gap can
            reduce leakage between neighbouring, correlated samples. Defaults
            to 0, meaning train and test are adjacent.
        train_ratio: Fraction of each block used for training. Defaults to 0.8.
    """

    def __init__(self, n_splits, margin=0, train_ratio=0.8):
        self.n_splits = n_splits
        self.margin = margin
        self.train_ratio = train_ratio

    def get_n_splits(self, X=None, y=None, groups=None):
        """Return the number of folds.

        ``X``, ``y`` and ``groups`` are accepted only for interface
        compatibility and are ignored.
        """
        return self.n_splits

    def split(self, X, y=None, groups=None):
        """Yield ``(train_indices, test_indices)`` index arrays, one pair per block.

        Args:
            X: Any sized collection; only ``len(X)`` is used.
            y: Ignored; present for interface compatibility.
            groups: Ignored; present for interface compatibility.

        Yields:
            Tuples of two numpy integer arrays holding positional indices into
            ``X``: the training slice and the test slice of one block.

        Raises:
            ValueError: If ``n_splits`` is less than 1 or greater than
                ``len(X)``. The error is raised when iteration starts,
                because this is a generator.
        """
        # Imported locally so the class works without a module-level
        # ``import numpy as np``.
        import numpy as np

        n_samples = len(X)
        if self.n_splits < 1:
            raise ValueError(
                f"n_splits must be at least 1, got {self.n_splits}."
            )
        if self.n_splits > n_samples:
            # Otherwise the block size would be 0 and every fold would be empty.
            raise ValueError(
                f"Cannot have n_splits={self.n_splits} greater than the "
                f"number of samples: {n_samples}."
            )

        k_fold_size = n_samples // self.n_splits
        indices = np.arange(n_samples)
        for i in range(self.n_splits):
            start = i * k_fold_size
            stop = start + k_fold_size
            # int() truncates, so the training slice rounds down and any
            # leftover sample in the block goes to the test side.
            mid = int(self.train_ratio * k_fold_size) + start
            yield indices[start:mid], indices[mid + self.margin:stop]
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment