""" | |
Three ways of computing the Hellinger distance between two discrete | |
probability distributions using NumPy and SciPy. | |
""" | |
import numpy as np | |
from scipy.linalg import norm | |
from scipy.spatial.distance import euclidean | |
_SQRT2 = np.sqrt(2) # sqrt(2) with default precision np.float64 | |
def hellinger1(p, q): | |
return norm(np.sqrt(p) - np.sqrt(q)) / _SQRT2 | |
def hellinger2(p, q): | |
return euclidean(np.sqrt(p), np.sqrt(q)) / _SQRT2 | |
def hellinger3(p, q): | |
return np.sqrt(np.sum((np.sqrt(p) - np.sqrt(q)) ** 2)) / _SQRT2 |
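As a quick sanity check, all three functions compute the same quantity; a minimal sketch using hellinger3 (repeated here so the snippet runs standalone), verifying the basic properties of a distance:

```python
import numpy as np

_SQRT2 = np.sqrt(2)

# hellinger3 from the gist above, repeated so this snippet is self-contained.
def hellinger3(p, q):
    return np.sqrt(np.sum((np.sqrt(p) - np.sqrt(q)) ** 2)) / _SQRT2

# Two example discrete probability distributions (each sums to 1).
p = np.array([0.1, 0.4, 0.5])
q = np.array([0.3, 0.3, 0.4])

d = hellinger3(p, q)

# The Hellinger distance is symmetric, zero for identical
# distributions, and bounded above by 1.
assert 0.0 <= d <= 1.0
assert np.isclose(hellinger3(q, p), d)
assert hellinger3(p, p) == 0.0
```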
@cscorley Probability distributions aren't supposed to ever contain negative numbers. These functions don't do input validation for speed reasons.
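For contexts where safety matters more than speed, one could wrap the computation with explicit checks; a sketch (the name `hellinger_checked` is my own, not from the gist):

```python
import numpy as np

_SQRT2 = np.sqrt(2)

def hellinger_checked(p, q):
    """Hellinger distance with basic input validation (hypothetical helper)."""
    p = np.asarray(p, dtype=np.float64)
    q = np.asarray(q, dtype=np.float64)
    if p.shape != q.shape:
        raise ValueError("p and q must have the same shape")
    if (p < 0).any() or (q < 0).any():
        raise ValueError("probabilities must be non-negative")
    if not (np.isclose(p.sum(), 1.0) and np.isclose(q.sum(), 1.0)):
        raise ValueError("p and q must each sum to 1")
    return np.sqrt(np.sum((np.sqrt(p) - np.sqrt(q)) ** 2)) / _SQRT2
```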
In case anyone is wondering, I believe hellinger2 and hellinger3 are faster than hellinger1. (I had been using hellinger1 in one of my projects until some profiling determined it was a rate-limiting step.) Here is some timing code:
"""
Three ways of computing the Hellinger distance between two discrete
probability distributions using NumPy and SciPy.
"""
import time
import numpy as np
from scipy.linalg import norm
from scipy.spatial.distance import euclidean
_SQRT2 = np.sqrt(2) # sqrt(2) with default precision np.float64
def hellinger1(p, q):
return norm(np.sqrt(p) - np.sqrt(q)) / _SQRT2
def hellinger2(p, q):
return euclidean(np.sqrt(p), np.sqrt(q)) / _SQRT2
def hellinger3(p, q):
return np.sqrt(np.sum((np.sqrt(p) - np.sqrt(q)) ** 2)) / _SQRT2
repeat = 1000000
p = np.array([.05, .05, .1, .1, .2, .2, .3] * repeat)
q = np.array([ 0, 0, 0, .1, .3, .3 ,.3] * repeat)
p /= p.sum()
q /= q.sum()
for hellingerFunction in [hellinger1, hellinger2, hellinger3]:
start = time.time()
hellingerFunction(p=p, q=q)
duration = time.time() - start
print("{} took {} long".format(hellingerFunction.__name__, duration))
Should get something like:
hellinger1 took 0.15966796875 long
hellinger2 took 0.10175800323486328 long
hellinger3 took 0.09865593910217285 long
The difference shrinks for shorter arrays p and q, but even if repeat=1, so that p and q are of length 7, hellinger3 is still faster.
Hope this is right.
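Single time.time() measurements are noisy; for more stable numbers, the standard library's timeit can repeat each call and take the best run. A sketch of the same comparison for hellinger3 (timings will of course vary by machine):

```python
import timeit

import numpy as np

_SQRT2 = np.sqrt(2)

def hellinger3(p, q):
    return np.sqrt(np.sum((np.sqrt(p) - np.sqrt(q)) ** 2)) / _SQRT2

# The same length-7 distributions as above, with repeat=1.
p = np.array([.05, .05, .1, .1, .2, .2, .3])
q = np.array([0, 0, 0, .1, .3, .3, .3])
p /= p.sum()
q /= q.sum()

# Best of 5 runs of 10000 calls each.
best = min(timeit.repeat(lambda: hellinger3(p, q), repeat=5, number=10000))
print("hellinger3: {:.4f} s per 10000 calls".format(best))
```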
In case anyone is interested, I've implemented the Hellinger distance in Cython as a split criterion for sklearn's DecisionTreeClassifier and RandomForestClassifier.
It performs great in my use cases of imbalanced-data classification, beating RandomForestClassifier with the gini criterion as well as XGBClassifier.
You are welcome to check it out at https://github.com/EvgeniDubov/hellinger-distance-criterion
To anyone who finds this gist at a later date and is getting the exception
ValueError: array must not contain infs or NaNs
make sure that the distributions given to these functions contain only non-negative values. Otherwise, np.sqrt will produce NaNs and cause you pain. Run them through np.absolute() first if you need to.
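A minimal sketch of that fix, assuming the negative entries are tiny floating-point noise (np.clip is an alternative to the np.absolute() suggested above, and renormalizing keeps the input a valid distribution):

```python
import numpy as np

p = np.array([0.2, -1e-12, 0.8])  # tiny negative from floating-point noise
p = np.clip(p, 0, None)           # or np.absolute(p), as suggested above
p /= p.sum()                      # renormalize to a valid distribution

assert (p >= 0).all()
assert np.isclose(p.sum(), 1.0)
# np.sqrt no longer produces NaNs, so norm/euclidean won't raise:
assert not np.isnan(np.sqrt(p)).any()
```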