Last active
July 8, 2020 17:11
-
-
Save anirudhjayaraman/aaec12f88919f22c1029cc09f221b7c9 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#' Set factor values unseen at model-fitting time to NA in new data
#'
#' `predict()` fails when new data contains a factor level the model was not
#' trained on. This helper blanks out such values so prediction can proceed
#' (yielding `NA` predictions for the affected rows).
#' Adapted from https://stackoverflow.com/a/39495480/4185785
#'
#' @param fit A fitted model object. For objects inheriting from `"glmmPQL"`
#'   the training levels are taken from the row names of the matrices in
#'   `fit$contrasts`; for every other object they are taken from
#'   `fit$xlevels`.
#' @param test_data A data frame (or something coercible to one) holding the
#'   new observations.
#' @return `test_data` as a plain `data.frame` with unused factor levels
#'   dropped and every value of a model factor that is not among the training
#'   levels replaced by `NA`. A `message()` is emitted for each column that
#'   was altered.
remove_missing_levels <- function(fit, test_data) {
  # Drop empty factor levels in the test data up front.
  test_data <- as.data.frame(droplevels(test_data))

  # 'lm'-like fits and 'glmmPQL' fits store their factor levels differently.
  if (inherits(fit, "glmmPQL")) {
    # Each contrast matrix has one row per training level. Entries that are
    # not matrices carry no row names and are skipped.
    model_levels <- Filter(Negate(is.null), lapply(fit$contrasts, rownames))
  } else {
    model_levels <- fit$xlevels
  }

  # Do nothing if the model has no factor predictors.
  if (length(model_levels) == 0) {
    return(test_data)
  }

  # Terms written as as.factor(x) / factor(x) refer to column x. Only the
  # wrapper is removed, so digits and punctuation in the column name survive.
  names(model_levels) <- sub(
    "^(as\\.)?factor\\((.*)\\)$", "\\2", names(model_levels)
  )

  # Columns of the test data that are factor predictors in the trained model.
  predictors <- names(test_data)[names(test_data) %in% names(model_levels)]

  # Iterating over the values is safe when 'predictors' is empty.
  for (predictor in predictors) {
    # Pool the levels of every model term that maps to this column.
    known_levels <- unlist(
      model_levels[names(model_levels) == predictor],
      use.names = FALSE
    )
    values <- test_data[[predictor]]
    # Values that are already NA are not "unseen levels".
    unseen <- !is.na(values) & !(values %in% known_levels)

    if (any(unseen)) {
      test_data[unseen, predictor] <- NA
      # Remove the levels that have just become empty.
      test_data <- droplevels(test_data)
      message(sprintf(paste0("Setting missing levels in '%s', only present",
                             " in test data but missing in train data,",
                             " to 'NA'."),
                      predictor))
    }
  }

  test_data
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment