# gap.py
# (c) 2013 Mikael Vejdemo-Johansson
# BSD License
#
# SciPy function to compute the gap statistic for evaluating k-means clustering.
# Gap statistic defined in
# Tibshirani, Walther, Hastie:
#  Estimating the number of clusters in a data set via the gap statistic
#  J. R. Statist. Soc. B (2001) 63, Part 2, pp 411-423

import scipy
import scipy.cluster.vq
import scipy.spatial.distance
dst = scipy.spatial.distance.euclidean

def gap(data, refs=None, nrefs=20, ks=range(1,11)):
    """
    Compute the Gap statistic for an nxm dataset in data.

    Either give a precomputed set of reference distributions in refs as an (n,m,k) scipy array,
    or state the number k of reference distributions in nrefs for automatic generation with a
    uniform distribution within the bounding box of data.

    Give the list of k-values for which you want to compute the statistic in ks.
    """
    shape = data.shape
    # "is None" rather than "== None": comparing an array to None is elementwise
    if refs is None:
        tops = data.max(axis=0)
        bots = data.min(axis=0)
        # Diagonal matrix of per-column ranges, used to scale unit-cube samples
        dists = scipy.matrix(scipy.diag(tops-bots))

        rands = scipy.random.random_sample(size=(shape[0],shape[1],nrefs))
        for i in range(nrefs):
            rands[:,:,i] = rands[:,:,i]*dists+bots
    else:
        rands = refs

    gaps = scipy.zeros((len(ks),))
    for (i,k) in enumerate(ks):
        # Dispersion of the real data: summed distance of each point to its centroid
        (kmc,kml) = scipy.cluster.vq.kmeans2(data, k)
        disp = sum([dst(data[m,:],kmc[kml[m],:]) for m in range(shape[0])])

        # Same dispersion for every reference data set
        refdisps = scipy.zeros((rands.shape[2],))
        for j in range(rands.shape[2]):
            (kmc,kml) = scipy.cluster.vq.kmeans2(rands[:,:,j], k)
            refdisps[j] = sum([dst(rands[m,:,j],kmc[kml[m],:]) for m in range(shape[0])])
        gaps[i] = scipy.log(scipy.mean(refdisps))-scipy.log(disp)
    return gaps
I believe the order of log and mean in the final gaps[i] assignment (line 48 of the gist) should be flipped. It should be:
gaps[i] = scipy.mean(scipy.log(refdisps))-scipy.log(disp)
Thanks. I agree with catfishy that the logarithms need to be taken first, before averaging.
Thanks! The implementation and the comments here help me a lot!
catfishy and faithefeng are correct. Please see equation 3 in Tibshirani's paper. You will not necessarily get the same result if you leave the code as is.
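To make the difference concrete, here is a small numeric check of the two orderings. The dispersion values are made up purely for illustration and do not come from any real clustering run.

```python
import numpy as np

# Hypothetical reference dispersions, invented for this illustration only.
refdisps = np.array([10.0, 20.0, 40.0])

log_of_mean = np.log(np.mean(refdisps))  # ordering used in the gist
mean_of_log = np.mean(np.log(refdisps))  # ordering used in equation 3 of the paper

# By Jensen's inequality the mean of the logs never exceeds the log of the
# mean, so the gist's ordering overstates the gap unless all reference
# dispersions are equal.
difference = log_of_mean - mean_of_log
```

Here the mean of the logs equals the log of the geometric mean (20), while the log of the mean is the log of roughly 23.3, so the two versions disagree whenever the reference dispersions vary.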
What's the purpose of nrefs?
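Going by the docstring, nrefs is the number of reference data sets that get drawn uniformly inside the bounding box of data when refs is not supplied. A minimal sketch of that generation step, written with plain NumPy rather than the gist's scipy aliases; make_references and seed are hypothetical names that do not appear in the gist:

```python
import numpy as np

def make_references(data, nrefs=20, seed=0):
    """Draw nrefs uniform data sets inside the bounding box of data.

    Returns an array of shape (n, m, nrefs), the layout the gist expects
    for its refs argument.
    """
    rng = np.random.default_rng(seed)
    bots = data.min(axis=0)  # per-column minimum
    tops = data.max(axis=0)  # per-column maximum
    n, m = data.shape
    # Sample as (nrefs, n, m) so the (m,) bounds broadcast over the last axis,
    # then move the reference index to the end.
    samples = rng.uniform(bots, tops, size=(nrefs, n, m))
    return np.moveaxis(samples, 0, -1)
```

A larger nrefs gives a more stable average of the reference dispersions at the cost of running k-means more times.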
Can someone suggest how to interpret the return value of the function, i.e. gaps?
Hello! It returns a list of "gaps" between the reference data sets and the given data. The number of clusters is the value in ks that corresponds to the maximum return value :)
The code is very complex. Can you make it easier to understand what's happening inside? There is very little material available on the gap statistic.
> Can someone suggest how to interpret the return value of the function, i.e. gaps?
> Hello! It returns a list of "gaps" between the reference data sets and the given data. The number of clusters is the value in ks that corresponds to the maximum return value :)
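This is inaccurate. If you found this by googling around, note that picking the maximum gap is not enough. You also need to compute the standard deviation of the log dispersions over the reference data sets and use it in the final step: the paper chooses the smallest k such that Gap(k) >= Gap(k+1) - s_{k+1}, where s_k is that standard deviation multiplied by sqrt(1 + 1/B) for B reference data sets.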
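A rough sketch of what that could look like, under several assumptions: it averages the logs as suggested earlier in the thread, it uses the within-cluster sum of squared distances as the dispersion (the gist sums unsquared distances instead), and it falls back to the largest gap if no k satisfies the rule. The names dispersion, gap_with_error and choose_k are hypothetical and not part of the gist.

```python
import numpy as np
from scipy.cluster.vq import kmeans2

def dispersion(points, k):
    """Within-cluster sum of squared distances to the assigned centroid."""
    centroids, labels = kmeans2(points, k, minit="++")
    return np.sum((points - centroids[labels]) ** 2)

def gap_with_error(data, nrefs=20, ks=range(1, 11), seed=0):
    """Return (gaps, errs): the gap and its error term s_k for each k in ks."""
    rng = np.random.default_rng(seed)
    bots, tops = data.min(axis=0), data.max(axis=0)
    # One set of uniform reference data sets, reused for every k.
    refs = rng.uniform(bots, tops, size=(nrefs,) + data.shape)
    gaps = np.zeros(len(ks))
    errs = np.zeros(len(ks))
    for i, k in enumerate(ks):
        log_w = np.log(dispersion(data, k))
        ref_log_w = np.array([np.log(dispersion(ref, k)) for ref in refs])
        gaps[i] = ref_log_w.mean() - log_w  # mean of logs, not log of mean
        errs[i] = ref_log_w.std() * np.sqrt(1.0 + 1.0 / nrefs)
    return gaps, errs

def choose_k(ks, gaps, errs):
    """Smallest k with Gap(k) >= Gap(k+1) - s_{k+1}.

    Falling back to the largest gap when no k qualifies is an assumption
    of this sketch, not something the paper prescribes.
    """
    ks = list(ks)
    for i in range(len(ks) - 1):
        if gaps[i] >= gaps[i + 1] - errs[i + 1]:
            return ks[i]
    return ks[int(np.argmax(gaps))]
```

Usage would be along the lines of `gaps, errs = gap_with_error(data)` followed by `choose_k(range(1, 11), gaps, errs)`. Because k-means is randomly initialized, results can vary slightly between runs.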
Hi!
I needed to switch to sklearn's KMeans implementation (http://scikit-learn.org/stable/modules/generated/sklearn.cluster.KMeans.html) to avoid getting "Matrix is not positive definite" errors with some data. Other than that, it worked perfectly.
Thanks!
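The commenter's modified code is not shown, so the following is only a guess at what such a swap might look like. It assumes scikit-learn is installed; dispersion_sklearn and seed are hypothetical names. Note that KMeans exposes the sum of squared distances to the nearest centroid as inertia_, which differs from the unsquared sum the gist computes.

```python
import numpy as np
from sklearn.cluster import KMeans

def dispersion_sklearn(points, k, seed=0):
    """Dispersion for k clusters using scikit-learn's KMeans.

    inertia_ is the within-cluster sum of squared distances, so this is a
    squared-distance dispersion rather than the gist's unsquared one.
    """
    model = KMeans(n_clusters=k, n_init=10, random_state=seed).fit(points)
    return model.inertia_
```

A function like this could stand in for the two kmeans2 calls and the list comprehensions that follow them, once for data and once per reference data set.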