Last active
November 7, 2017 14:52
-
-
Save kwilcox/9392156b7b7c14d56129b269025d4210 to your computer and use it in GitHub Desktop.
Calculating the "best in forecast" for a model.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def forecast_member(dt, timestep_period, forecast_period, runtimes, base_date=None):
    """Yield every forecast run that has ``dt`` as one of its timesteps.

    Forecast runs start on the hour, at the hours of the day listed in
    ``runtimes``, and produce one timestep every ``timestep_period`` from
    the run start through ``run start + forecast_period`` (inclusive).

    Args:
        dt: The datetime to look up. Before matching it is truncated to the
            resolution of ``timestep_period`` (seconds/microseconds dropped
            for steps of a minute or more, minutes as well for steps of an
            hour or more, hours as well for steps of a day or more).
        timestep_period: Positive ``timedelta`` between forecast timesteps.
        forecast_period: ``timedelta`` covered by a single forecast run.
        runtimes: Container of hours of the day (compared against
            ``datetime.hour``) at which forecast runs start.
        base_date: Only runs starting strictly before this datetime are
            considered. ``None`` (default) means ``datetime.utcnow()``;
            ``False`` disables the limit.

    Yields:
        ``(run_start, index, timestep)`` tuples in ascending ``run_start``
        order, where ``index`` is the position of the truncated ``dt`` in
        that run and ``timestep`` is the matching timestep datetime.

    Raises:
        ValueError: If ``timestep_period`` is zero or negative. Because this
            is a generator, the error surfaces on first iteration.
    """
    # A non-positive step would never advance the timestep walk below.
    if timestep_period <= timedelta(0):
        raise ValueError('timestep_period must be a positive timedelta')

    # Go back as far as we could go to still be in a forecast; the datetime
    # won't be in any future forecasts. Only hourly runtimes are supported,
    # so snap to the top of the hour to be able to match on them.
    candidate = (dt - forecast_period).replace(minute=0, second=0, microsecond=0)

    # Truncate dt to the timestep resolution. The checks cascade on purpose:
    # a daily step also satisfies the hourly and minutely conditions.
    compare_dt = dt
    if timestep_period >= timedelta(days=1):
        compare_dt = compare_dt.replace(hour=0)
    if timestep_period >= timedelta(hours=1):
        compare_dt = compare_dt.replace(minute=0)
    if timestep_period >= timedelta(minutes=1):
        compare_dt = compare_dt.replace(second=0, microsecond=0)

    # Don't return forecasts that are after base_date. If not passed in,
    # don't return forecasts after right now. If False, return all forecasts
    # regardless of the date.
    if base_date is False:
        base_date = datetime.max
    if base_date is None:
        base_date = datetime.utcnow()
    assert isinstance(base_date, datetime)

    # Figure out all of the possible forecast runtimes we could be a member of.
    one_hour = timedelta(hours=1)
    forecast_runtimes = []
    while candidate <= dt:
        # candidate only increases, so once the limit is hit we can stop.
        if candidate >= base_date:
            break
        if candidate.hour in runtimes:
            forecast_runtimes.append(candidate)
        candidate += one_hour

    for fs in forecast_runtimes:
        fs_end = fs + forecast_period
        # Walk this run's timesteps in order. They are strictly increasing,
        # so we can stop as soon as we pass the target.
        step = fs
        index = 0
        while step <= fs_end and step <= compare_dt:
            if step == compare_dt:
                yield (fs, index, step)
                break
            step += timestep_period
            index += 1
def best_in_forecast(timestep_period, forecast_period, runtimes, starting, ending, base_date=None):
    """Yield the most recent forecast membership for each step in a time range.

    Steps from ``starting`` through ``ending`` (inclusive) in increments of
    ``timestep_period``. For each step, ``forecast_member`` is asked for every
    forecast run containing that datetime, and the entry with the latest run
    start is yielded. Steps that belong to no forecast are skipped.

    Args:
        timestep_period: Positive ``timedelta`` between steps; also passed
            through to ``forecast_member``.
        forecast_period: Passed through to ``forecast_member``.
        runtimes: Passed through to ``forecast_member``.
        starting: First datetime to evaluate.
        ending: Last datetime to evaluate (inclusive).
        base_date: Passed through to ``forecast_member``.

    Yields:
        One item per covered step: the ``forecast_member`` result whose first
        element (the forecast run start) is greatest.

    Raises:
        ValueError: If ``timestep_period`` is zero or negative. Because this
            is a generator, the error surfaces on first iteration.
    """
    # A non-positive step would never move `starting` past `ending`.
    if timestep_period <= timedelta(0):
        raise ValueError('timestep_period must be a positive timedelta')

    run_start = itemgetter(0)
    while starting <= ending:
        fms = list(forecast_member(dt=starting,
                                   timestep_period=timestep_period,
                                   forecast_period=forecast_period,
                                   runtimes=runtimes,
                                   base_date=base_date))
        logger.debug('best_in_forecast fms: {}, starting: {}, ending: {}'.format(fms, starting, ending))
        if fms:
            # The latest run start is the "best" forecast. max() returns the
            # first maximal item, the same one a stable descending sort puts
            # at index 0.
            yield max(fms, key=run_start)
        starting += timestep_period
def test_actual_forecast(self):
    """Check forecast membership and best-forecast selection for 00Z/12Z runs.

    Uses hourly timesteps over a 28 hour forecast. The default ``base_date``
    is used, so every 2017 run involved is already in the past.
    """
    hourly = timedelta(hours=1)
    horizon = timedelta(hours=28)
    cycle_hours = [0, 12]
    window_start = datetime(2017, 6, 12, 19, 49)
    window_end = datetime(2017, 6, 13, 0, 34)

    # 19:49 truncates to 19:00, which sits in both the 00Z and the 12Z run.
    top_of_hour = window_start.replace(minute=0)
    members = list(forecast_member(dt=window_start,
                                   timestep_period=hourly,
                                   forecast_period=horizon,
                                   runtimes=cycle_hours))
    assert members == [
        (datetime(2017, 6, 12, 0), 19, top_of_hour),
        (datetime(2017, 6, 12, 12), 7, top_of_hour),
    ]

    # Across the window, the later (12Z) run wins for every hourly step.
    noon_run = datetime(2017, 6, 12, 12, 0)
    expected_best = [
        (noon_run, 7 + offset, datetime(2017, 6, 12, 19 + offset, 0))
        for offset in range(5)
    ]
    best = list(best_in_forecast(starting=window_start,
                                 ending=window_end,
                                 timestep_period=hourly,
                                 forecast_period=horizon,
                                 runtimes=cycle_hours))
    assert best == expected_best
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment