Last active: January 7, 2016 03:13
Save dela3499/a2a76c61dc5b86062a02 to your computer and use it in GitHub Desktop.
Refactoring the Kolmogorov-Smirnov Test implementation in Scipy
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Wikipedia: https://en.wikipedia.org/wiki/Kolmogorov%E2%80%93Smirnov_test#Two-sample_Kolmogorov.E2.80.93Smirnov_test
# Scipy source: https://github.com/scipy/scipy/blob/v0.15.1/scipy/stats/stats.py#L3966
def ks_2samp(data1, data2):
    """Return the two-sample Kolmogorov-Smirnov statistic D.

    D is the largest absolute difference between the empirical CDFs of
    ``data1`` and ``data2``, evaluated at every observed value of either
    sample. Only the statistic is returned (no p-value).

    Parameters
    ----------
    data1, data2 : array_like
        The two samples; they may have different lengths. Expected to be
        one-dimensional.

    Returns
    -------
    float
        The KS statistic D (a difference of two CDF values, so in [0, 1]).

    Raises
    ------
    ValueError
        If either sample is empty, since its empirical CDF is undefined
        (previously this produced a silent ``nan`` via 0/0).
    """
    # np.sort accepts any array_like and returns a new sorted array, so no
    # separate asarray step is needed and the caller's inputs are untouched.
    sorted1 = np.sort(data1)
    sorted2 = np.sort(data2)
    n1 = len(sorted1)
    n2 = len(sorted2)
    if n1 == 0 or n2 == 0:
        raise ValueError("Data passed to ks_2samp must not be empty")
    data_all = np.concatenate([sorted1, sorted2])
    # side='right' counts elements <= x, i.e. the right-continuous ECDF.
    cdf1 = np.searchsorted(sorted1, data_all, side='right') / float(n1)
    cdf2 = np.searchsorted(sorted2, data_all, side='right') / float(n2)
    return np.max(np.absolute(cdf1 - cdf2))
def ks_2samp_refactored(a, b):
    """Return the two-sample Kolmogorov-Smirnov statistic D for ``a`` and ``b``.

    D is the maximum absolute difference between the empirical CDFs of the
    two samples, evaluated at every value of the pooled data.

    Parameters
    ----------
    a, b : array_like
        The two samples (lengths may differ).

    Returns
    -------
    float
        The KS statistic D.

    Raises
    ------
    ValueError
        If either sample is empty (its empirical CDF is undefined).
    """
    a_sorted, b_sorted = map(np.sort, [a, b])
    if len(a_sorted) == 0 or len(b_sorted) == 0:
        raise ValueError("Data passed to ks_2samp must not be empty")
    ab = np.concatenate([a_sorted, b_sorted])
    # side='right' counts elements <= each pooled value: the right-continuous ECDF.
    cdf1, cdf2 = [
        np.searchsorted(x, ab, side='right') / float(len(x))
        for x in (a_sorted, b_sorted)
    ]
    return np.max(np.absolute(cdf1 - cdf2))
def ks_2samp_refactored2(a, b):
    """Return the two-sample Kolmogorov-Smirnov statistic D for ``a`` and ``b``.

    Compact variant: sort both samples, evaluate each empirical CDF on the
    pooled data, and return the largest absolute gap between them.

    Raises
    ------
    ValueError
        If either sample is empty (its empirical CDF is undefined).
    """
    a_sorted, b_sorted = map(np.sort, [a, b])
    if min(len(a_sorted), len(b_sorted)) == 0:
        raise ValueError("Data passed to ks_2samp must not be empty")
    ab = np.concatenate([a_sorted, b_sorted])
    cdf1, cdf2 = [np.searchsorted(x, ab, side='right') / float(len(x))
                  for x in [a_sorted, b_sorted]]
    return np.max(np.absolute(cdf1 - cdf2))
# Demo: every implementation must produce the same statistic on a small example.
a = [1, 2, 3]
b = [3, 5, 6]
# A single pre-formatted string prints identically on Python 2 and 3
# (the original `print x, y` statement is a SyntaxError on Python 3).
print("%s %s" % (ks_2samp(a, b), ks_2samp_refactored(a, b)))
# The implementations perform identical float operations, so exact equality holds.
assert ks_2samp(a, b) == ks_2samp_refactored(a, b) == ks_2samp_refactored2(a, b)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.