Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Save carloocchiena/73e9af745ed96cd0ebf08df86c6ccd34 to your computer and use it in GitHub Desktop.
Save carloocchiena/73e9af745ed96cd0ebf08df86c6ccd34 to your computer and use it in GitHub Desktop.
from statsmodels.tsa.statespace.sarimax import SARIMAX
from sklearn.metrics import mean_squared_error
# Aggregate the raw readings to one value per calendar day (daily mean).
# NOTE(review): df is defined outside this snippet; resample() requires it to
# carry a DatetimeIndex. resample('D') already yields a regular daily index
# with freq='D', so no separate asfreq('D') step is needed.
daily_data = df['Value'].resample('D').mean()
# Days without any readings come out of resample() as NaN; carry the last
# observed value forward so SARIMAX receives a gap-free series.
# Series.ffill() replaces the deprecated fillna(method='ffill').
daily_data = daily_data.ffill()
# Training window: Q1 2024. Label slicing on a DatetimeIndex includes both ends.
subset = daily_data.loc['2024-01-01':'2024-03-31']
# Seasonal ARIMA on the Q1-2024 daily training window:
#   order=(p, d, q)             -> one AR lag, one difference, one MA lag
#   seasonal_order=(P, D, Q, s) -> the same structure on a 7-day (weekly) cycle
# Both constraint flags are off, so the optimizer is not forced to keep the
# AR part stationary or the MA part invertible.
model = SARIMAX(
    subset,
    order=(1, 1, 1),
    seasonal_order=(1, 1, 1, 7),
    enforce_stationarity=False,
    enforce_invertibility=False,
)
# disp=False suppresses the optimizer's convergence output.
results = model.fit(disp=False)
print(results.summary())

# statsmodels' residual-diagnostics panel, drawn on its own new figure.
diagnostics_fig = results.plot_diagnostics(figsize=(15, 8))
diagnostics_fig.tight_layout()
plt.show()
# Forecast horizon in model steps. The training series is daily
# (resample('D')), so one step is one day.
FORECAST_DAYS = 7

forecast = results.get_forecast(steps=FORECAST_DAYS)
# summary_frame() holds the point forecast ('mean') and, with its default
# alpha=0.05, a 95% interval ('mean_ci_lower' / 'mean_ci_upper').
forecast_df = forecast.summary_frame()

# Plot the tail of the training window next to the forecast. The forecast
# starts the day after `subset` ends, so `subset` is the series it continues.
# `daily_data` spans the whole of df and may run past the training window.
plt.figure(figsize=(14, 5))
plt.plot(
    subset.iloc[-FORECAST_DAYS:],
    label=f'Observed (last {FORECAST_DAYS} days)',
)
plt.plot(
    forecast_df['mean'],
    label=f'Forecast (next {FORECAST_DAYS} days)',
    color='orange',
)
plt.fill_between(
    forecast_df.index,
    forecast_df['mean_ci_lower'],
    forecast_df['mean_ci_upper'],
    color='orange',
    alpha=0.3,
    label='95% confidence interval',
)
plt.title('SARIMA Forecast')
plt.xlabel('Date')
plt.ylabel('Value')
plt.legend()
plt.grid(True)
plt.tight_layout()
plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment