-
-
Save ysaito8015/e96a69a28d82c04f5466f5d9b1ff9539 to your computer and use it in GitHub Desktop.
Implementation of X-means clustering in Python
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
以下の論文で提案された改良x-means法の実装 | |
クラスター数を自動決定するk-meansアルゴリズムの拡張について | |
http://www.rd.dnc.ac.jp/~tunenori/doc/xmeans_euc.pdf | |
""" | |
import numpy as np | |
from scipy import stats | |
from sklearn.cluster import KMeans | |
class XMeans:
    """X-means clustering: k-means that chooses the number of clusters itself.

    ``fit`` first runs ``KMeans`` with ``k_init`` clusters, then repeatedly
    tries to split each cluster in two with a 2-means run.  A split is kept
    only when the BIC of the two-cluster model is lower than the BIC of the
    unsplit cluster.

    Attributes set by ``fit``:
        labels_: cluster number assigned to each row of the input.
        cluster_centers_: centre of each final cluster.
        cluster_log_likelihoods_: Gaussian log-likelihood of each final cluster.
        cluster_sizes_: number of samples in each final cluster.
    """

    def __init__(self, k_init=2, **k_means_args):
        """Store the configuration; no clustering happens here.

        k_init : initial number of clusters passed to ``KMeans()``.
        k_means_args : extra keyword arguments forwarded to every ``KMeans()``
            call (both the initial run and each 2-means split).
        """
        self.k_init = k_init
        self.k_means_args = k_means_args

    def fit(self, X):
        """Cluster the data ``X`` with the x-means method and return ``self``.

        X : array-like, shape=(n_samples, n_features).  Converted with
            ``np.asarray`` so that boolean-mask row selection works.

        NOTE(review): the log-likelihood of each final cluster is evaluated
        here; SciPy is expected to raise if a final cluster is too small to
        have a non-singular covariance (e.g. a single point) -- confirm
        whether such inputs need dedicated handling.
        """
        X = np.asarray(X)
        self.__clusters = []

        initial_model = KMeans(self.k_init, **self.k_means_args).fit(X)
        self.__recursively_split(self.Cluster.build(X, initial_model))

        # Each final cluster remembers the original row numbers of its
        # members, so labels can be written straight into place.
        self.labels_ = np.empty(X.shape[0], dtype=np.intp)
        for i, c in enumerate(self.__clusters):
            self.labels_[c.index] = i

        self.cluster_centers_ = np.array([c.center for c in self.__clusters])
        self.cluster_log_likelihoods_ = np.array(
            [c.log_likelihood() for c in self.__clusters]
        )
        self.cluster_sizes_ = np.array([c.size for c in self.__clusters])
        return self

    def __recursively_split(self, clusters):
        """Recursively split the given clusters, collecting the final ones.

        clusters : list-like object, which contains instances of
            'XMeans.Cluster'.

        Clusters that are not split any further are appended to the private
        list of final clusters that ``fit`` reads afterwards.
        """
        for cluster in clusters:
            # Very small clusters are accepted as they are.
            if cluster.size <= 3:
                self.__clusters.append(cluster)
                continue

            k_means = KMeans(2, **self.k_means_args).fit(cluster.data)
            c1, c2 = self.Cluster.build(cluster.data, k_means, cluster.index)

            # A sample covariance of n points has rank at most n - 1, so a
            # child with no more points than features has a singular (or
            # undefined) covariance and its Gaussian likelihood cannot be
            # evaluated.  Keep the parent instead of failing on such a split.
            n_features = cluster.data.shape[1]
            if min(c1.size, c2.size) <= n_features:
                self.__clusters.append(cluster)
                continue

            # BIC of the two-cluster model: the distance between the centres,
            # scaled by the covariance determinants, gives the normalising
            # factor alpha applied to the combined likelihood.
            beta = np.linalg.norm(c1.center - c2.center) / np.sqrt(
                np.linalg.det(c1.cov) + np.linalg.det(c2.cov)
            )
            alpha = 0.5 / stats.norm.cdf(beta)
            split_log_likelihood = (
                cluster.size * np.log(alpha)
                + c1.log_likelihood()
                + c2.log_likelihood()
            )
            # The split model has twice the free parameters of one cluster.
            bic = -2 * split_log_likelihood + 2 * cluster.df * np.log(cluster.size)

            if bic < cluster.bic():
                self.__recursively_split([c1, c2])
            else:
                self.__clusters.append(cluster)

    class Cluster:
        """One cluster produced by k-means, with likelihood and BIC helpers.

        Attributes:
            data: rows of the input that belong to this cluster.
            index: for each row of ``data``, its row number in the original
                data set.
            size: number of rows in ``data``.
            df: number of free parameters, d * (d + 3) / 2 for d features.
            center: the k-means centre of this cluster.
            cov: sample covariance of ``data``, always 2-D.
        """

        @classmethod
        def build(cls, X, k_means, index=None):
            """Return a tuple with one ``Cluster`` per k-means label.

            X : array of samples that ``k_means`` was fitted on.
            k_means : fitted estimator providing ``labels_``,
                ``cluster_centers_`` and ``get_params()["n_clusters"]``.
            index : vector telling which row of the original data each row of
                ``X`` corresponds to; defaults to 0..n-1.
            """
            # Identity test, not ``== None``: comparing a NumPy array with
            # ``==`` is elementwise and cannot be used as an ``if`` condition.
            if index is None:
                index = np.arange(X.shape[0])
            n_clusters = k_means.get_params()["n_clusters"]
            return tuple(cls(X, index, k_means, label) for label in range(n_clusters))

        def __init__(self, X, index, k_means, label):
            """Extract the members of cluster ``label`` from ``X``.

            index : vector telling which row of the original data each row of
                ``X`` corresponds to.
            """
            members = k_means.labels_ == label
            self.data = X[members]
            self.index = index[members]
            self.size = self.data.shape[0]
            self.df = self.data.shape[1] * (self.data.shape[1] + 3) / 2
            self.center = k_means.cluster_centers_[label]
            # np.cov returns a 0-d array for single-feature data; keep the
            # covariance 2-D so np.linalg.det accepts it.
            self.cov = np.atleast_2d(np.cov(self.data.T))

        def log_likelihood(self):
            """Return the summed Gaussian log-density of the cluster's samples.

            The density uses ``center`` as mean and ``cov`` as covariance.
            An empty cluster contributes 0.
            """
            if self.size == 0:
                return 0.0
            # One vectorised call over all rows instead of a Python loop.
            return np.sum(
                stats.multivariate_normal.logpdf(self.data, self.center, self.cov)
            )

        def bic(self):
            """Return the BIC of modelling this cluster as one Gaussian."""
            return -2 * self.log_likelihood() + self.df * np.log(self.size)
if __name__ == "__main__":
    import matplotlib.pyplot as plt

    # Prepare the data: four groups of 20 points, each drawn around one of
    # the grid positions (1,1), (1,2), (2,1), (2,2) with standard deviation 0.1.
    x_means_per_group = np.repeat([1, 2], 2)
    y_means_per_group = np.tile([1, 2], 2)
    x = np.array([np.random.normal(m, 0.1, 20) for m in x_means_per_group]).flatten()
    y = np.array([np.random.normal(m, 0.1, 20) for m in y_means_per_group]).flatten()

    # Run the clustering on the (x, y) pairs.
    samples = np.c_[x, y]
    x_means = XMeans(random_state=1).fit(samples)
    for fitted in (
        x_means.labels_,
        x_means.cluster_centers_,
        x_means.cluster_log_likelihoods_,
        x_means.cluster_sizes_,
    ):
        print(fitted)

    # Plot the result: samples coloured by label, cluster centres as red "+".
    centers = x_means.cluster_centers_
    plt.rcParams["font.family"] = "Hiragino Kaku Gothic Pro"
    plt.scatter(x, y, c=x_means.labels_, s=30)
    plt.scatter(centers[:, 0], centers[:, 1], c="r", marker="+", s=100)
    plt.xlim(0, 3)
    plt.ylim(0, 3)
    plt.title("改良x-means法の実行結果 参考: 石岡(2000)")
    plt.show()
    # To write the figure to a file instead:
    # plt.savefig("clustering.png", dpi = 200)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment