Last active
April 9, 2018 13:32
-
-
Save artemklevtsov/303b3cbe362064e26a2afbfe751d16f2 to your computer and use it in GitHub Desktop.
Binning (discretize) variables based on CART
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| library(recipes) | |
| library(dplyr) | |
| library(tibble) | |
#' Add a supervised binning step to a recipe
#'
#' Validates the tuning arguments, captures the variable selectors and
#' appends a "bin" step (built by `step_bin_new()`) to `recipe`.
#'
#' @param recipe Object passed unchanged as the first argument of
#'   `add_step()`.
#' @param ... Variable selectors; captured with `rlang::quos()`. At least one
#'   is required.
#' @param role Stored on the step.
#' @param trained Stored on the step; `prep()` sets it to `TRUE`.
#' @param threshold Single number strictly between 0 and 1. `prep()` forwards
#'   it to `binning()` as `min.perc`.
#' @param n.group `NULL`, or one or two numbers, each at least 2. `prep()`
#'   forwards it to `binning()`.
#' @param woe Stored on the step; `bake()` uses it to choose between bin
#'   labels and weight-of-evidence values.
#' @param objects Stored on the step; filled in by `prep()`.
#' @param skip Stored on the step.
#' @return The value of `add_step()`.
step_bin <- function(recipe, ..., role = "predictor", trained = FALSE,
                     threshold = .05, n.group = NULL, woe = FALSE,
                     objects = NULL, skip = FALSE) {
  # Reject NULL / NA / vector thresholds up front; otherwise the comparisons
  # below fail with an unhelpful "argument is of length zero" style error.
  if (!is.numeric(threshold) || length(threshold) != 1L || is.na(threshold)) {
    stop("`threshold` should be a single number", call. = FALSE)
  }
  if (threshold <= 0) {
    stop("`threshold` should be greater than zero", call. = FALSE)
  }
  if (threshold >= 1) {
    stop("`threshold` should be less than one", call. = FALSE)
  }

  if (!is.null(n.group)) {
    if (!is.numeric(n.group) || anyNA(n.group)) {
      stop("'n.group' should be numeric without missing values",
           call. = FALSE)
    }
    if (length(n.group) < 1L || length(n.group) > 2L) {
      stop("'n.group' length should be one or two", call. = FALSE)
    }
    if (any(n.group < 2L)) {
      stop("'n.group' should be greater than one", call. = FALSE)
    }
  }

  terms <- rlang::quos(...)
  if (length(terms) == 0L) {
    stop("Please supply at least one variable specification. See ?selections.",
         call. = FALSE)
  }

  add_step(
    recipe,
    step_bin_new(
      terms = terms,
      role = role,
      trained = trained,
      threshold = threshold,
      n.group = n.group,
      woe = woe,
      objects = objects,
      skip = skip
    )
  )
}
# Low-level constructor for the "bin" step.
#
# Collects every field of the step in a named list and hands it to `step()`
# together with `subclass = "bin"`. No validation happens here; `step_bin()`
# checks user input and `prep.step_bin()` calls this again with
# `trained = TRUE` and the fitted `objects`.
step_bin_new <- function(terms = NULL, role = "predictor", trained = FALSE,
                         threshold = NULL, n.group = NULL, woe = FALSE,
                         objects = NULL, skip = FALSE) {
  # list() keeps NULL entries, so unset fields are still passed on by name.
  fields <- list(
    terms = terms,
    role = role,
    trained = trained,
    threshold = threshold,
    n.group = n.group,
    woe = woe,
    objects = objects,
    skip = skip
  )
  do.call(step, c(list(subclass = "bin"), fields))
}
#' Train the binning step
#'
#' Resolves the selected columns, looks up the outcome variable in `info`
#' and runs `binning()` on every selected column of `training`. Columns for
#' which `binning()` returns nothing (zero-length result) are dropped.
#'
#' @param x The untrained step; reads `terms`, `role`, `threshold`,
#'   `n.group`, `woe` and `skip` from it.
#' @param training Data frame holding the selected columns and the outcome.
#' @param info Variable info table; its `variable` and `role` columns are
#'   used to find the outcome.
#' @param ... Ignored.
#' @return A new step from `step_bin_new()` with `trained = TRUE` and
#'   `objects` set to the named list of binning results. The outcome name is
#'   attached to that list as attribute `"target"`.
prep.step_bin <- function(x, training, info = NULL, ...) {
  col_names <- terms_select(terms = x$terms, info = info)

  # which() drops NA roles instead of propagating them into `target`.
  target <- info$variable[which(info$role == "outcome")]
  if (length(target) != 1L) {
    stop("`step_bin` requires exactly one outcome variable, found ",
         length(target), call. = FALSE)
  }

  # List-style `[` always returns a data frame, so a single selected column
  # is still iterated as a column (matrix-style `[, j]` would collapse it to
  # a vector on a plain data.frame).
  objects <- lapply(
    training[col_names],
    binning,
    target = training[[target]],
    min.perc = x$threshold,
    n.group = x$n.group
  )
  objects <- objects[lengths(objects) > 0L]
  attr(objects, "target") <- target

  step_bin_new(
    terms = x$terms,
    role = x$role,
    trained = TRUE,
    threshold = x$threshold,
    n.group = x$n.group,
    woe = x$woe,
    objects = objects,
    skip = x$skip
  )
}
#' Apply the trained binning step to new data
#'
#' For every variable stored in `object$objects` a new column is appended to
#' `newdata`: the binned factor (suffix `.bin`), or, when `object$woe` is
#' `TRUE`, the weight-of-evidence value of each bin (suffix `.woe`).
#'
#' @param object Trained step; reads `objects` and `woe`.
#' @param newdata Data frame containing the binned variables.
#' @param ... Ignored.
#' @return `newdata` as a tibble with the new columns appended.
#'
#' @importFrom tibble as_tibble is_tibble
#' @export
bake.step_bin <- function(object, newdata, ...) {
  ensure_tibble <- function(d) if (is_tibble(d)) d else as_tibble(d)

  specs <- object$objects
  if (length(specs) == 0L) {
    return(ensure_tibble(newdata))
  }

  use_woe <- isTRUE(object$woe)
  suffix <- if (use_woe) ".woe" else ".bin"

  target <- NULL
  if (use_woe) {
    # NOTE(review): WoE is computed from the outcome found in `newdata`, not
    # from the training data, so baking data without the outcome cannot work
    # and baking labelled data uses its own labels. Consider computing WoE
    # in prep() instead.
    target_name <- attr(specs, "target")
    if (is.null(target_name) || !target_name %in% names(newdata)) {
      stop("`woe = TRUE` requires the outcome column in `newdata`",
           call. = FALSE)
    }
    target <- newdata[[target_name]]
  }

  bin_one <- function(spec, name) {
    if (!name %in% names(newdata)) {
      stop("Column '", name, "' not found in `newdata`", call. = FALSE)
    }
    values <- newdata[[name]]

    if (identical(spec$type, "numeric")) {
      binned <- cut(values, breaks = spec$breaks, labels = spec$labels)
    } else if (identical(spec$type, "factor")) {
      # Flatten each group to a plain character vector; `binning()` may
      # store every group wrapped in an extra list.
      groups <- lapply(spec$levels, function(g) as.character(unlist(g)))
      # Using the trained levels makes unseen values NA and keeps the output
      # levels independent of which values happen to occur in `newdata`.
      binned <- factor(as.character(values),
                       levels = unlist(groups, use.names = FALSE))
      levels(binned) <- setNames(groups, spec$labels)
    } else {
      stop("Unknown binning type for variable '", name, "'", call. = FALSE)
    }

    if (use_woe) {
      shares <- prop.table(table(binned, target), margin = 2)
      if (ncol(shares) < 2L) {
        stop("`woe = TRUE` requires an outcome with two classes",
             call. = FALSE)
      }
      # Only the first two outcome classes are used.
      woe <- log(shares[, 1] / shares[, 2])
      binned <- unname(woe[match(binned, names(woe))])
    }
    binned
  }

  res <- as_tibble(Map(bin_one, specs, names(specs)))
  names(res) <- paste0(names(res), suffix)

  ensure_tibble(bind_cols(newdata, res))
}
#' Find bins for one variable with a conditional inference tree
#'
#' Fits `party::ctree(target ~ variable)` and turns the terminal nodes into
#' bins. When `n.group` is given, `mincriterion` is nudged up or down by 1%
#' per refit until the number of terminal nodes falls inside the requested
#' range or the criterion leaves the interval [0.01, 0.99].
#'
#' @param variable Numeric, factor or character vector to bin. Character
#'   input is converted to a factor.
#' @param target Outcome vector, same length as `variable`.
#' @param min.perc Minimum share of observations per terminal node; turned
#'   into `minbucket` for `party::ctree_control()`.
#' @param n.group `NULL`, a single number, or a range of acceptable bin
#'   counts (order does not matter).
#' @param ... Ignored.
#' @return `NULL` when the tree produces fewer than two groups. Otherwise a
#'   list with `type` ("factor" or "numeric"), `labels`, `bins`, and either
#'   `levels` (list of character vectors, one per group) or `breaks`
#'   (numeric cut points including `-Inf` and `Inf`). For numeric variables
#'   `bins` is the number of break points, not of intervals.
binning <- function(variable, target, min.perc = 0.05, n.group = NULL, ...) {
  if (is.character(variable)) {
    variable <- as.factor(variable)
  }

  # Normalise to c(lower, upper) so a single value or a reversed pair works.
  if (length(n.group) == 0L) {
    n.group <- NULL
  } else {
    n.group <- range(n.group)
  }

  # A terminal node needs at least one observation, even on tiny samples.
  minbucket <- max(1, floor(min.perc * length(target)))
  ctrl <- party::ctree_control(testtype = "Univariate", minbucket = minbucket)
  data <- data.frame(target = target, variable = variable,
                     stringsAsFactors = FALSE)
  fit <- party::ctree(target ~ variable, data = data, controls = ctrl)

  count_groups <- function(tree) length(unique(tree@where))

  if (!is.null(n.group)) {
    cri_lim <- c(0.01, 0.99)
    # Going from 0.95 down to 0.01 in 1% steps takes under 500 refits; the
    # cap only stops an oscillation around the requested range.
    max_iter <- 1000L
    for (iter in seq_len(max_iter)) {
      groups <- count_groups(fit)
      cri_val <- ctrl@gtctrl@mincriterion
      in_range <- groups >= n.group[1] && groups <= n.group[2]
      tunable <- cri_val >= cri_lim[1] && cri_val <= cri_lim[2]
      if (in_range || !tunable) {
        break
      }
      # Too many groups: raise the criterion so fewer splits pass;
      # too few: lower it.
      multiplier <- if (groups > n.group[2]) 1.01 else 0.99
      ctrl@gtctrl@mincriterion <- cri_val * multiplier
      fit <- party::ctree(target ~ variable, data = data, controls = ctrl)
    }
  }

  if (count_groups(fit) < 2L) {
    return(NULL)
  }

  nodes <- split(variable, fit@where)

  if (is.factor(variable)) {
    lvls <- lapply(nodes, function(x) sort(as.character(unique(x))))
    lbls <- vapply(lvls, paste, character(1), collapse = "|")
    res <- list(bins = length(lvls), levels = lvls, labels = lbls,
                type = "factor")
  } else if (is.numeric(variable)) {
    upper <- vapply(nodes, max, numeric(1), na.rm = TRUE)
    breaks <- c(-Inf, sort(upper)[-length(upper)], Inf)
    # Rounding to 3 significant digits can merge neighbouring cut points;
    # cut() rejects duplicated breaks, so drop them.
    breaks <- unique(signif(breaks, 3L))
    if (length(breaks) < 3L) {
      return(NULL)
    }
    bins <- length(breaks)
    labels <- paste0("(", breaks[-bins], ",", breaks[-1], "]")
    labels[1] <- gsub("(", "[", labels[1], fixed = TRUE)
    res <- list(bins = bins, breaks = breaks, labels = labels,
                type = "numeric")
  } else {
    stop("`variable` should be numeric, factor or character", call. = FALSE)
  }

  res
}
Author
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Example: