import numpy as np
import scipy.stats

def jsd(p, q, base=np.e):
    '''
    Implementation of pairwise `jsd` based on
    https://en.wikipedia.org/wiki/Jensen%E2%80%93Shannon_divergence
    '''
    ## convert to np.array
    p, q = np.asarray(p), np.asarray(q)
    ## normalize p, q to probabilities
    p, q = p/p.sum(), q/q.sum()
    ## m is the pointwise mean of the two distributions
    m = 1./2*(p + q)
    ## JSD is the average of the two KL divergences against m
    return scipy.stats.entropy(p, m, base=base)/2. + scipy.stats.entropy(q, m, base=base)/2.
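A quick usage sketch of the function above; the input distributions here are made up for illustration:

p = [0.10, 0.40, 0.50]
q = [0.80, 0.15, 0.05]
print(jsd(p, q))           ## natural-log JSD, between 0 and ln(2)
print(jsd(p, q, base=2.))  ## base-2 JSD, between 0 and 1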
What would be an optimal way to calculate the JSD when there is a large number of probability distributions? Say there is one data set with 10K probability distributions, and I want to calculate the JSD of each of them against every other one; that ends up being roughly 10K*10K/2 computations. Is there any smart way to do it that avoids so many for loops or distributed processing? Anything in numpy that can help?
Thanks guys, I've updated the code to do the normalization first.
@manjeetnagi, I'm not really sure about an "optimal" way. If I were you, I would simply use joblib to parallelize the process.
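A minimal sketch of that approach, assuming the jsd function defined above and the distributions stored as rows of a 2-D array; the names dists and pairwise_jsd are hypothetical:

from itertools import combinations
from joblib import Parallel, delayed

def pairwise_jsd(dists, n_jobs=-1):
    ## all n*(n-1)/2 unordered index pairs
    pairs = list(combinations(range(len(dists)), 2))
    ## evaluate each pair's JSD in parallel across CPU cores
    values = Parallel(n_jobs=n_jobs)(
        delayed(jsd)(dists[i], dists[j]) for i, j in pairs
    )
    return pairs, values

Note that for 10K distributions this is still ~50M pairs; parallelizing divides the wall-clock time by the number of cores but does not avoid the quadratic number of computations.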
See:
from scipy.spatial import distance
distance.jensenshannon(a, b)
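Note that scipy.spatial.distance.jensenshannon returns the Jensen-Shannon distance, i.e. the square root of the divergence, so square it to compare against the jsd function above (both default to the natural logarithm). A small usage sketch with made-up inputs:

import numpy as np
from scipy.spatial import distance

p = np.array([0.10, 0.40, 0.50])
q = np.array([0.80, 0.15, 0.05])
js_distance = distance.jensenshannon(p, q)  ## distance, not divergence
js_divergence = js_distance ** 2            ## square to recover the divergence
print(js_divergence)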
Just for those who land here looking for the Jensen-Shannon distance (using Monte Carlo integration) between two distributions:

import numpy as np
import scipy.stats as st

def distributions_js(distribution_p, distribution_q, n_samples=10 ** 5):
    ## Jensen-Shannon divergence (the Jensen-Shannon distance is the square root of the divergence).
    ## All logarithms are base 2 (because of information entropy), so the result is bounded by 1.
    ## sample from p and evaluate both densities at those points
    X = distribution_p.rvs(n_samples)
    p_X = distribution_p.pdf(X)
    q_X = distribution_q.pdf(X)
    log_mix_X = np.log2(p_X + q_X)
    ## sample from q and evaluate both densities at those points
    Y = distribution_q.rvs(n_samples)
    p_Y = distribution_p.pdf(Y)
    q_Y = distribution_q.pdf(Y)
    log_mix_Y = np.log2(p_Y + q_Y)
    ## Monte Carlo estimates of KL(p || m) and KL(q || m), where m = (p + q)/2
    return (np.log2(p_X).mean() - (log_mix_X.mean() - np.log2(2))
            + np.log2(q_Y).mean() - (log_mix_Y.mean() - np.log2(2))) / 2

print("should be different:")
print(distributions_js(st.norm(loc=10000), st.norm(loc=0)))  ## close to 1, the base-2 upper bound
print("should be same:")
print(distributions_js(st.norm(loc=0), st.norm(loc=0)))      ## close to 0, up to Monte Carlo noise
A simple fix could be to add
which gives in total