Created
November 18, 2017 20:34
-
-
Save oktomus/7b15699a20088475f10793db1799710d to your computer and use it in GitHub Desktop.
Python k-means implementation
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/python
# K-means clustering on a CSV data set
# 2016
import csv
import sys
# The CSV path is mandatory: without it there is nothing to cluster, so report
# the usage on stderr and stop with a non-zero status instead of falling
# through to the sys.argv[1] lookup further down.
if len(sys.argv) < 2:
    sys.exit("USAGE : %s csv-file" % (sys.argv[0]))
def square4DistanceTo(point1, point2):
    """Return the squared Euclidean distance over the first four coordinates.

    Only indexes 0 to 3 of each point are read, so either sequence may carry
    extra trailing items (for example a label column) without affecting the
    result.

    Args:
        point1: Indexable sequence with at least four numeric items.
        point2: Indexable sequence with at least four numeric items.

    Returns:
        The sum of the squared per-coordinate differences (not square-rooted).

    Raises:
        IndexError: If either point has fewer than four items.
    """
    total = 0
    for axis in range(4):
        delta = point1[axis] - point2[axis]
        # Squaring already discards the sign, so no abs() is needed.
        total += delta * delta
    return total
# Command-line arguments: the CSV path, then an optional number of clusters
# (3 when omitted).
filename = sys.argv[1]
nbCluster = 3
if len(sys.argv) > 2:
    try:
        nbCluster = int(sys.argv[2])
    except ValueError:
        sys.exit("Invalid number of clusters: %r (expected an integer)"
                 % (sys.argv[2],))
# At least one cluster is required: the assignment step always reads centre 0.
if nbCluster < 1:
    sys.exit("Invalid number of clusters: %d (must be at least 1)" % nbCluster)
# Parse the CSV file: every non-empty row is kept in `data`, with its first
# four columns converted to float (any further columns are left as strings).
data = []
# The csv module wants a binary file on Python 2 and a text file opened with
# newline="" on Python 3.
if sys.version_info[0] >= 3:
    csvfile = open(filename, "r", newline="")
else:
    csvfile = open(filename, "rb")
with csvfile:
    reader = csv.reader(csvfile, delimiter=',')
    for row in reader:
        # csv.reader yields an empty list for a blank line (e.g. a trailing
        # newline at the end of the file); there is nothing to parse there.
        if not row:
            continue
        # Parse to float
        for i in range(4):
            row[i] = float(row[i])
        data.append(row)
# Each cluster holds the list of indexes of the data rows attached to it.
# The initial centres are arbitrary: the first nbCluster data rows are used,
# so there must be at least that many rows.
if len(data) < nbCluster:
    sys.exit("Not enough data rows (%d) to initialise %d clusters"
             % (len(data), nbCluster))
clusters = [[] for _ in range(nbCluster)]
# Copy the four coordinates so that moving a centre never alters a data row.
centre_clusters = [list(data[i][:4]) for i in range(nbCluster)]
# K-means iterations: alternate between assigning every point to its nearest
# centre and recomputing each centre as the mean of its points, until no
# centre changes.
modification = True
while modification:
    # Assignment step: rebuild every cluster's member list from scratch.
    clusters = [[] for _ in range(nbCluster)]
    for rowIndex, row in enumerate(data):
        # Find the nearest cluster centre, starting from centre 0.
        nearest = 0
        nearestDistance = square4DistanceTo(centre_clusters[0], row)
        # A squared distance of 0 cannot be beaten, so the other centres are
        # only examined when the point is not exactly on centre 0.
        if nearestDistance > 0:
            for i in range(1, nbCluster):
                distance = square4DistanceTo(centre_clusters[i], row)
                # Strict comparison: ties go to the lowest cluster index.
                if distance < nearestDistance:
                    nearest = i
                    nearestDistance = distance
        # Assign the point to the cluster
        clusters[nearest].append(rowIndex)

    # Update step: recompute each centre as the mean of its members.
    modification = False
    for i in range(nbCluster):
        members = clusters[i]
        if not members:
            # An empty cluster has no mean (it would divide by zero); its
            # centre is left where it is, which counts as "unchanged".
            continue
        nbRows = len(members)
        totals = [0.0, 0.0, 0.0, 0.0]
        for rowIndex in members:
            row = data[rowIndex]
            for axis in range(4):
                totals[axis] += row[axis]
        centre = [total / nbRows for total in totals]
        # Exact comparison is intended: once the assignments stop changing,
        # the same sums are recomputed and give bit-identical centres.
        if centre != centre_clusters[i]:
            modification = True
        centre_clusters[i] = centre
# Write every data row to out.csv, cluster by cluster, prefixed with its
# 1-based cluster number.
# The csv module wants a binary file on Python 2 and a text file opened with
# newline="" on Python 3.
if sys.version_info[0] >= 3:
    csvfile = open("out.csv", "w", newline="")
else:
    csvfile = open("out.csv", "wb")
with csvfile:
    writer = csv.writer(csvfile, delimiter=",")
    for i in range(nbCluster):
        for rowIndex in clusters[i]:
            # Build a new list rather than inserting into the row, so the
            # rows stored in `data` are not modified.
            writer.writerow([i + 1] + data[rowIndex])
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment