Python implementation of the Kullback-Leibler divergence estimator described in the paper linked in the docstring below. It relies on scikit-learn for k-NN.
Kullback-Leibler Divergence Estimator.
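A minimal usage sketch (hypothetical sample sizes; it assumes the module is saved as kl_div_estimator.py, as in the toy script further down): draw samples from two known Gaussians and check that the estimate lands near the analytical value, D(N(0,1) || N(1,1)) = 0.5 nats ≈ 0.72 bits.

import numpy as np
from kl_div_estimator import klDivergence

x = np.random.randn(5000)      # samples from P = N(0, 1)
y = np.random.randn(5000) + 1  # samples from Q = N(1, 1)
print(klDivergence(x, y))      # should be roughly 0.72 (bits, the default logBase = 2)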
.gitignore
__pycache__/* |
kl_div_estimator.py
from __future__ import print_function
from __future__ import division

import sys
sys.dont_write_bytecode = True

import numpy as np
from sklearn.neighbors import NearestNeighbors as kNN
def klDivergence(x, y, logBase = 2, returnk = False, **kwds):
    # TODO: the estimator converges a.s. to the true KL divergence, but the convergence rate looks slow.
    # TODO: simple tests with 1-D Gaussians suggest (this implementation of) the estimator is biased.
    # TODO: does k-NN density estimation with k > 1 improve results cost-effectively, given the extra computation time?
    """
    Estimate the KL divergence D(P||Q) between unknown distributions P and Q from iid samples X_i ~ P and Y_i ~ Q.

    The estimator is the one developed in
    http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.422.5121&rep=rep1&type=pdf

    @param x: Samples from the *true* distribution, P.
    @param y: Samples from the *approximate* distribution, Q.
    @param logBase: Base of the logarithm used to compute the KL divergence (changes the units in which entropy is expressed).
    @param returnk: If True, also return the k used for the k-NN estimation as a second value.
    @param kwds: Keyword arguments passed to the sklearn.neighbors.NearestNeighbors constructor.
    """
    # Coerce inputs to 2-D arrays of shape (n_samples, n_features).
    x, y = np.asarray(x, dtype = float), np.asarray(y, dtype = float)
    if x.ndim == 1: # 1-D data
        x = x.reshape(-1, 1)
    if y.ndim == 1:
        y = y.reshape(-1, 1)

    n, m = len(x), len(y)
    d = x.shape[1]
    if d != y.shape[1]:
        raise Exception('Dimension mismatch: X{} != Y{}'.format(x.shape, y.shape))

    # Logarithm in an arbitrary base, with the convention log(0) = 0.
    _log = lambda z, b: np.log2(z) / np.log2(b) if z != 0 else 0
    _log = np.vectorize(_log)

    # The k-NN distance within x is zero whenever a point has been sampled more than once,
    # which is a problem because it ends up in a denominator. The density at x_i is estimated
    # as ~ k / dist_kNN(x_i), so we look for the smallest k such that every k-NN distance in x
    # is strictly positive. That k cannot be known upfront, hence the loop below.
    k = 1 # start with 1-NN
    while True:
        # n_neighbors = k + 1 because the query points belong to x, so the nearest
        # "neighbour" at distance 0 is the point itself.
        knnX = kNN(n_neighbors = k + 1, **kwds).fit(x)
        nnDistX = knnX.kneighbors(x)[0][:, k]
        if not nnDistX.all(): # some k-NN distance is still zero: repeated samples, increase k
            k += 1
        else:
            break

    knnY = kNN(n_neighbors = k, **kwds).fit(y)
    nnDistY = knnY.kneighbors(x)[0][:, k - 1]

    kl = (d / n) * np.sum(_log(nnDistY / nnDistX, logBase)) + _log(m / (n - 1), logBase)

    if returnk:
        return kl, k
    else:
        return kl
def kl_th(p, q, logBase = 2):
    """
    Analytical formula for the Kullback-Leibler divergence D(P||Q) in simple cases.

    @param p, q: Dictionaries containing the model specification, with keys:
        - 'model': Only 'normal' for now. TODO: include exponentials, etc.
        - Other keys: model parameters ('mu' and 'sigma' for 'normal').
        TODO: this argument should be something standard, e.g. a frozen scipy.stats distribution.
    @param logBase: Entropy is measured in:
        - Bits for base 2.
        - Nats for base e.
        - Hartleys (also called bans or dits) for base 10.
    """
    p_type, q_type = p['model'], q['model']
    tr = np.trace
    inv = np.linalg.inv
    det = np.linalg.det

    if p_type == q_type == 'normal':
        mu_p, mu_q = np.atleast_1d(p['mu']), np.atleast_1d(q['mu'])
        sigma_p, sigma_q = p['sigma'], q['sigma']
        dmu = mu_q - mu_p
        d = len(dmu)
        if d > 1:
            # Note the matrix product (not element-wise multiplication) inside the trace.
            kl_nats = .5 * (tr(np.dot(inv(sigma_q), sigma_p)) + np.dot(dmu, np.dot(inv(sigma_q), dmu)) - d + np.log(det(sigma_q) / det(sigma_p)))
        else:
            sp, sq = float(np.squeeze(sigma_p)), float(np.squeeze(sigma_q))
            kl_nats = .5 * (sp / sq + dmu[0] ** 2 / sq - d + np.log(sq / sp))
    elif p_type == 'normal' and q_type == 'exponential':
        raise NotImplementedError('Working on it!')
    elif p_type == 'exponential' and q_type == 'normal':
        raise NotImplementedError('Working on it!')
    else:
        raise Exception('Distributions {} and/or {} unknown!'.format(p_type, q_type))

    # The closed form above is computed in nats; convert to the requested base.
    return kl_nats / np.log(logBase)
# Toy script: compare the k-NN estimate against the analytical KL divergence for Gaussians.
import numpy as np
import matplotlib.pyplot as plt

from kl_div_estimator import klDivergence, kl_th

kl = klDivergence
def runToy(d = 1):
    # True distribution P and two candidate approximations: Q0 (exact) and Q1 (too wide).
    p = {
        'model': 'normal',
        'mu': np.zeros(d),
        'sigma': 5 * np.eye(d)
    }
    # NOTE: q0 and q1 are specified for d = 1; for d > 1 their 'mu' and 'sigma'
    # would need dimensions matching p.
    q0 = {
        'model': 'normal',
        'mu': np.array([0]),
        'sigma': np.array([5])
    }
    q1 = {
        'model': 'normal',
        'mu': np.array([0]),
        'sigma': np.array([10])
    }

    n = 1000 # sample size per draw
    m = 1000 # number of Monte Carlo repetitions

    kl_emp = []
    for _ in range(m):
        if d == 1:
            x = p['sigma'][0, 0] ** .5 * np.random.randn(n) + p['mu']
            y = q0['sigma'][0] ** .5 * np.random.randn(n) + q0['mu']
            z = q1['sigma'][0] ** .5 * np.random.randn(n) + q1['mu']
        else:
            x = np.random.multivariate_normal(p['mu'], p['sigma'], n)
            y = np.random.multivariate_normal(q0['mu'], q0['sigma'], n)
            z = np.random.multivariate_normal(q1['mu'], q1['sigma'], n)
        kl_emp.append([kl(x, y), kl(x, z)])

    kl_emp = np.array(kl_emp)
    kl_real = [kl_th(p, q) for q in (q0, q1)]

    plt.hist(kl_emp[:, 0], color = 'b', alpha = .25)
    plt.hist(kl_emp[:, 1], color = 'r', alpha = .25)
    plt.axvline(np.mean(kl_emp[:, 0]), ls = '--', color = 'b', lw = 3)
    plt.axvline(np.mean(kl_emp[:, 1]), ls = '--', color = 'r', lw = 3)
    plt.axvline(kl_real[0], c = 'b', lw = 3, label = r'D(P||$Q_0$)')
    plt.axvline(kl_real[1], c = 'r', lw = 3, label = r'D(P||$Q_1$)')
    plt.xlabel('KL Divergence')
    plt.ylabel('Count')
    plt.legend()
    plt.tight_layout()
    plt.show()

    return kl_real, kl_emp

if __name__ == '__main__':
    klth, kle = runToy()
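For the default d = 1 toy above, the closed form gives D(P||Q0) = 0 and D(P||Q1) = 0.5 * (1/2 - 1 + ln 2) ≈ 0.097 nats ≈ 0.139 bits (the estimator's default logBase is 2), so the empirical histograms should concentrate near those values, up to the bias noted in the TODO comments.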