Directed Agglomerative Clustering: similar to `sklearn.cluster.AgglomerativeClustering`, but the label assigned to each node is its root node. The root node is the root of a connected DAG (Directed Acyclic Graph). The algorithm can also be framed as finding the weakly connected components and identifying the root of each.
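For orientation, a minimal usage sketch (it assumes the `DirectedAgglomerativeClustering` class defined in the file below is in scope; edges follow the convention used in the docstring, `X[i, j] == 1` means an edge i -> j, i.e. edges point from a node towards its root):

import numpy as np

# Tiny hypothetical example: 1 -> 0 and 2 -> 0, node 3 is isolated.
X = np.array([[0, 0, 0, 0],
              [1, 0, 0, 0],
              [1, 0, 0, 0],
              [0, 0, 0, 0]])

clf = DirectedAgglomerativeClustering()
clf.fit(X)
print(clf.labels_)      # expected: [0 0 0 3] -- each node labelled with its root
print(clf.n_clusters_)  # expected: 2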
import numpy as np
from itertools import count, product


class DirectedAgglomerativeClustering:
    """
    Similar to `sklearn.cluster.AgglomerativeClustering`, but the label assigned to each node is its root node.
    The root node is the root of a connected DAG (Directed Acyclic Graph).
    The algorithm can also be framed as finding the weakly connected components and identifying the root of each.
    Naive implementation, O(n^3).
    """

    def __init__(self):
        self.labels_ = None
        self.n_clusters_ = None
    def fit(self, X):
        """
        Find the clusters and the root of each cluster.
        :param X: Directed adjacency matrix representing disconnected DAGs. X[i, j] == 1 indicates an edge i -> j,
                  0 indicates no connection.
        :return: self
        """
        assert X.shape[0] == X.shape[1]
        self.n_clusters_ = X.shape[0]
        self.labels_ = np.arange(self.n_clusters_)  # for each node, the root node it currently points to

        for scan_i in count():
            no_corrections = True
            for root1, root2 in product(self.labels_, repeat=2):
                if root1 == root2:
                    continue
                if X[root1, root2] > 0:  # edge root1 -> root2: merge root1's cluster into root2's cluster
                    mask = self.labels_ == root1
                    if not mask.any():
                        continue  # root1 was already merged away earlier in this scan
                    self.labels_[mask] = self.labels_[root2]  # merge into root2's *current* root, which may have changed
                    no_corrections = False
                    self.n_clusters_ -= 1
                    if self.n_clusters_ == 1:
                        break
            if no_corrections or self.n_clusters_ == 1:
                break

        return self
def test_compare_to_sklearn():
    from sklearn.cluster import AgglomerativeClustering  # Requires scikit-learn>=v0.21

    adjacency_matrix = np.array(
        # 0, 1, 2, 3, 4, 5
        [[0, 0, 0, 0, 0, 0],  # 0
         [1, 0, 0, 0, 0, 0],  # 1
         [1, 1, 0, 0, 0, 0],  # 2
         [0, 0, 0, 0, 0, 0],  # 3
         [0, 0, 0, 1, 0, 0],  # 4
         [0, 0, 0, 0, 0, 0]]  # 5
    )
    # 3 clusters. Root 0 includes 1, 2. Root 3 includes 4. Root 5 is on its own. Labels == [0, 0, 0, 3, 3, 5]

    # Test sklearn
    clf_sklearn = AgglomerativeClustering(
        n_clusters=None, affinity='precomputed', linkage='single', distance_threshold=0.5)
    distance_matrix = (~adjacency_matrix.astype('bool')).astype('float')  # edge -> distance 0, no edge -> distance 1
    i_lower = np.tril_indices(len(distance_matrix), -1)
    distance_matrix.T[i_lower] = distance_matrix[i_lower]  # make the matrix symmetric
    clf_sklearn.fit(distance_matrix)
    assert clf_sklearn.n_clusters_ == 3
    assert (clf_sklearn.labels_[:3] == clf_sklearn.labels_[0]).all()  # first 3 elements share a label
    assert (clf_sklearn.labels_[3:5] == clf_sklearn.labels_[3]).all()

    # Test DirectedAgglomerativeClustering
    clf_dir = DirectedAgglomerativeClustering()
    clf_dir.fit(adjacency_matrix)
    assert list(clf_dir.labels_) == [0, 0, 0, 3, 3, 5]
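The "weakly connected components" framing mentioned in the docstring can also be cross-checked with SciPy. The helper below is a sketch, not part of the gist; it uses `scipy.sparse.csgraph.connected_components` and assumes each weak component has exactly one sink (its root, i.e. a node with no outgoing edges):

import numpy as np
from scipy.sparse.csgraph import connected_components

def roots_via_weak_components(X):
    # Label each weak component, then pick as root the member with no outgoing edges.
    n_components, comp = connected_components(X, directed=True, connection='weak')
    labels = np.empty(X.shape[0], dtype=int)
    for c in range(n_components):
        members = np.flatnonzero(comp == c)
        root = members[np.argmin(X[members].sum(axis=1))]  # a row of zeros means no outgoing edge
        labels[members] = root
    return n_components, labels

# On the adjacency_matrix above this should agree with DirectedAgglomerativeClustering:
# n_components == 3 and labels == [0, 0, 0, 3, 3, 5].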