Last active
August 29, 2015 14:14
-
-
Save lionel-/0b41c6b9d3554725807a to your computer and use it in GitHub Desktop.
Making lowliner better understand data frames
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#' Set the grouping columns of a data frame
#'
#' @param .d A data frame (checked with `stopifnot()`).
#' @param .cols Columns to group by, given as names or as numeric
#'   positions (positions are translated to names via `names(.d)`).
#'   When `NULL`, `dplyr::group_by_()` is called with an empty `.dots`
#'   list.
#' @return Whatever `dplyr::group_by_()` returns for `.d`.
set_groups <- function(.d, .cols = NULL) {
  stopifnot(is.data.frame(.d))

  if (is.null(.cols)) {
    # Namespace-qualified like the call below, so this branch also works
    # when dplyr is not attached.
    return(dplyr::group_by_(.d, .dots = list()))
  }

  if (is.numeric(.cols)) {
    .cols <- names(.d)[.cols]
  }

  # NOTE(review): map_call() is defined outside this file; presumably it
  # passes the elements of `.cols` as arguments to group_by_() -- confirm.
  .cols %>% map_call(dplyr::group_by_, .data = .d)
}
#' Reset the grouping of a data frame
#'
#' Thin wrapper that calls `set_groups()` with no grouping columns.
#'
#' @param .d A data frame.
#' @return The result of `set_groups(.d, NULL)`.
unset_groups <- function(.d) {
  set_groups(.d, .cols = NULL)
}
#' Apply a function to each group of a grouped data frame
#'
#' @param .d A data frame. If it does not inherit from "grouped_df",
#'   `.f` is simply applied to the whole of `.d`.
#' @param .f A function, or a formula (converted with lowliner's
#'   `as_function()`).
#' @param ... Further arguments forwarded to `.f`.
#' @return For grouped input, a tbl_df made of the group labels bound to
#'   the combined per-group results.
by_group <- function(.d, .f, ...) {
  if (is.formula(.f)) {
    .f <- lowliner:::as_function(.f)
  }
  if (!inherits(.d, "grouped_df")) {
    return(.f(.d, ...))
  }

  # The stored indices are shifted by one before being used for subsetting
  group_rows <- lapply(attr(.d, "indices"), function(idx) 1 + idx)
  group_vars <- vapply(attr(.d, "vars"), as.character, character(1))

  # Drop the grouping columns, then cut every remaining column into one
  # piece per group
  payload <- .d[-match(group_vars, names(.d))]
  unpiled <- lapply(payload, function(column) {
    lapply(group_rows, function(idx) column[idx])
  })

  # Rebuild each group as a data frame and apply the function to it
  results <- lapply(seq_along(group_rows), function(g) {
    pieces <- lapply(unpiled, function(column_pieces) column_pieces[[g]])
    .f(dplyr::as_data_frame(pieces), ...)
  })

  # Data frame results are row-bound, remembering how many rows each
  # group produced. Anything else is stored in a list-column, one row
  # per group.
  if (every(results, is.data.frame)) {
    sizes <- vapply(results, nrow, numeric(1))
    combined <- dplyr::bind_rows(results)
  } else {
    sizes <- rep(1, length(results))
    combined <- dplyr::as_data_frame(list(out = results))
  }

  # Repeat each group's label row once per output row of that group.
  # The labels are converted with tbl_df() first so that the subset is
  # always a data frame.
  group_labels <- attr(.d, "labels")
  label_rows <- unlist(Map(rep, seq_len(nrow(group_labels)), sizes))
  group_labels <- dplyr::tbl_df(group_labels)[label_rows, ]

  dplyr::tbl_df(dplyr::bind_cols(group_labels, combined))
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include <Rcpp.h> | |
#include <dplyr.h> | |
// [[Rcpp::depends(dplyr)]] | |
using namespace Rcpp; | |
int is_object(SEXP obj); | |
// Apply `fun` to every slice of `data` and collect the results.
//
// `data` must carry an "indices" attribute holding one integer vector per
// slice. Each slice is extracted with dplyr::DataFrameVisitors::subset()
// and passed to `fun`; the returned list has one element per slice, in the
// order of the "indices" attribute.
List apply_slices(List data, Function fun) {
  ListOf<IntegerVector> indices(data.attr("indices"));
  int n_slices = indices.size();

  // Class vector handed to subset() for every extracted slice
  CharacterVector classes = Rcpp::CharacterVector::create(
    "tbl_df", "tbl", "data.frame"
  );
  dplyr::DataFrameVisitors visitors(data);

  // Apply fun on each slice
  List out(n_slices);
  for (int i = 0; i < n_slices; ++i) {
    out[i] = fun(visitors.subset(indices[i], classes));
  }
  return out;
}
// Apply `fun` to each slice of `data` (see apply_slices()) and prepare the
// results for row-binding on the R side.
//
// When every result is an object that is not a data frame, each result is
// wrapped in a one-row data frame with a single list-column named ".out".
// Otherwise the results are returned untouched.
// [[Rcpp::export]]
List by_slice_impl(const List data, Function fun) {
  List out = apply_slices(data, fun);
  const int n = out.size();

  // Stop scanning at the first result that disqualifies the list-column
  bool all_objects = true;
  for (int i = 0; i < n && all_objects; ++i) {
    SEXP result = out[i];
    all_objects = is_object(result) && !Rf_inherits(result, "data.frame");
  }

  // Make a list-column only if all outputs are non-data frame
  // objects. In all other cases, we let dplyr::bind_rows() check that
  // the outputs are compatible.
  if (all_objects) {
    for (int i = 0; i < n; ++i) {
      // Build the wrapper completely before storing it, rather than
      // relying on a second handle aliasing the stored element.
      List out_slice = List::create(_[".out"] = List::create(out[i]));
      // c(NA, -1) is R's compact row.names form for one row
      out_slice.attr("row.names") = IntegerVector::create(IntegerVector::get_na(), -1);
      // Same class vector as the slices produced by apply_slices()
      out_slice.attr("class") = CharacterVector::create("tbl_df", "tbl", "data.frame");
      out[i] = out_slice;
    }
  }
  return out;
}
// Split `data` into its slices without applying any function.
//
// `data` must carry an "indices" attribute holding one integer vector per
// slice. Returns a list with one element per slice, each produced by
// dplyr::DataFrameVisitors::subset().
// [[Rcpp::export]]
List subset_slices(const List data) {
  ListOf<IntegerVector> indices(data.attr("indices"));
  int n_slices = indices.size();

  // Class vector handed to subset() for every extracted slice
  CharacterVector classes = Rcpp::CharacterVector::create(
    "tbl_df", "tbl", "data.frame"
  );
  dplyr::DataFrameVisitors visitors(data);

  List out(n_slices);
  for (int i = 0; i < n_slices; ++i) {
    out[i] = visitors.subset(indices[i], classes);
  }
  return out;
}
/*** R | |
# Variant where ..f is invoked from the C++ side (by_slice_impl); the
# per-slice results are then row-bound in R.
by_slice2 <- function(.x, ..f) {
  per_slice <- by_slice_impl(.x, ..f)
  dplyr::bind_rows(per_slice)
}
# Variant where only the slicing happens in C++ (subset_slices) and ..f is
# called from R on each slice.
#
# Results are wrapped in a one-row data frame with a ".out" list-column
# only when every result is an object that is not a data frame. Data
# frames are objects too, so they must be excluded explicitly or they
# would be wrapped instead of row-bound; by_slice_impl() applies the same
# rule.
by_slice3 <- function(.x, ..f) {
  out <- subset_slices(.x) %>% lapply(..f)

  needs_wrapping <- function(x) is.object(x) && !is.data.frame(x)
  if (every(out, needs_wrapping)) {
    out <- lapply(out, function(x) {
      dplyr::as_data_frame(list(.out = list(x)))
    })
  }

  dplyr::bind_rows(out)
}
# Benchmark input: 1000 copies of mtcars bound together, grouped by cyl
# and vs.
data <- rerun(1000, mtcars) %>%
  dplyr::bind_rows() %>%
  group_by(cyl, vs)

# Benchmark 1: the function returns a model object for each slice.
# Author's observation: the R version is wildly faster than the first
# Rcpp variant and a bit slower than the second.
fnu <- partial(lm, disp ~ gear)
microbenchmark(
  R = data %>% by_slice(fnu),
  dplyr = data %>% do(as_data_frame(list(.out = list(fnu(.))))),
  Cpp1 = data %>% by_slice2(fnu),
  Cpp2 = data %>% by_slice3(fnu)
)

# Benchmark 2: the function maps mean over the columns of each slice.
# Author's observation: here the first Rcpp variant is the best, almost
# on par with dplyr.
fnu2 <- partial(map, .f = mean)
microbenchmark(
  R = data %>% by_slice(fnu2),
  dplyr = data %>% summarise_each(funs(mean)),
  Cpp1 = data %>% by_slice2(fnu2),
  Cpp2 = data %>% by_slice3(fnu2)
)
*/ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#' Map a function over a data frame and row-bind the results
#'
#' @param .d A data frame, passed to `map_n()` together with `.f`.
#' @param .f Function forwarded to `map_n()`.
#' @param ... Further arguments forwarded to `map_n()`.
#' @param .trace If true, the columns of `.d` are bound in front of the
#'   results and the whole is returned as a tbl_df; otherwise only the
#'   bound results are returned.
#' @return A data frame of the row-bound results (see `.trace`).
map_rows <- function(.d, .f, ..., .trace = TRUE) {
  # Each result is first coerced to a data frame by coerce_rows()
  row_frames <- lapply(map_n(.d, .f, ...), coerce_rows)
  out <- dplyr::bind_rows(row_frames)

  if (.trace) {
    out <- dplyr::tbl_df(dplyr::bind_cols(.d, out))
  }
  out
}
#' Coerce a single result to a data frame
#'
#' @param x A result to coerce.
#' @return
#'   * For input accepted by `is_bare_atomic()`: a data frame with one
#'     column per element, named by element position.
#'   * For a data frame: `x` unchanged.
#'   * For anything else: a data frame with a single list-column `out`
#'     holding `x`.
coerce_rows <- function(x) {
  if (is_bare_atomic(x)) {
    cells <- setNames(as.list(x), seq_along(x))
    return(dplyr::as_data_frame(cells))
  }

  if (is.data.frame(x)) {
    return(x)
  }

  dplyr::data_frame(out = list(x))
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment