dmsk
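# Soft k-means clustering with an adaptive number of clusters (split/merge):
# starting from K initial centers, each pass doubles K by splitting every
# center, runs soft (responsibility-weighted) k-means updates, merges centers
# that drift closer than a distance threshold, and stops once the change in
# cost falls below cost_threshold.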
from sklearn.datasets import make_blobs
import numpy as np
import matplotlib.pyplot as plt
import random
import math


def plot_k_means(x, r, k, centers):
    # Plot the data points and mark each cluster center with a red dot.
    #random_colors = np.random.random((k, 3))
    #colors = r.dot(random_colors)
    #colors = ('black')
    print("Probabilities = ", r[:40])
    for center in centers:
        print(center)
        plt.plot(center[0], center[1], "ro")
    plt.scatter(x[:, 0], x[:, 1])
    plt.show()


def initialize_centers(x, num_k):
    # Pick num_k distinct data points at random as the initial centers.
    N, D = x.shape
    centers = np.zeros((num_k, D))
    used_idx = []
    for k in range(num_k):
        idx = np.random.choice(N)
        while idx in used_idx:
            idx = np.random.choice(N)
        used_idx.append(idx)
        centers[k] = x[idx]
    return centers


def update_centers(x, r, K):
    # Each center becomes the responsibility-weighted mean of all points.
    N, D = x.shape
    centers = np.zeros((K, D))
    for k in range(K):
        centers[k] = r[:, k].dot(x) / r[:, k].sum()
    return centers


def square_dist(a, b):
    # Element-wise squared difference (currently unused).
    return (a - b) ** 2


def distance(p0, p1):
    # Euclidean distance between two 2-D points.
    return math.sqrt((p0[0] - p1[0]) ** 2 + (p0[1] - p1[1]) ** 2)


def cost_func(x, r, centers, K):
    # Responsibility-weighted sum of distances from every point to every center.
    cost = 0
    for k in range(K):
        norm = np.linalg.norm(x - centers[k], 2, axis=1)
        cost += (norm * r[:, k]).sum()
    return cost


def cluster_responsibilities(centers, x, beta):
    # Soft assignments: R[n, k] = exp(-beta * ||x_n - c_k||), normalized over k.
    N, _ = x.shape
    K, D = centers.shape
    R = np.zeros((N, K))
    for n in range(N):
        R[n] = np.exp(-beta * np.linalg.norm(centers - x[n], 2, axis=1))
    R /= R.sum(axis=1, keepdims=True)
    return R


def soft_k_means(x, K, beta=1.):
    centers = initialize_centers(x, K)
    merge = False
    print("centers after initialize: ", centers)
    prev_cost = 0
    cost = 1
    distance_threshold = 1.0 / K
    cost_threshold = 1e-2
    while True:
        #for _ in range(max_iters):
        r = cluster_responsibilities(centers, x, beta)
        centers = update_centers(x, r, K)
        idx = 0
        if merge:
            # Merge any pair of centers closer than distance_threshold by
            # replacing one with their midpoint and deleting the other.
            while idx < len(centers):
                j = 0
                while j < len(centers):
                    dist = distance(centers[idx], centers[j])
                    print('Calculating Distance Between: ' + str(centers[idx]) + ' & ' + str(centers[j]))
                    print('Distance: ' + str(dist))
                    if dist <= distance_threshold and idx != j:
                        print('Removing!!')
                        centers[idx] = [(centers[idx][0] + centers[j][0]) / 2, (centers[idx][1] + centers[j][1]) / 2]
                        centers = np.delete(centers, j, 0)
                        idx = -1  # restart the scan after a merge
                        break
                    else:
                        j += 1
                idx += 1
            K = len(centers)
            r = cluster_responsibilities(centers, x, beta)
            centers = update_centers(x, r, K)
        print("Centers after update are: ", centers)
        cost = cost_func(x, r, centers, K)
        print("change in cost = ", np.abs(cost - prev_cost))
        if np.abs(cost - prev_cost) < cost_threshold:
            print('Breaking!! Final K: ' + str(K))
            break
        prev_cost = cost
        # Cost is still improving: split every center into two, offsetting the
        # copy by a small random amount, and double K.
        K *= 2
        split_centers = np.zeros((K, x.shape[1]))
        print("new centers: ", split_centers)
        for i, center in enumerate(centers):
            split_centers[i * 2] = center
            split_centers[(i * 2) + 1] = [center[0] + (random.randint(0, 100) * 0.01), center[1] + (random.randint(0, 100) * 0.01)]
        print("new center = ", split_centers)
        centers = split_centers
        merge = True
    print("Final centers before plot = ", centers)
    plot_k_means(x, r, K, centers)


def generate_samples(std=.5, dim=2, dist=4):
    # Generate a toy 2-D dataset of four tight Gaussian blobs.
    x, _ = make_blobs(n_samples=100, centers=4, cluster_std=.1, random_state=12)
    #mu0 = np.array([0, 0])
    #mu1 = np.array([dist, dist])
    #mu2 = np.array([0, dist])
    # num samps per class
    #Nc = 10
    #x0 = np.random.randn(Nc, dim) * std + mu0
    #x1 = np.random.randn(Nc, dim) * std + mu1
    #x2 = np.random.randn(Nc, dim) * std + mu2
    #x = np.concatenate((x0, x1, x2), axis=0)
    return x


def main():
    x = generate_samples()
    K = 2
    soft_k_means(x, K)
    #K = 2
    #while K < 10:
    #    soft_k_means(x, K)
    #    K += 1


if __name__ == "__main__":
    main()