Skip to content

Instantly share code, notes, and snippets.

@kohnakagawa
Last active November 22, 2018 07:02
Show Gist options
  • Save kohnakagawa/be0b637c1c85862f14c0e492b7a28c11 to your computer and use it in GitHub Desktop.
Save kohnakagawa/be0b637c1c85862f14c0e492b7a28c11 to your computer and use it in GitHub Desktop.
クラスタリングの結果をneo4jで可視化するためのスクリプト
import pandas as pd
import numpy as np
from random import randint
from numpy import linalg as LA
from itertools import combinations
from sklearn import datasets
from sklearn.cluster import KMeans
from sklearn.metrics.cluster import homogeneity_score
from neomodel import StructuredNode, StructuredRel
from neomodel import StringProperty, FloatProperty
from neomodel import Relationship
from neomodel import config, db
# NOTE: change the password below to match your own Neo4j instance.
# Connection URL that neomodel uses for every database operation in this script.
config.DATABASE_URL = 'bolt://neo4j:hogehoge@localhost:7687'
class FlowerRel(StructuredRel):
    """Relationship model carried by the edges drawn between flower nodes."""

    # Float value stored on each edge; this script fills it with the
    # Euclidean distance between the two connected samples.
    similarity = FloatProperty()
def make_class_instance(cls_name, super_class, attribs):
    """Create a class on the fly and return a fresh instance of it.

    Args:
        cls_name: Name given to the generated class.
        super_class: Base class(es). A tuple is used as-is; a single class or
            any other iterable of classes is converted to a tuple, because
            ``type()`` only accepts a tuple of bases.
        attribs: Mapping of class attributes, or ``None`` for no attributes.

    Returns:
        An instance of the newly created class, built with no constructor
        arguments.
    """
    if isinstance(super_class, type):
        bases = (super_class,)
    else:
        bases = tuple(super_class)
    # type() requires a real dict as the namespace; copying also keeps the
    # caller's mapping untouched.
    namespace = dict(attribs) if attribs is not None else {}
    new_cls = type(cls_name, bases, namespace)
    return new_cls()
# Node label used for each known target value; any other value falls back to
# _FALLBACK_FLOWER_LABEL.
_FLOWER_LABELS = {0: "hoge", 1: "fuga", 2: "hoga"}
_FALLBACK_FLOWER_LABEL = "tsune"

# Node classes generated so far, keyed by label name.
_FLOWER_CLASSES = {}


def make_flower(name_):
    """Return a new, unsaved node instance for the given target value.

    Args:
        name_: Target value of the sample. 0, 1 and 2 map to the labels
            "hoge", "fuga" and "hoga"; every other value maps to "tsune".

    Returns:
        An instance of a dynamically generated ``StructuredNode`` subclass
        named after the label, with its ``name`` property set to the label.
    """
    label_name = _FLOWER_LABELS.get(name_, _FALLBACK_FLOWER_LABEL)

    # Build each node class only once and reuse it afterwards: defining a new
    # class per node is wasteful, and newer neomodel releases appear to reject
    # a second class registered under the same label -- TODO confirm against
    # the installed neomodel version.
    node_cls = _FLOWER_CLASSES.get(label_name)
    if node_cls is None:
        node_cls = type(label_name, (StructuredNode,), {
            "name": StringProperty(unique_index=False),
            "flowers": Relationship('StructuredNode', 'SPEC', model=FlowerRel),
        })
        _FLOWER_CLASSES[label_name] = node_cls

    obj = node_cls()
    obj.name = label_name
    return obj
def get_iris_data():
    """Load the scikit-learn iris dataset.

    Returns:
        A tuple ``(features, labels, num_class)`` holding the feature values,
        the target values, and the number of distinct target values.
    """
    bunch = datasets.load_iris()
    features = pd.DataFrame(data=bunch.data, columns=bunch.feature_names).values
    labels = pd.Series(data=bunch.target).values
    distinct_labels = set(labels)
    return features, labels, len(distinct_labels)
def clear_node():
    """Delete every node, together with its relationships, from the database."""
    wipe_all_query = 'MATCH (n) DETACH DELETE n'
    db.cypher_query(wipe_all_query)
def vizualize(X, y, y_pred, nclass, connect_threshold=200):
    """Rebuild the graph: one node per sample, edges inside each cluster.

    The database is cleared first. For every predicted cluster a node is
    saved per sample (labelled from its true target value), then a random
    subset of node pairs within the cluster is connected. Each edge stores
    the Euclidean distance between the two samples under the ``similarity``
    key, so a smaller value means the samples are closer.

    Args:
        X: Feature rows, one per sample.
        y: True target value of each sample.
        y_pred: Predicted cluster index of each sample.
        nclass: Number of clusters; indices ``0 .. nclass - 1`` are visited.
        connect_threshold: A pair is connected when a uniform integer draw
            from ``[0, 1000]`` is below this value. The default of 200 keeps
            roughly 20% of the pairs.
    """
    X = np.asarray(X)
    y = np.asarray(y)
    y_pred = np.asarray(y_pred)

    clear_node()
    for cluster_id in range(nclass):
        in_cluster = y_pred == cluster_id
        # Distinct loop names keep the X / y parameters from being shadowed
        # (on Python 2 a list comprehension would overwrite them).
        flowers = [
            (features, make_flower(label).save())
            for features, label in zip(X[in_cluster], y[in_cluster])
        ]
        for (features0, flower0), (features1, flower1) in combinations(flowers, 2):
            # NOTE: drawing every connection makes the visualisation heavy,
            # so only a random sample of the pairs is connected.
            if randint(0, 1000) >= connect_threshold:
                continue
            # Only pairs that are actually connected need their distance.
            dist = float(LA.norm(features0 - features1))
            flower0.flowers.connect(flower1, {'similarity': dist})
def main():
    """Cluster the iris samples with k-means and visualise the result."""
    features, labels, n_clusters = get_iris_data()
    model = KMeans(n_clusters=n_clusters)
    model.fit(features)
    cluster_ids = model.labels_
    vizualize(features, labels, cluster_ids, n_clusters)
# Run the pipeline only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment