Created
October 1, 2019 17:49
-
-
Save mattsgithub/12ac44606587d18da18471df03593a74 to your computer and use it in GitHub Desktop.
Cross-validation is more complicated for time series, because the splits must respect chronological order. This class generates the date ranges to train, validate, and test on.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class SampleForwardChainCV(object):
    """Generate forward-chaining train/validate/test date ranges.

    Every split trains on a window that starts at ``min_start_date``. The
    validation window starts on the date right after the training window
    ends, and the test window on the date right after the validation window
    ends. Each window is extended one date at a time until the observations
    summed over its dates reach that window's minimum.

    Args:
        dates: Iterable of hashable, mutually comparable date-like values.
            Duplicates are dropped and the remainder is sorted.
        obs_count: Sliceable sequence of observation counts. ``obs_count[k]``
            is used as the count for the k-th entry of the sorted,
            de-duplicated ``dates``, so the caller must supply it in that
            order.
        min_start_date: First date of every training window. Must be one of
            ``dates``. Defaults to the earliest date.
        max_end_date: Latest date any window may contain. It does not have to
            be one of ``dates``. Defaults to the latest date.
        n_min_train_obs: Minimum number of observations in a training window.
        n_min_validate_obs: Minimum number of observations in a validation
            window.
        n_min_test_obs: Minimum number of observations in a test window.
    """

    def __init__(self,
                 dates,
                 obs_count,
                 min_start_date=None,
                 max_end_date=None,
                 n_min_train_obs=20,
                 n_min_validate_obs=20,
                 n_min_test_obs=20):
        self.dates = sorted(set(dates))
        self.obs_count = obs_count
        # Compare against None explicitly: a falsy but valid bound (e.g. 0)
        # must not be silently replaced by the default.
        self.min_start_date = (
            self.dates[0] if min_start_date is None else min_start_date)
        self.max_end_date = (
            self.dates[-1] if max_end_date is None else max_end_date)
        self.n_min_train_obs = n_min_train_obs
        self.n_min_validate_obs = n_min_validate_obs
        self.n_min_test_obs = n_min_test_obs

    def split(self):
        """Yield ``(train, validate, test)`` date ranges.

        Each of the three items is a ``(start_date, end_date)`` tuple with
        both ends inclusive. Iteration stops as soon as one of the three
        windows can no longer reach its minimum number of observations
        without going past ``max_end_date``. No yielded date is later than
        ``max_end_date`` and no split is yielded twice.

        Raises:
            ValueError: If ``min_start_date`` is not one of ``dates``
                (raised when iteration starts).
        """
        from bisect import bisect_right

        dates = self.dates
        obs_count = self.obs_count
        # Index of the last date that is <= max_end_date. It is -1 when every
        # date lies after max_end_date, in which case nothing is yielded.
        last_index = bisect_right(dates, self.max_end_date) - 1

        def get_end_index(start_index, i, min_obs):
            """Return the first index >= i at which the window is big enough.

            The window covers ``start_index`` through the returned index and
            holds at least ``min_obs`` observations. Returns None when no
            such index exists at or before ``last_index``.
            """
            if i > last_index:
                return None
            n_obs = sum(obs_count[start_index:i + 1])
            while n_obs < min_obs:
                if i >= last_index:
                    return None
                i += 1
                # A one-element slice keeps the access positional and adds
                # nothing (instead of raising) if obs_count is shorter than
                # dates.
                n_obs += sum(obs_count[i:i + 1])
            return i

        train_start_index = dates.index(self.min_start_date)
        i = train_start_index
        while True:
            train_end_index = get_end_index(
                train_start_index, i, self.n_min_train_obs)
            if train_end_index is None:
                return

            valid_start_index = train_end_index + 1
            valid_end_index = get_end_index(
                valid_start_index, valid_start_index, self.n_min_validate_obs)
            if valid_end_index is None:
                return

            test_start_index = valid_end_index + 1
            test_end_index = get_end_index(
                test_start_index, test_start_index, self.n_min_test_obs)
            if test_end_index is None:
                return

            yield ((dates[train_start_index], dates[train_end_index]),
                   (dates[valid_start_index], dates[valid_end_index]),
                   (dates[test_start_index], dates[test_end_index]))

            # Resume right after the training window: any candidate between
            # i and train_end_index would extend to the same end date and
            # reproduce the split that was just yielded.
            i = train_end_index + 1
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment