Skip to content

Instantly share code, notes, and snippets.

@Killavus
Created January 29, 2016 00:25
Show Gist options
  • Save Killavus/819c8afb0c25ac051945 to your computer and use it in GitHub Desktop.
Save Killavus/819c8afb0c25ac051945 to your computer and use it in GitHub Desktop.
from fuel.datasets.cifar10 import CIFAR10
from sklearn.cluster import KMeans
from collections import Counter
from scipy.stats.mstats import mode
import numpy as np
def PerformKMeansAnalysis(train_data, train_labels, test_data, test_labels,
                          clusterer=None):
    """Cluster the training data and report how pure each cluster is.

    The training samples are clustered, and for every resulting cluster the
    Gini impurity of the true labels inside it and the most frequent true
    label (the "representative") are printed and collected.

    Parameters
    ----------
    train_data : array-like
        Samples handed unchanged to ``clusterer.fit``.
    train_labels : array-like
        One true label per training sample.  The labels are flattened, so a
        column of shape ``(n, 1)`` is accepted as well as a flat sequence.
    test_data, test_labels
        Currently unused; kept so existing call sites keep working.
    clusterer : object, optional
        Any object with a ``fit(data)`` method that exposes per-sample
        cluster ids as ``labels_`` afterwards.  Defaults to a 10-cluster
        ``KMeans`` with 5 initialisations.

    Returns
    -------
    list of tuple
        One ``(cluster_id, gini_index, representative)`` tuple per cluster,
        ordered by ascending cluster id.

    Raises
    ------
    ValueError
        If the number of labels differs from the number of clustered
        samples.
    """
    if clusterer is None:
        try:
            clusterer = KMeans(n_clusters=10, n_init=5, n_jobs=8)
        except TypeError:
            # Recent scikit-learn releases no longer accept ``n_jobs``.
            clusterer = KMeans(n_clusters=10, n_init=5)
    clusterer.fit(train_data)

    cluster_ids = np.asarray(clusterer.labels_).ravel()
    true_labels = np.asarray(train_labels).ravel()
    if true_labels.shape[0] != cluster_ids.shape[0]:
        raise ValueError(
            "got %d labels for %d clustered samples"
            % (true_labels.shape[0], cluster_ids.shape[0])
        )

    results = []
    # np.unique yields the cluster ids sorted, so the report order is stable.
    for cluster in np.unique(cluster_ids):
        members = true_labels[cluster_ids == cluster]
        classes, counts = np.unique(members, return_counts=True)
        shares = counts.astype(float) / members.shape[0]
        # Gini impurity: sum_i p_i * (1 - p_i); 0 means a pure cluster.
        gini_index = np.dot(shares, 1.0 - shares)
        # argmax returns the first maximum, so ties go to the smallest label.
        representative = classes[np.argmax(counts)]

        # Single-argument print calls behave the same on Python 2 and 3.
        print("For label %d: %s" % (cluster, gini_index))
        print("%s %s" % ("Representative: ", representative))
        print("======")
        results.append((cluster, gini_index, representative))
    return results
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment