Created
June 24, 2020 15:51
-
-
Save keuv-grvl/67d125166386c81769cfee7be791b178 to your computer and use it in GitHub Desktop.
G-Means
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
import pandas as pd | |
import seaborn as sbn | |
from matplotlib import pyplot as plt | |
from scipy.stats import anderson | |
from sklearn import datasets | |
from sklearn.cluster import MiniBatchKMeans | |
from sklearn.preprocessing import scale, LabelEncoder | |
# TODO doc, memoization | |
class GMeans(object):
    """The G-means recursive clustering algorithm.

    Adapted from: https://github.com/flylo/g-means/blob/master/gmeans.py

    The data is repeatedly split in two with k-means (k=2). A subset stops
    being split as soon as one of the following holds:

    * its projection onto the axis joining the two k-means centers passes an
      Anderson-Darling normality test ("gaussian"),
    * the recursion reached ``max_depth`` ("max_depth"),
    * one of the two children would hold ``min_obs`` observations or fewer
      ("min_obs").

    Parameters
    ----------
    min_obs : int
        A split is rejected when a child would contain ``min_obs``
        observations or fewer.
    max_depth : int
        Recursion depth at which splitting stops. The depth counter is
        incremented before it is compared, so ``max_depth=1`` (or less)
        means the data is never split.
    random_state : object
        Forwarded unchanged to every ``MiniBatchKMeans`` instance.
    strictness : int
        How strict the Anderson-Darling test for normality should be
        (0: not at all strict, 4: very strict). It is used as the index into
        the ``critical_values`` returned by ``scipy.stats.anderson``.

    Attributes (set by ``fit``)
    ---------------------------
    data : the object passed to ``fit`` (a reference, not a copy).
    data_index : integer array with one row per observation; column 0 is the
        row number, column 1 the id of the cluster the row ended up in.
    labels_ : column 1 of ``data_index``.
    stopping_criteria : list with one entry per final cluster, naming why it
        was not split further.

    Raises
    ------
    ValueError
        If ``strictness`` is not one of 0, 1, 2, 3, 4.
    """

    def __init__(self, min_obs=1, max_depth=10, random_state=None, strictness=4):
        if strictness not in range(5):
            raise ValueError("strictness parameter must be integer from 0 to 4")
        # Not used by the current code path (labels are returned as raw ids);
        # kept so the attribute stays available.
        self._le = LabelEncoder()
        self._current_id = -2  # next_id increments first, so ids start at -1
        self.max_depth = max_depth
        self.min_obs = min_obs
        self.random_state = random_state
        self.strictness = strictness
        self.stopping_criteria = []

    @property
    def next_id(self):
        """Return a fresh cluster id. Every access advances the counter."""
        self._current_id += 1
        return self._current_id

    def _gaussianCheck(self, vector):
        """Check whether ``vector`` follows a gaussian distribution.

        H0: vector is distributed gaussian
        H1: vector is not distributed gaussian

        Returns True when H0 is not rejected, i.e. when the Anderson-Darling
        statistic does not exceed the critical value selected by
        ``self.strictness``.
        """
        output = anderson(vector)
        return bool(output.statistic <= output.critical_values[self.strictness])

    def _recordLeaf(self, index, reason):
        """Store ``index`` as a final cluster and remember why it was kept."""
        # Column 0 of `index` holds the row numbers inside self.data_index.
        self.data_index[index[:, 0]] = index
        self.stopping_criteria.append(reason)

    def _recursiveClustering(self, data, depth, index):
        """Recursively run k-means with k=2 on ``data``.

        Recursion stops when ``max_depth`` is reached, when the subset looks
        gaussian, or when a split would create a child with ``min_obs``
        observations or fewer.

        ``index`` is the slice of the (row number, cluster id) table that
        corresponds to the rows of ``data``.
        """
        depth += 1
        # `>=` instead of `==`: with max_depth <= 0 an equality test would
        # never fire and the depth limit would be silently ignored.
        if depth >= self.max_depth:
            self._recordLeaf(index, "max_depth")
            return

        km = MiniBatchKMeans(n_clusters=2, random_state=self.random_state)
        km.fit(data)

        # Project every observation on the axis joining the two centers and
        # test that one-dimensional sample for normality.
        centers = km.cluster_centers_
        v = centers[0] - centers[1]
        x_prime = scale(data.dot(v) / (v.dot(v)))
        if self._gaussianCheck(x_prime):
            self._recordLeaf(index, "gaussian")
            return

        # Validate every child BEFORE recursing into any of them. Otherwise a
        # too-small second child would discard the finished clustering of the
        # first one, leaving stale stopping_criteria entries and burnt ids.
        masks = [km.labels_ == k for k in set(km.labels_)]
        if any(mask.sum() <= self.min_obs for mask in masks):
            self._recordLeaf(index, "min_obs")
            return

        for mask in masks:
            # Boolean indexing copies, so the parent's `index` is untouched.
            current_index = index[mask]
            current_index[:, 1] = self.next_id
            self._recursiveClustering(
                data=data[mask], depth=depth, index=current_index
            )

    def fit(self, data):
        """Fit the recursive clustering model to ``data``.

        Sets ``data``, ``data_index``, ``labels_`` and ``stopping_criteria``
        and returns ``self``. State from a previous ``fit`` is discarded.

        NOTE: ``labels_`` holds the raw ids handed out by ``next_id``. They
        identify clusters but are neither contiguous nor guaranteed to start
        at 0 (the first id handed out is -1; an unsplit dataset is labelled 0).
        """
        # Start from a clean slate so that fitting twice gives the same ids
        # and does not accumulate stopping criteria from earlier runs.
        self._current_id = -2
        self.stopping_criteria = []

        self.data = data  # FIXME potential massive memory usage?

        n_obs = data.shape[0]
        data_index = np.zeros((n_obs, 2), dtype=int)
        data_index[:, 0] = np.arange(n_obs)
        self.data_index = data_index

        self._recursiveClustering(data=data, depth=0, index=data_index)
        self.labels_ = self.data_index[:, 1]
        return self
if __name__ == "__main__":
    # Demo: cluster synthetic 2-D blobs with G-means and with a plain
    # mini-batch k-means, then plot the three labelings side by side.
    points, truth = datasets.make_blobs(
        n_samples=5000, n_features=2, centers=6, cluster_std=0.6
    )

    model = GMeans(min_obs=1)
    model.fit(points)

    frame = pd.DataFrame(points[:, 0:2], columns=["x", "y"])
    frame["labels_true"] = truth
    frame["labels_gmeans"] = model.labels_

    # Reference clustering with a fixed number of clusters.
    baseline = MiniBatchKMeans(n_clusters=4)
    baseline.fit(points)
    frame["labels_km"] = baseline.labels_

    # plot results: one scatter plot per labeling, coloured by label
    for hue_column in ("labels_true", "labels_gmeans", "labels_km"):
        sbn.lmplot(x="x", y="y", data=frame, hue=hue_column, fit_reg=False)
    plt.show()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment