-
-
Save krlmlr/e24f030afb501c48538549d161bef41b to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
--- | |
title: "Applying a function over rows of a data frame" | |
author: "Winston Chang" | |
output: html_document | |
--- | |
```{r setup, include=FALSE} | |
knitr::opts_chunk$set(collapse = TRUE, comment = "#>") | |
``` | |
[Source](https://gist.github.com/wch/0e564def155d976c04dd28a876dc04b4) for this document. | |
@daattali [asked](https://twitter.com/daattali/status/761058049859518464), "what's a safe way to iterate over rows of a data frame?" The example was to convert each row into a list and return a list of lists, indexed first by row, then by column.
A number of people gave suggestions on Twitter, which I've collected here. I've benchmarked these methods with data of various sizes; scroll down to see a plot of times. | |
```{r message=FALSE} | |
library(purrr) | |
library(dplyr) | |
library(tidyr) | |
# @daattali
# Using apply (only safe when all cols are same type) | |
f_apply <- function(df) {
  # apply() walks the rows (MARGIN = 1) and hands each one to as.list(),
  # giving one named list per row.
  apply(X = df, MARGIN = 1, FUN = as.list)
}
# @drob | |
# split + lapply | |
f_split_lapply <- function(df) {
  # One group per row index, so split() yields a list of one-row data frames
  row_frames <- split(df, seq_len(nrow(df)))
  # Turn each one-row data frame into a plain named list
  lapply(row_frames, as.list)
}
# @winston_chang | |
# lapply over row indices | |
f_lapply_row <- function(df) {
  # Visit every row index and convert the one-row slice into a named list.
  #
  # df: a data frame.
  # Returns an unnamed list with one element per row; each element is a list
  # keyed by column name.
  lapply(seq_len(nrow(df)), function(i) {
    # drop = FALSE keeps the slice a data frame even when df has a single
    # column, so the column name survives as.list(). Spell out FALSE: `F` is
    # an ordinary variable that can be reassigned.
    as.list(df[i, , drop = FALSE])
  })
}
# @winston_chang | |
# lapply + lapply: Treat data frame as list, and then slice out lists
f_lapply_lapply <- function(df) {
  # Column positions labelled with the column names, so the inner lapply()
  # returns a list named after the columns
  col_idx <- setNames(seq_along(df), names(df))

  # Pull the i-th value out of every column
  slice_row <- function(i) {
    lapply(col_idx, function(j) df[[j]][[i]])
  }

  lapply(seq_len(nrow(df)), slice_row)
}
# @winston_chang | |
# purrr::by_row | |
f_by_row <- function(df) {
  # NOTE(review): by_row() was dropped from later purrr releases (it appears
  # to live in purrrlyr now) -- confirm the installed purrr still exports it.
  per_row <- by_row(df, function(r) as.list(r))
  # The per-row results are stored in the `.out` column
  per_row[[".out"]]
}
# @JennyBryan | |
# purrr::pmap | |
f_pmap <- function(df) {
  # Map over the columns in parallel; list() receives one value from each
  # column per row
  purrr::pmap(.l = df, .f = list)
}
# purrr::pmap, but coerce df to a list first | |
f_pmap_aslist <- function(df) {
  # Strip the data frame down to a plain list of columns before mapping
  columns <- as.list(df)
  purrr::pmap(.l = columns, .f = list)
}
# @krlmlr | |
# dplyr::rowwise | |
f_rowwise <- function(df) {
  # Mark the data frame as row-wise, then let do() wrap each row in a list
  # under the name `row`.
  # NOTE(review): this returns whatever do() produces (presumably a data
  # frame with a `row` list-column), not a bare list like the other methods.
  by_rows <- rowwise(df)
  by_rows %>%
    do(row = list(.))
}
``` | |
Benchmark each of them, using data sets with varying numbers of rows: | |
```{r} | |
run_benchmark <- function(nrow) {
  # Time each row-iteration method once on freshly generated data.
  #
  # nrow: number of rows in the generated data frame.
  # Returns a list of elapsed seconds, one element per method (named after
  # the method), followed by an `nrow` element recording the data size.

  # Three numeric columns of random data
  df <- data.frame(
    x = rnorm(nrow),
    y = runif(nrow),
    z = runif(nrow)
  )

  # Methods to time, in the order they are run
  candidates <- list(
    apply = f_apply,
    split_lapply = f_split_lapply,
    lapply_row = f_lapply_row,
    lapply_lapply = f_lapply_lapply,
    by_row = f_by_row,
    pmap = f_pmap,
    pmap_aslist = f_pmap_aslist,
    rowwise = f_rowwise
  )

  # Keep only the wall-clock ("elapsed") component of each timing
  timings <- lapply(candidates, function(method) {
    system.time(method(df))[["elapsed"]]
  })

  timings$nrow <- nrow
  timings
}
# Benchmark data sets of 10, 100, ..., 100000 rows
all_times <- lapply(10^(1:5), run_benchmark)

# One single-row data frame per run, stacked into a wide table
times <- do.call(rbind, lapply(all_times, as.data.frame))

# Reshape to long format: one row per (nrow, method) pair
times <- gather(times, method, seconds, -nrow)

# Fix the method order so plot legends follow the order of definition
times[["method"]] <- factor(
  times[["method"]],
  levels = c(
    "apply", "split_lapply", "lapply_row", "lapply_lapply",
    "by_row", "pmap", "pmap_aslist", "rowwise"
  )
)

times
``` | |
## Plot times | |
This plot shows the number of seconds needed to process n rows, for each method. Both the x and y axes use log scales, so each step along the x scale represents a 10x increase in number of rows, and each step along the y scale represents a 10x increase in time.
```{r message=FALSE} | |
library(ggplot2) | |
library(scales) | |
# Elapsed time against row count on log-log axes, one coloured line per method
ggplot(data = times, mapping = aes(x = nrow, y = seconds, colour = method)) +
  geom_point() +
  geom_line() +
  # Log tick marks on all four sides (top, right, bottom, left)
  annotation_logticks(sides = "trbl") +
  theme_bw() +
  # Both axes: log10 transform, breaks at powers of ten labelled as 10^k
  scale_x_log10(
    breaks = trans_breaks("log10", function(x) 10^x),
    labels = trans_format("log10", math_format(10^.x)),
    minor_breaks = NULL
  ) +
  scale_y_log10(
    breaks = trans_breaks("log10", function(x) 10^x),
    labels = trans_format("log10", math_format(10^.x)),
    minor_breaks = NULL
  )
``` |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment