Skip to content

Instantly share code, notes, and snippets.

@topepo
Created August 4, 2025 15:39
Show Gist options
  • Save topepo/f2aa97f08c600a293003ec1f972d345d to your computer and use it in GitHub Desktop.
Save topepo/f2aa97f08c600a293003ec1f972d345d to your computer and use it in GitHub Desktop.
prototype for how agua can work in tune 2.0.0
library(tune)
library(recipes)
library(parsnip)
library(yardstick)
library(rsample)
library(workflows)
library(modeldata)
library(dials)
library(ggplot2)
library(dplyr)
library(agua)
# ------------------------------------------------------------------------------
# tidymodels_prefer()
# theme_set(theme_bw())
# Printing options for tibbles, then bring up h2o through agua's helper.
options(pillar.advice = FALSE, pillar.min_title_chars = Inf)
h2o_start()

# ------------------------------------------------------------------------------
# Simulated classification data and a training/testing split stratified on the
# outcome. The seed is set first so the simulation and the split are repeatable.

set.seed(1)
dat <- sim_classification(num_samples = 1000, intercept = -10)
split <- dat |> initial_split(strata = class)
tr_dat <- split |> training()
te_dat <- split |> testing()
# from tune
# Split a data frame into a list holding one single-row data frame per row.
#
# Arguments:
#   x  A data frame (or tibble).
#
# Returns a list with `nrow(x)` elements, in row order. A zero-row input gives
# an empty list.
vec_list_rowwise <- function(x) {
  # seq_len() instead of 1:nrow(x): with zero rows, `1:0` is c(1, 0), whose
  # size does not match `x`, so vec_split() would fail.
  vctrs::vec_split(x, by = seq_len(nrow(x)))$val
}
# for our example:
# Pair each partition with the integer row indices that rsample reports for it;
# these lists are what `agua_grid()` receives as `train_info` / `val_info`.
tr_info <- list(data = tr_dat, ind = as.integer(split))
te_info <- list(data = te_dat, ind = as.integer(split, data = "assessment"))

# ------------------------------------------------------------------------------
# Classification metrics

# fmt: skip
cls_mtr <- metric_set(
  kap, brier_class, roc_auc, pr_auc,
  sensitivity, specificity, mn_log_loss, mcc
)

# ------------------------------------------------------------------------------
# Penalized logistic regression on the h2o engine, with both penalty and
# mixture marked for tuning.

glmn_spec <-
  logistic_reg(penalty = tune(), mixture = tune()) |>
  set_engine("h2o")

rec <-
  recipe(class ~ ., data = tr_dat) |>
  step_normalize(all_numeric_predictors())

glmn_wflow <- workflow(rec, glmn_spec)

# Four space-filling candidates over the workflow's tuning parameters.
glmn_param <- extract_parameter_set_dials(glmn_wflow)
glmn_grid <- grid_space_filling(glmn_param, size = 4)
# Prototype of an h2o-backed grid search for one resample: fit the candidates
# in `grid` with `h2o::h2o.grid()` on the training data, then predict the
# validation data with each fitted model.
#
# Arguments:
#   wflow           A workflow. Used to look up the model mode, the h2o
#                   algorithm, and the h2o names of the tuning parameters.
#   grid            A data frame with one row per candidate and one column per
#                   tuning parameter, using the parsnip parameter names.
#   outcome_name    Name of the outcome column (a string).
#   train_info      List with element `data`, the training data frame.
#   val_info        List with elements `data` (validation data frame) and `ind`
#                   (forwarded to the prediction helper as `orig_rows`).
#   resample_label  Forwarded to the prediction helper as `fold_id`.
#
# Returns a list with one element per grid row: that row's parameter values
# column-bound to the predictions pulled from the corresponding h2o model.
#
# Side effects: two frames are uploaded to the h2o cluster and removed again
# when the function exits, whether it succeeds or fails. The fitted models are
# left on the cluster.
#
# NOTE(review): `train_info$data` / `val_info$data` are sent to h2o as given;
# the workflow's preprocessor is not applied in this function -- confirm the
# caller is expected to have done that.
agua_grid <- function(
  wflow,
  grid,
  outcome_name,
  train_info,
  val_info,
  resample_label
) {
  model_mode <-
    wflow |>
    hardhat::extract_spec_parsnip() |>
    purrr::pluck("mode")

  # ----------------------------------------------------------------------------
  # grid things

  orig_names <- names(grid)
  model_param_names_h2o <- agua:::extract_model_param_names_h2o(
    orig_names,
    wflow
  )

  # dplyr::rename() wants `new = old`, so the names are the h2o parameter names
  # and the values are the parsnip ones.
  parsnip_to_h2o <- orig_names
  names(parsnip_to_h2o) <- model_param_names_h2o

  # Kept in parsnip naming so the output is labelled the way the caller
  # specified the grid.
  grid_by_row <- vec_list_rowwise(grid)

  h2o_hyper_params <-
    grid |>
    dplyr::rename(!!!parsnip_to_h2o) |>
    as.list()

  # ----------------------------------------------------------------------------
  # data things

  # extract outcome and predictor names (used by h2o.grid)
  predictor_names <- colnames(train_info$data)
  predictor_names <- predictor_names[predictor_names != outcome_name]

  # Register the cleanup right after each upload so a failure further down
  # (second upload, h2o.grid, predictions) does not leave frames on the cluster.
  h2o_training_frame <- agua:::as_h2o(train_info$data, "training_frame")
  on.exit(h2o::h2o.rm(h2o_training_frame$data), add = TRUE)

  h2o_val_frame <- agua:::as_h2o(val_info$data, "val_frame")
  on.exit(h2o::h2o.rm(h2o_val_frame$data), add = TRUE)

  # ----------------------------------------------------------------------------

  h2o_algo <- agua:::extract_h2o_algorithm(wflow) # not exported

  # With more than one tuning parameter, request the "Sequential" strategy;
  # otherwise leave the search criteria unset.
  # NOTE(review): presumably this makes h2o walk the parameter vectors in
  # parallel (one model per grid row) rather than crossing them -- confirm
  # against the h2o.grid documentation.
  h2o_search_criteria <-
    if (length(h2o_hyper_params) > 1) {
      list(strategy = "Sequential")
    } else {
      NULL
    }

  h2o_res <- h2o::h2o.grid(
    h2o_algo,
    x = predictor_names,
    y = outcome_name,
    training_frame = h2o_training_frame$data,
    hyper_params = h2o_hyper_params,
    parallelism = 0,
    search_criteria = h2o_search_criteria
  )

  # ----------------------------------------------------------------------------

  h2o_model_ids <- as.character(h2o_res@model_ids)

  # map2() below recycles a length-one input, so a single returned model would
  # otherwise be paired with every grid row without any complaint.
  if (length(h2o_model_ids) != nrow(grid)) {
    stop(
      "h2o returned ", length(h2o_model_ids), " model(s) for a grid with ",
      nrow(grid), " row(s); predictions cannot be matched to grid rows.",
      call. = FALSE
    )
  }

  h2o_models <- purrr::map(h2o_model_ids, agua:::h2o_get_model)

  # One-column data frame holding the observed outcome.
  val_truth <- val_info$data[outcome_name]

  h2o_predictions <- purrr::map(
    h2o_models,
    agua:::pull_h2o_predictions, # not exported
    val_frame = h2o_val_frame$data,
    val_truth = val_truth,
    fold_id = resample_label,
    orig_rows = val_info$ind,
    mode = model_mode
  )

  # NOTE(review): models are paired with grid rows by position; this assumes
  # `h2o_res@model_ids` comes back in the same order as the rows of `grid` --
  # TODO confirm.
  purrr::map2(
    h2o_predictions,
    grid_by_row,
    \(pred, params) vctrs::vec_cbind(params, pred)
  )
}
# Run the prototype once, using the training/testing split as the single
# "resample"; the resample label is still a placeholder.
res <-
  glmn_wflow |>
  agua_grid(
    grid = glmn_grid,
    outcome_name = "class",
    train_info = tr_info,
    val_info = te_info,
    resample_label = tibble(id = "TODO")
  )

res
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment