Numpy and scipy ways to calculate KL Divergence.
""" | |
Specifically, the Kullback–Leibler divergence from Q to P, denoted DKL(P‖Q), is | |
a measure of the information gained when one revises one's beliefs from the | |
prior probability distribution Q to the posterior probability distribution P. In | |
other words, it is the amount of information lost when Q is used to approximate | |
P. | |
""" | |
import numpy as np | |
from scipy.stats import entropy | |
def kl(p, q): | |
"""Kullback-Leibler divergence D(P || Q) for discrete distributions | |
Parameters | |
---------- | |
p, q : array-like, dtype=float, shape=n | |
Discrete probability distributions. | |
""" | |
p = np.asarray(p, dtype=np.float) | |
q = np.asarray(q, dtype=np.float) | |
return np.sum(np.where(p != 0, p * np.log(p / q), 0)) | |
p = [0.1, 0.9]
q = [0.1, 0.9]

assert entropy(p, q) == kl(p, q)
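For distributions that actually differ, exact equality between the two implementations is fragile, since the two sums may round differently. A minimal sketch of a tolerance-based check (the example values are illustrative, not from the gist):

p2 = [0.1, 0.9]
q2 = [0.2, 0.8]
assert np.isclose(entropy(p2, q2), kl(p2, q2))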
Note that scipy.stats.entropy(pk, qk=None, base=None, axis=0) does compute the KL divergence when qk is not None.
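A short sketch of that scipy-only route (the example values and the base=2 call are illustrative assumptions, not from the gist; by default entropy returns the divergence in nats):

from scipy.stats import entropy

p = [0.1, 0.9]
q = [0.2, 0.8]

print(entropy(p, q))          # KL divergence D(P || Q) in nats
print(entropy(p, q, base=2))  # the same divergence in bits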