Created
September 10, 2015 22:28
-
-
Save warmlogic/bb9810f7a0dc350297a4 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Calculate and visualize silhouette scores from k-means clustering.
# Plots the first two features in 2D and the first three features in 3D.
# Adapted from: http://scikit-learn.org/stable/auto_examples/cluster/plot_kmeans_silhouette_analysis.html
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D  # noqa: F401 -- registers the '3d' projection on older matplotlib
import numpy as np
from sklearn.cluster import KMeans
from sklearn.metrics import silhouette_score, silhouette_samples
##### cluster data into K=1..K_MAX clusters #####
# Fit one KMeans model per k in 1..K_MAX; KM[k - 1] is the k-cluster model.
# The input is the first n_components columns of `trans`. Both names are
# defined outside this script (presumably a PCA projection, given the
# "kept %d PCs" plot title -- confirm against the caller).
K_MAX = 10
KK = range(1, K_MAX + 1)
_fit_data = trans[:, :n_components]  # same slice for every k; take it once
# `n_jobs` is intentionally not passed: it was deprecated in scikit-learn 0.23
# and removed in 1.0, where passing it raises TypeError.
# KMeans.fit returns the fitted estimator, so it can be collected directly.
KM = [KMeans(n_clusters=k, init='k-means++', n_init=10).fit(_fit_data)
      for k in KK]
# Choose a random subset of rows to score and visualize. Silhouette scoring
# needs all pairwise distances among the scored points, so scoring every row
# of a large dataset is slow. Smaller subsets (e.g. 7500 or 5000) are faster.
# The subset size is capped at the number of rows: sampling without
# replacement cannot draw more rows than exist.
n = min(10000, trans.shape[0])
# Passing the row count directly is equivalent to passing range(row_count),
# without materializing the range.
sil_rows = np.random.choice(trans.shape[0], n, replace=False)
# Only models with min_n_clus <= k <= max_n_clus are scored and plotted.
# silhouette_samples needs at least 2 distinct labels.
min_n_clus = 2
max_n_clus = K_MAX
silhouettes = []  # mean silhouette score of each scored model, in KM order
X = trans[sil_rows, :n_components]
def _plot_silhouette_panel(ax, sample_values, labels, n_clusters, avg):
    """Draw sorted per-cluster silhouette coefficients as stacked bands on *ax*.

    Bands are stacked bottom-to-top in label order with a 10-row gap between
    them. Each band is labelled with its cluster number, and a dashed red
    vertical line marks *avg*.
    """
    y_lower = 10
    for i in range(n_clusters):
        # Sorted coefficients of the sampled points assigned to cluster i.
        values_i = np.sort(sample_values[labels == i])
        size_i = values_i.shape[0]
        y_upper = y_lower + size_i
        color = plt.cm.nipy_spectral(float(i) / n_clusters)
        ax.fill_betweenx(np.arange(y_lower, y_upper),
                         0, values_i,
                         facecolor=color, edgecolor=color, alpha=0.7)
        # Label the band with its cluster number at its vertical midpoint.
        ax.text(-0.05, y_lower + 0.5 * size_i, str(i))
        # Leave a 10-row gap before the next cluster's band.
        y_lower = y_upper + 10
    ax.set_title("Silhouette plot")
    ax.set_xlabel("Silhouette coefficient values")
    ax.set_ylabel("Cluster label")
    # Vertical line at the average silhouette score of all sampled points.
    ax.axvline(x=avg, color="red", linestyle="--")
    ax.set_yticks([])  # clear the y-axis labels / ticks
    ax.set_xticks([-0.1, 0, 0.2, 0.4, 0.6, 0.8, 1])


def _plot_clusters_2d(ax, points, colors, centers):
    """Scatter the first two features of *points*, coloured by cluster.

    Each cluster centre is marked with a white disc and its index.
    """
    ax.scatter(points[:, 0], points[:, 1], marker='.', s=30, lw=0, alpha=0.7,
               c=colors)
    # White discs at the cluster centres, overlaid with the cluster number.
    ax.scatter(centers[:, 0], centers[:, 1], marker='o', c="white", alpha=1,
               s=200)
    for i, center in enumerate(centers):
        ax.scatter(center[0], center[1], marker='$%d$' % i, alpha=1, s=50)
    ax.set_title("2D visualization of clustered data")
    ax.set_xlabel("Feature space, 1st feature")
    ax.set_ylabel("Feature space, 2nd feature")


def _plot_clusters_3d(ax, points, colors):
    """Scatter the first three features of *points* on a 3D *ax*, coloured by cluster.

    Requires at least three columns in *points*. Cluster centres are not drawn.
    """
    ax.scatter(points[:, 0], points[:, 1], points[:, 2], marker='.', s=30,
               lw=0, alpha=0.7, c=colors)
    ax.set_title("3D visualization of clustered data")
    ax.set_xlabel("Feature space, 1st feature")
    ax.set_ylabel("Feature space, 2nd feature")
    ax.set_zlabel("Feature space, 3rd feature")


# Score each fitted model on the sampled rows and draw a three-panel figure:
# the silhouette plot, a 2D cluster scatter and a 3D cluster scatter.
for km in KM:
    n_clusters = km.n_clusters
    if not (min_n_clus <= n_clusters <= max_n_clus):
        continue
    cluster_labels = km.labels_[sil_rows]
    sample_silhouette_values = silhouette_samples(X, cluster_labels)
    # The mean of the per-sample values equals
    # silhouette_score(X, cluster_labels), without recomputing distances.
    silhouette_avg = np.mean(sample_silhouette_values)
    print('k=%d, score=%.5f' % (n_clusters, silhouette_avg))
    silhouettes.append(silhouette_avg)

    fig = plt.figure(figsize=(14, 5))
    ax1 = fig.add_subplot(1, 3, 1)
    ax2 = fig.add_subplot(1, 3, 2)
    ax3 = fig.add_subplot(1, 3, 3, projection='3d')

    # Same per-cluster colour mapping as the silhouette bands.
    sil_colors = plt.cm.nipy_spectral(cluster_labels.astype(float) / n_clusters)
    _plot_silhouette_panel(ax1, sample_silhouette_values, cluster_labels,
                           n_clusters, silhouette_avg)
    _plot_clusters_2d(ax2, X, sil_colors, km.cluster_centers_)
    _plot_clusters_3d(ax3, X, sil_colors)

    plt.suptitle(("Silhouette score on sample data=%.4f"
                  " for KMeans clustering (k=%d)"
                  ", kept %d PCs"
                  % (silhouette_avg, n_clusters, n_components)),
                 fontsize=14, fontweight='bold')
    plt.show()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment