Skip to content

Instantly share code, notes, and snippets.

@mrdwab
Created February 21, 2016 12:11
Show Gist options
  • Save mrdwab/b9391df98e1661fa46fb to your computer and use it in GitHub Desktop.
Save mrdwab/b9391df98e1661fa46fb to your computer and use it in GitHub Desktop.
Possible rewrite of `stratified`. Not considerably faster, but the refactored code seems easier to follow.
## Helper functions. Won't bother exporting.
## Return `indt` as a data.table: objects that already are one are passed
## through untouched, anything else is converted with as.data.table().
dt_check <- function(indt) {
  if (is.data.table(indt)) {
    return(indt)
  }
  as.data.table(indt)
}
## Group sizes: one row per combination of the `group` columns, with the
## number of rows in that group stored in column `N`.
g_s <- function(indt, group) {
  indt[, list(N = .N), by = group]
}
## Fixed sample size: one row per group, each carrying the requested `size`
## in column `ss`.
g_n <- function(indt, group, size) {
  indt[, .(ss = size), by = group]
}
## Fractional sample size: one row per group, with column `ss` holding the
## group's row count multiplied by `size`, rounded up to a whole number.
g_f <- function(indt, group, size) {
  indt[, .(ss = ceiling(size * .N)), by = group]
}
## Per-group sample sizes supplied by the caller.
##
## Builds a two-column data.table from a named `size` vector: the names go in
## the column called `group`, the values go in column `ss`.
##
## indt  : not used here; kept so the signature parallels g_n() and g_f().
## group : name of the grouping column (the two-column layout built here
##         leaves room for a single name only).
## size  : named vector; the names identify the groups, the values are the
##         requested sample sizes.
g_l <- function(indt, group, size) {
  # Without names there is nothing to key the sizes on, and data.table() would
  # silently build a one-column table that then fails on renaming with an
  # unrelated-looking error. Fail early with a clear message instead.
  if (is.null(names(size))) {
    stop("'size' must be named with the group values it applies to")
  }
  setnames(data.table(names(size), unname(size)), c(group, "ss"))[]
}
## Keep only the rows of `indt` that satisfy every condition in `select`.
##
## indt   : a data.table (rows are subset with `indt[<logical vector>]`).
## select : a named list. Each name must be a column of `indt`; each element
##          holds the values of that column that should be kept.
##
## Returns the rows for which every named column has a value found in the
## corresponding element of `select`. Stops if `select` has no names or names
## a column that `indt` does not have.
g_sel <- function(indt, select) {
  if (is.null(names(select))) {
    stop("'select' must be a named list")
  }
  if (!all(names(select) %in% names(indt))) {
    stop("Please verify your 'select' argument")
  }
  # One logical vector per condition, AND-ed together. This replaces a
  # vapply() + rowSums() construction: vapply() returns a plain vector rather
  # than a matrix when `indt` has exactly one row, which made rowSums() fail.
  hits <- lapply(names(select), function(x) indt[[x]] %in% select[[x]])
  keep <- Reduce(`&`, hits)
  indt[keep]
}
## Reconcile requested sample sizes with the rows actually available.
##
## indt      : data.table of requested sizes, with column `ss`.
## comparedt : data.table of group counts, with column `N`.
## group     : column name(s) the two tables are joined on.
## replace   : if TRUE, requested sizes are left alone even when they exceed
##             the group size.
##
## Returns the joined table without the `N` column. Unless `replace` is TRUE,
## any `ss` larger than its group's `N` is lowered to `N`, and the affected
## groups (with their original request) are reported through message().
g_comp <- function(indt, comparedt, group, replace) {
  checker <- indt[comparedt, on = group]
  # which() drops NA comparisons, so a missing `ss` or `N` neither trips the
  # condition below nor gets reported as a shortfall.
  short <- checker[, which(ss > N)]
  # Report from inside the branch that detected the shortfall. The previous
  # exists("checker_meta") test also searched enclosing environments, so an
  # unrelated global of that name could trigger a bogus message.
  if (length(short) > 0L && !isTRUE(replace)) {
    checker_meta <- checker[short]  # snapshot taken before `ss` is capped
    checker[, ss := pmin(ss, N)]
    message("Some groups have fewer values than requested:\n")
    message(paste(capture.output(checker_meta), collapse = "\n"))
    message("\n")
  }
  checker[, N := NULL][]
}
## Draw a stratified random sample from a data set.
##
## indt     : a data.table, or anything dt_check() can convert into one.
## group    : grouping column(s); passed through splitstackshape:::Names()
##            (presumably to turn column positions into names -- verify
##            against that package).
## size     : how much to sample per group:
##              * length > 1 : named vector of per-group sizes; its length
##                             must equal the number of groups,
##              * value < 1  : fraction of each group, rounded up,
##              * otherwise  : the same number of rows from every group.
## select   : optional named list used to subset `indt` before sampling
##            (see g_sel()).
## replace  : passed on to the sampler. Unless TRUE, requests larger than a
##            group are reduced to the group size (see g_comp()).
## bothSets : if TRUE, return list(SAMP1 = sampled rows, SAMP2 = remaining
##            rows) instead of just the sample.
stratified <- function(indt, group, size, select = NULL, replace = FALSE, bothSets = FALSE) {
  indt <- dt_check(indt)
  group <- splitstackshape:::Names(indt, group)
  if (!is.null(select)) indt <- g_sel(indt, select)
  compare_me <- g_s(indt, group)
  n <- if (length(size) > 1) {
    if (length(size) != nrow(compare_me)) {
      stop("Incorrect number of groups specified in 'size'")
    }
    g_comp(g_l(indt, group, size), compare_me, group, replace)
  } else if (size < 1) {
    g_comp(g_f(indt, group, size), compare_me, group, replace)
  } else {
    g_comp(g_n(indt, group, size), compare_me, group, replace)
  }
  # Sample positions within each group and map them back through `.I`.
  # Calling sample(.I, ...) directly is wrong for one-row groups: sample()
  # treats a single number k as 1:k, so rows outside the group could be drawn.
  out1 <- indt[
    n,
    .I[sample.int(length(.I), unlist(ss), replace)],
    on = group,
    by = .EACHI
  ]$V1
  if (isTRUE(bothSets)) {
    out2 <- indt[!seq_len(nrow(indt)) %in% out1]
    list(SAMP1 = indt[out1], SAMP2 = out2)
  } else {
    indt[out1]
  }
}
@karagawa
Copy link

karagawa commented Sep 1, 2016

hi @mrdwab. stratified is a great function, generally, for cross validation. I am wondering whether you could make size = 1 available for sampling all observations in each stratum with replacement? I think that would be very useful for non-parametric bootstrapping. Thanks!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment