Created
August 11, 2016 06:24
-
-
Save slopp/82272f00993c28249816a0024f0d60e6 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#' Spark ML - Binary Classifier Area under ROC
#'
#' Evaluates the area under the ROC curve of a binary classifier's
#' predictions using Spark's \code{BinaryClassificationEvaluator}.
#'
#' @param predicted_tbl_spark The result of running \code{sdf_predict}.
#' @param label A character string specifying which column contains the true,
#'   indexed labels (0 / 1).
#' @param score A character string specifying which column contains the scored
#'   probability of a 1.
#'
#' @return The area under the ROC curve, as a single number.
#' @export
#'
ml_auc_roc <- function(predicted_tbl_spark, label, score) {
  stopifnot(
    is.character(label), length(label) == 1L,
    is.character(score), length(score) == 1L
  )
  df <- spark_dataframe(predicted_tbl_spark)
  sc <- spark_connection(df)
  # ml_prepare_dataframe() stores the names of the columns it creates in
  # `envir`; they are read back below as envir$response and envir$features.
  envir <- new.env(parent = emptyenv())
  # The score is passed twice so the assembled feature column is a
  # two-element vector whose entries both equal the score; whichever element
  # the evaluator reads as the raw prediction, it sees the score.
  # NOTE(review): presumably needed because older Spark versions only accept
  # a vector raw-prediction column -- confirm before simplifying.
  tdf <- df %>%
    ml_prepare_dataframe(
      response = label,
      feature = c(score, score),
      envir = envir
    )
  invoke_new(sc, "org.apache.spark.ml.evaluation.BinaryClassificationEvaluator") %>%
    invoke("setLabelCol", envir$response) %>%
    invoke("setRawPredictionCol", envir$features) %>%
    invoke("setMetricName", "areaUnderROC") %>%
    invoke("evaluate", tdf)
}
#' Spark ML - Classifier Accuracy
#'
#' Computes classification accuracy using Spark's
#' \code{MulticlassClassificationEvaluator} with the \code{"accuracy"} metric.
#' Requires Spark 2.0.0 or later.
#'
#' @param predicted_tbl_spark The result of running \code{sdf_predict}.
#' @param label A string specifying the column that contains the true, indexed
#'   label. Supports binary and multi-class labels; the column should be of
#'   double type (use as.double).
#' @param predicted_lbl A string specifying the column that contains the
#'   predicted label, NOT the scored probability. Supports binary and
#'   multi-class labels; the column should be of double type (use as.double).
#'
#' @return The accuracy, as a single number.
#' @export
#'
ml_accuracy <- function(predicted_tbl_spark, label, predicted_lbl) {
  stopifnot(
    is.character(label), length(label) == 1L,
    is.character(predicted_lbl), length(predicted_lbl) == 1L
  )
  df <- spark_dataframe(predicted_tbl_spark)
  sc <- spark_connection(df)
  # The "accuracy" metric of MulticlassClassificationEvaluator was added in
  # Spark 2.0.0; fail early with a clear message on older versions.
  spark_require_version(sc, "2.0.0")
  invoke_new(sc, "org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator") %>%
    invoke("setLabelCol", label) %>%
    invoke("setPredictionCol", predicted_lbl) %>%
    invoke("setMetricName", "accuracy") %>%
    invoke("evaluate", df)
}
#' Spark ML - Feature Importance for Tree Models
#'
#' @param model An ml_model object. Supported are decision trees (>1.5.0),
#'   random forests (>2.0.0) and gradient boosted trees (>2.0.0).
#'
#' @return A data frame with a numeric \code{importance} column and a
#'   character \code{feature} column, sorted by decreasing importance.
#' @export
#'
ml_tree_feature_importance <- function(model) {
  supported <- c("ml_model_gradient_boosted_trees",
                 "ml_model_decision_tree",
                 "ml_model_random_forest")
  # inherits() is safe for objects with several classes, unlike class(x) == ..
  if (!inherits(model, supported)) {
    stop("Supported models include: ", paste(supported, collapse = ", "))
  }
  # Only decision trees expose featureImportances before Spark 2.0.0. The
  # connection is taken from the fitted model itself.
  if (!inherits(model, "ml_model_decision_tree")) {
    spark_require_version(spark_connection(model$.model), "2.0.0")
  }
  # unlist() + as.numeric() accept the array whether it comes back from the
  # JVM as a numeric vector or as a list of numbers.
  importances <- invoke(model$.model, "featureImportances") %>%
    invoke("toArray") %>%
    unlist() %>%
    as.numeric()
  features <- as.character(model$features)
  if (length(importances) != length(features)) {
    stop("Model reports ", length(importances), " feature importances but ",
         length(features), " feature names")
  }
  # Build the frame column-wise so `importance` stays numeric; combining the
  # two vectors with cbind() would coerce it to character and sort lexically.
  importance <- data.frame(
    importance = importances,
    feature = features,
    stringsAsFactors = FALSE
  )
  importance %>% arrange(desc(importance))
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment