Last active
May 7, 2020 03:44
-
-
Save ground0state/0e7e671ae0da76417360a166276cf425 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
MIT License
Copyright (c) 2017-2020 Packt, ground0state
https://github.com/PacktPublishing/Artificial-Intelligence-with-Python/blob/master/LICENSE
""" | |
from itertools import cycle

import matplotlib.pyplot as plt
import numpy as np
from sklearn.cluster import MeanShift, estimate_bandwidth
from sklearn.datasets import make_blobs
# Build a synthetic data set: 1000 samples with 5 features, generated around
# 3 centers. random_state is fixed so every run produces the same points.
X, _ = make_blobs(n_samples=1000, centers=3, n_features=5, random_state=0)

# Estimate the kernel bandwidth from the data itself, using every sample
# (n_samples=len(X)) rather than a subsample.
bandwidth_X = estimate_bandwidth(X, quantile=0.1, n_samples=len(X))

# Fit Mean Shift with the estimated bandwidth. The number of clusters is not
# given up front; it is read back from the fitted model below.
meanshift_model = MeanShift(bandwidth=bandwidth_X, bin_seeding=True)
meanshift_model.fit(X)

# Report the cluster centers found by the model.
cluster_centers = meanshift_model.cluster_centers_
print('Centers of clusters:\n', cluster_centers)

# Count the clusters as the number of distinct labels assigned to the samples.
labels = meanshift_model.labels_
num_clusters = len(np.unique(labels))
print("\nNumber of clusters in input data =", num_clusters)

# Plot the samples and the cluster centers.
# NOTE: the data has 5 features but only the first two (columns 0 and 1) are
# drawn, so the figure is a 2-D projection of the clustering.
plt.figure()
markers = 'o*xvs'
# cycle() repeats the marker string indefinitely. A plain zip() against the
# 5-character string would stop after five clusters and silently leave any
# further clusters off the plot.
for i, marker in zip(range(num_clusters), cycle(markers)):
    # Samples that were assigned to cluster i.
    in_cluster = labels == i
    plt.scatter(X[in_cluster, 0], X[in_cluster, 1],
                marker=marker, color='black')

    # Large filled circle marking the center of cluster i.
    cluster_center = cluster_centers[i]
    plt.plot(cluster_center[0], cluster_center[1], marker='o',
             markerfacecolor='black', markeredgecolor='black',
             markersize=15)

plt.title('Clusters')
plt.show()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment