Created
July 22, 2016 07:47
-
-
Save hristo-vrigazov/747780cc1afa3cd2bf39d4e14b4eaeec to your computer and use it in GitHub Desktop.
Simple k-means implementation to solve Coursera's quiz
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import math | |
import statistics | |
import copy | |
# Starting positions of the two cluster centres.
centroids = [(2, 2), (-2, -2)]

# The five 2-D sample points that get clustered.
data_points = [
    (-1.88, 2.05),
    (-0.71, 0.42),
    (2.41, -0.67),
    (1.85, -3.80),
    (-3.69, -1.33),
]

# belongs_to[k] is the index of the centroid currently owning data point k;
# every point starts out attached to centroid 0.
belongs_to = [0 for _ in data_points]

# changed[k] tallies how often data point k switched centroid; this is the
# figure Coursera's quiz question asks about.
changed = [0 for _ in data_points]
def distance(data_point, centroid):
    """Return the Euclidean distance between ``data_point`` and ``centroid``.

    Coordinates are paired by position; only the first ``len(data_point)``
    entries of ``centroid`` are read.
    """
    squared_gaps = (
        (coordinate - centroid[axis]) ** 2
        for axis, coordinate in enumerate(data_point)
    )
    return math.sqrt(sum(squared_gaps))
def find_closest_centroid_index_to_data_point(i):
    """Return the index of the centroid nearest to ``data_points[i]``.

    Ties go to the lowest centroid index because a later centroid only wins
    when it is strictly closer. Returns 0 when ``centroids`` is empty.
    """
    best_index, best_distance = 0, float("+inf")
    for index, centroid in enumerate(centroids):
        candidate = distance(data_points[i], centroid)
        if candidate < best_distance:
            best_index, best_distance = index, candidate
    return best_index
def assign_data_point_to_nearest_centroid(i):
    """Attach data point ``i`` to its nearest centroid and record any switch.

    ``changed[i]`` is incremented whenever the nearest centroid differs from
    the value currently stored in ``belongs_to[i]``.
    """
    nearest = find_closest_centroid_index_to_data_point(i)
    # NOTE: the comparison is against whatever belongs_to[i] holds right now,
    # including its initial value, so the very first assignment of a point
    # can register as a change.
    moved = belongs_to[i] != nearest
    if moved:
        changed[i] += 1
    belongs_to[i] = nearest
def assign_data_points_to_nearest_centroid():
    """Run the assignment step for every data point, in index order."""
    for index, _ in enumerate(data_points):
        assign_data_point_to_nearest_centroid(index)
# True once an update step leaves every centroid where it was.
converged = False


def adjust_centers():
    """Move each centroid to the mean of the data points assigned to it.

    Reads ``data_points`` and ``belongs_to``, rewrites ``centroids`` in place
    and updates the module-level ``converged`` flag: it becomes True when no
    centroid moved during this call and False otherwise.

    A centroid that currently owns no data points is left where it is;
    ``statistics.mean`` raises ``StatisticsError`` on empty input, so taking
    the mean of an empty cluster would otherwise abort the run.
    """
    global converged

    # Centroids are immutable tuples that get replaced, never mutated, so a
    # shallow snapshot is enough to detect movement.
    previous_centroids = list(centroids)

    for index in range(len(centroids)):
        members = [
            point
            for point, owner in zip(data_points, belongs_to)
            if owner == index
        ]
        if not members:
            # Empty cluster: keep the previous position.
            continue
        centroids[index] = (
            statistics.mean(point[0] for point in members),
            statistics.mean(point[1] for point in members),
        )

    converged = centroids == previous_centroids
def k_means():
    """Alternate the assignment and update steps until ``converged`` is set.

    Returns immediately, without doing any work, if ``converged`` is already
    true when called.
    """
    while True:
        if converged:
            break
        assign_data_points_to_nearest_centroid()
        adjust_centers()
def show_most_changed_data_point():
    """Print the 1-based number of the data point with the highest change count.

    When several points share the highest count, the one with the lowest
    index is reported.
    """
    leader = 0
    for index, count in enumerate(changed):
        if changed[leader] < count:
            leader = index
    message = "Data point {} changed the most; it changed {} times"
    print(message.format(leader + 1, changed[leader]))
def _run_quiz():
    """Cluster the sample points, then report the most often reassigned one."""
    k_means()
    show_most_changed_data_point()


if __name__ == "__main__":
    _run_quiz()
It uses Python 3
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
The gist was created to help people struggling to answer Week 3, Quiz 1 of the Clustering and Retrieval course. Note the list "changed", which counts how many times each data point switched clusters.