Last active
March 12, 2018 14:58
-
-
Save brandonwillard/710a7bab8526db394bad to your computer and use it in GitHub Desktop.
Slow fix for creating large sparse matrices with interactions
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
library(Rcpp) | |
#library(RcppProgress) | |
#library(RcppParallel) | |
{ | |
code = ' | |
// [[Rcpp::depends(RcppProgress, RcppArmadillo, RcppParallel)]]
#define ARMA_64BIT_WORD
#include <RcppArmadillo.h>
#include <RcppParallel.h>
#include <progress.hpp>
#include <algorithm>
#include <iostream>
#include <iterator>
#include <limits>
struct SProd : public RcppParallel::Worker { | |
// source matrix | |
const arma::sp_mat& spA; | |
const arma::sp_mat& spB; | |
const std::set<int>& anrows; | |
const std::set<int>& bnrows; | |
const Progress& p; | |
//std::set<int>::iterator anrows_iter; | |
//std::set<int>::iterator bnrows_iter; | |
// destination matrix | |
arma::sp_mat& output; | |
// initialize with source and destination | |
SProd(const arma::sp_mat& spA, | |
const arma::sp_mat& spB, | |
const std::set<int>& anrows, | |
const std::set<int>& bnrows, | |
const Progress& p, | |
arma::sp_mat& output) | |
: spA(spA), spB(spB), anrows(anrows), bnrows(bnrows), p(p), output(output) | |
//, anrows_iter(std::unique(spA.row_indices, spA.row_indices + spA.n_nonzero)) | |
//, bnrows_iter(std::unique(spB.row_indices, spB.row_indices + spB.n_nonzero)) | |
{ | |
} | |
// take the square root of the range of elements requested | |
void operator()(const std::size_t begin, const std::size_t end) { | |
std::cout << "begin=" << begin << ", end=" << end << std::endl << std::flush; | |
std::set<int>::const_iterator b_start = bnrows.begin(); | |
std::advance(b_start, begin); | |
std::set<int>::const_iterator b_end = bnrows.begin(); | |
std::advance(b_end, end+1); | |
for (std::set<int>::const_iterator b_it = b_start; b_it != b_end; b_it++) { | |
//std::cout << "b_row=" << *b_it << std::endl << std::flush; | |
if (Progress::check_abort()) | |
break; //return(R_NilValue); | |
for (std::set<int>::const_iterator a_it = anrows.begin(); a_it != anrows.end(); a_it++) { | |
if (Progress::check_abort()) | |
break; //return(R_NilValue); | |
const int k = (*b_it) * spA.n_rows + (*a_it); | |
//std::cout << "k=" << k << std::endl << std::flush; | |
output.row(k) = spA.row(*a_it) % spB.row(*b_it); | |
//p.increment(); | |
} | |
} | |
} | |
}; | |
// [[Rcpp::export("sparse.row.idx.mult")]] | |
SEXP sparse_row_idx_mult(arma::sp_mat spA, arma::sp_mat spB) { | |
Progress p(0, false); | |
std::set<int> anrows(spA.row_indices, spA.row_indices + spA.n_nonzero); | |
std::set<int> bnrows(spB.row_indices, spB.row_indices + spB.n_nonzero); | |
const int bnrows_len = bnrows.size(); | |
const int res_rows = spB.n_rows * spA.n_rows; | |
arma::sp_mat C(res_rows, spA.n_cols); | |
std::cout << "B_nonzero_rows=" << bnrows_len << std::endl << std::flush; | |
std::cout << "res_rows=" << res_rows << std::endl << std::flush; | |
SProd prodWorker(spA, spB, anrows, bnrows, p, C); | |
RcppParallel::parallelFor(0, bnrows_len-1, prodWorker, 1000); | |
Rcpp::S4 Cout(Rcpp::wrap(C)); | |
return(Cout); | |
} | |
' | |
sourceCpp(code=code) | |
} | |
## Keep a handle on the unpatched Matrix implementation for the comparison
## test below. Guarded so that re-sourcing this script after the patch has
## been applied does not capture the patched function instead, which would
## make the comparison vacuous.
if (!exists("old.sparse2int", inherits = FALSE)) {
  old.sparse2int <- Matrix:::sparse2int
}
## Patched replacement for Matrix:::sparse2int.
##
## Builds the row-wise interaction product of X and Y: row  i + nx * (j - 1)
## of the result is X[i, ] * Y[j, ], so rows vary fastest over X, matching
## the "rowX:rowY" names built with outer() at the end.
##
## X, Y        : numeric matrices or sparse Matrix objects with one column
##               per observation (same number of columns).
## do.names    : if TRUE, set the rownames of the result to "rowX:rowY".
## forceSparse : if TRUE and both X and Y are numeric, return a sparse matrix.
## verbose     : if TRUE, print the kind and row count of X and Y.
##
## The compiled sparse.row.idx.mult() only accepts sparse matrices, so dense
## inputs are coerced with as(., "CsparseMatrix") before calling it. When
## both inputs are numeric the dense product of Matrix:::sparse2int is kept,
## so the result stays dense unless forceSparse is TRUE.
fixed.sparse2int = function (X, Y, do.names = TRUE, forceSparse = FALSE, verbose = FALSE) {
  if (do.names) {
    dnx <- dimnames(X)
    dny <- dimnames(Y)
  }
  dimnames(Y) <- dimnames(X) <- list(NULL, NULL)
  nx <- nrow(X)
  ny <- nrow(Y)
  toSparse <- function(m) if (is.numeric(m)) as(m, "CsparseMatrix") else m
  ## `|` (not `||`) so that nX and nY are both always assigned; the verbose
  ## message below uses both.
  r <- if ((nX <- is.numeric(X)) | (nY <- is.numeric(Y))) {
    if (nX && nY) {
      ## Both dense: dense row-replicated product; drop = FALSE keeps a
      ## matrix even when there is a single column (observation).
      F <- if (forceSparse) function(m) as(m, "CsparseMatrix") else identity
      F((if (ny == 1) X else X[rep.int(seq_len(nx), ny), , drop = FALSE]) *
        (if (nx == 1) Y else Y[rep(seq_len(ny), each = nx), , drop = FALSE]))
    } else if (nX) {
      if (nx > 1) {
        sparse.row.idx.mult(toSparse(X), Y)
      } else {
        ## One numeric row times sparse Y: scale the stored values of Y.
        ## Relies on every column of Y having at most one non-zero, so that
        ## X[dp == 1L] lines up with Y@x.
        r <- Y
        dp <- Y@p[-1] - Y@p[-(Y@Dim[2] + 1L)]
        r@x <- X[dp == 1L] * Y@x
        r
      }
    } else {
      if (ny == 1) {
        ## Sparse X times one numeric row: same trick with the roles swapped.
        r <- X
        dp <- X@p[-1] - X@p[-(X@Dim[2] + 1L)]
        r@x <- Y[dp == 1L] * X@x
        r
      } else {
        sparse.row.idx.mult(X, toSparse(Y))
      }
    }
  } else {
    sparse.row.idx.mult(X, Y)
  }
  if (verbose)
    cat(sprintf(" sp..2int(%s[%d],%s[%d]) ", if (nX) "<N>" else "<sparse>", nx, if (nY) "<N>" else "<sparse>", ny))
  if (do.names) {
    if (!is.null(dim(r)) && !is.null(nX <- dnx[[1]]) && !is.null(nY <- dny[[1]]))
      rownames(r) <- outer(nX, nY, paste, sep = ":")
  }
  return(r)
}
## Attach dependencies with library() so a missing package stops the script
## immediately (require() only warns and returns FALSE). Matrix is attached
## before it is patched so the package is guaranteed to be loaded.
library(R.utils)
library(Matrix)
library(testthat)
## Replace sparse2int inside the Matrix package so that code in Matrix that
## calls it (exercised below through sparse.model.matrix) uses the patch.
reassignInPackage("sparse2int", pkgName = "Matrix", fixed.sparse2int)
test_that("test sparse prod only", {
  dat <- data.frame(a = gl(2, 4), b = gl(4, 2))
  ## Transposed indicator matrices: one row per factor level, one column per
  ## observation.
  spX <- drop0(t(model.matrix(~ 0 + a, dat)))
  spY <- drop0(t(model.matrix(~ 0 + b, dat)))
  ## The compiled kernel must agree element-wise with the original Matrix code.
  expect_true(all(sparse.row.idx.mult(spX, spY) == old.sparse2int(spX, spY)))
})
#test_that("large interaction, one term, no intercept", { | |
# test.dat = data.frame(a=gl(10,1e6), b=gl(1e6,10)) | |
# x.fml = as.formula(~ 0 + a + a:b) | |
# fix.m.mat = sparse.model.matrix(x.fml, test.dat, verbose=TRUE) | |
# rm(test.dat) | |
# # just make sure it doesn't err out | |
# expect_true(!is.null(fix.m.mat)) | |
# rm(fix.m.mat) | |
# gc() | |
#}) | |
## Shared check for the model-matrix tests below: on a small two-factor
## design, sparse.model.matrix() must reproduce the dense model.matrix()
## for the formula `fml`. These cases exercise the patched sparse2int.
expect_sparse_model_matrix_matches <- function(fml) {
  test.dat <- data.frame(a = gl(2, 4), b = gl(4, 2))
  m.mat <- model.matrix(fml, test.dat)
  fix.m.mat <- sparse.model.matrix(fml, test.dat)
  ## Compare shapes first so a mismatch is reported as a clear expectation
  ## failure rather than only as an error from the element-wise comparison.
  expect_equal(dim(fix.m.mat), dim(m.mat))
  expect_true(all(m.mat == fix.m.mat))
}
test_that("interaction, one term, no intercept", {
  expect_sparse_model_matrix_matches(~ 0 + a + a:b)
})
test_that("all terms, no intercept", {
  expect_sparse_model_matrix_matches(~ 0 + a + b + a:b)
})
test_that("no interaction, no intercept", {
  expect_sparse_model_matrix_matches(~ 0 + a + b)
})
test_that("one term, interaction", {
  expect_sparse_model_matrix_matches(~ a + a:b)
})
test_that("all terms", {
  expect_sparse_model_matrix_matches(~ .*.)
})
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment