Created
August 4, 2025 15:39
-
-
Save topepo/f2aa97f08c600a293003ec1f972d345d to your computer and use it in GitHub Desktop.
Prototype of how agua can work in tune 2.0.0
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
library(tune) | |
library(recipes) | |
library(parsnip) | |
library(yardstick) | |
library(rsample) | |
library(workflows) | |
library(modeldata) | |
library(dials) | |
library(ggplot2) | |
library(dplyr) | |
library(agua) | |
# Session setup ----------------------------------------------------------------
# tidymodels_prefer()
# theme_set(theme_bw())

# Printing options for tibbles (pillar)
options(pillar.min_title_chars = Inf, pillar.advice = FALSE)

# agua helper; called before any h2o work is done below
h2o_start()

# Simulated data ---------------------------------------------------------------
set.seed(1)
dat <- sim_classification(intercept = -10, num_samples = 1000)

# Stratify the initial split on the outcome column `class`
split <- dat |> initial_split(strata = class)
tr_dat <- split |> training()
te_dat <- split |> testing()
# from tune
#' Split a data frame into a list with one element per row
#'
#' @param x A data frame (or tibble).
#' @return The `val` column of `vctrs::vec_split()` when every row is its own
#'   group: a list with one one-row slice of `x` per row, in row order.
vec_list_rowwise <- function(x) {
  # seq_len() instead of 1:nrow(x): for a zero-row `x`, `1:0` is c(1, 0),
  # whose size does not match `x`.
  vctrs::vec_split(x, by = seq_len(nrow(x)))$val
}
# For our example: pair each data partition with the integer row indices that
# `as.integer()` returns for the corresponding side of the split.
tr_info <- list(data = tr_dat, ind = as.integer(split, data = "analysis"))
te_info <- list(data = te_dat, ind = as.integer(split, data = "assessment"))

# Metrics ----------------------------------------------------------------------
# fmt: skip
cls_mtr <- metric_set(kap, brier_class, roc_auc, pr_auc, sensitivity, specificity,
                      mn_log_loss, mcc)

# Model, preprocessing and grid ------------------------------------------------
# Logistic regression on the h2o engine, with both parameters marked for tuning
glmn_spec <- set_engine(
  logistic_reg(penalty = tune(), mixture = tune()),
  "h2o"
)

rec <- step_normalize(
  recipe(class ~ ., data = tr_dat),
  all_numeric_predictors()
)

glmn_wflow <- workflow(preprocessor = rec, spec = glmn_spec)

glmn_param <- extract_parameter_set_dials(glmn_wflow)

# Four candidate parameter combinations
glmn_grid <- grid_space_filling(glmn_param, size = 4)
#' Fit an h2o grid on one training set and predict on one validation set
#'
#' Prototype of the per-resample work for tuning a workflow with an h2o
#' engine: the tuning grid is handed to `h2o::h2o.grid()`, and every resulting
#' model is used to predict the validation data.
#'
#' @param wflow A workflow. Its parsnip model spec supplies the mode, the h2o
#'   algorithm, and the mapping from parsnip to h2o parameter names.
#' @param grid A data frame of tuning parameter values, one candidate per row,
#'   with columns named as in the workflow's parameter set.
#' @param outcome_name String; name of the outcome column in the data.
#' @param train_info,val_info Lists with a `data` element (data frame). The
#'   `ind` element of `val_info` is forwarded as `orig_rows`; `train_info$ind`
#'   is not used here.
#' @param resample_label Forwarded to `agua:::pull_h2o_predictions()` as
#'   `fold_id`.
#' @return A list with one element per h2o model: that model's predictions
#'   with the columns of a grid row bound in front.
#'
#' @details
#' NOTE(review): `train_info$data` and `val_info$data` are sent to h2o as-is;
#' nothing in this function applies the workflow's preprocessor to them --
#' confirm that callers are expected to pass already-processed data.
agua_grid <- function(
  wflow,
  grid,
  outcome_name,
  train_info,
  val_info,
  resample_label
) {
  model_mode <-
    wflow |>
    hardhat::extract_spec_parsnip() |>
    purrr::pluck("mode")

  # ----------------------------------------------------------------------------
  # grid things

  orig_names <- names(grid)
  model_param_names_h2o <- agua:::extract_model_param_names_h2o(
    orig_names,
    wflow
  )

  # rename() takes `new = old` pairs: names are the h2o parameter names,
  # values are the original (parsnip) column names.
  parsnip_to_h2o <- orig_names
  names(parsnip_to_h2o) <- model_param_names_h2o

  grid_by_row <- vec_list_rowwise(grid)

  h2o_hyper_params <-
    grid |>
    dplyr::rename(!!!parsnip_to_h2o) |>
    as.list()

  # ----------------------------------------------------------------------------
  # data things

  # extract outcome and predictor names (used by h2o.grid)
  predictor_names <- setdiff(colnames(train_info$data), outcome_name)

  # Remove each uploaded frame from the h2o cluster when this function exits
  # (normally or on error) so repeated calls do not accumulate frames.
  h2o_training_frame <- agua:::as_h2o(train_info$data, "training_frame")
  on.exit(h2o::h2o.rm(h2o_training_frame$data), add = TRUE)

  h2o_val_frame <- agua:::as_h2o(val_info$data, "val_frame")
  on.exit(h2o::h2o.rm(h2o_val_frame$data), add = TRUE)

  # ----------------------------------------------------------------------------

  h2o_algo <- agua:::extract_h2o_algorithm(wflow) # not exported

  # With more than one tuning parameter, ask h2o to walk the supplied values
  # in sequence (one model per grid row).
  h2o_search_criteria <-
    if (length(h2o_hyper_params) > 1) {
      list(strategy = "Sequential")
    } else {
      NULL
    }

  h2o_res <- h2o::h2o.grid(
    h2o_algo,
    x = predictor_names,
    y = outcome_name,
    training_frame = h2o_training_frame$data,
    hyper_params = h2o_hyper_params,
    parallelism = 0,
    search_criteria = h2o_search_criteria
  )

  # ----------------------------------------------------------------------------

  h2o_model_ids <- as.character(h2o_res@model_ids)
  h2o_models <- purrr::map(h2o_model_ids, agua:::h2o_get_model)

  # One-column data frame holding the observed outcome
  val_truth <- val_info$data[outcome_name]

  h2o_predictions <- purrr::map(
    h2o_models,
    agua:::pull_h2o_predictions, # not exported
    val_frame = h2o_val_frame$data,
    val_truth = val_truth,
    fold_id = resample_label,
    orig_rows = val_info$ind,
    mode = model_mode
  )

  # NOTE(review): predictions are matched to grid rows by position, which
  # assumes `h2o_res@model_ids` is in the same order as the rows of `grid`
  # -- confirm against how h2o orders the models of a grid.
  purrr::map2(
    h2o_predictions,
    grid_by_row,
    \(preds, grid_row) vctrs::vec_cbind(grid_row, preds)
  )
}
# Run the prototype once: fit on the training partition and predict on the
# test partition. The resample label is a placeholder for now.
res <- agua_grid(
  wflow = glmn_wflow,
  grid = glmn_grid,
  outcome_name = "class",
  train_info = tr_info,
  val_info = te_info,
  resample_label = tibble(id = "TODO")
)

res
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment