Efficient Gaussian Kullback-Leibler divergence in R

Erik-Jan van Kesteren

Introduction

In this document, I create a small R function for computing the Kullback-Leibler divergence between two (multivariate) normal (Gaussian) distributions. The general formula for the divergence is as follows:

$$D_{KL}(p || q) = \int_{-\infty}^{\infty} p(x) \log \left(\frac{p(x)}{q(x)}\right)\, dx$$

If we assume that both $p(x)$ and $q(x)$ are multivariate normal, then all the difficult integration drops out and we are left with the following formula, adapted from this excellent answer on StackOverflow (https://stats.stackexchange.com/a/60699/116878):

$$D_{KL}(p||q) = \frac{1}{2}\left[\log|\Sigma_q|-\log|\Sigma_p| - d + \text{tr}\{\Sigma_q^{-1}\Sigma_p\} + (\mu_q - \mu_p)^T \Sigma_q^{-1}(\mu_q - \mu_p)\right]$$

where $d$ indicates the dimensionality of the distributions and $|\cdot|$ indicates the determinant.
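For reference, a literal and deliberately unoptimized translation of this formula into R might look like the sketch below (kldiv_naive is just an illustrative name; the rest of this post builds a faster version term by term):

# naive sketch: a direct translation of the closed-form KL divergence, not optimized
kldiv_naive <- function(mu_p, mu_q, Sigma_p, Sigma_q) {
  d     <- length(mu_p)
  Omega <- solve(Sigma_q)           # Sigma_q^{-1}
  delta <- mu_q - mu_p
  0.5 * (log(det(Sigma_q)) - log(det(Sigma_p)) - d +
         sum(diag(Omega %*% Sigma_p)) +
         c(t(delta) %*% Omega %*% delta))
}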

Of course, a function already exists for computing this divergence in R: rags2ridges::KLdiv(). However, installing the {rags2ridges} package pulls several dependencies from Bioconductor (they are not on CRAN), and the package in general has a lot of functionality we might not need if we just want to compute the KL-divergence.

The goal of this post is thus to create a lightweight KL-divergence function that is fast, has low memory requirements, and has no dependencies beyond what's included in base R. Below, I show how I achieved a big speedup compared to the existing KLdiv function.

Example inputs

We will be using the following example inputs:

set.seed(45)

# we use 25-dimensional normal distributions
P <- 25

# mean vectors
mean_0 <- rnorm(P)
mean_1 <- rnorm(P)

# covariance matrices
S_0 <- rWishart(1, P*2, diag(P))[,,1]
S_1 <- rWishart(1, P*2, diag(P))[,,1]

The existing function returns the following value for $D_{KL}(p||q)$ with these inputs (note that the argument ordering is a bit weird; KL-divergence is not symmetric, so the order matters):

rags2ridges::KLdiv(mean_1, mean_0, S_1, S_0)
[1] 15.40342

Computing log-determinants

The first two terms in our equation are the log-determinants of the covariance matrices:

$$\log|\Sigma_q|-\log|\Sigma_p|$$

There are several ways to compute a log-determinant in R; here, we benchmark three methods, once on the 25-dimensional S_1 we created earlier, and another time on a random 1000-dimensional covariance matrix:

bench::mark(
  logdet = log(det(S_1)),
  determ = determinant(S_1)$modulus,
  cholesky = 2*sum(log(diag(chol(S_1)))),
  check = FALSE
)
# A tibble: 3 × 6
  expression      min   median `itr/sec` mem_alloc `gc/sec`
  <bch:expr> <bch:tm> <bch:tm>     <dbl> <bch:byt>    <dbl>
1 logdet         19µs   21.3µs    43297.    4.93KB     4.33
2 determ       15.9µs   17.4µs    52696.    4.93KB    10.5 
3 cholesky       22µs   23.8µs    39327.   11.98KB     7.87
set.seed(45)
S1000 <- rWishart(1, 1000, diag(1000))[,,1]
bench::mark(
  logdet = log(det(S1000)),
  determ = determinant(S1000)$modulus,
  cholesky = 2*sum(log(diag(chol(S1000)))),
  check = FALSE
)
# A tibble: 3 × 6
  expression      min   median `itr/sec` mem_alloc `gc/sec`
  <bch:expr> <bch:tm> <bch:tm>     <dbl> <bch:byt>    <dbl>
1 logdet        231ms    232ms      4.32    7.63MB     2.16
2 determ        230ms    230ms      4.35    7.63MB     8.70
3 cholesky      200ms    201ms      4.97    7.65MB     2.48

Generally, determinant() performs slightly better than log(det()). For low-dimensional distributions, the Cholesky version is worst, but it is actually fastest for the 1000-dimensional distribution. This is especially interesting, as we may be able to reuse the Cholesky decomposition to speed up later steps.
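The Cholesky trick works because $\Sigma = R^T R$ with $R$ upper triangular, so $|\Sigma| = |R|^2 = \prod_i R_{ii}^2$ and thus $\log|\Sigma| = 2\sum_i \log R_{ii}$. As a quick sanity check outside the benchmark, the three methods should agree up to floating-point error (check = FALSE is presumably needed above because determinant() returns its result with a logarithm attribute attached, so the three outputs are not identical objects even though the numbers match):

# all three log-determinant computations should agree numerically
all.equal(log(det(S_1)), c(determinant(S_1)$modulus))     # should be TRUE
all.equal(log(det(S_1)), 2*sum(log(diag(chol(S_1)))))     # should be TRUE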

Computing the trace

The second difficult term to compute is the trace part:

$$\text{tr}\{\Sigma_q^{-1}\Sigma_p\}$$

Note that here we need the inverse of $\Sigma_q$, which we will also need at a later stage. Again, we have several options for computing the trace. One trick is adapted from the lavaan package (see lav_matrix.R, https://github.com/yrosseel/lavaan/blob/4eb699f1300d22657041f1d04965a8cb6d89811f/R/lav_matrix.R#L970-L974), with the additional knowledge that covariance matrices are symmetric.

bench::mark(
  naive = sum(diag(solve(S_1) %*% S_0)),
  solve = sum(diag(solve(S_1, S_0))),
  lavaan = sum(solve(S_1) * S_0)
)
# A tibble: 3 × 6
  expression      min   median `itr/sec` mem_alloc `gc/sec`
  <bch:expr> <bch:tm> <bch:tm>     <dbl> <bch:byt>    <dbl>
1 naive        64.7µs     69µs    14011.      21KB     12.7
2 solve          48µs   53.6µs    17565.    11.2KB     11.1
3 lavaan       49.1µs     52µs    18438.    20.6KB     10.4

Here, the naive method is definitely the worst, and the other two methods are roughly on par. Again, the last identity is especially interesting, because we want to precompute $\Sigma_q^{-1}$ anyway for the next part.
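The lavaan trick relies on the identity $\text{tr}(AB) = \sum_{ij} A_{ij}B_{ji}$, which for a symmetric $B$ equals the sum of the elementwise product $\sum_{ij} A_{ij}B_{ij}$, so the full matrix product can be skipped. A quick check that the shortcut matches the naive trace:

# the elementwise-product shortcut should equal the naive matrix-product trace
all.equal(sum(diag(solve(S_1) %*% S_0)), sum(solve(S_1) * S_0))  # should be TRUE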

Note that if we have the Cholesky decomposition of $\Sigma_q$ precomputed from the previous part, these computations (especially the lavaan-style one) become even faster:

S_1_c <- chol(S_1)

bench::mark(
  naive = sum(diag(chol2inv(S_1_c) %*% S_0)),
  solve = sum(diag(forwardsolve(S_1_c, backsolve(S_1_c, S_0, transpose = TRUE), upper.tri = TRUE))),
  lavaan = sum(chol2inv(S_1_c) * S_0)
)
# A tibble: 3 × 6
  expression      min   median `itr/sec` mem_alloc `gc/sec`
  <bch:expr> <bch:tm> <bch:tm>     <dbl> <bch:byt>    <dbl>
1 naive        23.8µs     26µs    34524.   10.34KB     6.91
2 solve          35µs   36.9µs    26524.      26KB     7.96
3 lavaan       10.1µs   10.7µs    89282.    9.86KB    26.8 
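Here, chol2inv() builds $\Sigma_q^{-1}$ from the already-computed Cholesky factor instead of factorizing the matrix again, which is where the speedup comes from; it should return the same matrix as a direct solve():

# inverting via the precomputed Cholesky factor should match a direct inverse
all.equal(chol2inv(S_1_c), solve(S_1))  # should be TRUE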

Quadratic form

The last computationally intensive part of the equation is the following quadratic form:

$$(\mu_q - \mu_p)^T \Sigma_q^{-1}(\mu_q - \mu_p)$$

Here, we first compute the mean differences, which we call delta. Then, we again need the inverse of the covariance matrix of $q(x)$, so we will assume that it is pre-computed from the previous step.

delta <- mean_1 - mean_0
Omega_1 <- chol2inv(S_1_c)

bench::mark(
  naive = t(delta) %*% Omega_1 %*% delta,
  cross = crossprod(delta, Omega_1 %*% delta),
  cross2 = crossprod(delta, crossprod(Omega_1, delta))
)
# A tibble: 3 × 6
  expression      min   median `itr/sec` mem_alloc `gc/sec`
  <bch:expr> <bch:tm> <bch:tm>     <dbl> <bch:byt>    <dbl>
1 naive         6.7µs    7.3µs   129531.      496B      0  
2 cross         2.4µs    2.7µs   357982.      248B      0  
3 cross2        2.9µs    3.2µs   302875.      248B     30.3

The second version, using only one crossprod(), is the fastest here. It is substantially faster than the naïve implementation with standard matrix products.
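This is because crossprod(x, y) computes $x^T y$ in a single compiled routine, without explicitly forming the transpose of x first. A quick equivalence check:

# the crossprod-based quadratic form should equal the naive version
all.equal(c(t(delta) %*% Omega_1 %*% delta),
          c(crossprod(delta, Omega_1 %*% delta)))  # should be TRUE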

Putting it all together

Combining all the above sections, we can make two versions of the KL-divergence function: one which precomputes and uses the Cholesky decomposition of $\Sigma_q$, and one which does not. We can compare them for our example distributions:

kldiv_base <- function(mu_p, mu_q, Sigma_p, Sigma_q) {
  logdet_p <- determinant(Sigma_p)$modulus
  logdet_q <- determinant(Sigma_q)$modulus
  d <- length(mu_p)
  Omega_q <- solve(Sigma_q)
  trace <- sum(Omega_q * Sigma_p)
  delta <- mu_q - mu_p
  quad <- crossprod(delta, Omega_q %*% delta)
  return(c(logdet_q - logdet_p - d + trace + quad) / 2)
}

kldiv_chol <- function(mu_p, mu_q, Sigma_p, Sigma_q) {
  chol_q <- chol(Sigma_q)
  logdet_p <- determinant(Sigma_p)$modulus
  logdet_q <- 2*sum(log(diag(chol_q)))
  d <- length(mu_p)
  Omega_q <- chol2inv(chol_q)
  trace <- sum(Omega_q * Sigma_p)
  delta <- mu_q - mu_p
  quad <- crossprod(delta, Omega_q %*% delta)
  return(c(logdet_q - logdet_p - d + trace + quad) / 2)
}

bench::mark(
  base = kldiv_base(mean_0, mean_1, S_0, S_1),
  chol = kldiv_chol(mean_0, mean_1, S_0, S_1)
)
# A tibble: 2 × 6
  expression      min   median `itr/sec` mem_alloc `gc/sec`
  <bch:expr> <bch:tm> <bch:tm>     <dbl> <bch:byt>    <dbl>
1 base         87.9µs   92.8µs    10474.     100KB     8.30
2 chol         54.1µs   59.7µs    16409.     103KB    10.5 

As you can see, the Cholesky version is quite a bit (>1/3) faster. But does it also give the correct answer?

kldiv_chol(mean_0, mean_1, S_0, S_1)
[1] 15.40342

It does! So let's use that version and compare it to the existing rags2ridges::KLdiv() function.

bench::mark(
  ours = kldiv_chol(mean_0, mean_1, S_0, S_1),
  theirs = rags2ridges::KLdiv(mean_1, mean_0, S_1, S_0)
)
# A tibble: 2 × 6
  expression      min   median `itr/sec` mem_alloc `gc/sec`
  <bch:expr> <bch:tm> <bch:tm>     <dbl> <bch:byt>    <dbl>
1 ours         55.5µs   58.8µs    16151.    20.7KB    10.5 
2 theirs      788.8µs  822.5µs     1203.   105.9KB     4.09

We’ve achieved quite a dramatic speedup of about 14x. Additionally, we allocate roughly 5x less memory. This speedup is maintained for 2-dimensional distributions:

# generate parameters
set.seed(45)
P <- 2
mean_0 <- rnorm(P)
mean_1 <- rnorm(P)
S_0 <- rWishart(1, P*2, diag(P))[,,1]
S_1 <- rWishart(1, P*2, diag(P))[,,1]

bench::mark(
  ours = kldiv_chol(mean_0, mean_1, S_0, S_1),
  theirs = rags2ridges::KLdiv(mean_1, mean_0, S_1, S_0)
)
# A tibble: 2 × 6
  expression      min   median `itr/sec` mem_alloc `gc/sec`
  <bch:expr> <bch:tm> <bch:tm>     <dbl> <bch:byt>    <dbl>
1 ours         34.4µs     36µs    26593.        0B     2.66
2 theirs      541.4µs    566µs     1705.        0B     6.18

And a positive side-effect of the new function’s memory efficiency is that it also works for larger, 250-dimensional distributions, unlike the old function:

# generate parameters
set.seed(45)
P <- 250
mean_0 <- rnorm(P)
mean_1 <- rnorm(P)
S_0 <- rWishart(1, P*2, diag(P))[,,1]
S_1 <- rWishart(1, P*2, diag(P))[,,1]

kldiv_chol(mean_0, mean_1, S_0, S_1)
[1] 128.2745
rags2ridges::KLdiv(mean_1, mean_0, S_1, S_0)
[1] NaN
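I have not dug into the rags2ridges internals, but the NaN is most likely an overflow problem: at 250 dimensions the raw determinant no longer fits in a double, so any implementation that computes det() before taking the log breaks down, whereas our function stays on the log scale throughout. The overflow itself is easy to demonstrate:

# the raw determinant overflows at this dimension, but the log-determinant is finite
det(S_0)                    # overflows to Inf
determinant(S_0)$modulus    # finite log-determinant, no overflow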

Clean-up and document

What remains is to tidy up the function, make it more robust, and document it properly. Allow me to present the fast and furious Gaussian Kullback-Leibler divergence:

#' Kullback-Leibler divergence between two Gaussians
#' 
#' This function computes $D_{KL}(p||q)$, where $p(x)$ 
#' and $q(x)$ are two multivariate normal distributions.
#' 
#' @param mu_p the mean vector of p
#' @param mu_q the mean vector of q
#' @param Sigma_p the covariance matrix of p
#' @param Sigma_q the covariance matrix of q
#' 
#' @return Kullback-Leibler divergence from p to q 
#' (numeric scalar)
kldiv <- function(mu_p, mu_q, Sigma_p, Sigma_q) {
  chol_q   <- chol(Sigma_q)
  logdet_p <- determinant(as.matrix(Sigma_p))$modulus  # as.matrix() also allows scalar (univariate) input
  logdet_q <- 2*sum(log(diag(chol_q)))
  d        <- length(mu_p)
  Omega_q  <- chol2inv(chol_q)
  trace    <- sum(Omega_q * Sigma_p)
  delta    <- mu_q - mu_p
  quad     <- crossprod(delta, Omega_q %*% delta)
  return(c(logdet_q - logdet_p - d + trace + quad) / 2)
}
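As a quick usage check, the coercion to a matrix also lets the function handle the univariate case, where the "covariance matrices" are plain variances. For $p = N(0, 1)$ and $q = N(1, 2)$, the closed form gives $\frac{1}{2}\left[\log 2 - 0 - 1 + \frac{1}{2} + \frac{1}{2}\right] = \frac{\log 2}{2} \approx 0.347$:

# univariate sanity check: KL( N(0,1) || N(1,2) ) should be log(2)/2, about 0.347
kldiv(mu_p = 0, mu_q = 1, Sigma_p = 1, Sigma_q = 2)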