Created
September 15, 2012 14:36
-
-
Save weilinear/3728246 to your computer and use it in GitHub Desktop.
centers using sparse matrix
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def _centers(X, labels, n_clusters, distances):
    """M step of the K-means EM algorithm

    Computation of cluster centers / means, expressed as one product of a
    row-normalized (n_clusters, n_samples) cluster-indicator matrix with X.

    Parameters
    ----------
    X: array or sparse matrix, shape (n_samples, n_features)

    labels: array of integers, shape (n_samples)
        Current label assignment; values are used as row indices of the
        indicator matrix, so they must lie in [0, n_clusters).

    n_clusters: int
        Number of desired clusters

    distances: array, shape (n_samples)
        Only read when at least one cluster is empty. Samples are ranked by
        descending value and the top ones are handed to the empty clusters.
        NOTE(review): presumably the distance of each sample to its closest
        center -- confirm against the caller.

    Returns
    -------
    centers: array, shape (n_clusters, n_features)
        The resulting centers

    Raises
    ------
    IndexError
        If there are more empty clusters than samples to reassign.
    """
    n_samples = X.shape[0]
    # Row/column indices handed to scipy must be integers.
    labels = np.asarray(labels).astype(np.intp, copy=False)

    # TODO: explicit dtype handling

    # A cluster is empty when no sample carries its label. Counting members
    # (instead of testing an array of sample indices for truthiness) keeps a
    # cluster whose only member is sample 0 from being reported as empty.
    member_counts = np.bincount(labels, minlength=n_clusters)[:n_clusters]
    empty_ids = np.flatnonzero(member_counts == 0)

    if empty_ids.size:
        # Reassign each empty cluster center to a sample far from any
        # cluster: the i-th empty cluster gets the i-th farthest sample.
        farthest_first = np.asarray(distances).argsort()[::-1]
        donor_samples = farthest_first[np.arange(empty_ids.size)]
    else:
        donor_samples = np.empty(0, dtype=np.intp)

    # Indicator matrix: entry (k, i) is 1 when sample i contributes to
    # center k. A donor sample also stays in its original cluster.
    rows = np.concatenate([labels, empty_ids]).astype(np.intp, copy=False)
    cols = np.concatenate([np.arange(n_samples), donor_samples])
    cols = cols.astype(np.intp, copy=False)
    cluster_indicator = sp.csr_matrix(
        (np.ones(rows.shape[0], dtype=np.float64), (rows, cols)),
        shape=(n_clusters, n_samples),
        dtype=np.float64)

    # Normalize each row to sum to one so that the product below yields the
    # per-cluster mean rather than the per-cluster sum.
    inplace_csr_row_normalize_l1(cluster_indicator)

    centers = safe_sparse_dot(cluster_indicator, X, dense_output=True)
    return centers
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment