# GitHub gist cimentadaj/19128a958bcb6ffe21d66f28ff8e1015
# Author: @cimentadaj
# Created: January 31, 2019 15:03
# Redo of this post https://www.brodrigues.co/blog/2018-11-25-tidy_cv/
library("tidyverse")
library("tidymodels")
library("brotools")
library("mlbench")
# Fix the RNG state so the split and the resamples below are reproducible.
set.seed(231451)

# Load the Boston housing data and take a quick look at it.
data("BostonHousing2")
head(BostonHousing2)

# Use `cmedv` as the outcome (renamed to `price`) and drop `medv`, `town`,
# `lon` and `lat` before converting to a tibble.
boston <- as_tibble(
  rename(
    select(BostonHousing2, -medv, -town, -lon, -lat),
    price = cmedv
  )
)

# Initial split with prop = 0.9: training rows vs. held-out test rows.
train_test_split <- initial_split(boston, prop = 0.9)
housing_train <- training(train_test_split)
housing_test <- testing(train_test_split)

# 30 Monte Carlo resamples of the training set, each with prop = 0.9.
validation_data <- mc_cv(housing_train, prop = 0.9, times = 30)
#' Build the (unprepped) preprocessing recipe used throughout this script.
#'
#' The recipe models `price` on every other column, then adds three steps in
#' order: centering and scaling of the columns picked by `all_numeric()`, and
#' dummy encoding of the columns picked by `all_nominal()`.
#'
#' @param dt A data frame that contains a `price` column.
#' @return A recipe object; call `prep()` on it before `juice()`/`bake()`.
simple_recipe <- function(dt) {
  base_rec <- recipe(dt, price ~ .)
  centered <- step_center(base_rec, all_numeric())
  scaled <- step_scale(centered, all_numeric())
  step_dummy(scaled, all_nominal())
}
# Baseline: linear regression fitted on the preprocessed training set and
# scored on the preprocessed test set.

# Estimate the preprocessing on the training rows, then apply it to both sets.
train_rec <- prep(simple_recipe(housing_train))
train_data <- juice(train_rec)
test_data <- bake(train_rec, new_data = housing_test)

# Linear regression specification using the "lm" engine.
mod_obj <- set_engine(linear_reg(), "lm")

# Fit on the training data and predict the test rows in one expression.
estimate <- predict_numeric(
  fit(mod_obj, formula(train_rec), data = train_data),
  new_data = test_data
)

# RMSE of the baseline on the test set.
predict_data <- tibble(truth = test_data$price, estimate = estimate)
predict_data %>%
  rmse(truth = truth, estimate = estimate)
#' Fit a ranger random forest on one resample and predict its assessment set.
#'
#' The recipe from `simple_recipe()` is prepped on the analysis rows only and
#' then applied to the assessment rows.
#'
#' @param mtry Value forwarded to `rand_forest(mtry = )`.
#' @param trees Value forwarded to `rand_forest(trees = )`.
#' @param rsplit A single split object; `analysis()` and `assessment()` are
#'   called on it.
#' @param id Label stored in the `id` column of the result.
#' @return A tibble with columns `id`, `truth` (the assessment `price`) and
#'   `estimate` (the model's predictions for the assessment rows).
fit_rf <- function(mtry, trees, rsplit, id) {
  split_rec <- prep(simple_recipe(analysis(rsplit)))
  analysis_baked <- juice(split_rec)
  assessment_baked <- bake(split_rec, new_data = assessment(rsplit))

  rf_spec <- set_engine(
    rand_forest(mode = "regression", trees = trees, mtry = mtry),
    "ranger",
    importance = "impurity"
  )
  rf_fit <- fit(rf_spec, formula = formula(split_rec), data = analysis_baked)

  tibble(
    id = id,
    truth = assessment_baked$price,
    estimate = predict_numeric(rf_fit, new_data = assessment_baked)
  )
}
# Predictions of a forest with mtry = 3 and 200 trees on every resample,
# row-bound into a single data frame.
complete_predictions <- map2_df(
  validation_data$splits,
  validation_data$id,
  function(split, split_id) fit_rf(3, 200, split, split_id)
)

# RMSE per resample, then its mean, standard deviation and a
# mean +/- 1.96 * sd band across resamples.
complete_predictions %>%
  group_by(id) %>%
  rmse(truth, estimate) %>%
  summarize(
    avg_rmse = mean(.estimate),
    sd = sd(.estimate),
    ci_low = avg_rmse - (1.96 * sd),
    ci_high = avg_rmse + (1.96 * sd)
  )
#' Average resampled RMSE of a random forest for one hyperparameter pair.
#'
#' @param x Vector whose first element is used as `mtry` and whose second
#'   element is used as `trees`.
#' @param rsplit A resampling object with `splits` and `id` columns (e.g. the
#'   result of `mc_cv()`); `fit_rf()` is run once per split.
#' @return A single number: the mean of the per-split RMSE values.
tuning <- function(x, rsplit) {
  n_mtry <- x[1]
  n_trees <- x[2]

  preds <- map2_df(
    rsplit$splits,
    rsplit$id,
    function(split, split_id) fit_rf(n_mtry, n_trees, split, split_id)
  )

  # One RMSE per resample id; only their mean is returned.
  per_split <- rmse(group_by(preds, id), truth, estimate)
  mean(per_split$.estimate)
}
# Grid search over the number of trees with mtry fixed at 3.
#
# tuning() has the signature tuning(x, rsplit): the hyperparameters must be
# passed as ONE vector c(mtry, trees), followed by the resamples. Passing them
# as separate positional arguments fails with an "unused argument" error.

# Sanity check for a single hyperparameter pair.
tuning(c(3, 200), validation_data)

# Every combination of mtry = 3 with trees = 200, 201, ..., 300.
grid_search <- crossing(mtry = 3, trees = 200:300)

# Average resampled RMSE for each grid row (one full resampling run per row).
final_results <-
  grid_search %>%
  mutate(rmse = map2_dbl(mtry, trees, ~ tuning(c(.x, .y), validation_data)))

# RMSE as a function of the number of trees.
final_results %>%
  ggplot(aes(x = trees, y = rmse)) +
  geom_line(colour = "#82518c") +
  theme_blog() +
  ggtitle("RMSE for mtry = 3")
# Model-based (Bayesian) optimisation of mtry and trees with mlrMBO.
library("mlrMBO")
library(lhs) # for randomLHS

# Objective: tuning(x, rsplit), where x is the vector of the two integer
# parameters declared below (x1 is read as mtry, x2 as trees inside tuning()).
fn <- makeSingleObjectiveFunction(
  name = "tuning",
  fn = tuning,
  par.set = makeParamSet(
    makeIntegerParam("x1", lower = 3, upper = 8),
    makeIntegerParam("x2", lower = 50, upper = 500)
  )
)

# Initial design: 10 points drawn by Latin hypercube sampling.
des <- generateDesign(n = 5L * 2L, getParamSet(fn), fun = randomLHS)

# Surrogate model; predict.type = "se" is requested because the infill
# criterion below is expected improvement.
surrogate <- makeLearner("regr.ranger", predict.type = "se", keep.inbag = TRUE)

# Set general controls
ctrl <- makeMBOControl()
ctrl <- setMBOControlTermination(ctrl, iters = 10L)
ctrl <- setMBOControlInfill(ctrl, crit = makeMBOInfillCritEI())

# more.args is forwarded to the objective, so its names must match tuning()'s
# formals: the resamples go in as `rsplit`.
result <- mbo(
  fn,
  des,
  surrogate,
  ctrl,
  more.args = list(rsplit = validation_data)
)
# Final model: refit on the whole training set and score on the test set.

# prep() estimates the recipe from `training`; bake() applies it to `new_data`
# (same argument spelling as the other bake() calls in this script).
training_rec <- prep(simple_recipe(housing_train), training = housing_train)
train_data <- bake(training_rec, new_data = housing_train)

# NOTE(review): mtry = 6 and trees = 381 are presumably the optimum reported by
# the mbo() run above -- confirm against `result`.
# The mode is set explicitly, as in fit_rf().
final_model <-
  rand_forest(mode = "regression", mtry = 6, trees = 381) %>%
  set_engine("ranger", importance = "impurity") %>%
  fit(price ~ ., data = train_data)

# `test_data` is the baked test set created earlier from a recipe that was
# also prepped on housing_train.
price_predict <- predict(final_model, new_data = select(test_data, -price))

# Undo the centering/scaling of the outcome so predictions can be compared
# with the raw test prices side by side.
cbind(
  price_predict * sd(housing_train$price) + mean(housing_train$price),
  housing_test$price
)

# Test-set RMSE, computed on the preprocessed (centered and scaled) prices.
tibble::tibble(
  "truth" = test_data$price,
  "prediction" = price_predict$.pred
) %>%
  rmse(truth, prediction)
# Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment