# Toy example for online k-means on the Grassmann manifold.
# Originally published by @nguigs as GitHub Gist
# 92d290c99439495383ffa701dbf681db (created December 2, 2022).
"""Toy example of online Riemannian k-means on the Grassmann manifold.

Nicolas Guigui, 29/11/2022
"""
import geomstats.backend as gs
from geomstats.geometry.grassmannian import Grassmannian
from geomstats.learning.online_kmeans import OnlineKMeans
n_samples = 10
n = 3
p = 1
manifold = Grassmannian(n, p)
metric = manifold.metric

# Seed the backend RNG so that repeated runs sample the same data and
# therefore print the same results.
gs.random.seed(0)


def sample_cluster(center, n_points=n_samples, shrink=10):
    """Sample ``n_points`` points of ``manifold`` around ``center``.

    Random tangent vectors at ``center`` are divided by ``shrink`` times the
    injectivity radius of ``metric`` at ``center``. They are then mapped onto
    the manifold with the Riemannian exponential, so the samples stay near
    ``center``.

    Parameters
    ----------
    center : array-like
        Point of ``manifold`` around which the cluster is generated.
    n_points : int
        Number of points to sample (defaults to ``n_samples``).
    shrink : float
        Extra factor by which the tangent vectors are scaled down.

    Returns
    -------
    array-like
        The ``n_points`` sampled points, stacked along the first axis.
    """
    tangent_vecs = manifold.random_tangent_vec(center, n_points)
    tangent_vecs = tangent_vecs / (metric.injectivity_radius(center) * shrink)
    return metric.exp(tangent_vecs, center)


# Generate data around a first random center and check it lies on the manifold.
center_1 = manifold.random_point()
cluster_1 = sample_cluster(center_1)
print(manifold.belongs(cluster_1).all())

# Generate data around a second random center and check it as well.
center_2 = manifold.random_point()
cluster_2 = sample_cluster(center_2)
print(manifold.belongs(cluster_2).all())

# Pool both clusters: rows [0, n_samples) come from cluster_1, the rest from
# cluster_2.
data = gs.concatenate((cluster_1, cluster_2), axis=0)
# Fit online k-means with two clusters on the pooled data.
kmeans = OnlineKMeans(metric, n_clusters=2, max_iter=2000, atol=1e-3)
kmeans.fit(data)
labels = kmeans.predict(data)
centroids = kmeans.cluster_centers_

# The first n_samples rows of `data` were sampled around center_1 and the
# remaining ones around center_2. A successful clustering therefore gives
# one label to the first half and the other label to the second half.
print(labels)
# Consistent with the checks on the clusters: one boolean for all centroids.
print(manifold.belongs(centroids).all())
# Distance from each true center to both estimated centroids. Which centroid
# ends up closest is not tied to the order of the true centers.
print(metric.dist(center_1, centroids))
print(metric.dist(center_2, centroids))