Last active
March 1, 2019 01:19
-
-
Save efruchter/7c4816f0d9083312ca5d967ac3d79a22 to your computer and use it in GitHub Desktop.
A very basic k-means clustering implementation in C#
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/// <summary>
/// Partitions <paramref name="data"/> into <paramref name="clusterCount"/> groups with a basic
/// k-means loop: repeatedly compute each cluster's mean, then move every point to the cluster
/// whose mean is nearest (squared Euclidean distance).
/// </summary>
/// <param name="clusterCount">Number of clusters. Values &lt;= 0 yield an empty result.</param>
/// <param name="data">Points to cluster.</param>
/// <param name="maxIterations">
/// Upper bound on mean/reassign passes. Iteration also stops as soon as a pass leaves every
/// assignment unchanged. Values &lt;= 0 return the initial round-robin assignment.
/// </param>
/// <returns>
/// An array parallel to <paramref name="data"/> holding each point's cluster index in
/// [0, <paramref name="clusterCount"/>). Empty when <paramref name="clusterCount"/> &lt;= 0.
/// </returns>
/// <exception cref="System.ArgumentNullException">
/// <paramref name="data"/> is null (and <paramref name="clusterCount"/> &gt; 0).
/// </exception>
public static int[] KMeansCluster(int clusterCount, float3[] data, int maxIterations = 1000)
{
    if (clusterCount <= 0)
    {
        return new int[0];
    }
    if (data == null)
    {
        throw new System.ArgumentNullException(nameof(data));
    }

    // Freshly allocated int arrays are already zero-filled.
    int[] countWithinClusters = new int[clusterCount];
    int[] clusterAssignments = new int[data.Length];

    // Assign groups in a fair but deterministic way (round-robin).
    for (int d = 0; d < data.Length; d++)
    {
        int cluster = d % clusterCount;
        clusterAssignments[d] = cluster;
        countWithinClusters[cluster]++;
    }

    // Early exit in the trivial case: fewer points than clusters, so round-robin already
    // places each point in its own cluster.
    if (data.Length < clusterCount)
    {
        return clusterAssignments;
    }

    float3[] clusterCentroids = new float3[clusterCount];
    // Per-pass accumulator, kept separate so an empty cluster can retain its last centroid.
    float3[] clusterSums = new float3[clusterCount];
    bool converged = false;
    int iterationsRemaining = maxIterations;

    // Iterate until no point changes cluster, or we hit the iteration limit.
    while (!converged && iterationsRemaining > 0)
    {
        // Recompute the mean of each cluster.
        System.Array.Clear(clusterSums, 0, clusterCount);
        for (int d = 0; d < data.Length; d++)
        {
            clusterSums[clusterAssignments[d]] += data[d];
        }
        for (int c = 0; c < clusterCount; c++)
        {
            // An empty cluster keeps its previous centroid rather than snapping to the origin.
            // Every cluster is non-empty on the first pass (round-robin seeding with
            // data.Length >= clusterCount), so a previous centroid always exists here.
            if (countWithinClusters[c] > 0)
            {
                clusterCentroids[c] = clusterSums[c] / countWithinClusters[c];
            }
        }

        // Reset counters and tick the iteration budget.
        System.Array.Clear(countWithinClusters, 0, clusterCount);
        converged = true;
        iterationsRemaining--;

        // Assign each data point to its nearest centroid (lowest index wins ties).
        for (int d = 0; d < data.Length; d++)
        {
            float closestDist = float.PositiveInfinity;
            // Defaulting to the current cluster keeps the index valid when no distance compares
            // below +infinity (NaN or infinite coordinates); such points stay where they are.
            int closestCluster = clusterAssignments[d];
            for (int c = 0; c < clusterCount; c++)
            {
                float dist = math.lengthsq(data[d] - clusterCentroids[c]);
                if (dist < closestDist)
                {
                    closestDist = dist;
                    closestCluster = c;
                }
            }

            if (closestCluster != clusterAssignments[d])
            {
                clusterAssignments[d] = closestCluster;
                converged = false;
            }
            countWithinClusters[closestCluster]++;
        }
    }

    return clusterAssignments;
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment