Created
February 14, 2019 21:37
-
-
Save tmastny/002cb06e28c2f6fadceaaa65018660c4 to your computer and use it in GitHub Desktop.
Trying to get custom models for thresholds
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| library(caret) | |
| library(dplyr) | |
## Build a caret custom-model specification that wraps an existing caret
## method and adds a probability cutoff ("threshold") as an extra tuning
## parameter for two-class problems.
##
## Arguments:
##   method  Name of a caret model, looked up exactly (regex = FALSE) with
##           getModelInfo().
##
## Returns the model-info list for `method` with `type`, `parameters`,
## `grid`, `loop`, `predict` and `prob` replaced; every other element
## (e.g. `fit`) is passed through untouched.
##
## The closures below use base R only, so they do not rely on dplyr/tidyr
## being attached in whatever process evaluates them.
threshold_method <- function(method) {
  thresh_code <- getModelInfo(method, regex = FALSE)[[1]]
  thresh_code$type <- "Classification"

  ## Register the cutoff next to the base model's own tuning parameters.
  ## The name must be "threshold": grid, loop and predict all use it.
  thresh_code$parameters <- rbind(
    thresh_code$parameters,
    data.frame(
      parameter = "threshold",
      class = "numeric",
      label = "Probability Cutoff",
      stringsAsFactors = FALSE
    )
  )

  ## Keep a handle on the base model's grid builder BEFORE the slot is
  ## overwritten; looking it up through `thresh_code$grid` from inside the
  ## replacement would find the replacement itself and recurse forever.
  base_grid <- thresh_code$grid

  ## Pair every row of `grid` with every value in `cutoffs`.
  add_thresholds <- function(grid, cutoffs) {
    grid <- as.data.frame(grid)
    pairs <- expand.grid(
      row = seq_len(nrow(grid)),
      cut = seq_along(cutoffs)
    )
    out <- grid[pairs$row, , drop = FALSE]
    out$threshold <- cutoffs[pairs$cut]
    out <- unique(out)
    rownames(out) <- NULL
    out
  }

  thresh_code$grid <- function(x, y, len = NULL, search = "grid") {
    grid <- base_grid(x, y, len, search)
    cutoffs <- if (search == "grid") {
      seq(0.01, 0.99, length.out = len)
    } else {
      ## Random search: `len` cutoffs drawn uniformly from [0, 1].
      runif(len, min = 0, max = 1)
    }
    add_thresholds(grid, cutoffs)
  }

  ## One model is fitted per combination of the base parameters (at the
  ## largest cutoff of that combination); the remaining cutoffs of the same
  ## combination become its submodels, since they only need the class
  ## probabilities of the already-fitted model.
  thresh_code$loop <- function(grid) {
    grid <- as.data.frame(grid)
    rows <- seq_len(nrow(grid))
    tuning_cols <- setdiff(names(grid), "threshold")

    key <- if (length(tuning_cols) > 0) {
      do.call(paste, c(unname(as.list(grid[tuning_cols])), sep = "\r"))
    } else {
      rep("", nrow(grid))
    }
    ## Groups are numbered in order of first appearance in `grid`.
    groups <- unname(split(rows, match(key, unique(key))))

    keep <- vapply(
      groups,
      function(idx) idx[which.max(grid$threshold[idx])],
      integer(1)
    )
    loop <- grid[keep, , drop = FALSE]
    rownames(loop) <- NULL

    submodels <- lapply(groups, function(idx) {
      cuts <- unique(grid$threshold[idx])
      data.frame(threshold = cuts[cuts != max(cuts)])
    })

    list(loop = loop, submodels = submodels)
  }

  ## Class predictions at the fitted cutoff and, when `submodels` is given,
  ## at each submodel cutoff as well (returned as a list, fitted cutoff
  ## first). Reads `modelFit$obsLevels` and `modelFit$tuneValue$threshold`
  ## -- presumably populated by caret::train(); confirm against caret docs.
  thresh_code$predict <- function(modelFit, newdata, submodels = NULL) {
    first_level <- modelFit$obsLevels[1]
    class1_prob <- predict(modelFit, newdata, type = "prob")[, first_level]

    ## Raise the threshold for class #1 and a higher level of evidence is
    ## needed to call it class 1, so it should decrease sensitivity and
    ## increase specificity.
    call_class <- function(cutoff) {
      ifelse(
        class1_prob >= cutoff,
        modelFit$obsLevels[1],
        modelFit$obsLevels[2]
      )
    }

    out <- call_class(modelFit$tuneValue$threshold)
    if (!is.null(submodels)) {
      out <- c(list(out), lapply(submodels$threshold, call_class))
    }
    out
  }

  ## Class probabilities do not depend on the cutoff, so with submodels the
  ## same data frame is repeated once per cutoff (fitted one included).
  thresh_code$prob <- function(modelFit, newdata, submodels = NULL) {
    out <- as.data.frame(predict(modelFit, newdata, type = "prob"))
    if (!is.null(submodels)) {
      out <- rep(list(out), length(submodels$threshold) + 1)
    }
    out
  }

  thresh_code
}
## Summary function for caret resampling: the twoClassSummary() statistics
## plus "Dist", the Euclidean distance between the current
## (Spec, Sens) pair and the ideal point (1, 1).
##
## Arguments:
##   data   Data frame of held-out observations/predictions; passed on to
##          twoClassSummary().
##   lev    Class levels; defaults to levels(data$obs).
##   model  Model name; passed on to twoClassSummary().
##
## Returns the vector produced by twoClassSummary() with one extra element
## named "Dist" (smaller is better).
fourStats <- function(data, lev = levels(data$obs), model = NULL) {
  ## Forward the caller's `lev` and `model` instead of silently replacing
  ## them with hard-coded values.
  out <- c(twoClassSummary(data, lev = lev, model = model))

  ## The best possible model has sensitivity of 1 and specificity of 1.
  ## How far are we from that value?
  coords <- matrix(
    c(1, 1, out["Spec"], out["Sens"]),
    ncol = 2,
    byrow = TRUE
  )
  colnames(coords) <- c("Spec", "Sens")
  rownames(coords) <- c("Best", "Current")

  c(out, Dist = dist(coords)[1])
}
## Simulated two-class training and test sets (same generator settings).
set.seed(442)
trainingSet <- twoClassSim(n = 500, intercept = -16)
testingSet <- twoClassSim(n = 500, intercept = -16)

## Class frequencies
table(trainingSet$Class)

## Tune the random forest together with its probability cutoff, selecting
## the candidate that minimises "Dist" as computed by fourStats().
## `thresh_code` only exists inside threshold_method(), so the model
## specification has to be built by calling that function here.
set.seed(949)
mod1 <- train(
  Class ~ .,
  data = trainingSet,
  method = threshold_method("rf"),
  metric = "Dist",
  maximize = FALSE,
  tuneLength = 20,
  ntree = 1000,
  trControl = trainControl(
    method = "repeatedcv",
    repeats = 5,
    classProbs = TRUE,
    summaryFunction = fourStats
  )
)

mod1
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment