Erik-Jan van Kesteren
In this document, I create a small R function for computing the Kullback-Leibler divergence between two (multivariate) normal (Gaussian) distributions. The general formula for the divergence is as follows:

$$D_{KL}(p \,\|\, q) = \frac{1}{2}\left[\log\frac{|\Sigma_q|}{|\Sigma_p|} - d + \operatorname{tr}\!\left(\Sigma_q^{-1}\Sigma_p\right) + (\mu_q - \mu_p)^\top \Sigma_q^{-1} (\mu_q - \mu_p)\right]$$

where $p = \mathcal{N}(\mu_p, \Sigma_p)$ and $q = \mathcal{N}(\mu_q, \Sigma_q)$ are the two distributions and $d$ is their dimensionality.
Of course, a function for computing this divergence already exists in R: specifically, rags2ridges::KLdiv(). However, when installing the {rags2ridges} package, several dependencies need to be pulled from Bioconductor because they are not on CRAN, and the package in general has a lot of functionality we might not need if we just want to compute the KL-divergence.
The goal of this post is thus to create a lightweight version of the KL-divergence function which is fast, has low memory requirements, and has no dependencies beyond what’s included in base R. Below, I show how I achieved a big speedup compared to the existing KLdiv function.
We will be using the following example inputs:
set.seed(45)
# we use 25-dimensional normal distributions
P <- 25
# mean vectors
mean_0 <- rnorm(P)
mean_1 <- rnorm(P)
# covariance matrices
S_0 <- rWishart(1, P*2, diag(P))[,,1]
S_1 <- rWishart(1, P*2, diag(P))[,,1]
The existing function returns the following value for our example inputs:
rags2ridges::KLdiv(mean_1, mean_0, S_1, S_0)
[1] 15.40342
The first two terms in our equation are the log-determinants of the covariance matrices:

$$\log|\Sigma_q| - \log|\Sigma_p|$$

There are several ways to compute a log-determinant in R; here, we benchmark three methods, once on the 25-dimensional S_1 we created earlier, and another time on a random 1000-dimensional covariance matrix. The Cholesky-based method uses the fact that for the decomposition $\Sigma = R^\top R$, the log-determinant is $2\sum_i \log R_{ii}$:
bench::mark(
  logdet = log(det(S_1)),
  determ = determinant(S_1)$modulus,
  cholesky = 2*sum(log(diag(chol(S_1)))),
  check = FALSE
)
# A tibble: 3 × 6
expression min median `itr/sec` mem_alloc `gc/sec`
<bch:expr> <bch:tm> <bch:tm> <dbl> <bch:byt> <dbl>
1 logdet 19µs 21.3µs 43297. 4.93KB 4.33
2 determ 15.9µs 17.4µs 52696. 4.93KB 10.5
3 cholesky 22µs 23.8µs 39327. 11.98KB 7.87
set.seed(45)
S1000 <- rWishart(1, 1000, diag(1000))[,,1]
bench::mark(
  logdet = log(det(S1000)),
  determ = determinant(S1000)$modulus,
  cholesky = 2*sum(log(diag(chol(S1000)))),
  check = FALSE
)
# A tibble: 3 × 6
expression min median `itr/sec` mem_alloc `gc/sec`
<bch:expr> <bch:tm> <bch:tm> <dbl> <bch:byt> <dbl>
1 logdet 231ms 232ms 4.32 7.63MB 2.16
2 determ 230ms 230ms 4.35 7.63MB 8.70
3 cholesky 200ms 201ms 4.97 7.65MB 2.48
Generally, determinant() performs slightly better than log(det()). For low-dimensional distributions, the Cholesky version is the worst, but it is actually the fastest for the 1000-dimensional distribution. This is especially interesting because we may be able to reuse this Cholesky decomposition to speed up later steps.
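As a quick sanity check (my addition; the benchmarks above ran with check = FALSE, presumably because the three methods return differently-structured objects), we can confirm that they all compute the same value:

# all three log-determinant methods should agree numerically
all.equal(log(det(S_1)), c(determinant(S_1)$modulus))  # expected: TRUE
all.equal(log(det(S_1)), 2*sum(log(diag(chol(S_1)))))  # expected: TRUE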
The second difficult term to compute is the trace part:

$$\operatorname{tr}\left(\Sigma_q^{-1}\Sigma_p\right)$$

Note that here we need the inverse of $\Sigma_q$ (S_1 in our example). We benchmark three methods: a naive version that explicitly computes the inverse and the full matrix product, a version that uses solve() to avoid the explicit inverse, and a version borrowed from the lavaan package, from the code here, with the additional knowledge that covariance matrices are symmetric: for symmetric $B$, $\operatorname{tr}(AB) = \sum_{ij} A_{ij}B_{ij}$, so the trace reduces to the sum of an elementwise product.
bench::mark(
  naive = sum(diag(solve(S_1) %*% S_0)),
  solve = sum(diag(solve(S_1, S_0))),
  lavaan = sum(solve(S_1) * S_0)
)
# A tibble: 3 × 6
expression min median `itr/sec` mem_alloc `gc/sec`
<bch:expr> <bch:tm> <bch:tm> <dbl> <bch:byt> <dbl>
1 naive 64.7µs 69µs 14011. 21KB 12.7
2 solve 48µs 53.6µs 17565. 11.2KB 11.1
3 lavaan 49.1µs 52µs 18438. 20.6KB 10.4
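The benchmark above already verifies that all three expressions return the same value (bench::mark() checks results for equality by default), but to make the lavaan identity concrete, here is an explicit check (my addition):

# tr(S_1^{-1} S_0) equals the sum of the elementwise product
all.equal(sum(diag(solve(S_1) %*% S_0)), sum(solve(S_1) * S_0))  # expected: TRUE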
In this benchmark, the naive method is definitely the worst, and the other two methods are almost at the same level. Again, the last identity is especially interesting, because the inverse it needs can be obtained cheaply with chol2inv() from a Cholesky decomposition we may already have computed for the log-determinant. Note that if we assume that we have the Cholesky decomposition of $\Sigma_q$ available, the comparison looks as follows:
S_1_c <- chol(S_1)
bench::mark(
  naive = sum(diag(chol2inv(S_1_c) %*% S_0)),
  solve = sum(diag(forwardsolve(S_1_c, backsolve(S_1_c, S_0, transpose = TRUE), upper.tri = TRUE))),
  lavaan = sum(chol2inv(S_1_c) * S_0)
)
# A tibble: 3 × 6
expression min median `itr/sec` mem_alloc `gc/sec`
<bch:expr> <bch:tm> <bch:tm> <dbl> <bch:byt> <dbl>
1 naive 23.8µs 26µs 34524. 10.34KB 6.91
2 solve 35µs 36.9µs 26524. 26KB 7.96
3 lavaan 10.1µs 10.7µs 89282. 9.86KB 26.8
With a precomputed Cholesky decomposition, the lavaan-style version using chol2inv() is by far the fastest. The last computationally intensive part of the equation is the following quadratic form:

$$(\mu_q - \mu_p)^\top \Sigma_q^{-1} (\mu_q - \mu_p)$$

Here, we first compute the mean differences, which we call delta. Then, we again need the inverse of the covariance matrix S_1, which we compute from its Cholesky decomposition:
delta <- mean_1 - mean_0
Omega_1 <- chol2inv(S_1_c)
bench::mark(
  naive = t(delta) %*% Omega_1 %*% delta,
  cross = crossprod(delta, Omega_1 %*% delta),
  cross2 = crossprod(delta, crossprod(Omega_1, delta))
)
# A tibble: 3 × 6
expression min median `itr/sec` mem_alloc `gc/sec`
<bch:expr> <bch:tm> <bch:tm> <dbl> <bch:byt> <dbl>
1 naive 6.7µs 7.3µs 129531. 496B 0
2 cross 2.4µs 2.7µs 357982. 248B 0
3 cross2 2.9µs 3.2µs 302875. 248B 30.3
The second version, using only one crossprod(), is the fastest here: almost three times as fast as the naive implementation with standard matrix products.
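All three expressions are algebraically identical: crossprod(x, y) computes t(x) %*% y, and the cross2 variant additionally relies on the symmetry of Omega_1. A quick equivalence check (my addition):

# the crossprod() variant matches the naive quadratic form
all.equal(t(delta) %*% Omega_1 %*% delta,
          crossprod(delta, Omega_1 %*% delta))  # expected: TRUE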
Combining all the above sections, we can make two versions of the KL-divergence function: a plain version, and one which precomputes the Cholesky decomposition of $\Sigma_q$ and reuses it for both the log-determinant and the inverse:
kldiv_base <- function(mu_p, mu_q, Sigma_p, Sigma_q) {
  logdet_p <- determinant(Sigma_p)$modulus
  logdet_q <- determinant(Sigma_q)$modulus
  d <- length(mu_p)
  Omega_q <- solve(Sigma_q)
  trace <- sum(Omega_q * Sigma_p)
  delta <- mu_q - mu_p
  quad <- crossprod(delta, Omega_q %*% delta)
  return(c(logdet_q - logdet_p - d + trace + quad) / 2)
}

kldiv_chol <- function(mu_p, mu_q, Sigma_p, Sigma_q) {
  chol_q <- chol(Sigma_q)
  logdet_p <- determinant(Sigma_p)$modulus
  logdet_q <- 2*sum(log(diag(chol_q)))
  d <- length(mu_p)
  Omega_q <- chol2inv(chol_q)
  trace <- sum(Omega_q * Sigma_p)
  delta <- mu_q - mu_p
  quad <- crossprod(delta, Omega_q %*% delta)
  return(c(logdet_q - logdet_p - d + trace + quad) / 2)
}
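Before timing the two implementations, a quick check (my addition) that they agree; bench::mark() will also verify this by default:

# both versions should return the same value up to floating-point error
all.equal(kldiv_base(mean_0, mean_1, S_0, S_1),
          kldiv_chol(mean_0, mean_1, S_0, S_1))  # expected: TRUE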
bench::mark(
  base = kldiv_base(mean_0, mean_1, S_0, S_1),
  chol = kldiv_chol(mean_0, mean_1, S_0, S_1)
)
# A tibble: 2 × 6
expression min median `itr/sec` mem_alloc `gc/sec`
<bch:expr> <bch:tm> <bch:tm> <dbl> <bch:byt> <dbl>
1 base 87.9µs 92.8µs 10474. 100KB 8.30
2 chol 54.1µs 59.7µs 16409. 103KB 10.5
As you can see, the Cholesky version is quite a bit (more than a third) faster. But does it also give the correct answer?
kldiv_chol(mean_0, mean_1, S_0, S_1)
[1] 15.40342
It does! So let’s use that version and compare it to the original rags2ridges::KLdiv() function.
bench::mark(
  ours = kldiv_chol(mean_0, mean_1, S_0, S_1),
  theirs = rags2ridges::KLdiv(mean_1, mean_0, S_1, S_0)
)
# A tibble: 2 × 6
expression min median `itr/sec` mem_alloc `gc/sec`
<bch:expr> <bch:tm> <bch:tm> <dbl> <bch:byt> <dbl>
1 ours 55.5µs 58.8µs 16151. 20.7KB 10.5
2 theirs 788.8µs 822.5µs 1203. 105.9KB 4.09
We’ve achieved quite a dramatic speedup, between 14 and 15x. Additionally, we have allocated 5x less memory. This speedup is maintained for 2-dimensional distributions:
# generate parameters
set.seed(45)
P <- 2
mean_0 <- rnorm(P)
mean_1 <- rnorm(P)
S_0 <- rWishart(1, P*2, diag(P))[,,1]
S_1 <- rWishart(1, P*2, diag(P))[,,1]
bench::mark(
  ours = kldiv_chol(mean_0, mean_1, S_0, S_1),
  theirs = rags2ridges::KLdiv(mean_1, mean_0, S_1, S_0)
)
# A tibble: 2 × 6
expression min median `itr/sec` mem_alloc `gc/sec`
<bch:expr> <bch:tm> <bch:tm> <dbl> <bch:byt> <dbl>
1 ours 34.4µs 36µs 26593. 0B 2.66
2 theirs 541.4µs 566µs 1705. 0B 6.18
And a positive side-effect of the new function is that it also works for larger, 250-dimensional distributions, unlike the old function:
# generate parameters
set.seed(45)
P <- 250
mean_0 <- rnorm(P)
mean_1 <- rnorm(P)
S_0 <- rWishart(1, P*2, diag(P))[,,1]
S_1 <- rWishart(1, P*2, diag(P))[,,1]
kldiv_chol(mean_0, mean_1, S_0, S_1)
[1] 128.2745
rags2ridges::KLdiv(mean_1, mean_0, S_1, S_0)
[1] NaN
The old function most likely overflows when computing the determinant of these large covariance matrices, whereas our function stays on the log scale throughout. What remains is to tidy up the function, make it more robust, and document it properly. Allow me to present the fast and furious Gaussian Kullback-Leibler divergence:
#' Kullback-Leibler divergence between two Gaussians
#'
#' This function computes $D_{KL}(p||q)$, where $p(x)$
#' and $q(x)$ are two multivariate normal distributions.
#'
#' @param mu_p the mean vector of p
#' @param mu_q the mean vector of q
#' @param Sigma_p the covariance matrix of p
#' @param Sigma_q the covariance matrix of q
#'
#' @return Kullback-Leibler divergence from p to q
#' (numeric scalar)
kldiv <- function(mu_p, mu_q, Sigma_p, Sigma_q) {
  # Cholesky decomposition of Sigma_q, reused for the log-determinant and the inverse
  chol_q <- chol(Sigma_q)
  # as.matrix() makes this robust to 1-dimensional (scalar) input
  logdet_p <- determinant(as.matrix(Sigma_p))$modulus
  logdet_q <- 2*sum(log(diag(chol_q)))
  d <- length(mu_p)
  Omega_q <- chol2inv(chol_q)
  # trace term via the elementwise (lavaan) identity
  trace <- sum(Omega_q * Sigma_p)
  delta <- mu_q - mu_p
  # quadratic form via a single crossprod()
  quad <- crossprod(delta, Omega_q %*% delta)
  return(c(logdet_q - logdet_p - d + trace + quad) / 2)
}
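Finally, a small usage example (my addition). The divergence of a distribution from itself is zero, and the measure is famously asymmetric:

# divergence from a distribution to itself: zero (up to floating-point error)
kldiv(mean_0, mean_0, S_0, S_0)

# KL divergence is asymmetric, so these two values generally differ
kldiv(mean_0, mean_1, S_0, S_1)
kldiv(mean_1, mean_0, S_1, S_0)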