Last active
March 29, 2021 09:36
-
-
Save mathzero/bdedbc329b9d5dac70d84f4ba0c64023 to your computer and use it in GitHub Desktop.
Permutation function to calculate variable importance in xgboost
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# xgboost permutation function
#' Permutation variable importance (AUC-based) for an XGBoost model
#'
#' Takes an XGBoost model and some X and y data (ideally unseen holdout data,
#' although the training data can also be used) and estimates the contribution
#' each predictor makes to the overall AUC. Each predictor column is shuffled
#' `nperm` times; the drop in AUC relative to the unpermuted data is the
#' importance estimate. This can take a long time on a large data set: the
#' model is re-scored `nperm * length(predictors)` times, so reduce `nperm` to
#' reduce compute time.
#'
#' @param model A fitted model for which `predict(model, X)` returns one score
#'   per row of `X`. Predictor names are read from `model$feature_names`; when
#'   that is `NULL`, `colnames(X)` is used instead.
#' @param X Matrix-like predictor data with column names, supporting
#'   `X[, name]` extraction and `X[, name] <- value` assignment.
#' @param y Binary response, one value per row of `X`.
#' @param nperm Number of permutations per predictor (default 100).
#' @return A data frame with one row per predictor and the columns
#'   `predictor`, `mean_AUC_permute`, `SD_AUC_permute`,
#'   `mean_AUC_permute_delta` (full AUC minus mean permuted AUC),
#'   `mean_AUC_permute_delta_lower` and `mean_AUC_permute_delta_upper`
#'   (delta -/+ 1.96 * SD / sqrt(nperm)).
PermuteImportXBG <- function(model, X, y, nperm = 100) {
  stopifnot(is.numeric(nperm), length(nperm) == 1L, nperm >= 1)

  predictors <- model$feature_names
  if (is.null(predictors)) {
    # The model carries no feature names: fall back to the data's columns.
    predictors <- colnames(X)
  }
  if (length(predictors) == 0L) {
    stop("No predictor names found in `model$feature_names` or `colnames(X)`.",
         call. = FALSE)
  }
  absent <- setdiff(predictors, colnames(X))
  if (length(absent) > 0L) {
    stop("Predictors missing from `X`: ", paste(absent, collapse = ", "),
         call. = FALSE)
  }

  response <- as.vector(y)

  # Baseline ROC on the unpermuted data. Levels and direction are resolved
  # here once and reused below: letting pROC auto-detect the direction on every
  # permuted prediction vector biases the permuted AUCs upwards (and therefore
  # the importance deltas downwards).
  roc_full <- pROC::roc(
    response = response,
    predictor = as.vector(predict(model, X)),
    quiet = TRUE
  )
  auc_full <- as.numeric(pROC::auc(roc_full))

  perm_auc <- matrix(
    NA_real_,
    nrow = nperm,
    ncol = length(predictors),
    dimnames = list(NULL, predictors)
  )

  for (i in seq_along(predictors)) {
    predictor <- predictors[[i]]
    message("Permuting ", predictor)

    # Only this one column changes between permutations, so copy X once per
    # predictor and overwrite the column on each iteration.
    original_col <- X[, predictor]
    permuted_X <- X

    for (j in seq_len(nperm)) {
      # Index-based shuffle: same as sample(x) but safe for length-one input.
      permuted_X[, predictor] <- original_col[sample.int(length(original_col))]
      perm_roc <- pROC::roc(
        response = response,
        predictor = as.vector(predict(model, permuted_X)),
        levels = roc_full$levels,
        direction = roc_full$direction,
        quiet = TRUE
      )
      perm_auc[j, i] <- as.numeric(pROC::auc(perm_roc))
    }
  }

  mean_auc <- unname(colMeans(perm_auc))
  sd_auc <- unname(apply(perm_auc, 2, stats::sd))
  delta <- auc_full - mean_auc
  # Half-width of a normal-approximation interval around the mean delta.
  half_width <- 1.96 * (sd_auc / sqrt(nperm))

  data.frame(
    predictor = predictors,
    mean_AUC_permute = mean_auc,
    SD_AUC_permute = sd_auc,
    mean_AUC_permute_delta = delta,
    mean_AUC_permute_delta_lower = delta - half_width,
    mean_AUC_permute_delta_upper = delta + half_width
  )
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment