Last active
March 10, 2023 06:38
-
-
Save tomonori-masui/0f6fc07571a24de27d771bd50521ca74 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from sktime.forecasting.compose import make_reduction, TransformedTargetForecaster | |
from sktime.forecasting.model_selection import ExpandingWindowSplitter, ForecastingGridSearchCV | |
from sktime.performance_metrics.forecasting import MeanAbsolutePercentageError | |
import lightgbm as lgb | |
def create_forecaster(window_length=5, strategy="recursive", **regressor_params):
    """Build a reduction forecaster backed by a LightGBM regressor.

    The tabular ``lgb.LGBMRegressor`` is wrapped into a time-series
    forecaster with sktime's ``make_reduction``, which uses the previous
    ``window_length`` observations as regression features.

    Args:
        window_length: Number of lagged values used as features. Defaults to
            5, as before; a grid search can still override it through a
            ``"window_length"`` entry in its parameter grid.
        strategy: Reduction strategy forwarded to ``make_reduction``.
            Defaults to ``"recursive"``, as before.
        **regressor_params: Optional keyword arguments forwarded to
            ``lgb.LGBMRegressor``. With none given, LightGBM's defaults are
            used, exactly as in the original.

    Returns:
        The forecaster object returned by ``make_reduction``.
    """
    regressor = lgb.LGBMRegressor(**regressor_params)
    return make_reduction(regressor, window_length=window_length, strategy=strategy)
def grid_serch_forecaster(train, test, forecaster, param_grid):
    """Tune ``forecaster`` by grid search on ``train``, then forecast ``test``.

    Hyperparameters are selected with an expanding-window backtest on
    ``train``, scored by symmetric MAPE. The refit best model then forecasts
    one step per observation in ``test``, and the result is passed to
    ``plot_forecast``.

    Args:
        train: Training series passed to ``ForecastingGridSearchCV.fit``.
        test: Hold-out series; only its length sets the forecast horizon here,
            and it is forwarded to ``plot_forecast``.
        forecaster: An sktime forecaster to tune.
        param_grid: Parameter grid for ``ForecastingGridSearchCV``.

    Returns:
        ``(mae, mape)`` as returned by ``plot_forecast(train, test, y_pred)``.
        NOTE(review): ``plot_forecast`` is defined outside this snippet.
    """
    # numpy is not imported at module level in this snippet; importing it
    # here prevents a NameError on ``np.arange`` below.
    import numpy as np

    # Expanding-window CV: the initial training window is the first 70% of
    # ``train``, and the splitter grows it from there.
    cv = ExpandingWindowSplitter(initial_window=int(len(train) * 0.7))
    gscv = ForecastingGridSearchCV(
        forecaster,
        strategy="refit",  # refit the forecaster on every CV training window
        cv=cv,
        param_grid=param_grid,
        scoring=MeanAbsolutePercentageError(symmetric=True),
    )
    gscv.fit(train)
    print(f"best params: {gscv.best_params_}")

    # Relative horizon 1..len(test): one step ahead per test observation.
    fh = np.arange(1, len(test) + 1)
    y_pred = gscv.predict(fh=fh)
    mae, mape = plot_forecast(train, test, y_pred)
    return mae, mape


# Correctly spelled alias; the original (misspelled) name stays for existing callers.
grid_search_forecaster = grid_serch_forecaster
# Candidate lag-window sizes for the grid search: 5, 10, 15, 20, 25, 30.
param_grid = {"window_length": list(range(5, 35, 5))}

forecaster = create_forecaster()

# NOTE(review): sun_train / sun_test are not defined in this snippet; they are
# presumably the train/test split of the sunspot series prepared earlier.
sun_lgb_mae, sun_lgb_mape = grid_serch_forecaster(sun_train, sun_test, forecaster, param_grid)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
@hiteshgupta2507
Only the WPI data has that conversion in the blog post.
The Nile dataset is not indexed with datetime values; its index just holds numeric year values, so no frequency needs to be set on it.