# Source: GitHub Gist by @cimentadaj, created February 5, 2020 15:46
# (https://gist.github.com/cimentadaj/788cf5d6ad9dca79617c4abd582439e0).
library(AmesHousing)
library(tidymodels)
# Load the Ames housing data and drop the quality-related predictors, i.e.
# every column whose name matches the regex "Qu". Note that tidyselect's
# matches() ignores case by default, so the match is case-insensitive.
ames <- select(make_ames(), -matches("Qu"))
# Data splitting ----------------------------------------------------------

# Train/test split stratified on the outcome. The fixed seed makes the split
# reproducible; initial_split() keeps 3/4 of the rows for training by default.
set.seed(4595)
data_split <- ames %>%
  initial_split(strata = "Sale_Price")
ames_train <- data_split %>% training()
ames_test <- data_split %>% testing()
# Preprocessing -----------------------------------------------------------

# Predict Sale_Price from location only (coordinates plus neighborhood).
# Steps, in order:
#   1. log10-transform the outcome;
#   2. pool Neighborhood levels rarer than 5% of the training rows into a
#      single "other" level;
#   3. convert every nominal column into dummy (indicator) variables.
mod_rec <- recipe(
  Sale_Price ~ Longitude + Latitude + Neighborhood,
  data = ames_train
) %>%
  step_log(Sale_Price, base = 10) %>%
  step_other(Neighborhood, threshold = 0.05) %>%
  step_dummy(all_nominal())

# Using a recipe by hand:
#   prep()  estimates the statistics each step needs from the training data;
#   juice() returns the preprocessed training data;
#   bake()  applies a prepped recipe to new data, such as a test set.
# Modelling ---------------------------------------------------------------

# Penalized linear regression fitted with glmnet. Both penalty (amount of
# regularization) and mixture (share of L1 penalty: 0 = ridge, 1 = lasso)
# are tune() placeholders whose values are chosen later by tune_grid().
lm_mod <- set_engine(
  linear_reg(penalty = tune(), mixture = tune()),
  "glmnet"
)

# Bundle the recipe and the model so they are estimated together.
ml_wflow <- workflow() %>%
  add_recipe(mod_rec) %>%
  add_model(lm_mod)
# Model tuning ------------------------------------------------------------

# Regular grid: 20 log-spaced penalties in [1e-3, 1e-1] crossed with six
# mixture values (0, 0.2, ..., 1), i.e. 120 candidate models.
grid_params <- expand.grid(
  penalty = 10^seq(-3, -1, length.out = 20),
  mixture = (0:5) / 5
)

# Seed the fold assignment explicitly so it does not depend on how many
# random numbers earlier code happened to consume.
set.seed(2020)
cv_splits <- vfold_cv(ames_train) # 10 folds (vfold_cv default)

# Evaluate every grid candidate on every fold, scoring by RMSE only.
res <- ml_wflow %>%
  tune_grid(
    resamples = cv_splits,
    grid = grid_params,
    metrics = metric_set(rmse)
  )

# One row per candidate with the RMSE averaged over folds. The recipe
# log10-transforms Sale_Price, so the RMSE is on that log10 scale.
rmse_vals <- res %>%
  collect_metrics()
# Tuning profile: mean resampled RMSE against penalty (log10 x-axis), with
# one line per mixture value. mixture is formatted as text so that ggplot
# maps it to a discrete colour scale.
ggplot(
  data = mutate(rmse_vals, mixture = format(mixture, digits = 1)),
  mapping = aes(x = penalty, y = mean, colour = mixture, group = mixture)
) +
  geom_line() +
  geom_point() +
  scale_x_log10()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment