Skip to content

Instantly share code, notes, and snippets.

@thomasahle
Created October 21, 2022 20:53
Show Gist options
  • Save thomasahle/4f16b19aa395f25e8fee882e3a82a4d9 to your computer and use it in GitHub Desktop.
Save thomasahle/4f16b19aa395f25e8fee882e3a82a4d9 to your computer and use it in GitHub Desktop.
Various methods of finding semi-sparse K-means clusterings
import numpy as np
from sklearn.cluster import KMeans
from sklearn.metrics.pairwise import euclidean_distances
import collections
_POLICIES = ('top', 'greedy', 'greedy2', 'rvq')


def _reconstruct(centers, labels):
    """Return each point's approximation: the sum of its assigned centers.

    ``labels`` is an (n, l) integer array of row indices into ``centers``;
    the result has one row per point.
    """
    return centers[labels].sum(axis=1)


def _fit_centers(X, labels, k):
    """Jointly refit all ``k`` centers by least squares for fixed labels."""
    n = X.shape[0]
    # H[p, c] counts how many of point p's labels equal c, so H @ M is the
    # reconstruction for centers M; solve min ||H @ M - X||.
    H = np.zeros((n, k))
    rows = np.arange(n)
    for col in labels.T:
        # Each row index appears once per column, so fancy-index += is safe.
        H[rows, col] += 1
    centers, *_ = np.linalg.lstsq(H, X, rcond=None)
    return centers


def _update_labels(X, centers, labels, k, l, policy, random_state):
    """Return new (n, l) label assignments for ``policy`` from the current state."""
    if policy == 'top':
        # Each point independently takes its l nearest centers.
        dists = euclidean_distances(X, centers, squared=True)
        return np.argsort(dists)[:, :l]

    if policy == 'greedy':
        # Repeatedly pick the center nearest to what is still unexplained.
        residual = X.copy()
        picks = []
        for _ in range(l):
            dists = euclidean_distances(residual, centers, squared=True)
            best = np.argmin(dists, axis=1)
            residual -= centers[best]
            picks.append(best)
        return np.stack(picks, axis=1)

    if policy == 'greedy2':
        # Re-pick one label slot at a time, holding the other slots fixed.
        Y = _reconstruct(centers, labels)
        picks = []
        for i in range(l):
            Y -= centers[labels[:, i]]
            dists = euclidean_distances(X - Y, centers, squared=True)
            best = np.argmin(dists, axis=1)
            picks.append(best)
            Y += centers[best]
        return np.stack(picks, axis=1)

    # policy == 'rvq': slot i gets its own contiguous block of centers,
    # refit with a fresh KMeans on the residual of the other slots.
    Y = _reconstruct(centers, labels)
    new_labels = labels.copy()
    shift = 0
    for i in range(l):
        # Remove the current part so Y holds only the parts not being edited.
        Y -= centers[labels[:, i]]
        size = (k + i) // l  # these sizes sum to exactly k over i = 0..l-1
        km = KMeans(n_clusters=size, random_state=random_state).fit(X - Y)
        new_labels[:, i] = shift + km.labels_
        shift += size
        # Add the new part back so Y is again the current approximation to X.
        Y += km.cluster_centers_[km.labels_]
    return new_labels


def kl_means(X, k: int, l: int, policy: str, *, tol: float = 1e-8,
             max_iter=None, random_state=None):
    """Approximate each row of ``X`` by a sum of ``l`` out of ``k`` shared centers.

    Starts from an ordinary KMeans fit, where every point gets ``l`` copies
    of its cluster center scaled by ``1/l``.  It then alternates two steps:
    reassign each point's ``l`` labels with ``policy``, and jointly refit
    all centers by least squares.  It stops when the reconstruction error
    fails to drop by at least ``tol``.

    Args:
        X: (n, d) data matrix.
        k: number of shared centers.
        l: number of centers summed per point.
        policy: label-update rule, one of ``'top'``, ``'greedy'``,
            ``'greedy2'``, ``'rvq'``.
        tol: minimum decrease in loss needed to keep iterating.
        max_iter: optional cap on the number of iterations (``None`` = no cap).
        random_state: forwarded to every ``KMeans`` call for reproducibility.

    Returns:
        A tuple ``(labels, centers, loss, t)``:
            labels: (n, l) label array.
            centers: (k, d) center array.
            loss: Frobenius norm ``||X - reconstruction||`` of that solution.
            t: number of iterations run.
        If the final iteration made the loss worse, the best earlier
        solution is returned instead.

    Raises:
        ValueError: if ``policy`` is unknown, ``l < 1``, ``l > k`` for the
            ``'top'``/``'rvq'`` policies, or ``max_iter < 1``.
    """
    if policy not in _POLICIES:
        raise ValueError(f"unknown policy {policy!r}; expected one of {_POLICIES}")
    if l < 1:
        raise ValueError(f"l must be at least 1, got {l}")
    if l > k and policy in ('top', 'rvq'):
        # 'top' would return fewer than l labels; 'rvq' would need KMeans(0).
        raise ValueError(f"policy {policy!r} requires l <= k, got l={l}, k={k}")
    if max_iter is not None and max_iter < 1:
        raise ValueError(f"max_iter must be at least 1, got {max_iter}")

    # Work in float so in-place residual updates also work for integer input.
    X = np.asarray(X, dtype=float)
    km = KMeans(n_clusters=k, random_state=random_state).fit(X)
    # l copies of center / l reproduce the plain KMeans reconstruction.
    centers = km.cluster_centers_ / l
    labels = np.repeat(km.labels_[:, None], l, axis=1)

    best = None  # (labels, centers, loss) of the lowest-loss solution so far
    prev_loss = np.inf
    t = 0
    while max_iter is None or t < max_iter:
        t += 1
        labels = _update_labels(X, centers, labels, k, l, policy, random_state)
        centers = _fit_centers(X, labels, k)
        loss = np.linalg.norm(X - _reconstruct(centers, labels))
        if best is None or loss < best[2]:
            best = (labels, centers, loss)
        # Written as `not (... >= tol)` so a NaN loss also terminates.
        if not prev_loss - loss >= tol:
            break
        prev_loss = loss

    best_labels, best_centers, best_loss = best
    return best_labels, best_centers, best_loss, t
# Imported first so a missing matplotlib fails fast, before the slow experiment.
import matplotlib.pyplot as plt

# Compare label-update policies of kl_means on Gaussian data for every L.
n, d = 500, 10
k = 20
maxl = k - 1
reps = 3
data = collections.defaultdict(list)
baselines = []  # plain KMeans reconstruction error, one entry per repetition
ls = range(1, maxl + 1)
for _ in range(reps):
    X = np.random.randn(n, d)
    km = KMeans(k).fit(X)
    baseline = np.linalg.norm(X - km.cluster_centers_[km.labels_])
    baselines.append(baseline)
    print('KMeans baseline:', baseline)
    for l in ls:
        print('L:', l)
        for p in ['top', 'greedy', 'rvq']:
            labels, centers, loss, t = kl_means(X, k, l, p)
            print(f'{p}, {t}its, {loss:.3f}')
            data[p].append(loss)

for label, series in data.items():
    # Each series is rep-major (all L values for rep 0, then rep 1, ...),
    # so this reshape yields shape (reps, len(ls)).
    ar = np.array(series).reshape(reps, -1)
    low, hi, mean = ar.min(axis=0), ar.max(axis=0), ar.mean(axis=0)
    plt.fill_between(ls, low, hi, alpha=.1)
    plt.plot(ls, mean, label=label)
# The KMeans baseline does not depend on L, so draw it as a horizontal reference.
plt.axhline(np.mean(baselines), color='gray', linestyle='--', label='KMeans baseline')
plt.legend()
plt.title(f'L-dense {k}-means on ({n},{d}) randn')
plt.xlabel('L')
# The plotted value is np.linalg.norm(X - X_hat), the Frobenius norm
# (sqrt of the total squared error), not a root-mean-squared error.
plt.ylabel('Reconstruction error ||X - X_hat||_F')
plt.show()
@aribenjamin
Copy link

aribenjamin commented Oct 27, 2022

In the RVQ method, I wondered whether the full call to KMeans in the inner loop was wasteful. Why not just do a single half-iteration of KMeans (the label-assignment step only) on the residuals? Like this:

        elif policy == 'rvq_partial':
            Y = sum(centers[c] for c in labels.T)
            new_labels = np.copy(labels)
            # get the start and end of the active clusters. If k%l != 0, evenly distribute the remainder instead of tacking it on the end.
            d = np.array(sum([[i//(k/l)] for i in range(k)],[]))
            changepoints = np.concatenate(([0], np.argwhere(np.diff(d)).flatten(), [k]))
            for i in range(l):
                # Remove the current part from Y so it becomes just
                # the parts we are not currently editing
                Y -= centers[labels[:, i]]
                shift = changepoints[i]

                residual = X - Y
                # just do one half iteration of kmeans
                dists = euclidean_distances(residual, centers[shift:changepoints[i+1]], squared=True)
                new_labels[:, i] = shift + np.argmin(dists, axis=1)

                # Make Y again the
                # current approximation to X
                Y += centers[new_labels[:,i]]
                ```
I found this to be about 50x faster, while the performance seems almost unchanged:
![image](https://user-images.githubusercontent.com/9724569/198394790-90716f4c-34ab-4ace-a714-e8addf27052e.png)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment