|
# minimal example |
|
library(randomForest) |
|
library(ggplot2) |
|
library(dplyr) |
|
library(purrr) |
|
library(colormap) |
|
library(tree) |
|
library(plotrix) |
|
library(cowplot) |
|
library(gridGraphics) |
|
library(magick) |
|
library(forcats) |
|
|
|
# Source the reprtree helpers used for plotting individual trees generated by
# randomForest. Clone the repo at https://github.com/richpauloo/reprtree and
# change `reprtree_dir` below to the directory containing its R files.
reprtree_dir <- "/Users/richpauloo/GitHub/reprtree/R"

# list.files() returns character(0) for a missing or empty directory, in which
# case nothing would be sourced and the script would only fail much later, at
# the first call to a function those files are expected to define.
# Fail early with an actionable message instead.
reprtree_files <- list.files(reprtree_dir, full.names = TRUE)
if (length(reprtree_files) == 0) {
  stop(
    "No files found in '", reprtree_dir, "'. Clone ",
    "https://github.com/richpauloo/reprtree and set `reprtree_dir` to the ",
    "directory containing its R files.",
    call. = FALSE
  )
}

# source() is called for its side effect (defining the helpers); the list of
# return values is discarded.
invisible(lapply(reprtree_files, source))
|
|
|
# Example data set used throughout the script: mtcars.
df <- mtcars

# How many single-tree models to grow (one per random seed).
nn <- 500
|
|
|
# Fit one single-tree random forest (ntree = 1) of mpg on all other columns
# of the script-level data frame `df`.
#
# rand_seed: value passed to set.seed() before fitting, so the same seed
#            reproduces the same tree.
#
# Returns the fitted randomForest object, with importance measures requested
# via importance = TRUE.
run_rf <- function(rand_seed) {
  set.seed(rand_seed)
  randomForest(mpg ~ ., data = df, importance = TRUE, ntree = 1)
}
|
|
|
# At least one model is needed: l[[1]] is inspected below.
stopifnot(is.numeric(nn), length(nn) == 1, nn >= 1)

# List of fitted models; l[[k]] is the model grown with seed k.
# seq_len() is used instead of 1:nn, which would yield c(1, 0) for nn == 0.
l <- lapply(seq_len(nn), run_rf)

# Number of predictors in the RF model, read from the first fit.
npred <- length(names(l[[1]]$forest$xlevels))

# Importance table of every model, stacked into one long data frame.
# For each model the importance() matrix is converted to a data frame, its row
# names (the variable names) are moved into a `var` column, and the model's
# index is stored in `tree_num`. Attaching `tree_num` per model avoids relying
# on every table having exactly `npred` rows in a fixed order.
impdf <- map(seq_along(l), function(k) {
  imp <- as.data.frame(importance(l[[k]]))
  imp$var <- rownames(imp)
  rownames(imp) <- NULL
  imp$tree_num <- k
  imp
}) %>%
  bind_rows()
|
|
|
# Mean %IncMSE of each variable across all trees, sorted from largest to
# smallest.
tot_mse <- impdf %>%
  group_by(var) %>%
  summarise(`%IncMSE` = mean(`%IncMSE`)) %>%
  arrange(desc(`%IncMSE`))

# Variable names in ranked order; used as the factor levels of `var` so that
# every plot orders (and colours) the variables the same way.
rv <- tot_mse$var
impdf <- mutate(impdf, var = factor(var, levels = rv))
|
|
|
# Trees at which a frame is drawn: tree 1, then every `plot_every`-th tree.
# Plotting only a subset keeps the run fast; lower `plot_every` for more
# frames.
plot_every <- 10

# Multiples of `plot_every` up to nn. Unlike seq(10, nn, 10), this also works
# when nn < plot_every (it then contributes no trees instead of raising a
# "wrong sign in 'by' argument" error). unique() drops the duplicate tree 1
# that would otherwise appear when plot_every is 1.
plt_vec <- unique(c(1, seq_len(nn %/% plot_every) * plot_every))

# Initialize lists for: variable-importance plots, tree plots, plot titles,
# and combined plots (one slot per entry of plt_vec).
pl <- tl <- pt <- bp <- vector("list", length = length(plt_vec))
|
|
|
# Upper limit of the importance axis. It is the same for every frame, so the
# bars are drawn on a common scale throughout.
imp_max <- max(tot_mse$`%IncMSE`)

for (i in seq_along(plt_vec)) {
  # Index of the most recent tree included in this frame.
  tree_id <- plt_vec[i]

  # Cumulative variable importance: mean %IncMSE per variable over
  # trees 1..tree_id.
  cum_imp <- impdf %>%
    filter(tree_num <= tree_id) %>%
    group_by(var) %>%
    summarise(mse = mean(`%IncMSE`))

  # Horizontal bar chart of the cumulative importance; fct_rev() reverses the
  # factor order on the (flipped) axis.
  pl[[i]] <- ggplot(cum_imp, aes(forcats::fct_rev(var), mse, fill = var)) +
    geom_col() +
    coord_flip(ylim = c(0, imp_max)) +
    scale_fill_viridis_d() +
    labs(
      x = "Variable",
      y = "Importance (% Inc MSE)",
      fill = "Variable",
      title = paste0("Tree ", tree_id)
    ) +
    theme_minimal() +
    theme(
      legend.position = "bottom",
      plot.title = element_text(size = 25)
    )

  # Draw this tree on the current graphics device, then capture the drawing
  # so it can be replayed inside the combined plot.
  plot.getTree(l[[tree_id]], k = 1, npred = npred, rv = rv)
  tl[[i]] <- recordPlot()

  # Stand-alone title panel, stored in `pt`. The left margin is meant to align
  # the title with the left edge of the first plot. Note that `pt` is not
  # passed to plot_grid() below; the bar chart carries its own title.
  pt[[i]] <- ggdraw() +
    draw_label(
      paste0("Tree ", tree_id),
      fontface = "bold",
      x = 0,
      hjust = 0
    ) +
    theme(plot.margin = margin(0, 0, 0, 7))

  # Combine the importance chart and the recorded tree plot into one figure.
  bp[[i]] <- plot_grid(pl[[i]], tl[[i]])
}
|
|
|
|
|
# use magick to turn plots into a GIF. |
|
# WARNING: magick doesn't handle hundreds of plots well in my experience
|
# and it may be better to print them into a single PDF, then render the |
|
# GIF elsewhere. Also beware of temporary files that magick creates... |
|
# Also, this animation may be too large to fit in your viewer, so |
|
# be sure to expand it! |
|
# img <- image_graph(1000, 600, res = 96) |
|
# for(i in seq_along(plt_vec)){ print( bp[[i]] ) } |
|
# dev.off() |
|
# animation <- image_animate(img, fps = 2) |
|
# print(animation) |
|
# |
|
# # save to working directory |
|
# image_write(animation, "anim.gif") |
|
|
|
# Print every combined plot to a single PDF; the GIF can then be made
# elsewhere, e.g. at https://ezgif.com/maker
pdf("all.pdf", width = 12, height = 7)

# Close the PDF device even if printing one of the plots fails, so an error
# does not leave the graphics device (and all.pdf) open. The list returned by
# lapply() is discarded.
invisible(tryCatch(lapply(bp, print), finally = dev.off()))
Hi - great code, thanks for sharing.
One question: I'm having issues with plot.getTree - is this part of the reprtree package? It could be my mistake, but I can't seem to run the function.
Thanks