Skip to content

Instantly share code, notes, and snippets.

@topepo
Created July 10, 2020 01:37
Show Gist options
  • Save topepo/f20770bd916463313bb379c73ac52e5b to your computer and use it in GitHub Desktop.
Save topepo/f20770bd916463313bb379c73ac52e5b to your computer and use it in GitHub Desktop.
First of two S3 methods for plotting PCA components.
#' Plot PCA coefficients from a recipe or workflow
#'
#' `component_plot()` draws a bar chart of the coefficients estimated by a
#' `step_pca()` in a prepped recipe (or in the recipe of a fitted workflow),
#' with one facet per component.
#'
#' @param x A prepped recipe or fitted workflow that uses a recipe. The recipe
#' must have used at least one `step_pca()`.
#' @param id A single numeric or character value that is used to pick the step
#' with the PCA results. A numeric value is matched against the step number and
#' a character value against the step `id` (as listed by `tidy()` on the
#' recipe). If a single `step_pca()` was used, this argument is ignored.
#' *Note*: if used, `id` must be named.
#' @param ... An optional series of conditional statements used to filter the
#' PCA data before plotting. See Details below.
#' @details
#' The conditions given in `...` are passed to [dplyr::filter()] and applied to
#' the `tidy()` output of the selected PCA step before plotting. Columns that
#' can be referenced include `terms`, `value`, `component` and
#' `component_number`. The last is a numeric version of `component`, created by
#' removing the letters from each label (e.g. `"PC2"` becomes `2`).
#' @return A `ggplot` object.
#' @examples
#' library(recipes)
#' library(parsnip)
#' library(workflows)
#' library(ggplot2)
#'
#' data("Chicago", package = "modeldata")
#'
#' theme_set(theme_minimal())
#'
#' ## -----------------------------------------------------------------------------
#'
#' train_pca <-
#' recipe(ridership ~ ., data = Chicago %>% dplyr::select(1:21)) %>%
#' step_center(all_predictors()) %>%
#' step_scale(all_predictors()) %>%
#' step_pca(all_predictors())
#'
#' # or when used in a workflow
#' lm_workflow <-
#' workflow() %>%
#' add_model(linear_reg() %>% set_engine("lm")) %>%
#' add_recipe(train_pca)
#'
#' ## -----------------------------------------------------------------------------
#'
#' train_pca <- prep(train_pca)
#'
#' component_plot(train_pca, component_number <= 3)
#'
#' component_plot(train_pca, component_number <= 3, value > 0)
#'
#' ## -----------------------------------------------------------------------------
#'
#' lm_workflow <- lm_workflow %>% fit(data = Chicago)
#'
#' component_plot(lm_workflow, component_number <= 3)
#'
#' @export
component_plot <- function(x, ...) {
  UseMethod("component_plot")
}
# Locate the PCA step to plot and return its step `id`.
#
# `x` is a recipe; its `tidy()` summary (columns `number`, `type`, `trained`,
# `id`) is used to find steps whose type is "pca".
# - No PCA step: error.
# - Exactly one PCA step: that step is used and `id` is ignored.
# - Several PCA steps: `id` must be a single numeric value (matched against
#   the `number` column) or a single string (matched against the `id` column)
#   that points at a PCA step.
# The selected step must be trained, i.e. the recipe must have been prepped.
# Errors are raised with base `stop(call. = FALSE)`; the original called
# `rlang::stop()`, which does not exist.
check_recipe_for_pca <- function(x, id) {
  rec_steps <- tidy(x)
  is_pca <- rec_steps$type == "pca"
  n_pca <- sum(is_pca)

  if (n_pca == 0) {
    stop(
      paste("The recipe does not appear to use step_pca(). Use the `tidy()`",
            "method for more details."),
      call. = FALSE
    )
  }

  if (n_pca == 1) {
    # A single PCA step is unambiguous, so `id` is ignored.
    pca_id <- rec_steps$id[is_pca]
  } else {
    if (is.null(id)) {
      stop(
        paste("The recipe appears to have multiple PCA steps. Use the 'id'",
              "argument to pick one. The `tidy()` method can list the",
              "available steps."),
        call. = FALSE
      )
    }
    # Scalar `||` ensures the check is only TRUE for genuinely invalid input.
    # The original vectorised `|` chain was always TRUE.
    if (length(id) != 1 || !(is.numeric(id) || is.character(id))) {
      stop(
        paste("'id' should be either a single character string or single",
              "numeric value."),
        call. = FALSE
      )
    }
    # Numeric ids refer to the step number, character ids to the step id.
    lookup <- if (is.numeric(id)) rec_steps$number else rec_steps$id
    row <- which(lookup == id)
    # `which()` returns integer(0) for an unknown id, so check the length
    # rather than feeding a zero-length logical to `if`.
    if (length(row) != 1 || !isTRUE(is_pca[row])) {
      stop(
        paste0("'id' value ", id, " does not appear to correspond to a ",
               "PCA step. The `tidy()` method can list the available steps."),
        call. = FALSE
      )
    }
    pca_id <- rec_steps$id[row]
  }

  if (!isTRUE(rec_steps$trained[match(pca_id, rec_steps$id)])) {
    stop("Please `prep()` the recipe.", call. = FALSE)
  }
  pca_id
}
# Apply optional user-supplied filter conditions to a data frame.
#
# The expressions in `...` are captured unevaluated as quosures and spliced
# into `dplyr::filter()`. When no conditions are given, `x` is returned as is.
filter_pca_data <- function(x, ...) {
  conditions <- rlang::enquos(...)
  if (rlang::is_empty(conditions)) {
    return(x)
  }
  dplyr::filter(x, !!!conditions)
}
#' @rdname component_plot
#' @export
component_plot.recipe <- function(x, ..., id = NULL) {
  # Plot the coefficients of one PCA step of a prepped recipe: one facet per
  # component, one bar per term. Returns a ggplot object.
  pca_id <- check_recipe_for_pca(x, id)
  pca_vals <- tidy(x, id = pca_id)

  # Numeric component index: strip the letters from each label (e.g. "PC3"
  # becomes 3) so callers can filter with `component_number <= k`.
  pca_vals$component_number <-
    as.numeric(gsub("[[:alpha:]]", "", pca_vals$component))

  # Optional user-supplied filtering via `...`.
  pca_vals <- filter_pca_data(pca_vals, ...)
  if (nrow(pca_vals) == 0) {
    # Otherwise max() below returns -Inf with a warning and xlim() breaks.
    stop("No PCA coefficients remain after filtering.", call. = FALSE)
  }

  # Order facets by component number. Ordering by first appearance could put
  # e.g. PC2 before PC1 once a filter such as `value > 0` drops leading rows.
  comp_levels <- unique(as.character(
    pca_vals$component[order(pca_vals$component_number)]
  ))
  pca_vals$component <- factor(pca_vals$component, levels = comp_levels)

  # Symmetric x-axis limits centred on zero.
  pca_max <- max(abs(pca_vals$value))

  ggplot2::ggplot(pca_vals, ggplot2::aes(value, terms, fill = terms)) +
    ggplot2::geom_col(show.legend = FALSE) +
    ggplot2::facet_wrap(~component) +
    ggplot2::labs(y = NULL, x = "Coefficient Value") +
    ggplot2::xlim(c(-pca_max, pca_max))
}
#' @rdname component_plot
#' @export
component_plot.workflow <- function(x, ..., id = NULL) {
  # Extract the prepped recipe from a fitted workflow and delegate to the
  # recipe method. `pull_workflow_prepped_recipe()` is deprecated in newer
  # workflows releases in favour of `extract_recipe()`, so the replacement is
  # used when the installed version exports it. The old function remains a
  # fallback for older versions.
  if ("extract_recipe" %in% getNamespaceExports("workflows")) {
    rec <- workflows::extract_recipe(x, estimated = TRUE)
  } else {
    rec <- workflows::pull_workflow_prepped_recipe(x)
  }
  component_plot(rec, ..., id = id)
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment