Skip to content

Instantly share code, notes, and snippets.

@altescy
Last active December 25, 2022 13:34
Show Gist options
  • Select an option

  • Save altescy/8465dbb030b6d97ac9a9296ca60c8e9f to your computer and use it in GitHub Desktop.

Select an option

Save altescy/8465dbb030b6d97ac9a9296ca60c8e9f to your computer and use it in GitHub Desktop.
from typing import TypeVar
import numpy
from sklearn.base import BaseEstimator, ClusterMixin
from sklearn.cluster import KMeans
Self = TypeVar("Self", bound="HierarchicalKmeans")


class HierarchicalKmeans(BaseEstimator, ClusterMixin):
    """Recursively partition samples with K-means into a tree of clusters.

    The whole dataset is clustered with ``KMeans``. Every resulting cluster
    that still holds more than ``max_leaf_nodes`` samples is clustered again,
    and so on. Each sample is labelled with its path of cluster ids from the
    root down to its leaf. The path ends with the sample's position inside
    that leaf, so it identifies the sample uniquely within the fitted data.

    Parameters
    ----------
    n_clusters_per_level:
        Maximum number of clusters fitted at each node of the tree. A node
        holding fewer samples than this is fitted with one cluster per sample,
        because ``KMeans`` cannot fit more clusters than samples.
    max_leaf_nodes:
        A cluster with at most this many samples becomes a leaf and is not
        split further.
    random_state:
        Forwarded unchanged to every ``KMeans`` instance. Pass an int for
        reproducible trees. ``None`` (the default) leaves KMeans unseeded.

    Attributes
    ----------
    kmeans:
        Fitted ``KMeans`` models keyed by tree position. The key is ``"root"``
        for the top level and ``"root.<id>.<id>..."`` for nested nodes. It is
        cleared at the start of every ``fit_predict`` call.
    labels_:
        Hierarchical ids of the most recently fitted data (set by ``fit`` and
        ``fit_predict``).
    """

    def __init__(
        self,
        n_clusters_per_level: int = 8,
        max_leaf_nodes: int = 8,
        random_state: "int | numpy.random.RandomState | None" = None,
    ) -> None:
        super().__init__()
        self.n_clusters_per_level = n_clusters_per_level
        self.max_leaf_nodes = max_leaf_nodes
        self.random_state = random_state
        self.kmeans: dict[str, KMeans] = {}

    def fit(
        self: Self,
        X: numpy.ndarray,
        y: "numpy.ndarray | None" = None,
    ) -> Self:
        """Fit the cluster tree on ``X`` and return the estimator itself.

        ``y`` is ignored; it is accepted only for scikit-learn API
        compatibility. The per-sample ids are stored in ``labels_``.
        """
        self.fit_predict(X, y)
        return self

    def fit_predict(
        self: Self,
        X: numpy.ndarray,
        y: "numpy.ndarray | None" = None,
    ) -> list[list[int]]:
        """Fit the cluster tree on ``X`` and return every sample's path.

        Parameters
        ----------
        X:
            Samples, one per row, in any form accepted by ``KMeans``.
        y:
            Ignored; present for scikit-learn API compatibility.

        Returns
        -------
        list[list[int]]
            One path per sample, in the order of ``X``. A path holds the
            cluster id chosen at each level, starting at the root, followed
            by the sample's position within its leaf cluster. An empty ``X``
            yields ``[]``.
        """
        # Start from a clean model table so sub-models from an earlier fit
        # (possibly at tree positions that no longer exist) do not linger.
        self.kmeans = {}
        inputs = numpy.asarray(X)
        predictions: list[list[int]] = [[] for _ in range(len(inputs))]
        if len(inputs) > 0:
            self._fit_node("root", inputs, predictions)
        self.labels_ = predictions
        return predictions

    def _fit_node(
        self,
        level: str,
        inputs: numpy.ndarray,
        predictions: list[list[int]],
    ) -> None:
        """Cluster ``inputs`` at tree node ``level`` and extend their paths.

        ``predictions[i]`` is the path list of ``inputs[i]``. It is shared
        with the caller, and ids are appended to it in place.
        """
        # KMeans rejects n_clusters > n_samples. This occurs for a small root,
        # or whenever max_leaf_nodes < n_clusters_per_level, so clamp it.
        kmeans = KMeans(
            n_clusters=min(self.n_clusters_per_level, len(inputs)),
            n_init="auto",
            random_state=self.random_state,
        )
        self.kmeans[level] = kmeans
        clusters = kmeans.fit_predict(inputs)
        for path, cluster_id in zip(predictions, clusters):
            # Cast away the numpy integer type so the result is list[list[int]].
            path.append(int(cluster_id))

        # numpy.unique gives a sorted, deterministic visiting order.
        for cluster_id in numpy.unique(clusters):
            mask = clusters == cluster_id
            inputs_cluster = inputs[mask]
            predictions_cluster = [
                path for path, selected in zip(predictions, mask) if selected
            ]
            # Recurse only while the cluster is too large AND strictly smaller
            # than this node. If KMeans put every sample into one cluster
            # (e.g. duplicate points), recursing would repeat this exact node
            # forever.
            if self.max_leaf_nodes < len(inputs_cluster) < len(inputs):
                self._fit_node(
                    f"{level}.{cluster_id}",
                    inputs_cluster,
                    predictions_cluster,
                )
            else:
                # Leaf: the position inside the leaf disambiguates samples.
                for position, path in enumerate(predictions_cluster):
                    path.append(position)
if __name__ == "__main__":
    # Demo: build a tree over 128 random 16-dimensional points, then report
    # how many distinct hierarchical ids were produced for them.
    samples = numpy.random.uniform(size=(128, 16))
    model = HierarchicalKmeans(n_clusters_per_level=4, max_leaf_nodes=4)
    paths = model.fit_predict(samples)
    distinct_paths = {tuple(path) for path in paths}
    print(paths)
    print("num of examples :", len(samples))
    print("num of unique ids:", len(distinct_paths))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment