Skip to content

Instantly share code, notes, and snippets.

@tmastny
Created February 14, 2019 21:37
Show Gist options
  • Select an option

  • Save tmastny/002cb06e28c2f6fadceaaa65018660c4 to your computer and use it in GitHub Desktop.

Select an option

Save tmastny/002cb06e28c2f6fadceaaa65018660c4 to your computer and use it in GitHub Desktop.
Trying to get custom models for thresholds
library(caret)
library(dplyr)
#' Wrap a caret model so the class-probability cutoff becomes a tuning parameter
#'
#' Looks up the model definition for `method` with `getModelInfo()` and returns
#' a modified copy that can be passed as `method =` to `caret::train()`. The
#' copy adds a numeric tuning parameter named `threshold`: an observation is
#' called the first class level when its predicted probability for that level
#' is at least `threshold`, otherwise the second level.
#'
#' @param method Name of a caret model (matched exactly, `regex = FALSE`).
#' @return A caret model-definition list restricted to classification, with
#'   replaced `parameters`, `grid`, `loop`, `predict` and `prob` elements.
threshold_method <- function(method) {
  thresh_code <- getModelInfo(method, regex = FALSE)[[1]]

  # Keep handles on the base model's functions *before* overwriting them.
  # Referring to `thresh_code$grid` from inside the replacement would resolve
  # to the replacement itself and recurse forever.
  base_grid <- thresh_code$grid
  base_prob <- thresh_code$prob
  if (!is.function(base_prob)) {
    stop(
      "Model '", method, "' has no class-probability function; ",
      "a probability threshold cannot be tuned.",
      call. = FALSE
    )
  }

  thresh_code$type <- "Classification"

  thresh_code$parameters <- rbind(
    thresh_code$parameters,
    data.frame(
      parameter = "threshold",
      class = "numeric",
      label = "Probability Cutoff",
      stringsAsFactors = FALSE
    )
  )

  # Cross the base model's tuning grid with `len` candidate cutoffs.
  thresh_code$grid <- function(x, y, len = NULL, search = "grid") {
    grid <- base_grid(x, y, len, search)
    thresholds <- if (search == "grid") {
      seq(0.01, 0.99, length.out = len)
    } else {
      runif(len, min = 0, max = 1)
    }
    as.data.frame(tidyr::crossing(grid, threshold = thresholds))
  }

  # Fit one model per combination of the base parameters (at the largest
  # cutoff); every other cutoff is a submodel that reuses the same fit.
  # NOTE: the submodel cutoffs are taken from the whole grid, which assumes
  # every base-parameter combination is paired with the same set of cutoffs
  # (true for grids built by `grid` above).
  thresh_code$loop <- function(grid) {
    loop <- grid %>%
      group_by_at(vars(-threshold)) %>%
      summarise(threshold = max(threshold)) %>%
      ungroup() %>%
      as.data.frame()

    top_threshold <- max(grid$threshold)
    unique_thresh <- unique(grid$threshold[grid$threshold != top_threshold])

    submodels <- lapply(
      seq_len(nrow(loop)),
      function(i) data.frame(threshold = unique_thresh)
    )

    list(loop = loop, submodels = submodels)
  }

  thresh_code$predict <- function(modelFit, newdata, submodels = NULL) {
    first_level <- modelFit$obsLevels[1]
    second_level <- modelFit$obsLevels[2]
    class1Prob <- base_prob(modelFit, newdata)[, first_level]

    ## Raise the threshold for class #1 and a higher level of
    ## evidence is needed to call it class 1 so it should
    ## decrease sensitivity and increase specificity
    classify <- function(cutoff) {
      ifelse(class1Prob >= cutoff, first_level, second_level)
    }

    out <- classify(modelFit$tuneValue$threshold)
    if (!is.null(submodels)) {
      # First element: the fitted cutoff; then one element per submodel cutoff.
      out <- c(list(out), lapply(submodels$threshold, classify))
    }
    out
  }

  thresh_code$prob <- function(modelFit, newdata, submodels = NULL) {
    out <- as.data.frame(base_prob(modelFit, newdata))
    if (!is.null(submodels)) {
      # Probabilities do not depend on the cutoff, so each submodel gets the
      # same data frame.
      out <- rep(list(out), length(submodels$threshold) + 1)
    }
    out
  }

  thresh_code
}
#' Summary function: ROC, sensitivity, specificity and distance to perfection
#'
#' Intended for use as `summaryFunction` in `trainControl()`.
#'
#' @param data Data frame of resampled predictions; must contain `obs`.
#' @param lev Class levels; defaults to `levels(data$obs)`.
#' @param model Model name, forwarded to `twoClassSummary()`.
#' @return The named vector returned by `twoClassSummary()` with an extra
#'   element `Dist`: the Euclidean distance between the current
#'   (specificity, sensitivity) pair and the ideal point (1, 1).
fourStats <- function(data, lev = levels(data$obs), model = NULL) {
  ## This code uses the area under the ROC curve and the
  ## sensitivity and specificity values obtained with the current
  ## candidate value of the probability threshold.
  ## `lev` and `model` are forwarded rather than recomputed so that the
  ## caller's arguments are honoured.
  out <- c(twoClassSummary(data, lev = lev, model = model))

  ## The best possible model has sensitivity of 1 and specificity of 1.
  ## How far are we from that value?
  coords <- matrix(
    c(1, 1, out["Spec"], out["Sens"]),
    ncol = 2,
    byrow = TRUE
  )
  colnames(coords) <- c("Spec", "Sens")
  rownames(coords) <- c("Best", "Current")

  c(out, Dist = dist(coords)[1])
}
# Simulate an imbalanced two-class problem (separate training and test sets).
set.seed(442)
trainingSet <- twoClassSim(n = 500, intercept = -16)
testingSet <- twoClassSim(n = 500, intercept = -16)

## Class frequencies
table(trainingSet$Class)

# Tune a random forest together with its probability cutoff, choosing the
# candidate that minimises the distance to perfect sensitivity/specificity.
# `thresh_code` only exists inside threshold_method(), so the model definition
# has to be built by calling the function.
set.seed(949)
mod1 <- train(
  Class ~ .,
  data = trainingSet,
  method = threshold_method("rf"),
  metric = "Dist",
  maximize = FALSE,
  tuneLength = 20,
  ntree = 1000,
  trControl = trainControl(
    method = "repeatedcv",
    repeats = 5,
    classProbs = TRUE,
    summaryFunction = fourStats
  )
)

mod1
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment