Skip to content

Instantly share code, notes, and snippets.

@artemklevtsov
Last active April 9, 2018 13:32
Show Gist options
  • Select an option

  • Save artemklevtsov/303b3cbe362064e26a2afbfe751d16f2 to your computer and use it in GitHub Desktop.

Select an option

Save artemklevtsov/303b3cbe362064e26a2afbfe751d16f2 to your computer and use it in GitHub Desktop.
Binning (discretizing) variables based on conditional inference trees (party::ctree)
library(recipes)
library(dplyr)
library(tibble)
#' Discretize predictors into tree-based bins
#'
#' Appends a `step_bin` step to `recipe`. When the recipe is prepped, a
#' conditional inference tree (`party::ctree`) is fitted for every selected
#' column against the recipe's outcome, and the terminal nodes become bins.
#' With `woe = TRUE` the baked columns hold weight-of-evidence values
#' instead of bin labels.
#'
#' @param recipe A recipe object.
#' @param ... Selector expressions choosing the columns to bin.
#' @param role Role stored on the step.
#' @param trained Whether the step has been trained.
#' @param threshold Minimum share of observations per terminal node; must be
#'   a single number strictly between 0 and 1.
#' @param n.group `NULL`, or a numeric of length one (target number of bins)
#'   or two (inclusive range); every value must be at least 2.
#' @param woe If `TRUE`, output weight-of-evidence values instead of bins.
#' @param objects Fitted binning specifications (filled in by `prep()`).
#' @param skip Stored on the step unchanged.
#' @return `recipe` with the new step appended.
step_bin <- function(recipe, ..., role = "predictor", trained = FALSE,
                     threshold = .05, n.group = NULL, woe = FALSE,
                     objects = NULL, skip = FALSE) {
  if (!is.numeric(threshold) || length(threshold) != 1L || is.na(threshold)) {
    stop("`threshold` should be a single number", call. = FALSE)
  }
  if (threshold <= 0) {
    stop("`threshold` should be greater than zero", call. = FALSE)
  }
  if (threshold >= 1) {
    stop("`threshold` should be less than one", call. = FALSE)
  }
  if (!is.null(n.group)) {
    # An empty vector used to pass these checks and crash later in binning().
    if (length(n.group) < 1L || length(n.group) > 2L) {
      stop("'n.group' length should be one or two", call. = FALSE)
    }
    if (!is.numeric(n.group) || anyNA(n.group)) {
      stop("'n.group' should be numeric without missing values", call. = FALSE)
    }
    if (any(n.group < 2L)) {
      stop("'n.group' should be greater than one", call. = FALSE)
    }
  }
  terms <- rlang::quos(...)
  if (length(terms) == 0L) {
    stop("Please supply at least one variable specification. See ?selections.",
         call. = FALSE)
  }
  add_step(
    recipe,
    step_bin_new(
      terms = terms,
      role = role,
      trained = trained,
      threshold = threshold,
      n.group = n.group,
      woe = woe,
      objects = objects,
      skip = skip
    )
  )
}
# Low-level constructor for the "bin" step. It performs no validation
# (step_bin() validates its inputs) and simply hands every setting to
# step(), which is presumably recipes' step constructor since the file
# attaches recipes. The fields are passed in a fixed order: terms, role,
# trained, threshold, n.group, woe, objects, skip.
step_bin_new <- function(terms = NULL, role = "predictor", trained = FALSE,
                         threshold = NULL, n.group = NULL, woe = FALSE,
                         objects = NULL, skip = FALSE) {
  step(subclass = "bin",
       terms = terms, role = role, trained = trained,
       threshold = threshold, n.group = n.group, woe = woe,
       objects = objects, skip = skip)
}
# Train the step: fit one binning specification per selected column against
# the recipe's single outcome, then return a trained copy of the step.
#
# x        - the untrained step_bin object.
# training - training data holding the selected columns and the outcome.
# info     - term-info table with `variable` and `role` columns.
#
# Columns for which binning() returns NULL (fewer than two bins) are dropped
# from `objects`. The outcome name is kept in attr(objects, "target").
# When the step has `woe = TRUE`, each column's weight-of-evidence map is
# computed here from the training data and stored as `woe_values`. Baking
# new data can then reuse it instead of needing (and leaking) its outcome.
prep.step_bin <- function(x, training, info = NULL, ...) {
  # Map raw values onto the bin labels of one binning specification.
  apply_bins <- function(spec, values) {
    if (identical(spec$type, "numeric")) {
      return(cut(values, spec$breaks, spec$labels))
    }
    # Values outside the trained levels become NA.
    known <- unique(unlist(spec$levels, use.names = FALSE))
    out <- factor(values, levels = known)
    # Flatten each level group so `levels<-` gets one character vector per
    # bin label (this also tolerates list-wrapped groups).
    levels(out) <- setNames(lapply(spec$levels, unlist, use.names = FALSE),
                            spec$labels)
    out
  }
  # Log ratio of each bin's share within the first vs. the second outcome
  # class. NOTE(review): assumes a binary outcome; extra classes are ignored.
  woe_map <- function(bins, target) {
    pb <- prop.table(table(bins, target), margin = 2)
    log(pb[, 1] / pb[, 2])
  }

  col_names <- terms_select(terms = x$terms, info = info)
  target <- info$variable[info$role == "outcome"]
  if (length(target) != 1L) {
    stop("`step_bin` requires exactly one outcome variable", call. = FALSE)
  }
  target_values <- training[[target]]
  # `training[col_names]` stays a data frame even for one column, so lapply
  # walks columns (`training[, col]` would drop to a vector on data.frames).
  objects <- lapply(training[col_names],
                    binning,
                    min.perc = x$threshold,
                    target = target_values,
                    n.group = x$n.group)
  objects <- objects[lengths(objects) > 0L]
  if (isTRUE(x$woe)) {
    for (nm in names(objects)) {
      bins <- apply_bins(objects[[nm]], training[[nm]])
      objects[[nm]]$woe_values <- woe_map(bins, target_values)
    }
  }
  attr(objects, "target") <- target
  step_bin_new(
    terms = x$terms,
    role = x$role,
    trained = TRUE,
    threshold = x$threshold,
    n.group = x$n.group,
    woe = x$woe,
    objects = objects,
    skip = x$skip
  )
}
#' Apply trained bins to new data
#'
#' Appends one column per binned variable, named `<variable>.bin` (bin
#' labels) or `<variable>.woe` (weight-of-evidence values when the step has
#' `woe = TRUE`). The original columns are kept.
#'
#' @param object A trained `step_bin` object.
#' @param newdata Data to transform.
#' @param ... Unused.
#' @param new_data Alternative name for `newdata`, for recipes versions
#'   that call `bake(step, new_data = ...)`.
#' @return A tibble: `newdata` with the new columns appended.
#' @importFrom tibble as_tibble is_tibble
#' @export
bake.step_bin <- function(object, newdata, ..., new_data = NULL) {
  if (missing(newdata)) {
    newdata <- new_data
  }
  if (is.null(newdata)) {
    stop("`newdata` must be supplied", call. = FALSE)
  }
  specs <- object$objects
  use_woe <- isTRUE(object$woe)
  target_name <- attr(specs, "target")

  # Map one column's raw values onto the labels of its trained bins.
  apply_bins <- function(spec, values) {
    if (identical(spec$type, "numeric")) {
      return(cut(values, spec$breaks, spec$labels))
    }
    # Values outside the trained levels become NA.
    known <- unique(unlist(spec$levels, use.names = FALSE))
    out <- factor(values, levels = known)
    # Flatten each level group so `levels<-` gets one character vector per
    # bin label; list-wrapped groups otherwise mis-map merged levels.
    levels(out) <- setNames(lapply(spec$levels, unlist, use.names = FALSE),
                            spec$labels)
    out
  }

  if (length(specs) > 0L) {
    new_cols <- lapply(names(specs), function(col) {
      spec <- specs[[col]]
      bins <- apply_bins(spec, getElement(newdata, col))
      if (!use_woe) {
        return(bins)
      }
      # Prefer the WoE map learned at prep time; fall back to computing it
      # from this data's outcome (the only option for older trained steps).
      woe <- spec$woe_values
      if (is.null(woe)) {
        target <- if (is.null(target_name)) {
          NULL
        } else {
          getElement(newdata, target_name)
        }
        if (is.null(target)) {
          stop("step_bin: the outcome column is needed to compute WoE values",
               call. = FALSE)
        }
        pb <- prop.table(table(bins, target), margin = 2)
        woe <- log(pb[, 1] / pb[, 2])
      }
      woe[match(bins, names(woe))]
    })
    names(new_cols) <- paste0(names(specs), if (use_woe) ".woe" else ".bin")
    newdata <- bind_cols(newdata, as_tibble(new_cols))
  }
  if (!is_tibble(newdata)) {
    newdata <- as_tibble(newdata)
  }
  newdata
}
# Fit a conditional inference tree (party::ctree) of `target` on a single
# `variable` and turn its terminal nodes into bins.
#
# variable - numeric, factor, character or logical vector to discretize
#            (character and logical are converted to factor).
# target   - outcome vector of the same length.
# min.perc - minimum share of observations per terminal node; converted to
#            ctree's `minbucket`.
# n.group  - NULL, or the desired number of bins (length 1) or an inclusive
#            range (length 2, any order). When given, ctree's
#            `mincriterion` is scaled by 1% per refit until the node count
#            falls in the range, `mincriterion` leaves [0.01, 0.99], or
#            `max.iter` refits have been made.
# max.iter - maximum number of refits made for `n.group` (guards against
#            oscillating between node counts forever).
# ...      - unused.
#
# Returns NULL when the tree has fewer than two terminal nodes, or when the
# rounded numeric breaks collapse to a single bin. Otherwise returns a list
# with `bins` (number of bins), `labels`, `type` ("numeric" or "factor"),
# and either `breaks` (numeric, for cut()) or `levels` (factor, a list with
# one character vector of original levels per bin).
binning <- function(variable, target, min.perc = 0.05, n.group = NULL,
                    max.iter = 500L, ...) {
  if (is.character(variable) || is.logical(variable)) {
    variable <- as.factor(variable)
  }
  if (!is.factor(variable) && !is.numeric(variable)) {
    stop("`variable` should be numeric, factor, character or logical",
         call. = FALSE)
  }

  model_data <- data.frame(target = target, variable = variable)
  # Every (re)fit uses the same formula and data; only the controls change.
  fit_tree <- function(controls) {
    party::ctree(target ~ variable, data = model_data, controls = controls)
  }
  # Number of distinct terminal nodes reached by the observations.
  n_nodes <- function(fit) length(unique(fit@where))

  ctrl <- party::ctree_control(
    testtype = "Univariate",
    minbucket = floor(min.perc * length(target))
  )
  fit <- fit_tree(ctrl)

  if (!is.null(n.group)) {
    # One value means "exactly n"; sorting tolerates c(hi, lo).
    bounds <- sort(rep_len(n.group, 2L))
    centre <- median(bounds)
    iter <- 0L
    repeat {
      groups <- n_nodes(fit)
      criterion <- ctrl@gtctrl@mincriterion
      in_range <- groups >= bounds[1] && groups <= bounds[2]
      in_limits <- criterion >= 0.01 && criterion <= 0.99
      if (in_range || !in_limits || iter >= max.iter) {
        break
      }
      # Too many nodes: demand stronger evidence to split (raise the
      # criterion by 1%); too few: relax it (lower it by 1%).
      nudge <- if (groups > centre) 1.01 else 0.99
      ctrl@gtctrl@mincriterion <- criterion * nudge
      fit <- fit_tree(ctrl)
      iter <- iter + 1L
    }
  }

  if (n_nodes(fit) < 2L) {
    return(NULL)
  }
  nodes <- fit@where

  if (is.factor(variable)) {
    # One plain character vector of original levels per terminal node.
    lvls <- lapply(split(variable, nodes),
                   function(x) sort(as.character(unique(x))))
    lbls <- vapply(lvls, paste, character(1), collapse = "|")
    return(list(bins = length(lvls), levels = lvls, labels = lbls,
                type = "factor"))
  }

  # Each node's upper edge is its largest value; the top edge becomes Inf.
  # Rounding to 3 significant digits can merge edges, so duplicates are
  # dropped to keep the breaks valid for cut().
  node_max <- vapply(split(variable, nodes), max, numeric(1), na.rm = TRUE)
  inner <- sort(node_max)[-length(node_max)]
  breaks <- unique(signif(c(-Inf, inner, Inf), 3L))
  n_bins <- length(breaks) - 1L
  if (n_bins < 2L) {
    return(NULL)
  }
  labels <- paste0("(", breaks[-length(breaks)], ",", breaks[-1L], "]")
  labels[1] <- sub("(", "[", labels[1], fixed = TRUE)
  list(bins = n_bins, breaks = breaks, labels = labels, type = "numeric")
}
@artemklevtsov
Copy link
Copy Markdown
Author

artemklevtsov commented Apr 9, 2018

Example:

# Packages needed by the example and called directly by the gist's code
# (party fits the trees; purrr, dplyr and tibble are used in the step code):
# install.packages(c("caret", "recipes", "party", "purrr", "dplyr", "tibble"))

library(recipes)

# Example data; `Class` is used as the recipe's outcome below.
data("GermanCredit", package = "caret")

# The commit hash in this URL pins one revision of the gist, so edits made
# to a local copy are not picked up; source() the local file to use them.
source("https://gist.githubusercontent.com/artemklevtsov/303b3cbe362064e26a2afbfe751d16f2/raw/dfa0c243fce23df1d09da2a2d36f6206ff7f63c3/bin.R")

# Step 1 bins the three predictors into "<name>.bin" columns.
# Step 2 re-bins those ".bin" columns and appends their weight-of-evidence
# values as "<name>.bin.woe" columns.
rec <- recipe(Class ~ Amount + Age + Duration, GermanCredit) %>% 
    step_bin(all_predictors()) %>% 
    step_bin(ends_with(".bin"), woe = TRUE)
train_prep <- prep(rec, GermanCredit)
# The baked data keeps the original columns and appends the new ones.
train_data <- bake(train_prep, GermanCredit)
train_data

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment