Skip to content

Instantly share code, notes, and snippets.

@oktomus
Created November 18, 2017 20:34
Show Gist options
  • Save oktomus/7b15699a20088475f10793db1799710d to your computer and use it in GitHub Desktop.
Save oktomus/7b15699a20088475f10793db1799710d to your computer and use it in GitHub Desktop.
Python kmean implementation
#!/usr/bin/python
# Kmean on a CSV data set
# 2016
import sys
import csv
# The CSV path is mandatory: without it there is nothing to cluster, so show
# the usage line and stop instead of falling through to the sys.argv[1] read
# further down (which would raise IndexError).
if len(sys.argv) < 2:
    print("USAGE : %s csv-file" % (sys.argv[0]))
    sys.exit(1)
def square4DistanceTo(point1, point2):
    """Return the squared Euclidean distance between two points.

    Only the coordinates at indices 0 to 3 of each point are read; anything
    stored after them is ignored. No square root is taken, so the value is
    meant for comparing distances, not for reporting them.
    """
    total = 0
    for axis in range(4):
        gap = abs(point1[axis] - point2[axis])
        total = total + gap ** 2
    return total
# Command line: the CSV file to cluster, then an optional number of clusters.
filename = sys.argv[1]
nbCluster = 3  # default when no cluster count is given
if len(sys.argv) > 2:
    nbCluster = int(sys.argv[2])
data = []
# Parse csv
# NOTE(review): binary mode is what the Python 2 csv module expects; under
# Python 3 the file would have to be opened in text mode with newline=''.
with open(filename, 'rb') as csvfile:
    reader = csv.reader(csvfile, delimiter=',')
    for row in reader:
        if not row:
            # csv.reader yields an empty list for a blank line (typically a
            # trailing one); there is nothing to convert, so skip it.
            continue
        # Parse to float: only the first four columns are numeric features,
        # any further column is kept as the string read from the file.
        for i in range(0, 4):
            row[i] = float(row[i])
        data.append(row)
# Each cluster holds the list of indices (into data) of the rows attached to it.
# The initial centres are arbitrary: the first nbCluster rows are used as seeds,
# so there must be at least that many rows.
if nbCluster < 1 or nbCluster > len(data):
    sys.exit("Cannot build %d cluster(s) from %d data row(s)"
             % (nbCluster, len(data)))
clusters = [[] for i in range(0, nbCluster)]
centre_clusters = [[data[i][0], data[i][1], data[i][2], data[i][3]]
                   for i in range(0, nbCluster)]
modification = True
while modification:
    # Assignment step: attach every point to the cluster with the nearest centre.
    clusters = [[] for i in range(0, nbCluster)]  # Reset clusters
    for rowIndex, row in enumerate(data):  # For each point
        # Find the nearest cluster centre; on a tie the lowest index wins
        # because only a strictly smaller distance replaces the current best.
        nearest = 0
        nearestDistance = square4DistanceTo(centre_clusters[0], row)
        for i in range(1, nbCluster):
            distance = square4DistanceTo(centre_clusters[i], row)
            if distance < nearestDistance:
                nearest = i
                nearestDistance = distance
        # Assign point to cluster
        clusters[nearest].append(rowIndex)
    modification = False
    # Update step: recompute each centre as the mean of its members.
    for i in range(0, nbCluster):
        nbRows = len(clusters[i])
        if nbRows == 0:
            # A cluster can end up empty (e.g. two seeds are identical rows).
            # Its mean is undefined, so keep the previous centre rather than
            # dividing by zero.
            continue
        totals = [0.0, 0.0, 0.0, 0.0]
        for rowIndex in clusters[i]:
            row = data[rowIndex]
            for axis in range(4):
                totals[axis] += row[axis]
        centre = [total / nbRows for total in totals]
        # Iterate again as soon as any centre moved.
        if centre != centre_clusters[i]:
            modification = True
        centre_clusters[i] = centre
# Write every row to out.csv, prefixed with its 1-based cluster number and
# grouped cluster by cluster.
# NOTE(review): binary mode is what the Python 2 csv module expects; under
# Python 3 the file would have to be opened in text mode with newline=''.
with open("out.csv", "wb") as csvfile:
    writer = csv.writer(csvfile, delimiter = ",")
    for i in range(0, nbCluster):
        for rowIndex in clusters[i]:
            # Build a new list instead of inserting into the row, so the
            # loaded data set is not modified as a side effect of writing.
            writer.writerow([i + 1] + data[rowIndex])
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment