Skip to content

Instantly share code, notes, and snippets.

@dela3499
Last active January 7, 2016 03:13
Show Gist options
  • Save dela3499/a2a76c61dc5b86062a02 to your computer and use it in GitHub Desktop.
Refactoring the two-sample Kolmogorov–Smirnov test implementation in SciPy
# Wikipedia: https://en.wikipedia.org/wiki/Kolmogorov%E2%80%93Smirnov_test#Two-sample_Kolmogorov.E2.80.93Smirnov_test
# SciPy v0.15.1 source of ks_2samp: https://github.com/scipy/scipy/blob/v0.15.1/scipy/stats/stats.py#L3966
def ks_2samp(data1, data2):
    """Return the two-sample Kolmogorov-Smirnov statistic D.

    This mirrors the statistic computed by SciPy v0.15.1's ``ks_2samp``
    (statistic only, no p-value), kept close to the upstream structure so it
    can serve as the reference for the refactored versions.

    D is the largest absolute difference between the empirical CDFs of the
    two samples, evaluated at every observed value of either sample.

    Parameters
    ----------
    data1, data2 : array_like
        One-dimensional samples; they may have different lengths.

    Returns
    -------
    numpy.float64
        The KS statistic, in the closed interval [0, 1].

    Raises
    ------
    ValueError
        If either sample is empty (its empirical CDF is undefined).
    """
    import numpy as np  # the file has no module-level numpy import

    # np.sort accepts any array_like and returns a new array, so the inputs
    # are neither mutated nor need a separate asarray() conversion.
    data1 = np.sort(data1)
    data2 = np.sort(data2)
    n1 = len(data1)
    n2 = len(data2)
    if n1 == 0 or n2 == 0:
        # Without this guard a single empty sample yields 0/0 -> nan, and two
        # empty samples fail inside np.max with an unrelated-looking message.
        raise ValueError("ks_2samp requires two non-empty samples")
    data_all = np.concatenate([data1, data2])
    # side='right' counts the values <= each point, i.e. the right-continuous
    # empirical CDF evaluated at every pooled observation.
    cdf1 = np.searchsorted(data1, data_all, side='right') / float(n1)
    cdf2 = np.searchsorted(data2, data_all, side='right') / float(n2)
    d = np.max(np.absolute(cdf1 - cdf2))
    return d
def ks_2samp_refactored(a, b):
    """Return the two-sample Kolmogorov-Smirnov statistic D for *a* and *b*.

    Each per-sample step (sorting, empirical CDF) is written once and applied
    to both samples instead of being duplicated. This variant spreads every
    step over several lines for readability.

    Parameters
    ----------
    a, b : array_like
        One-dimensional samples; they may have different lengths.

    Returns
    -------
    numpy.float64
        Largest absolute gap between the two empirical CDFs, in [0, 1].

    Raises
    ------
    ValueError
        If either sample is empty.
    """
    import numpy as np  # the file has no module-level numpy import

    a_sorted, b_sorted = [
        np.sort(x)
        for x in (a, b)]
    if len(a_sorted) == 0 or len(b_sorted) == 0:
        # An empty sample has no empirical CDF (would divide 0 by 0).
        raise ValueError("ks_2samp_refactored requires two non-empty samples")
    ab = np.concatenate(
        [a_sorted, b_sorted])
    # side='right' counts values <= each pooled point: the empirical CDF.
    cdf1, cdf2 = [
        np.searchsorted(x, ab, side='right') / float(len(x))
        for x in (a_sorted, b_sorted)]
    return np.max(
        np.absolute(cdf1 - cdf2))
def ks_2samp_refactored2(a, b):
    """Return the two-sample Kolmogorov-Smirnov statistic D (compact form).

    Same algorithm as the multi-line refactoring, with one line per step:
    sort both samples, pool them, evaluate each empirical CDF at every pooled
    point, and take the largest absolute difference.

    Parameters
    ----------
    a, b : array_like
        One-dimensional samples; they may have different lengths.

    Returns
    -------
    numpy.float64
        The KS statistic, in [0, 1].

    Raises
    ------
    ValueError
        If either sample is empty.
    """
    import numpy as np  # the file has no module-level numpy import

    a_sorted, b_sorted = map(np.sort, [a, b])
    if min(len(a_sorted), len(b_sorted)) == 0:
        raise ValueError("ks_2samp_refactored2 requires two non-empty samples")
    ab = np.concatenate([a_sorted, b_sorted])
    # side='right' counts values <= each pooled point: the empirical CDF.
    cdf1, cdf2 = [np.searchsorted(x, ab, side='right') / float(len(x)) for x in (a_sorted, b_sorted)]
    return np.max(np.absolute(cdf1 - cdf2))
# Smoke check: each refactoring performs the same floating-point operations
# as the original, so the statistics are expected to match exactly.
a = [1, 2, 3]
b = [3, 5, 6]
d_original = ks_2samp(a, b)
d_refactored = ks_2samp_refactored(a, b)
d_refactored2 = ks_2samp_refactored2(a, b)
# A single pre-formatted string makes print(...) produce the same output on
# Python 2 (print statement) and Python 3 (print function).
print("%s %s" % (d_original, d_refactored))
assert d_original == d_refactored
assert d_original == d_refactored2
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment