Created
October 22, 2010 10:45
-
-
Save timperrett/640325 to your computer and use it in GitHub Desktop.
K-Means Clustering
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import scalala.Scalala._; | |
import scalala.tensor._; | |
import scalala.tensor.dense_; | |
object KMeans {
  // Iteration stops once no center moves by this much (Euclidean distance) or more.
  private val ConvergenceThreshold = 1E-4

  /**
   * Clusters `data` with the k-means algorithm and returns the cluster centers.
   *
   * The first `k` points of `data` seed the centers. Each iteration assigns every point
   * to its nearest current center (Euclidean distance), then moves each center to the mean
   * of the points assigned to it. Iteration stops when every center moved by less than 1E-4.
   *
   * A center that attracts no points in an iteration (for example when two seed points are
   * identical) keeps its previous position.
   *
   * @param data the points to cluster; must contain at least `k` points
   * @param k    the number of clusters; must be positive
   * @return the `k` cluster centers only; cluster memberships are not returned
   * @throws IllegalArgumentException if `k` is not positive or `data` has fewer than `k` points
   */
  def cluster(data: Seq[Vector], k: Int): Seq[Vector] = {
    require(k > 0, "k must be positive")
    require(data.size >= k, "need at least k data points to seed k clusters")

    var means = data.take(k)
    var converged = false
    while (!converged) {
      // Group the points by the index of the closest current center.
      val newClusters = data.groupBy { x =>
        val distances = means.map { distance(x, _) }
        argmin(distances.iterator) // which cluster minimizes euclidean distance
      }
      // groupBy only produces keys for non-empty groups, so look the cluster up safely and
      // fall back to the previous center instead of throwing NoSuchElementException.
      val newMeans = Seq.tabulate(k) { i =>
        newClusters.get(i).map(points => mean(points)).getOrElse(means(i))
      }
      converged = (newMeans zip means).forall { case (a, b) =>
        distance(a, b) < ConvergenceThreshold
      }
      means = newMeans
    }
    means
  }

  /** Euclidean distance between `a` and `b`, computed as the 2-norm of their difference. */
  def distance(a: Vector, b: Vector) = norm(a - b, 2)
}
// just test code: | |
import scalanlp.stats.sampling._; | |
// Parameters of the synthetic two-component mixture used to exercise KMeans.
val std = 1.0 // spread handed to gaussian(...) for every component
val m1 = Vector(-3, -4)
val m2 = Vector(4, 5)
val means = Seq(m1, m2) // component centers, indexed by component number
// Multivariate gaussians have not been implemented, so each coordinate is drawn from its
// own one-dimensional gaussian built from the matching entry of `m` and the shared `sig`.
// NOTE(review): DenseVector presumably comes from scalala.tensor.dense; the import at the top
// of this file is spelled `scalala.tensor.dense_`, which looks like a typo for `dense._` — confirm.
def gaussian(m: Vector, sig: Double) = new Rand[Vector] {
  def draw() = {
    val coordinates = Array.tabulate(m.size) { i =>
      val component = Rand.gaussian(m(i), sig)
      component.draw
    }
    new DenseVector(coordinates)
  }
}
// Random generator for one labelled sample: picks a component index with Rand.randInt,
// draws a point from the gaussian around that component's mean, and yields (index, point).
def randPoint = for {
  // Size the index range from `means` so every component is reachable if more are added.
  index <- Rand.randInt(means.size)
  m = means(index)
  draw <- gaussian(m, std)
} yield (index, draw)
// Ask randPoint for 10000 labelled samples, then cluster only the points themselves
// (the second element of each pair), ignoring the component index they were drawn from.
val samples = randPoint.sample(10000)
KMeans.cluster(samples.map { case (_, point) => point }, 2)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment