@justinhchae
Created February 28, 2021 15:33

import warnings

import numpy as np
from statsmodels.tools.sm_exceptions import ConvergenceWarning
from statsmodels.tsa.arima.model import ARIMA


def run_arima(chunked_data, price_col='y', n_prediction_units=1):
    # consume chunked data from https://gist.github.com/justinhchae/13d246e8e2e2d521a8d2cce20eb09a09
    # suppress trivial convergence warnings from ARIMA
    warnings.simplefilter('ignore', ConvergenceWarning)
    # initialize a list to hold results (a list of dataframes)
    results = []
    # iterate through the chunked tuples, each holding a pair of dataframes
    for x_i, y_i in chunked_data:
        # create an ARIMA(0, 1, 0) model (a random walk) on the x_i price values
        m = ARIMA(x_i[price_col].values, order=(0, 1, 0))
        # fit the model
        m_fit = m.fit()
        # forecast n_prediction_units steps past the end of x_i
        yhat = np.asarray(m_fit.forecast(steps=n_prediction_units))
        # copy y_i so the caller's dataframe is not modified in place
        y_i = y_i.copy()
        # align predictions with the targets: drop surplus forecasts,
        # pad with NaN if there are fewer forecasts than target rows
        n = min(len(yhat), len(y_i))
        preds = np.full(len(y_i), np.nan)
        preds[:n] = yhat[:n]
        y_i['yhat'] = preds
        # save results to a list and then return the list
        results.append(y_i)
    return results
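`run_arima` expects `chunked_data` to be an iterable of `(x_i, y_i)` DataFrame pairs, produced by the gist linked in the code comment (not reproduced here). As a hedged illustration of that data shape, the sketch below builds such pairs with a hypothetical `make_chunks` helper (a simple sliding window; the window and horizon sizes are assumptions, not the linked gist's method). It then fills in the forecast that an ARIMA(0, 1, 0) model without a constant produces: a random walk, whose point forecast at every horizon is the last observed value. The `naive_forecast` helper is also hypothetical and stands in for `ARIMA(...).fit().forecast()` so the example runs without statsmodels.

```python
import numpy as np
import pandas as pd


def make_chunks(df, window=5, horizon=2):
    """Hypothetical chunker: slide a window over df, pairing `window`
    history rows (x_i) with the next `horizon` target rows (y_i)."""
    chunks = []
    for start in range(len(df) - window - horizon + 1):
        x_i = df.iloc[start:start + window]
        y_i = df.iloc[start + window:start + window + horizon]
        chunks.append((x_i, y_i))
    return chunks


def naive_forecast(values, steps):
    """Point forecast of a driftless random walk, i.e. ARIMA(0, 1, 0):
    every future step equals the last observed value."""
    return np.repeat(values[-1], steps)


prices = pd.DataFrame({'y': [10.0, 11.0, 12.5, 12.0, 13.0, 14.0, 13.5, 15.0]})
chunks = make_chunks(prices, window=5, horizon=2)

results = []
for x_i, y_i in chunks:
    # copy so the slice of `prices` is not modified
    y_i = y_i.copy()
    y_i['yhat'] = naive_forecast(x_i['y'].values, steps=len(y_i))
    results.append(y_i)

print(len(chunks))                    # 2
print(results[0]['yhat'].tolist())    # [13.0, 13.0]
```

Passing the same `chunks` to `run_arima(chunks, price_col='y', n_prediction_units=2)` would be expected to give matching `yhat` values up to numerical precision, since the only estimated parameter of a driftless ARIMA(0, 1, 0) is the noise variance, which does not affect the point forecast.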