Skip to content

Instantly share code, notes, and snippets.

@jrosell
Created November 1, 2024 01:49
Show Gist options
  • Save jrosell/4224c3114a0cb1de62557bed4ac609c7 to your computer and use it in GitHub Desktop.
Save jrosell/4224c3114a0cb1de62557bed4ac609c7 to your computer and use it in GitHub Desktop.
library(testthat)
# Reference result for the input 8:1 with cuts at positions 4 and 7,
# written one segment per inner vector: the running total restarts at
# zero at the first position and at every cut position.
expected <- c(
  c(0L, 8L, 15L),
  c(0L, 5L, 9L),
  c(0L, 2L)
)
#' Lagged cumulative sum that restarts at given positions (base R, split).
#'
#' Element `i` of the result is the sum of the elements of `x` that precede
#' `i` within its segment; a new segment starts at position 1 and at every
#' position listed in `cuts`.
#'
#' @param x Numeric (integer or double) vector.
#' @param cuts Positions at which the running total restarts, or `NULL`
#'   for a single segment.
#' @return A double vector with the same length as `x`.
cumsum_cut_base <- function(x, cuts = NULL) {
  # split()/unlist() would yield NULL here; return a typed empty vector.
  if (length(x) == 0L) {
    return(numeric(0))
  }
  is_start <- logical(length(x))
  if (!is.null(cuts)) {
    is_start[cuts] <- TRUE
  }
  # Sum in double precision: the result is double anyway (it is prefixed
  # with 0), and integer cumsum() would turn into NA on overflow.
  segments <- split(as.numeric(x), cumsum(is_start))
  lagged <- lapply(segments, \(.x) c(0, cumsum(.x)[-length(.x)]))
  unlist(lagged, use.names = FALSE)
}
#' Lagged cumulative sum that restarts at given positions (split + rlang).
#'
#' The element just before each cut is zeroed and becomes the first element
#' of its group; the per-group cumulative sums are then shifted right by one
#' position, which yields a total that restarts at every cut.
#'
#' @param x Numeric (integer or double) vector.
#' @param cuts Positions at which the running total restarts.
#' @return A double vector with the same length as `x`.
cumsum_cut_rlang <- function(x, cuts) {
  n <- length(x)
  # With fewer than two elements there is nothing to shift.
  if (n < 2L) {
    return(numeric(n))
  }
  # Work in double so the poked values match the type of `to_fill`.
  x <- as.numeric(x)
  # Track group boundaries explicitly instead of testing `x == 0`, so
  # zeros that are genuinely part of the data do not start a new group.
  boundary <- logical(n)
  boundary[cuts - 1] <- TRUE # a cut at 1 gives index 0, which is ignored
  x[boundary] <- 0
  res <- split(x, cumsum(boundary)) |>
    lapply(cumsum) |>
    unlist(use.names = FALSE)
  to_fill <- numeric(n)
  # Copies res[1:(n - 1)] into to_fill[2:n]; to_fill is modified in place.
  rlang::vec_poke_range(to_fill, 2, res, to = n - 1)
  to_fill
}
#' Lagged cumulative sum that restarts at given positions (vectorised).
#'
#' Computes one lagged cumulative sum over the whole vector and subtracts,
#' for every element, the total accumulated before the start of its segment.
#'
#' @param x Numeric (integer or double) vector.
#' @param cut Positions at which the running total restarts. Must be sorted
#'   in non-decreasing order, as required by `findInterval()`.
#' @return A double vector with the same length as `x`.
cumsum_cut_findInterval <- function(x, cut) {
  n <- length(x)
  if (n == 0L) {
    return(numeric(0))
  }
  # Position 1 always opens the first segment.
  starts <- c(1, cut)
  # Accumulate in double precision: the total over the whole vector can
  # exceed the integer range even when every per-segment sum is small,
  # and integer cumsum() would then produce NA.
  running <- cumsum(as.numeric(x))
  lagged_cumsum <- c(0, running[-n])
  # Total accumulated before each segment start, expanded to one value
  # per element via the index of the segment containing it.
  offsets <- lagged_cumsum[starts]
  segment_id <- findInterval(seq_len(n), starts)
  lagged_cumsum - offsets[segment_id]
}
# C implementation of the restarting lagged cumulative sum. After
# callme::compile() the C function is called from R as mikec_inplace(x, cut).
# It overwrites the integer vector it is given and returns that same vector,
# so callers that still need the original data must pass a copy.
code2 <- r"(
SEXP mikec_inplace(SEXP x_, SEXP cut_) {
  int nx = length(x_);
  int nc = length(cut_);

  // Nothing to do for an empty vector; also avoids reading x[0] below.
  if (nx == 0) {
    return x_;
  }

  int *x = INTEGER(x_);
  int *cut = INTEGER(cut_);

  // 'last' always holds the original (pre-overwrite) value of the
  // element that was processed most recently.
  int last = x[0];
  x[0] = 0;

  // Position 1 always starts a segment, so skip every leading cut that
  // points at it. Testing ci < nc first keeps an empty cut vector from
  // being read.
  int ci = 0;
  while (ci < nc && cut[ci] <= 1) {
    ci++;
  }

  for (int i = 1; i < nx; i++) {
    int tmp = last;
    last = x[i];
    // cut holds 1-based R positions, i is a 0-based C index
    if (ci < nc && i == cut[ci] - 1) {
      x[i] = 0;
      ci++;
      // skip repeated elements
      while (ci < nc && cut[ci] == cut[ci - 1]) {
        ci++;
      }
    } else {
      x[i] = x[i - 1] + tmp;
    }
  }
  return x_;
}
)"
callme::compile(code2)
#' Lagged cumulative sum that restarts at given positions (explicit loop).
#'
#' Walks `x` once, overwriting each element with the sum of the elements
#' that precede it within its segment. A segment starts at position 1 and
#' at every position listed in `cut`.
#'
#' @param x Numeric (integer or double) vector.
#' @param cut Positions at which the running total restarts; assumed to be
#'   sorted in non-decreasing order (a single forward pointer is used).
#'   May be empty and may contain repeated values.
#' @return The transformed vector, same length as `x`.
cumsum_cut_r_loops <- function(x, cut) {
  nx <- length(x)
  if (nx == 0L) {
    return(x)
  }
  nc <- length(cut)
  # `last` always holds the original value of the element processed most
  # recently, because x itself is overwritten as the loop advances.
  last <- x[1]
  x[1] <- 0
  ci <- 1L
  # Position 1 always starts a segment, so skip every leading cut that
  # points at it; checking `ci <= nc` first also covers an empty `cut`.
  while (ci <= nc && cut[ci] <= 1) {
    ci <- ci + 1L
  }
  # seq_len() keeps the loop empty for a length-one input, where `2:nx`
  # would count down and index past the end of x.
  for (i in seq_len(nx - 1L) + 1L) {
    tmp <- last
    last <- x[i]
    if (ci <= nc && i == cut[ci]) {
      x[i] <- 0
      ci <- ci + 1L
      # Skip repeated cut positions.
      while (ci <= nc && cut[ci] == cut[ci - 1L]) {
        ci <- ci + 1L
      }
    } else {
      x[i] <- x[i - 1L] + tmp
    }
  }
  x
}
# Smoke test: the loop implementation must reproduce the reference result
# for the input 8:1 with cuts at positions 4 and 7.
expected <- c(0L, 8L, 15L, 0L, 5L, 9L, 0L, 2L)
res <- cumsum_cut_r_loops(
  x = c(8L, 7L, 6L, 5L, 4L, 3L, 2L, 1L),
  cut = c(4L, 7L)
)
testthat::expect_equal(object = res, expected = expected)
# Benchmark 1: every implementation on the small 8-element example.
# Each expression builds its own input vector, so the in-place C version
# always receives a fresh vector.
results <- bench::mark(
  min_time = 5,
  max_iterations = 1e5,
  r_base = cumsum_cut_base(c(8L, 7L, 6L, 5L, 4L, 3L, 2L, 1L), c(4L, 7L)),
  r_rlang = cumsum_cut_rlang(c(8L, 7L, 6L, 5L, 4L, 3L, 2L, 1L), c(4L, 7L)),
  r_findInterval = cumsum_cut_findInterval(
    c(8L, 7L, 6L, 5L, 4L, 3L, 2L, 1L), c(4L, 7L)
  ),
  r_loops = cumsum_cut_r_loops(c(8L, 7L, 6L, 5L, 4L, 3L, 2L, 1L), c(4L, 7L)),
  c_inplace = mikec_inplace(c(8L, 7L, 6L, 5L, 4L, 3L, 2L, 1L), c(4L, 7L))
)
print(results[, c("expression", "median")])
# Benchmark 2: every implementation on 800 random values in 1..255 with
# 32 sorted cut positions (duplicates possible).
x <- sample(255, 800, replace = TRUE)
cut <- sort(sample(800, 32, replace = TRUE))
results <- bench::mark(
  r_base = cumsum_cut_base(x, cut),
  r_rlang = cumsum_cut_rlang(x, cut),
  r_findInterval = cumsum_cut_findInterval(x, cut),
  r_loops = cumsum_cut_r_loops(x, cut),
  # mikec_inplace() overwrites its argument. Passing `x` directly would
  # corrupt the shared input after the first evaluation, so every later
  # iteration (of all implementations) would run on different data.
  # `x + 0L` hands it a fresh copy on each evaluation.
  c_inplace = mikec_inplace(x + 0L, cut),
  # bench's default max_iterations applies here (no explicit override).
  min_time = 5
)
print(results[, c("expression", "median")])
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment