Created
October 24, 2017 07:26
-
-
Save thomasp85/e68b8cf62b9500afde8d832661c7d3c3 to your computer and use it in GitHub Desktop.
Trim all unnecessary data from model objects
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
library(future)
#' Trim all unnecessary data from model objects
#'
#' Computes a reference prediction from `model`, then walks the model's list
#' structure and tentatively removes each element in turn. An element whose
#' removal leaves the prediction `identical()` to the reference is dropped;
#' any other element is kept and, when it is itself a list (but not a
#' data.frame), its children are tested the same way.
#'
#' @param model A list-based model object. Anything that is not a list, or
#'   that is a data.frame, is returned unchanged.
#' @param predictor Function called as `predictor(model, ...)` to produce the
#'   output that must be preserved. Defaults to `predict`.
#' @param ... Further arguments forwarded to `predictor` on every call.
#' @param ignore_warnings If `FALSE`, `options(warn = 2)` is set while
#'   trimming (and restored on exit) so that a candidate model whose
#'   prediction raises a warning is treated as a failed prediction. The
#'   reference prediction is computed before the option is changed.
#'
#' @return `model` with every element removed whose absence does not change
#'   the output of `predictor`. The class of each visited list is reapplied.
trim_model <- function(model, predictor = predict, ...,
                       ignore_warnings = TRUE) {
  # Cache the correct output; every candidate is compared against it
  true_pred <- predictor(model, ...)

  # Treat prediction warnings as errors?
  if (!ignore_warnings) {
    old_ops <- options(warn = 2)
    # add = TRUE so this handler never replaces another registered one
    on.exit(options(old_ops), add = TRUE)
  }

  # Recursive worker. `model` is the current (partially trimmed) full model
  # and `path` is the recursive `[[` index of the element under test;
  # an empty path denotes the model itself.
  trimmer <- function(model, path = integer()) {
    if (length(path) != 0) { # Not the top level object
      # Try a model without this element
      temp_model <- model
      temp_model[[path]] <- NULL
      # Run in another process to avoid crashing main session
      # NOTE(review): newer releases of the future package may discourage
      # calling multicore() directly - confirm against the installed version
      try_pred <- multicore(predictor(temp_model, ...))
      # A failed prediction yields a "try-error" object, which is never
      # identical to the reference, so the element is kept in that case
      trial_pred <- try(value(try_pred), silent = TRUE)
      # No change if removed - signal removal by returning NULL
      if (identical(trial_pred, true_pred)) {
        return(NULL)
      }
      # Change if removed - dig into it
      element <- model[[path]]
    } else { # Top level model object - dig into it
      element <- model
    }

    # Only dig into list elements, ignoring data.frames
    if (!is.list(element) || is.data.frame(element)) {
      return(element)
    }

    # Work on the bare list so `[[<-` is not dispatched to a class method
    element_class <- class(element)
    element <- unclass(element)

    # Loop backwards: assigning NULL deletes the element, and going from the
    # end keeps the indices of the not-yet-visited elements stable
    for (i in rev(seq_along(element))) {
      if (is.null(element[[i]])) {
        next
      }
      element[[i]] <- trimmer(model, c(path, i)) # Trim element
      # Update model to match trimmed element so later trials build on it
      trimmed <- structure(element, class = element_class)
      if (length(path) == 0) {
        model <- trimmed
      } else {
        model[[path]] <- trimmed
      }
    }

    # Reclass the list
    class(element) <- element_class
    element
  }

  # Trim
  trimmer(model)
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Does this work on `train` objects from the `caret` package?