Skip to content

Instantly share code, notes, and snippets.

@AdamSpannbauer
Last active September 11, 2024 11:50
Show Gist options
  • Save AdamSpannbauer/72af0e090c344b21e9de307fad5380ad to your computer and use it in GitHub Desktop.
Save AdamSpannbauer/72af0e090c344b21e9de307fad5380ad to your computer and use it in GitHub Desktop.
Helpers to make conditionally formatted tables of k-means centroids and per-cluster summaries with ggplot2
library(ggplot2)
library(dplyr)
library(tidyr)
#' Plot k-means centroids as a colour-coded table
#'
#' Draws one tile per (cluster, feature) pair. Each tile is filled by the
#' centroid value and labelled with that value rounded to `digits` places.
#' Clusters run along the x axis in numeric order (Cluster 1, 2, ..., 10),
#' and features run along the y axis.
#'
#' @param kmeans_object An object with a `centers` matrix (clusters in rows,
#'   features in columns), such as the result of [stats::kmeans()].
#' @param digits Number of decimal places shown in the tile labels.
#' @return A ggplot object.
plot_centroids_table <- function(kmeans_object, digits = 2) {
  centers <- kmeans_object[["centers"]]
  if (is.null(centers)) {
    stop("`kmeans_object` must contain a `centers` matrix (e.g. from kmeans()).",
         call. = FALSE)
  }

  n_clusters <- nrow(centers)
  cluster_labels <- paste("Cluster", seq_len(n_clusters))

  # Transpose so there is one row per feature and one column per cluster.
  plot_df <- data.frame(t(centers))
  names(plot_df) <- cluster_labels
  plot_df$feature_name <- rownames(plot_df)
  plot_df <- pivot_longer(plot_df, cols = -feature_name)

  # Use factor levels to fix the column order. A plain character axis would
  # sort alphabetically and put "Cluster 10" before "Cluster 2".
  plot_df$name <- factor(plot_df$name, levels = cluster_labels)

  ggplot(plot_df, aes(x = name, y = feature_name, fill = value)) +
    geom_tile() +
    geom_text(aes(label = round(value, digits)), color = "white") +
    labs(x = "", y = "")
}
#' Plot a per-label summary of every column as a colour-coded table
#'
#' Groups `your_data` by the values in `label_column` and applies
#' `summary_func` to every other column within each group. The result is
#' drawn as tiles: groups ("Cluster <label>") on the x axis, columns on the
#' y axis, filled by and labelled with the summary value.
#'
#' @param your_data A data frame holding the label column and the columns
#'   to summarise.
#' @param label_column Name of the column holding the cluster/group labels.
#' @param summary_func Function applied to each remaining column per group.
#'   It should return a single value. May also be a purrr-style lambda or
#'   the name of a function as a string. Defaults to `mean`.
#' @param drop_cols Columns to remove before summarising.
#' @param digits Number of decimal places shown in the tile labels.
#' @return A ggplot object.
plot_mean_by_label_table <- function(your_data, label_column,
                                     summary_func = mean,
                                     drop_cols = c(),
                                     digits = 2) {
  # Skip when there is nothing to drop. This avoids relying on how `[<-`
  # treats an empty/NULL column index.
  if (length(drop_cols) > 0) {
    your_data[, drop_cols] <- NULL
  }

  labels <- your_data[[label_column]]
  if (is.null(labels)) {
    stop("`label_column` was not found in `your_data` ",
         "(it may have been listed in `drop_cols`).", call. = FALSE)
  }
  # Remove the raw label column *before* adding `group_name`. Otherwise a
  # label column that is itself called "group_name" would be deleted.
  your_data[[label_column]] <- NULL

  # A factor keeps the groups in natural label order. As plain character,
  # the axis would read "Cluster 1", "Cluster 10", "Cluster 2", ...
  group_levels <- unique(
    paste("Cluster", sort(unique(labels), na.last = TRUE))
  )
  your_data$group_name <- factor(paste("Cluster", labels),
                                 levels = group_levels)

  # across() expects a function, formula, or list of them, so resolve a
  # function name passed as a string.
  if (is.character(summary_func)) {
    summary_func <- match.fun(summary_func)
  }

  plot_df <- your_data |>
    group_by(group_name) |>
    summarise(across(everything(), summary_func)) |>
    pivot_longer(-group_name)

  ggplot(plot_df, aes(x = group_name, y = name, fill = value)) +
    geom_tile() +
    geom_text(aes(label = round(value, digits)), color = "white") +
    labs(x = "", y = "")
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment