Created
February 28, 2021 15:33
-
-
Save justinhchae/d2a2dc8b71b5f5fbbb0f7eabf68b6850 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def run_arima(chunked_data, price_col='y', n_prediction_units=1):
    """Fit an ARIMA(0, 1, 0) model per chunk and attach its forecast.

    Consumes chunked data produced by
    https://gist.github.com/justinhchae/13d246e8e2e2d521a8d2cce20eb09a09

    Args:
        chunked_data: iterable of ``(x_i, y_i)`` pairs of DataFrames. The model
            is fit on ``x_i[price_col]`` and its forecast is attached to ``y_i``.
        price_col: name of the column in ``x_i`` holding the series to model.
        n_prediction_units: number of steps to forecast for each chunk.

    Returns:
        A list with one DataFrame per chunk: a copy of ``y_i`` with an added
        ``'yhat'`` column. Forecast steps beyond ``len(y_i)`` are dropped; if
        fewer than ``len(y_i)`` steps were forecast, the remaining rows of
        ``'yhat'`` are NaN. The input DataFrames are not modified.
    """
    results = []
    # Suppress trivial ARIMA convergence warnings only for the duration of
    # this call instead of changing the process-wide warning filters.
    with warnings.catch_warnings():
        warnings.simplefilter('ignore', ConvergenceWarning)
        for x_i, y_i in chunked_data:
            # order=(p, d, q) = (0, 1, 0): no AR/MA terms, one differencing.
            m_fit = ARIMA(x_i[price_col].values, order=(0, 1, 0)).fit()
            yhat = m_fit.forecast(steps=n_prediction_units)

            # Align predictions to the number of target rows: truncate extra
            # steps and pad missing ones with NaN (assigning a shorter array
            # to a DataFrame column raises a length-mismatch ValueError).
            n_rows = len(y_i)
            preds = list(yhat[:n_rows])
            preds.extend([float('nan')] * (n_rows - len(preds)))

            # Work on a copy so the caller's y_i (possibly a slice of a larger
            # frame) is not mutated in place.
            y_i = y_i.copy()
            y_i['yhat'] = preds
            results.append(y_i)
    return results
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment