Last active
March 12, 2020 10:01
-
-
Save boennecd/db151d693435c0bdbcec96f61e380014 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// [[Rcpp::depends(RcppArmadillo)]] | |
#include <RcppArmadillo.h> | |
/* d vech(chol(X)) / d vech(X). See mathoverflow.net/a/232129/134083 | |
* | |
* Args: | |
* X: symmetric positive definite matrix. | |
* upper: logical for whether vech denotes the upper triangular part. | |
*/ | |
// [[Rcpp::export]] | |
arma::mat dchol(arma::mat const &X, bool const upper = false) | |
{ | |
using arma::uword; | |
uword const ndim = X.n_rows, | |
nvech = (ndim * (ndim + 1L)) / 2L; | |
arma::mat out(nvech, nvech, arma::fill::zeros); | |
arma::mat const F = arma::chol(X), | |
Fi = arma::inv(arma::trimatu(F)); | |
/* class to do the computation */ | |
struct util { | |
using uword = arma::uword; | |
arma::mat const &F, &Fi; | |
uword const ndim = F.n_cols; | |
util(arma::mat const &F, arma::mat const &Fi): F(F), Fi(Fi) { | |
assert(F .n_rows == ndim); | |
assert(F .n_cols == ndim); | |
assert(Fi.n_cols == ndim); | |
assert(Fi.n_rows == ndim); | |
} | |
double operator() | |
(uword const i, uword const j, uword const k) const { | |
double out(0); | |
double const *f = &F.at(j, i), | |
*fik = &Fi.at(k, j), | |
mult = *fik; | |
out += *f++ * *fik / 2.; | |
fik += ndim; | |
for(uword m = j + 1; m < ndim; m++, fik += ndim) | |
out += *f++ * *fik; | |
out *= mult; | |
return out; | |
} | |
double operator() | |
(uword const i, uword const j, uword const k, uword const l) const { | |
double out(0); | |
double x(0); | |
double const *f = &F.at(j, i), | |
*fik = &Fi.at(k, j), | |
*fil = &Fi.at(l, j), | |
mult_k = *fik, | |
mult_l = *fil; | |
out += *f * *fik / 2.; | |
x += *f++ * *fil / 2.; | |
fik += ndim; | |
fil += ndim; | |
for(uword m = j + 1; m < ndim; m++, fik += ndim, fil += ndim){ | |
out += *f * *fik; | |
x += *f++ * *fil; | |
} | |
out *= mult_l; | |
x *= mult_k; | |
out += x; | |
return out; | |
} | |
}; | |
if(upper){ | |
/* get index map from index in vech that maps to a lower triangular | |
* matrix to one that maps to an upper triangular matrix. | |
* TODO: very slow... */ | |
auto im = [&](uword const idx){ | |
uword co(0), ro(0); | |
{ | |
uword dum(idx), remain(ndim); | |
while(dum >= remain){ | |
++co; | |
dum -= remain--; | |
} | |
ro = co + dum; | |
} | |
return (ro * (ro + 1L)) / 2L + co; | |
}; | |
uword r(0); | |
for(uword j = 0; j < ndim; j++) | |
for(uword i = j; i < ndim; i++, r++){ | |
uword c(0); | |
uword const rim = im(r); | |
for(uword k = 0; c <= r and k < ndim; k++){ | |
out.at(rim, im(c++)) = util(F, Fi)(i, j, k); | |
for(uword l = k + 1L; l < ndim; l++, c++) | |
out.at(rim, im(c)) = util(F, Fi)(i, j, k, l); | |
} | |
} | |
} else { | |
uword r(0); | |
for(uword j = 0; j < ndim; j++) | |
for(uword i = j; i < ndim; i++, r++){ | |
uword c(0); | |
for(uword k = 0; c <= r and k < ndim; k++){ | |
out.at(r, c++) = util(F, Fi)(i, j, k); | |
for(uword l = k + 1L; l < ndim; l++, c++) | |
out.at(r, c) = util(F, Fi)(i, j, k, l); | |
} | |
} | |
} | |
return out; | |
} | |
/*** R
# Checks and benchmarks for dchol(). Requires the matrixcalc, numDeriv and
# microbenchmark packages. library() is used so that a missing package
# fails immediately instead of later with "could not find function".
options(digits = 4)
library(matrixcalc)
set.seed(2349025)
n <- 10
Z <- drop(rWishart(1, 2 * n, diag(n)))

#####
# simple R-version of the derivative of the Cholesky Factor
fn <- function(xin){
  x <- matrix(nrow = n, ncol = n)
  x[lower.tri(x, TRUE)] <- xin
  x[upper.tri(x)] <- t(x)[upper.tri(x)]
  t(chol(x))
}

dchol_R <- function(Z){
  X <- fn(Z[lower.tri(Z, TRUE)])
  L <- elimination.matrix(n)
  d <- L %*% (diag(n^2) + commutation.matrix(r = n)) %*%
    tcrossprod(X %x% diag(n), L)
  solve(d)
}

# check function
library(numDeriv)
d <- dchol_R(Z)
jac <- jacobian(fn, Z[lower.tri(Z, TRUE)])
keep <- lower.tri(Z, TRUE)
all.equal(jac[keep, ], d)
#R> [1] TRUE

#####
# C++ version
all.equal(d, dchol(Z))
#R> [1] TRUE

#####
# same function but with the upper triangular parts
fnU <- function(xin){
  x <- matrix(nrow = n, ncol = n)
  x[upper.tri(x, TRUE)] <- xin
  x[lower.tri(x)] <- t(x)[lower.tri(x)]
  chol(x)
}

jac <- jacobian(fnU, Z[upper.tri(Z, TRUE)])
keep <- upper.tri(Z, TRUE)
all.equal(jac[keep, ], dchol(Z, upper = TRUE))
#R> [1] TRUE

#####
# benchmarks
microbenchmark::microbenchmark(
  R = dchol_R(Z), `C++` = dchol(Z), `C++ (upper)` = dchol(Z, upper = TRUE),
  times = 1000)
#R> Unit: microseconds
#R>         expr      min       lq     mean   median       uq      max neval
#R>            R 11233.06 11651.94 12810.15 12959.27 13388.65 44049.72  1000
#R>          C++    14.33    15.68    21.67    16.88    30.17    99.71  1000
#R>  C++ (upper)    16.09    18.38    26.83    21.55    35.07  1467.63  1000

# the R function is mainly slow due to the use of the matrixcalc package
*/
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment