Skip to content

Instantly share code, notes, and snippets.

@joshlk
Last active October 1, 2019 14:23
Show Gist options
  • Save joshlk/2626dbba1bd7bd47f9f2adb6b6d17be0 to your computer and use it in GitHub Desktop.
Save joshlk/2626dbba1bd7bd47f9f2adb6b6d17be0 to your computer and use it in GitHub Desktop.
Directed Agglomerative Clustering: similar to `sklearn.cluster.AgglomerativeClustering`, but each node's label is the root node of its cluster, i.e. the root of the connected DAG (Directed Acyclic Graph) it belongs to. The algorithm can equivalently be framed as finding the weakly connected components and identifying the root of each one.
import numpy as np
from itertools import count, product
class DirectedAgglomerativeClustering:
    """
    Cluster the nodes of a directed graph by weakly connected component, labelling each node with its root.

    Similar to `sklearn.cluster.AgglomerativeClustering`, but the label of each node is the index of the root node of
    the connected DAG (Directed Acyclic Graph) it belongs to, rather than an arbitrary cluster id. Equivalently: find
    the weakly connected components and identify the root of each one.

    An edge ``i -> j`` is encoded as ``X[i, j] > 0``. A root is a node with no outgoing edge to another node
    (e.g. with edges ``1 -> 0`` and ``2 -> 0``, node 0 is the root). Self-loops (``X[i, i]``) are ignored.
    If a component has several roots, the lowest-indexed one is used; if it has none (the input contains a cycle,
    so it is not a DAG), the lowest-indexed node of the component is used.

    Cost is dominated by scanning the dense ``n x n`` adjacency matrix plus one union-find operation per edge.

    Attributes
    ----------
    labels_ : numpy.ndarray of int, shape (n,)
        For each node, the index of its cluster's root node. ``None`` until :meth:`fit` is called.
    n_clusters_ : int
        Number of clusters (weakly connected components) found. ``None`` until :meth:`fit` is called.
    """

    def __init__(self):
        self.labels_ = None
        self.n_clusters_ = None

    def fit(self, X):
        """
        Find clusters and the root of each cluster.

        :param X: Square directed adjacency matrix (array-like, shape (n, n)) representing disconnected DAGs.
            ``X[i, j] > 0`` indicates an edge i -> j, 0 indicates no connection.
        :return: self, with ``labels_`` and ``n_clusters_`` set.
        :raises ValueError: if ``X`` is not a square 2-D matrix.
        """
        X = np.asarray(X)
        if X.ndim != 2 or X.shape[0] != X.shape[1]:
            raise ValueError(f"X must be a square 2-D adjacency matrix, got shape {X.shape}")
        n_nodes = X.shape[0]

        # Directed edges i -> j; self-loops are dropped because a node never merges with itself.
        src, dst = np.nonzero(X > 0)
        not_self_loop = src != dst
        src, dst = src[not_self_loop], dst[not_self_loop]

        # Union-find over the edges with direction ignored yields the weakly connected components.
        parent = list(range(n_nodes))

        def find(node):
            while parent[node] != node:
                parent[node] = parent[parent[node]]  # path halving keeps the trees shallow
                node = parent[node]
            return node

        for i, j in zip(src.tolist(), dst.tolist()):
            root_i, root_j = find(i), find(j)
            if root_i != root_j:
                parent[root_i] = root_j

        # A node is a root of its DAG when it has no outgoing edge to another node.
        has_out_edge = np.zeros(n_nodes, dtype=bool)
        has_out_edge[src] = True

        # Nodes are visited in increasing index order, so each component keeps its lowest-indexed root
        # (or, when it has none because of a cycle, its lowest-indexed node).
        component_root = {}
        for node in range(n_nodes):
            component = find(node)
            chosen = component_root.get(component)
            if chosen is None or (has_out_edge[chosen] and not has_out_edge[node]):
                component_root[component] = node

        self.labels_ = np.array([component_root[find(node)] for node in range(n_nodes)], dtype=int)
        self.n_clusters_ = len(component_root)
        return self
def test_compare_to_sklearn():
    """
    Check that DirectedAgglomerativeClustering groups nodes the same way as sklearn single-linkage clustering.

    sklearn is given a symmetric distance matrix (edge direction is discarded), so it can only confirm the grouping;
    the directed version must additionally label every node with the index of its root.
    """
    from sklearn.cluster import AgglomerativeClustering  # `distance_threshold` requires scikit-learn>=0.21

    adjacency_matrix = np.array(
        #  0, 1, 2, 3, 4, 5
        [[0, 0, 0, 0, 0, 0],  # 0
         [1, 0, 0, 0, 0, 0],  # 1
         [1, 1, 0, 0, 0, 0],  # 2
         [0, 0, 0, 0, 0, 0],  # 3
         [0, 0, 0, 1, 0, 0],  # 4
         [0, 0, 0, 0, 0, 0]]  # 5
    )
    # 3 clusters. Root 0 inc. 1, 2. Root 3 inc. 4. Root 5 is on its own. Labels == [0, 0, 0, 3, 3, 5]

    # Test sklearn.
    # Distance 0 between nodes joined by an edge in either direction and 1 otherwise, so single linkage with a 0.5
    # threshold merges exactly the weakly connected components. OR-ing with the transpose keeps every edge,
    # whichever triangle of the matrix it sits in.
    has_edge = adjacency_matrix.astype(bool)
    distance_matrix = (~(has_edge | has_edge.T)).astype(float)
    np.fill_diagonal(distance_matrix, 0.0)  # a node is at distance 0 from itself

    sklearn_kwargs = dict(n_clusters=None, linkage='single', distance_threshold=0.5)
    try:
        # scikit-learn>=1.2 names the parameter `metric`; `affinity` was removed in 1.4.
        clf_sklearn = AgglomerativeClustering(metric='precomputed', **sklearn_kwargs)
    except TypeError:
        # scikit-learn<1.2 only accepts `affinity`.
        clf_sklearn = AgglomerativeClustering(affinity='precomputed', **sklearn_kwargs)
    clf_sklearn.fit(distance_matrix)
    assert clf_sklearn.n_clusters_ == 3
    assert (clf_sklearn.labels_[:3] == clf_sklearn.labels_[0]).all()  # nodes 0, 1, 2 share a label
    assert (clf_sklearn.labels_[3:5] == clf_sklearn.labels_[3]).all()  # nodes 3, 4 share a label

    # Test Directed
    clf_dir = DirectedAgglomerativeClustering()
    clf_dir.fit(adjacency_matrix)
    assert list(clf_dir.labels_) == [0, 0, 0, 3, 3, 5]
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment