Last active
March 18, 2026 17:11
-
-
Save vpnagraj/59fa609c5adf47c8c7a5b156eb261be7 to your computer and use it in GitHub Desktop.
Demonstration of using "workflow sets" with tidymodels in R
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
## Script to demonstrate running sets of workflows.
## Adapted from the workflowsets package vignette:
## https://workflowsets.tidymodels.org/articles/evaluating-different-predictor-sets.html
## Uses a credit scoring dataset from the modeldata package.
## Original source: https://github.com/gastonstat/CreditScoring

## Load "meta" packages; tidyverse and tidymodels each attach
## multiple packages under the hood.
library(tidymodels)
library(tidyverse)

## Use data from the modeldata package for the demo --
## a credit scoring classification dataset.
?modeldata::credit_data
data(credit_data, package = "modeldata")

## Take a peek at the data: check the outcome, predictors,
## and column types before modeling.
glimpse(credit_data)
## Configure the logistic regression "engine".
## This is a parsnip model spec, not a fitted model yet:
## logistic_reg() sets the model type and glm is the fitting engine.
logreg_model <- logistic_reg() |>
  set_engine("glm")
## Set the random seed so the split and fold assignments
## below are reproducible.
set.seed(123)

## Create the training/testing split. This is an rsample split
## object that stores the partition; stratifying on Status helps
## preserve class balance for this classification problem.
to_split <- initial_split(credit_data, strata = Status)

## Printing the split object shows how many
## training / testing / total observations there are.
to_split
## Resample the training data with cross-validation.
## CV happens only on the training portion; each fold has an
## analysis set for fitting and an assessment set for validation.
folds <- vfold_cv(training(to_split), strata = Status)

## The CV object is a tibble-like resampling object with one row
## per fold; the splits column is a list-column of rsplit objects.
glimpse(folds)

## Inspect the split objects directly. This is the resampling plan,
## not model output: each split defines which rows fall in the
## analysis vs assessment set for that fold.
folds$splits
## Build the list of formulas that iteratively leave out each
## predictor. Setting full_model = TRUE also keeps a formula with
## every predictor, so we get metrics for the full model plus
## versions that drop one predictor at a time.
formulas <- leave_var_out_formulas(Status ~ ., data = credit_data, full_model = TRUE)

## Think of this object as the "menu" of predictor sets that
## will feed the workflow set.
glimpse(formulas)
## Create a set of multiple workflows: each one pairs a single
## formula with the same logistic regression spec.
credit_workflows <- workflow_set(
  preproc = formulas,
  models = list(logistic = logreg_model)
)

## The result is the catalog of candidate workflows
## we are about to evaluate.
glimpse(credit_workflows)
## Fit the workflows across the cross-validation folds.
## workflow_map() applies the same resampling process to every
## workflow in the set.
credit_fits <- credit_workflows |>
  workflow_map("fit_resamples", resamples = folds)
## Extract accuracy values. collect_metrics(summarize = FALSE)
## keeps the fold-level values rather than averaging across folds,
## which preserves the distribution so we can compare variability
## later. The "_logistic" suffix is stripped so workflow ids read
## as predictor-set names.
acc_values <- credit_fits |>
  collect_metrics(summarize = FALSE) |>
  filter(.metric == "accuracy") |>
  mutate(wflow_id = gsub("_logistic", "", wflow_id))

## Accuracy is retained for each workflow (i.e., model formula)
## and each CV fold.
glimpse(acc_values)
## Pull out the full-model estimates for the comparison below,
## keeping the fold id so comparisons line up within the same
## resample.
full_model <- acc_values |>
  filter(wflow_id == "everything") |>
  select(full_model = .estimate, id)

## Metrics for the CV folds of the "everything" workflow
## (i.e., the full model).
glimpse(full_model)
## Take the accuracy values for all the other model formulations,
## join them with the full-model accuracies by fold, and compute
## the drop in performance relative to the full model.
differences <- acc_values |>
  filter(wflow_id != "everything") |>
  full_join(full_model, by = "id") |>
  mutate(performance_drop = full_model - .estimate)

## After the join we have the performance difference for each
## model formulation, matched within each CV fold.
glimpse(differences)
## Summarize the performance drop across folds: the mean drop plus
## a normal-approximation 95% interval. Note the order inside
## summarize(): std_err is computed from the fold-level drops
## *before* performance_drop is replaced by its mean, and lower /
## upper then use both. Finally, order the workflow factor by mean
## drop so the plot below sorts sensibly.
summary_stats <- differences |>
  group_by(wflow_id) |>
  summarize(
    std_err = sd(performance_drop, na.rm = TRUE) / sqrt(sum(!is.na(performance_drop))),
    performance_drop = mean(performance_drop, na.rm = TRUE),
    lower = performance_drop - qnorm(0.975) * std_err,
    upper = performance_drop + qnorm(0.975) * std_err,
    .groups = "drop"
  ) |>
  mutate(wflow_id = reorder(factor(wflow_id), performance_drop))
## Plot the mean performance drop with 95% error bars. Points
## farther to the right show a bigger drop in accuracy when that
## variable is removed from the model.
ggplot(summary_stats, aes(x = performance_drop, y = wflow_id)) +
  geom_point(size = 2) +
  geom_errorbar(aes(xmin = lower, xmax = upper), width = .25) +
  ylab("") +
  theme_minimal()
## But what about the test dataset held out above? It wasn't used
## to arrive at any of the metrics in the workflow set procedure --
## and that is by design. Cross-validation accuracy analysis like
## the example above is useful for *model selection*, while the
## test set is held out for final *model evaluation*. In other
## words, based on the leave-one-variable-out results above we may
## choose the "best" model and then evaluate its generalization
## performance using the test set.
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment