Created October 24, 2017 07:26
-
-
Save thomasp85/e68b8cf62b9500afde8d832661c7d3c3 to your computer and use it in GitHub Desktop.
Trim all unnecessary data from model objects
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
library(future) | |
#' Trim all unnecessary data from a model object
#'
#' Recursively drops every list element of `model` whose removal leaves the
#' output of `predictor(model, ...)` unchanged, as judged by `identical()`.
#' Data frames and non-list elements are never descended into. Each candidate
#' is predicted in a separate process using the `future` package's `multicore`
#' strategy, so a crash caused by a missing element cannot take down the main
#' session. Per the `future` docs, `multicore` falls back to sequential
#' evaluation where forking is unsupported (e.g. Windows), losing that
#' protection.
#'
#' @param model A model object; only its list-based structure is trimmed.
#' @param predictor Function called as `predictor(model, ...)`. Its output on
#'   the untouched model is the reference every candidate must reproduce.
#' @param ... Further arguments forwarded to `predictor` (e.g. `newdata`).
#' @param ignore_warnings If `TRUE` (default), warnings raised while predicting
#'   with a candidate are suppressed and do not block removal. If `FALSE`,
#'   warnings are promoted to errors (`options(warn = 2)`), so an element
#'   whose removal triggers a warning is kept.
#' @return `model` with every removable element dropped and its class
#'   attributes unchanged. Only `predictor` is checked; other methods may no
#'   longer work on the result.
trim_model <- function(model, predictor = predict, ...,
                       ignore_warnings = TRUE) {
  # Reference output that every trimmed candidate must reproduce exactly
  predict_with <- function(candidate) predictor(candidate, ...)
  true_pred <- predict_with(model)

  # Only plain lists / list-based S3 objects are dug into, never data.frames
  is_trimmable <- function(x) is.list(x) && !is.data.frame(x)
  # Nothing to trim: skip option/plan changes and worker start-up entirely
  if (!is_trimmable(model)) {
    return(model)
  }

  # Treat prediction warnings as errors?
  if (!ignore_warnings) {
    old_ops <- options(warn = 2)
    on.exit(options(old_ops), add = TRUE)
  }
  # Strategy functions can no longer be called as future constructors
  # (`multicore(expr)`); choose the strategy via plan() and restore the
  # caller's plan when done.
  old_plan <- future::plan(future::multicore)
  on.exit(future::plan(old_plan), add = TRUE)

  # Predict with `candidate` in another process to avoid crashing the main
  # session. Returns the prediction, or a condition object if creating or
  # resolving the future failed (error, worker crash, promoted warning).
  predict_isolated <- function(candidate) {
    collect <- function() {
      tryCatch(
        future::value(future::future(predict_with(candidate))),
        error = function(e) e
      )
    }
    # value() relays worker warnings; drop them when they are being ignored
    if (ignore_warnings) suppressWarnings(collect()) else collect()
  }

  # Returns the trimmed element at `path` (an index vector into `model`),
  # or NULL when the element can be dropped altogether.
  trimmer <- function(model, path = integer()) {
    if (length(path) != 0) { # Not the top level object
      # Try a model without this element
      temp_model <- model
      temp_model[[path]] <- NULL
      # No change if removed - signal removal by returning NULL
      if (identical(predict_isolated(temp_model), true_pred)) {
        return(NULL)
      }
      # Change if removed - dig into it
      element <- model[[path]]
    } else { # Top level model object - dig into it
      element <- model
    }
    if (!is_trimmable(element)) {
      return(element)
    }
    # oldClass() is NULL for plain lists; class() would report "list" and
    # re-adding that as an explicit attribute alters the object predicted on
    element_class <- oldClass(element)
    restore_class <- function(x) {
      oldClass(x) <- element_class
      x
    }
    element <- unclass(element)
    # Loop backwards so removing element i leaves indices < i valid
    for (i in rev(seq_along(element))) {
      if (is.null(element[[i]])) next
      element[[i]] <- trimmer(model, c(path, i)) # NULL drops the element
      # Keep `model` in sync so later siblings are tested without the
      # elements already removed
      if (length(path) == 0) {
        model <- restore_class(element)
      } else {
        model[[path]] <- restore_class(element)
      }
    }
    restore_class(element)
  }

  trimmer(model)
}
Does this work on `train` objects from the caret package?
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.
By default, the trimmed model is allowed to throw warnings during prediction as long as it returns the same output. Set `ignore_warnings = FALSE` to change this. There is no guarantee that methods other than the given predictor will continue to work on the trimmed model.