Skip to content

Instantly share code, notes, and snippets.

@timperrett
Created October 22, 2010 10:45
Show Gist options
  • Save timperrett/640325 to your computer and use it in GitHub Desktop.
Save timperrett/640325 to your computer and use it in GitHub Desktop.
K-Means Clustering
import scalala.Scalala._;
import scalala.tensor._;
import scalala.tensor.dense_;
/** K-means clustering (iterative assign-then-recenter) over Scalala vectors. */
object KMeans {
  /**
   * Partitions `data` into `k` clusters and returns the cluster centers.
   *
   * The first `k` points seed the centers. Each pass assigns every point to
   * the center at minimal Euclidean distance and then moves each center to the
   * mean of the points assigned to it. Iteration stops once every center moved
   * by less than 1E-4 in a single pass.
   *
   * Only the centers are returned; the cluster memberships are not.
   *
   * @param data points to cluster; must contain at least `k` elements
   * @param k    number of clusters; must be non-negative
   * @return the `k` centers, in the same order as their seed points
   * @throws IllegalArgumentException if `k` is negative or exceeds `data.size`
   */
  def cluster(data: Seq[Vector], k: Int): Seq[Vector] = {
    require(k >= 0, "k must be non-negative, got " + k)
    require(k <= data.size, "need at least k = " + k + " points, got " + data.size)

    var means = data.take(k)
    // With zero centers there is nothing to refine and no nearest center to pick.
    var converged = k == 0
    while (!converged) {
      val newClusters = data.groupBy { x =>
        val distances = means.map { distance(x, _) }
        argmin(distances.iterator) // index of the center minimizing euclidean distance
      }
      // groupBy only yields keys for centers that attracted at least one point
      // (e.g. duplicate seed points leave one of them without members), so an
      // empty cluster keeps its previous center instead of failing the lookup.
      val newMeans: Seq[Vector] = Seq.tabulate(k) { i =>
        newClusters.get(i) match {
          case Some(members) => mean(members)
          case None          => means(i)
        }
      }
      converged = (newMeans zip means).forall { case (a, b) => distance(a, b) < 1E-4 }
      means = newMeans
    }
    means
  }

  /** Distance between `a` and `b`, computed as the 2-norm of their difference. */
  def distance(a: Vector, b: Vector) = norm(a - b, 2)
}
// Test code: sample points around two known centers, then cluster them.
import scalanlp.stats.sampling._;
// Centers of the two synthetic point clouds used by the test below.
val m1 = Vector(-3, -4)
val m2 = Vector(4, 5)
val means = List(m1, m2)
// Spread parameter handed to `gaussian` for every generated point.
val std = 1.0
/**
 * Stand-in for a multivariate gaussian (none has been written): a sampler of
 * vectors with the same size as `m`, where coordinate `i` of every draw comes
 * from its own `Rand.gaussian(m(i), sig)`.
 *
 * @param m   per-coordinate first argument passed to `Rand.gaussian`
 * @param sig second argument passed to `Rand.gaussian`, shared by all coordinates
 */
def gaussian(m: Vector, sig: Double) = new Rand[Vector] {
  def draw() = {
    // Coordinates are sampled one at a time, in index order.
    val coords = (0 until m.size).map { i =>
      val coordinate = Rand.gaussian(m(i), sig)
      coordinate.draw
    }.toArray
    new DenseVector(coords)
  }
}
/**
 * Sampler of `(index, point)` pairs: `index` is drawn with
 * `Rand.randInt(means.size)` and `point` is drawn from `gaussian` around
 * `means(index)` with spread `std`.
 *
 * The index bound is taken from `means` itself rather than hard-coded, so it
 * stays in range if centers are added to or removed from `means`.
 */
def randPoint = for {
  index <- Rand.randInt(means.size)
  m = means(index)
  draw <- gaussian(m, std)
} yield (index, draw)
// Draw 10000 (index, point) samples, then cluster the points alone into two
// groups; the generating index of each sample is dropped.
val samples = randPoint.sample(10000)
KMeans.cluster(samples.map { case (_, point) => point }, 2)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment