Created
July 10, 2020 01:37
-
-
Save topepo/f20770bd916463313bb379c73ac52e5b to your computer and use it in GitHub Desktop.
first of two S3 methods for plotting PCA components
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#' Plot the loadings of a PCA step
#'
#' `component_plot()` draws the loadings (coefficients) estimated by a trained
#' `step_pca()` as horizontal bars, with one facet per principal component.
#'
#' @param x A prepped recipe or fitted workflow that uses a recipe. The recipe
#'   must have used at least one `step_pca()`.
#' @param id A single numeric or character value that is used to pick the step
#'   with the PCA results. If a single `step_pca()` was used, this argument is
#'   ignored. *Note*: if used, `id` must be named.
#' @param ... An optional series of conditional statements used to filter the
#'   PCA data before plotting. See Details below.
#' @details
#' The conditions supplied in `...` are passed to `dplyr::filter()` and applied
#' to the loadings returned by `tidy()` for the selected PCA step. Before
#' filtering, a numeric `component_number` column is added that holds the
#' number at the end of each component label. Useful columns to filter on
#' include `terms`, `value`, `component`, and `component_number`. For example,
#' `component_number <= 3` keeps the first three components, and `value > 0`
#' keeps only the positive loadings.
#' @return A ggplot object.
#' @examples
#' library(recipes)
#' library(parsnip)
#' library(workflows)
#' library(ggplot2)
#'
#' data("Chicago", package = "modeldata")
#'
#' theme_set(theme_minimal())
#'
#' ## -----------------------------------------------------------------------------
#'
#' train_pca <-
#'   recipe(ridership ~ ., data = Chicago %>% dplyr::select(1:21)) %>%
#'   step_center(all_predictors()) %>%
#'   step_scale(all_predictors()) %>%
#'   step_pca(all_predictors())
#'
#' # or when used in a workflow
#' lm_workflow <-
#'   workflow() %>%
#'   add_model(linear_reg() %>% set_engine("lm")) %>%
#'   add_recipe(train_pca)
#'
#' ## -----------------------------------------------------------------------------
#'
#' train_pca <- prep(train_pca)
#'
#' component_plot(train_pca, component_number <= 3)
#'
#' component_plot(train_pca, component_number <= 3, value > 0)
#'
#' ## -----------------------------------------------------------------------------
#'
#' lm_workflow <- lm_workflow %>% fit(data = Chicago)
#'
#' component_plot(lm_workflow, component_number <= 3)
#'
component_plot <- function(x, ...) {
  UseMethod("component_plot")
}
# Locate the `step_pca()` whose loadings should be plotted.
#
# @param x A recipe. `tidy(x)` is expected to return one row per step with the
#   columns `number`, `type`, `trained`, and `id`.
# @param id `NULL`, or a single step number (numeric) or step id (character)
#   that identifies the PCA step to use. It is only consulted when the recipe
#   has more than one PCA step.
# @return The character `id` of the selected PCA step.
# Errors when the recipe has no PCA step, when there are several and `id`
# does not identify one of them, or when the selected step is untrained.
check_recipe_for_pca <- function(x, id = NULL) {
  rec_steps <- tidy(x)
  is_pca <- rec_steps$type == "pca"
  n_pca <- sum(is_pca)

  if (n_pca == 0) {
    stop(
      "The recipe does not appear to use step_pca(). Use the `tidy()` ",
      "method for more details.",
      call. = FALSE
    )
  }

  if (n_pca == 1) {
    # With exactly one PCA step there is nothing to choose; `id` is ignored.
    pca_row <- which(is_pca)
  } else {
    if (is.null(id)) {
      stop(
        "The recipe appears to have multiple PCA steps. Use the 'id' ",
        "argument to pick one. The `tidy()` method can list the available ",
        "steps.",
        call. = FALSE
      )
    }
    # `||` short-circuits, so `is.na()` only ever sees a length-one atomic.
    if (length(id) != 1 || !(is.numeric(id) || is.character(id)) || is.na(id)) {
      stop(
        "'id' should be either a single character string or single ",
        "numeric value.",
        call. = FALSE
      )
    }
    # A numeric `id` refers to the step number; a character one to the step id.
    lookup <- if (is.numeric(id)) rec_steps$number else rec_steps$id
    pca_row <- match(id, lookup)
    if (is.na(pca_row) || !isTRUE(is_pca[pca_row])) {
      stop(
        "'id' value ", id, " does not appear to correspond to a ",
        "PCA step. The `tidy()` method can list the available steps.",
        call. = FALSE
      )
    }
  }

  if (!isTRUE(rec_steps$trained[pca_row])) {
    stop("Please `prep()` the recipe.", call. = FALSE)
  }
  rec_steps$id[pca_row]
}
# Optionally subset the PCA loadings with caller-supplied conditions.
#
# @param x A data frame of PCA loadings.
# @param ... Zero or more filtering expressions, captured as quosures.
# @return `x` itself when no conditions are given; otherwise the result of
#   `dplyr::filter()` applied to `x` with those conditions.
filter_pca_data <- function(x, ...) {
  conditions <- rlang::enquos(...)
  if (length(conditions) == 0) {
    return(x)
  }
  dplyr::filter(x, !!!conditions)
}
# Plot the loadings of a trained `step_pca()` in a recipe, one facet per
# component.
#
# @param x A prepped recipe containing at least one `step_pca()`.
# @param ... Optional filtering conditions evaluated against the loadings,
#   which have the columns `terms`, `value`, `component`, and
#   `component_number`.
# @param id Which PCA step to use when the recipe has several. It must be
#   passed by name because it follows `...`.
# @return A ggplot object.
component_plot.recipe <- function(x, ..., id = NULL) {
  pca_id <- check_recipe_for_pca(x, id)
  pca_vals <- tidy(x, id = pca_id)

  # Component labels are a prefix followed by a number (e.g. "PC1", "PC01").
  # Use only the trailing digits, so that a custom prefix containing digits or
  # punctuation cannot corrupt the number or turn it into NA.
  labels <- as.character(pca_vals$component)
  digits_at <- regexpr("[0-9]+$", labels)
  component_number <- rep(NA_real_, length(labels))
  component_number[digits_at > 0] <- as.numeric(regmatches(labels, digits_at))
  pca_vals$component_number <- component_number

  # Optional filtering
  pca_vals <- filter_pca_data(pca_vals, ...)
  if (nrow(pca_vals) == 0) {
    stop(
      "No PCA loadings remain after filtering; check the conditions ",
      "passed to `...`.",
      call. = FALSE
    )
  }

  # Facets follow the numeric component order ("PC2" before "PC10"). order()
  # is stable, so unnumbered labels keep their original row order.
  facet_levels <- unique(
    as.character(pca_vals$component)[order(pca_vals$component_number)]
  )
  pca_vals$component <- factor(pca_vals$component, levels = facet_levels)

  # A symmetric x-axis makes positive and negative loadings directly comparable.
  max_abs <- max(abs(pca_vals$value), na.rm = TRUE)
  pca_rng <- c(-max_abs, max_abs)

  ggplot2::ggplot(pca_vals, ggplot2::aes(value, terms, fill = terms)) +
    ggplot2::geom_col(show.legend = FALSE) +
    ggplot2::facet_wrap(~component) +
    ggplot2::labs(y = NULL, x = "Coefficient Value") +
    ggplot2::xlim(pca_rng)
}
# Plot PCA loadings from the recipe stored in a fitted workflow.
#
# @param x A fitted workflow whose preprocessor is a recipe.
# @param ... Filtering conditions, forwarded to the recipe method.
# @param id Forwarded to the recipe method; must be passed by name.
# @return The result of `component_plot()` on the workflow's trained recipe.
component_plot.workflow <- function(x, ..., id = NULL) {
  # `pull_workflow_prepped_recipe()` is deprecated in newer workflows releases
  # in favour of `extract_recipe()`. Fall back to the old accessor only when
  # the installed version does not export the replacement.
  has_extract <- "extract_recipe" %in% getNamespaceExports("workflows")
  prepped <- if (has_extract) {
    workflows::extract_recipe(x, estimated = TRUE)
  } else {
    workflows::pull_workflow_prepped_recipe(x)
  }
  component_plot(prepped, ..., id = id)
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment