Created
February 17, 2014 07:21
-
-
Save hammady/9046168 to your computer and use it in GitHub Desktop.
After sourcing this patch, you can use "wss" as a new measure when plotting performance curves, just like fp, tp, rec, auc, etc. ROCR must be installed and loaded first.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Patch ROCR R package to support plotting of WSS (Work Saved over Sampling)
# Author: Hossam Hammady ([email protected])
# Organization: Qatar Computing Research Institute; Data Analytics Group (http://da.qcri.qa)
# Date: 17-Feb-2014
# License: MIT
# Description: After sourcing this patch, you can use "wss" as a new measure
# .. when plotting performance curves, just like fp, tp, rec, auc, ...
# .. ROCR must be installed and loaded first.
# Keep a reference to the pre-patch .define.environments() so the patched
# version below can delegate to it. The guard matters when this file is sourced
# more than once in the same environment: without it the second run would save
# the *patched* function as the "original", and the patched function would
# then call itself without end.
if (!exists("original.define.environments", inherits = FALSE)) {
  original.define.environments <- .define.environments
}

# Patched .define.environments(): obtains the measure registry from the
# original function and registers one extra measure, "wss".
#
# Returns a list with the elements long.unit.names, function.names,
# obligatory.x.axis, optional.arguments and default.values, each taken
# unchanged from the original function's result (apart from the "wss" entries
# assigned into long.unit.names and function.names).
.define.environments <- function() {
  # get original environments
  envir.list <- original.define.environments()
  long.unit.names <- envir.list$long.unit.names
  function.names <- envir.list$function.names
  obligatory.x.axis <- envir.list$obligatory.x.axis
  optional.arguments <- envir.list$optional.arguments
  default.values <- envir.list$default.values

  # Measure function for "wss". The signature mirrors the named arguments that
  # performance() passes to every measure function; append any optional
  # arguments after n.neg.pred.
  # Value: list(x values, y values) where x is the cutoffs and
  #   y = (tn + fn) / (n.pos + n.neg) - 1 + tp / (tp + fn)
  # i.e. the fraction predicted negative minus (1 - recall).
  .performance.wss <- function(predictions, labels, cutoffs,
                               fp, tp, fn, tn, n.pos, n.neg,
                               n.pos.pred, n.neg.pred) {
    list(cutoffs, (tn + fn) / (n.pos + n.neg) - 1 + tp / (tp + fn))
  }

  # Register the axis label and the implementation under the key "wss".
  assign("wss", "Work Saved over Random Sampling",
         envir = long.unit.names)
  assign("wss", .performance.wss,
         envir = function.names)

  list(
    long.unit.names = long.unit.names,
    function.names = function.names,
    obligatory.x.axis = obligatory.x.axis,
    optional.arguments = optional.arguments,
    default.values = default.values
  )
}
# Copied from the original performance() function so that the call to
# .define.environments() below resolves to the definition visible where this
# file is sourced.
#
# Arguments:
#   prediction.obj  object of (or extending) S4 class "prediction".
#   measure         name of the measure to compute; must be a key of the
#                   long.unit.names registry.
#   x.measure       name of a second measure, or "cutoff" (the default); must
#                   also be a key of long.unit.names.
#   ...             optional arguments for the measure; only those listed for
#                   `measure` in the optional.arguments registry are retained
#                   (filtering is delegated to .select.args()).
#
# Returns an S4 "performance" object. Stops with an error when the argument
# types are wrong, when `x.measure` is a measure that has an obligatory
# x axis of its own, or when the collected x and y value lists disagree in
# length.
performance <- function(prediction.obj, measure, x.measure = "cutoff", ...) {
  envir.list <- .define.environments()
  long.unit.names <- envir.list$long.unit.names
  function.names <- envir.list$function.names
  obligatory.x.axis <- envir.list$obligatory.x.axis
  optional.arguments <- envir.list$optional.arguments
  default.values <- envir.list$default.values

  # is() instead of class(x) != "...": the comparison form yields a
  # length > 1 condition (an error in `||` on recent R) for objects with
  # several classes.
  if (!methods::is(prediction.obj, "prediction") ||
      !exists(measure, where = long.unit.names, inherits = FALSE) ||
      !exists(x.measure, where = long.unit.names, inherits = FALSE)) {
    stop(paste("Wrong argument types: First argument must be of type",
               "'prediction'; second and optional third argument must",
               "be available performance measures!"))
  }

  # A measure with an obligatory x axis cannot itself serve as x.measure.
  if (exists(x.measure, where = obligatory.x.axis, inherits = FALSE)) {
    msg <- paste("The performance measure", x.measure,
                 "can only be used as 'measure', because it has",
                 "the following obligatory 'x.measure':\n",
                 get(x.measure, envir = obligatory.x.axis))
    stop(msg)
  }

  # If `measure` dictates its x axis, that overrides the caller's x.measure.
  has.obligatory.x <- exists(measure, where = obligatory.x.axis,
                             inherits = FALSE)
  if (has.obligatory.x) {
    x.measure <- get(measure, envir = obligatory.x.axis)
  }

  # Two free measures: compute each one on its own (x.measure defaults to
  # "cutoff" in the recursive calls) and merge the two results.
  if (!(x.measure == "cutoff" || has.obligatory.x)) {
    perf.obj.1 <- performance(prediction.obj, measure = x.measure, ...)
    perf.obj.2 <- performance(prediction.obj, measure = measure, ...)
    return(.combine.performance.objects(perf.obj.1, perf.obj.2))
  }

  optional.args <- list(...)
  argnames <- c()
  if (exists(measure, where = optional.arguments, inherits = FALSE)) {
    argnames <- get(measure, envir = optional.arguments)

    # Defaults are stored in default.values under the key "<measure>:<arg>".
    default.arglist <- list()
    for (i in seq_along(argnames)) {
      default.arglist <- c(default.arglist,
                           get(paste0(measure, ":", argnames[i]),
                               envir = default.values, inherits = FALSE))
    }
    names(default.arglist) <- argnames

    # Pass each default through .farg() together with the caller's arguments;
    # presumably .farg() fills the default in only when the caller did not
    # supply that argument - TODO confirm against the helper's definition.
    for (i in seq_along(argnames)) {
      templist <- list(optional.args, default.arglist[[i]])
      names(templist) <- c("arglist", argnames[i])
      optional.args <- do.call(".farg", templist)
    }
  }
  optional.args <- .select.args(optional.args, argnames)

  function.name <- get(measure, envir = function.names)
  x.values <- list()
  y.values <- list()
  # One evaluation of the measure function per element of the predictions slot.
  for (i in seq_along(prediction.obj@predictions)) {
    argumentlist <- .sarg(optional.args,
                          predictions = prediction.obj@predictions[[i]],
                          labels = prediction.obj@labels[[i]],
                          cutoffs = prediction.obj@cutoffs[[i]],
                          fp = prediction.obj@fp[[i]],
                          tp = prediction.obj@tp[[i]],
                          fn = prediction.obj@fn[[i]],
                          tn = prediction.obj@tn[[i]],
                          n.pos = prediction.obj@n.pos[[i]],
                          n.neg = prediction.obj@n.neg[[i]],
                          n.pos.pred = prediction.obj@n.pos.pred[[i]],
                          n.neg.pred = prediction.obj@n.neg.pred[[i]])
    ans <- do.call(function.name, argumentlist)
    # Measure functions return list(x, y); a NULL x is not recorded.
    if (!is.null(ans[[1]])) {
      x.values <- c(x.values, list(ans[[1]]))
    }
    y.values <- c(y.values, list(ans[[2]]))
  }

  # Either no evaluation produced x values, or every one of them did.
  if (!(length(x.values) == 0 || length(x.values) == length(y.values))) {
    stop("Consistency error.")
  }

  new("performance",
      x.name = get(x.measure, envir = long.unit.names),
      y.name = get(measure, envir = long.unit.names),
      alpha.name = "none",
      x.values = x.values,
      y.values = y.values,
      alpha.values = list())
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment