Last active
February 10, 2017 18:16
-
-
Save artemklevtsov/aad8368d94a54b24cc88aac703859c6c to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#' Plot a ROC curve with base graphics
#'
#' Draws the curve stored in `x` on a square plotting region with a dotted
#' grid, the diagonal reference line, and a dashed line of slope 1 through the
#' point that maximises `tpr - fpr`. The `auc`, `gini` and `ks` attributes of
#' `x` are printed in the bottom-right corner.
#'
#' @param x An object with `fpr` and `tpr` columns and `auc`, `gini` and `ks`
#'   attributes.
#' @param ... Accepted for S3 method compatibility; currently unused.
#' @return `NULL`, invisibly. Called for its side effect of drawing a plot.
plot.roc <- function(x, ...) {
  # Index of the point farthest above the diagonal (maximum tpr - fpr).
  ind <- which.max(x$tpr + 1 - x$fpr)

  # Force a square plotting region and restore the previous setting on exit.
  opar <- par(pty = "s")
  on.exit(par(opar), add = TRUE)

  ticks <- seq(from = 0, to = 1, by = 0.1)
  plot(x = x$fpr, y = x$tpr,
       xlim = c(0, 1), ylim = c(0, 1),
       type = "s", lwd = 2,
       xlab = "False Positive Rate",
       ylab = "True Positive Rate",
       axes = FALSE)
  axis(side = 2, at = ticks)
  axis(side = 1, at = ticks)
  abline(h = ticks, v = ticks, col = "lightgray", lty = "dotted", lwd = 1)

  # Diagonal reference line and its label.
  abline(a = 0, b = 1)
  # Line parallel to the diagonal through the selected point.
  abline(a = x$tpr[ind] - x$fpr[ind], b = 1, lty = "dashed")
  text(x = 0.5, y = 0.5, labels = "Random Guess", col = "gray50",
       pos = 1, srt = 45)

  metrics <- c(AUC = attr(x, "auc"), Gini = attr(x, "gini"), KS = attr(x, "ks"))
  metrics <- sprintf("%s: %.3f", names(metrics), metrics)
  legend("bottomright", legend = metrics, bty = "n")

  box()
  title(main = "ROC curve")
  invisible(NULL)
}
library(ggplot2) | |
library(scales) | |
#' Plot a ROC curve with ggplot2
#'
#' Builds a ggplot of `tpr` against `fpr` with percent-formatted axes on a
#' fixed 1:1 aspect ratio, the dashed diagonal reference line, and a dashed
#' line of slope 1 through the point that maximises `tpr - fpr`.
#'
#' @param x An object with `fpr` and `tpr` columns and `auc`, `gini` and `ks`
#'   attributes.
#' @param metrics If `TRUE`, the `auc`, `gini` and `ks` attributes of `x` are
#'   printed in the bottom-right corner of the panel.
#' @return A ggplot object.
autoplot.roc <- function(x, metrics = TRUE) {
  # Index of the point farthest above the diagonal (maximum tpr - fpr).
  ind <- which.max(x$tpr + 1 - x$fpr)

  p <- ggplot(x, aes(x = fpr, y = tpr)) +
    geom_path() +
    scale_x_continuous("False Positive Rate",
                       labels = percent_format(), limits = c(0, 1)) +
    scale_y_continuous("True Positive Rate",
                       labels = percent_format(), limits = c(0, 1)) +
    coord_equal() +
    geom_abline(intercept = 0, slope = 1, color = "gray50",
                linetype = "dashed") +
    annotate("text", label = "Random Guess", color = "gray50",
             x = 0.5, y = 0.5, vjust = 1.5, angle = 45) +
    geom_abline(intercept = x$tpr[ind] - x$fpr[ind], slope = 1,
                color = "gray50", linetype = "dashed") +
    labs(title = "ROC Curve")

  if (metrics) {
    # Keep the logical `metrics` argument intact; build the label separately.
    stats <- c(AUC = attr(x, "auc"), Gini = attr(x, "gini"), KS = attr(x, "ks"))
    stats_label <- paste(sprintf("%s: %.3f", names(stats), stats),
                         collapse = "\n")
    p <- p + annotate("text", label = stats_label,
                      x = 1, y = 0, hjust = 1, vjust = 0)
  }
  p
}
library(highcharter) | |
#' Plot a ROC curve with highcharter
#'
#' Builds an area chart of `tpr` against `fpr`. The tooltip shows the cutoff,
#' sensitivity, specificity and accuracy of the hovered point, and the `auc`,
#' `gini` and `ks` attributes of `x` are added as an annotation.
#'
#' @param x An object with `cutoff`, `tpr`, `fpr` and `acc` columns and `auc`,
#'   `gini` and `ks` attributes.
#' @return A highchart object.
hchart.roc <- function(x) {
  # Specificity is not a column of `x`; derive it so the tooltip can show it.
  x$spec <- 1 - x$fpr

  # Tooltip: one row per label, each value formatted with three decimals.
  tooltip_labels <- c("Cutoff", "Sensitivity", "Specificity", "Accuracy")
  tooltip_fields <- c("cutoff", "tpr", "spec", "acc")
  tooltip_html <- tooltip_table(
    tooltip_labels,
    sprintf("{point.%s:.3f}", tooltip_fields)
  )

  # Summary statistics, one per line in the annotation.
  stats <- c(AUC = attr(x, "auc"), Gini = attr(x, "gini"), KS = attr(x, "ks"))
  stats_html <- paste(
    sprintf("%s: %.3f", names(stats), stats),
    collapse = "<br />"
  )

  chart <- highchart()
  chart <- hc_add_series(chart, x, type = "area", hcaes(x = fpr, y = tpr),
                         name = "ROC Curve")
  chart <- hc_tooltip(chart, useHTML = TRUE, headerFormat = "",
                      pointFormat = tooltip_html)
  chart <- hc_xAxis(chart,
                    title = list(text = "False Positive Rate (1 - Specificity)"),
                    min = 0, max = 1, tickInterval = 0.1, gridLineWidth = 1)
  chart <- hc_yAxis(chart,
                    title = list(text = "True Positive Rate (Sensitivity)"),
                    min = 0, max = 1, tickInterval = 0.1, gridLineWidth = 1)
  chart <- hc_add_annotation(chart, xValue = 0.95, yValue = 0.05,
                             anchorX = "right", anchorY = "bottom",
                             title = list(text = stats_html))
  hc_title(chart, text = "ROC Curve")
}
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// [[Rcpp::plugins("cpp11")]]
#include <algorithm>
#include <cmath>
#include <numeric>
#include <vector>
#include <Rcpp.h>
using namespace Rcpp;
// Returns the indices 0..x.size()-1 arranged so that the referenced elements
// of `x` are in descending order (compared with operator>). The relative order
// of equal elements is whatever std::sort produces.
template <typename T>
std::vector<size_t> order(const T& x) {
  const size_t count = x.size();
  std::vector<size_t> idx(count);
  for (size_t k = 0; k < count; ++k)
    idx[k] = k;
  std::sort(idx.begin(), idx.end(),
            [&x](const size_t& lhs, const size_t& rhs) {
              return x[lhs] > x[rhs];
            });
  return idx;
}
// Builds the points of a ROC curve together with summary statistics.
//
// pred         predicted scores; higher scores are treated as more "positive".
// target       integer class codes, one per score.
// positive     the code in `target` that denotes the positive class; every
//              other non-missing code is counted as negative.
// find_cutoffs when true, only the corner points of the curve are kept
//              (points lying on a straight line between their neighbours are
//              dropped); when false, one point per distinct score is returned.
//
// Observations with a missing score (NA or NaN) or a missing class are ignored.
// Tied scores are treated as one threshold, so every observation is counted.
//
// Returns a list with columns cutoff, tpr, fpr, acc and f1, carrying the
// classes "roc", "tbl_df", "tbl", "data.frame" and the attributes n_pos,
// n_neg, auc, gini and ks. The first row is always the origin (cutoff = +Inf).
// If either class is absent the rates are NaN (division by zero), as before.
// [[Rcpp::export]]
List roc_impl(const NumericVector& pred, const IntegerVector& target,
              int positive, bool find_cutoffs = true) {
  if (pred.size() != target.size())
    stop("Sizes of the input vectors must be equal.");

  const R_xlen_t n_in = pred.size();

  // Keep complete observations only. Comparing against NA_REAL with `==` is
  // always false (NA is a NaN payload), so the dedicated predicate is used.
  // Filtering before sorting also keeps NaN out of the comparator.
  std::vector<R_xlen_t> ind;
  ind.reserve(n_in);
  for (R_xlen_t i = 0; i < n_in; ++i) {
    if (!NumericVector::is_na(pred[i]) && !IntegerVector::is_na(target[i]))
      ind.push_back(i);
  }
  std::sort(ind.begin(), ind.end(),
            [&pred](R_xlen_t a, R_xlen_t b) { return pred[a] > pred[b]; });
  const size_t n_obs = ind.size();

  // tpr / fpr hold raw counts (tp / fp) here and are normalised further below.
  std::vector<double> cutoff;
  std::vector<double> tpr;
  std::vector<double> fpr;
  cutoff.reserve(n_obs + 1);
  tpr.reserve(n_obs + 1);
  fpr.reserve(n_obs + 1);

  // Origin of the curve: nothing is classified as positive.
  cutoff.push_back(R_PosInf);
  tpr.push_back(0);
  fpr.push_back(0);

  size_t tp = 0;
  size_t fp = 0;

  size_t i = 0;
  while (i < n_obs) {
    // Consume every observation sharing this score: they all fall on the same
    // side of any threshold, so they contribute to a single curve point.
    const double thr = pred[ind[i]];
    while (i < n_obs && pred[ind[i]] == thr) {
      if (target[ind[i]] == positive)
        ++tp;
      else
        ++fp;
      ++i;
    }

    // With thinning enabled, the last emitted point is redundant when it lies
    // on the straight line from its predecessor to the new point (zero cross
    // product). The values are integer counts stored as doubles, so the test
    // is exact for realistic sample sizes.
    bool redundant = false;
    const size_t k = cutoff.size();
    if (find_cutoffs && k >= 2) {
      const double dx_prev = fpr[k - 1] - fpr[k - 2];
      const double dy_prev = tpr[k - 1] - tpr[k - 2];
      const double dx_new = static_cast<double>(fp) - fpr[k - 1];
      const double dy_new = static_cast<double>(tp) - tpr[k - 1];
      redundant = (dx_prev * dy_new == dy_prev * dx_new);
    }

    if (redundant) {
      cutoff.back() = thr;
      tpr.back() = tp;
      fpr.back() = fp;
    } else {
      cutoff.push_back(thr);
      tpr.push_back(tp);
      fpr.push_back(fp);
    }
  }

  const size_t n_pos = tp;
  const size_t n_neg = fp;
  const size_t n = n_pos + n_neg;
  const size_t n_out = cutoff.size();

  std::vector<double> acc;
  std::vector<double> f1;
  acc.reserve(n_out);
  f1.reserve(n_out);

  double auc = 0;
  double ks = 0;

  for (size_t j = 0; j < n_out; ++j) {
    // Accuracy = (TP + TN) / n, with TN = n_neg - FP.
    acc.push_back((tpr[j] + n_neg - fpr[j]) / n);
    // F1 = 2 TP / (2 TP + FN + FP), with FN = n_pos - TP.
    f1.push_back((2.0 * tpr[j]) / (2.0 * tpr[j] + n_pos - tpr[j] + fpr[j]));

    // Convert counts to rates.
    tpr[j] = tpr[j] / n_pos;
    fpr[j] = fpr[j] / n_neg;

    ks = std::max(ks, std::abs(tpr[j] - fpr[j]));

    // Trapezoid between the previous point and this one (doubled; halved
    // after the loop). Starts at the first segment, i.e. from the origin.
    if (j > 0)
      auc += (tpr[j] + tpr[j - 1]) * (fpr[j] - fpr[j - 1]);
  }

  auc = auc * 0.5;
  // An area below 0.5 is reported as its complement.
  auc = auc < 0.5 ? 1 - auc : auc;

  List res = List::create(cutoff, tpr, fpr, acc, f1);
  res.attr("names") = CharacterVector::create("cutoff", "tpr", "fpr", "acc", "f1");
  res.attr("class") = CharacterVector::create("roc", "tbl_df", "tbl", "data.frame");
  // Compact row names c(NA, -n): convert to int before negating so the
  // negation is not performed on an unsigned value.
  res.attr("row.names") = IntegerVector::create(NA_INTEGER, -static_cast<int>(n_out));
  res.attr("n_pos") = n_pos;
  res.attr("n_neg") = n_neg;
  res.attr("auc") = auc;
  res.attr("gini") = 2.0 * auc - 1;
  res.attr("ks") = ks;

  return res;
}
template <int RTYPE> | |
IntegerVector as_factor_impl(const Vector<RTYPE> & x) { | |
Vector<RTYPE> lvls = sort_unique(x); | |
IntegerVector out = match(x, lvls); | |
out.attr("levels") = as<CharacterVector>(lvls); | |
out.attr("class") = "factor"; | |
return out; | |
} | |
// Coerces an integer, double or character vector to a factor. An object that
// already is a factor is returned unchanged; any other type raises an error.
// [[Rcpp::export]]
SEXP as_factor(SEXP x) {
  if (Rf_isFactor(x))
    return x;

  const int type = TYPEOF(x);
  if (type == INTSXP)
    return as_factor_impl<INTSXP>(x);
  if (type == REALSXP)
    return as_factor_impl<REALSXP>(x);
  if (type == STRSXP)
    return as_factor_impl<STRSXP>(x);

  stop("Unsupported type.");
  return R_NilValue;  // not reached; stop() throws
}
/*** R | |
#' @title Build a ROC curve
#' @param pred a numeric vector of the same length as `target`, containing
#' the predicted value of each observation.
#' @param target a factor, numeric or character vector of responses, typically
#' encoded with 0 (controls) and 1 (cases). Only two classes can be used in a ROC curve.
#' @param pos.lab the level of `target` to treat as the positive class. When
#' `NULL` (the default) the second level is used.
#' @return the object returned by `roc_impl()`: a data frame of class `"roc"`.
roc <- function(pred, target, pos.lab = NULL) {
  target <- as_factor(target)
  stopifnot(nlevels(target) == 2)
  if (is.null(pos.lab)) {
    pos <- 2L
  } else {
    # Look the label up among the levels; levels are character strings, so
    # the label is compared as a string too.
    pos <- match(as.character(pos.lab), levels(target))
    if (length(pos) != 1L || is.na(pos)) {
      stop("`pos.lab` must be one of the levels of `target`.", call. = FALSE)
    }
  }
  roc_impl(pred, target, pos, TRUE)
}
# Example: ROC curve of a logistic regression fitted to the German credit data.
data("GermanCredit", package = "caret")

# Recode the response as 1 for "Bad" and 0 otherwise, and keep a copy of it.
GermanCredit$Class <- as.integer(GermanCredit$Class == "Bad")
target <- GermanCredit$Class

# Fitted probabilities from a model using every other column as a predictor.
score <- fitted.values(
  glm(Class ~ ., data = GermanCredit, family = binomial)
)

r <- roc(score, target)
str(r)
*/ |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment