Created
October 23, 2014 23:07
-
-
Save azymnis/a1718e5cb014d168ae6f to your computer and use it in GitHub Desktop.
K-Means in scalding
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import com.twitter.algebird.{Aggregator, Semigroup} | |
import com.twitter.scalding._ | |
import scala.util.Random | |
/**
 * This job is a tutorial of sorts for scalding's Execution[T] abstraction.
 * It is a simple implementation of Lloyd's algorithm for k-means on 2D data.
 *
 * http://en.wikipedia.org/wiki/K-means_clustering
 *
 * The assumption here is that the number of clusters is much smaller than the
 * number of points, and that the cluster centroids can easily fit in memory.
 *
 * Input: a text TSV file with two columns, where each row holds the (x, y)
 * coordinates of one point in the dataset.
 *
 * Output: a three-column TSV where the first two columns are the (x, y)
 * coordinates of each point and the last column is the final cluster id.
 *
 * E.g. to cluster data into 3 clusters:
 *
 * scald --local KMeansJob.scala --clusters 3 --input input.tsv --output output.tsv
 *
 */
class KMeansJob(args: Args) extends ExecutionJob[Unit](args) {
  import TDsl._

  // Number of clusters to partition the data into (from --clusters).
  val numClusters: Int = args("clusters").toInt
  // Fail at construction time with a readable message; otherwise a non-positive
  // value would only surface later, inside the mapper's Random.nextInt call.
  require(numClusters > 0, s"--clusters must be a positive integer, got $numClusters")

  // Path of the two-column (x, y) TSV to read (from --input).
  val inputFile: String = args("input")
  // Path of the three-column (x, y, cluster) TSV to write (from --output).
  val outputFile: String = args("output")

  // Lets `.sum` below add up (count, Point) pairs per cluster.
  implicit val pointSemigroup: Semigroup[Point] = PointSemigroup

  // Reads in points and allocates each of them to a random cluster.
  // NOTE(review): the cluster id is drawn inside the pipe's map function, so it is
  // presumably re-drawn every time this pipe is evaluated — confirm whether the
  // first iteration should see one consistent assignment.
  val initialPoints: TypedPipe[Point] =
    TypedTsv[(Double, Double)](inputFile)
      .map { case (x, y) => Point(x, y, Random.nextInt(numClusters)) }

  /**
   * A single step of the k-means iteration.
   *
   * Given the current cluster assignments, first computes the mean of each
   * cluster, then re-assigns every point to the cluster of its closest mean.
   *
   * @param currentPoints points carrying their current cluster id
   * @return an Execution wrapping the number of points whose cluster changed,
   *         together with the pipe of re-assigned points
   */
  def kMeansStep(currentPoints: TypedPipe[Point]): Execution[(Long, TypedPipe[Point])] = {
    // Update step: sum (1, point) per cluster and divide the summed coordinates
    // by the count to obtain each cluster's centroid.
    val newClusterMeansExecution = currentPoints
      .map { p => (p.cluster, (1L, p)) }
      .group
      .sum
      .map { case (_, (count, summed)) => summed.normalize(count) }
      .toIterableExecution

    newClusterMeansExecution.flatMap { means =>
      val meanSeq = means.toSeq

      // Assignment step: move each point to the cluster of its closest centroid
      // and emit 1 if that changed its cluster, 0 otherwise.
      val newPointsWithDeltas = currentPoints.map { p =>
        val newClosestCluster = p.closestPoint(meanSeq).cluster
        val delta = if (p.cluster == newClosestCluster) 0L else 1L
        (p.copy(cluster = newClosestCluster), delta)
      }

      val newPoints = newPointsWithDeltas.map { case (point, _) => point }

      newPointsWithDeltas
        .map { case (_, delta) => delta }
        .aggregate(Aggregator.fromSemigroup[Long])
        .toOptionExecution
        // A missing aggregate is treated as "no point changed".
        .map { deltaOpt => (deltaOpt.getOrElse(0L), newPoints) }
    }
  }

  /**
   * Performs one update and checks whether any point changed cluster.
   * If so, recurses with the new assignments; otherwise returns them as the
   * final solution.
   *
   * @param points points carrying their current cluster id
   * @return an Execution wrapping the converged cluster assignments
   */
  def updateAndCheck(points: TypedPipe[Point]): Execution[TypedPipe[Point]] =
    kMeansStep(points).flatMap {
      case (0L, newPoints) => Execution.from(newPoints)
      case (count, newPoints) =>
        System.out.println("%d points changed".format(count))
        updateAndCheck(newPoints)
    }

  // Run the recursion and finally write the (x, y, cluster) rows to disk.
  override def execution: Execution[Unit] =
    updateAndCheck(initialPoints).flatMap { points =>
      points
        .map { p => (p.x, p.y, p.cluster) }
        .writeExecution(TypedTsv[(Double, Double, Int)](outputFile))
    }
}
/**
 * A 2D point tagged with the id of the cluster it currently belongs to.
 *
 * @param x       x coordinate
 * @param y       y coordinate
 * @param cluster cluster id attached to this point
 */
case class Point(x: Double, y: Double, cluster: Int) {

  /**
   * Divides both coordinates by `count`, keeping the cluster id.
   * Used to turn a coordinate-wise sum of `count` points into their mean.
   */
  def normalize(count: Long): Point = Point(x / count, y / count, cluster)

  /** Squared Euclidean distance between this point and `other`. */
  def squareDistanceFrom(other: Point): Double = {
    val dx = this.x - other.x
    val dy = this.y - other.y
    dx * dx + dy * dy
  }

  /**
   * Returns the element of `points` with the smallest squared distance to this point.
   *
   * Runs in a single pass; when several candidates are equally close, the one that
   * appears first in `points` is returned.
   *
   * @param points candidate points, must be non-empty
   * @throws NoSuchElementException if `points` is empty
   */
  def closestPoint(points: Seq[Point]): Point = {
    if (points.isEmpty) {
      throw new NoSuchElementException("closestPoint called with no candidate points")
    }
    // minBy needs no intermediate (point, distance) collection and no sort.
    points.minBy(squareDistanceFrom)
  }
}
/**
 * Semigroup that combines two points by adding their coordinates.
 * The combined point carries the cluster id of the left operand.
 */
object PointSemigroup extends Semigroup[Point] {
  def plus(l: Point, r: Point): Point = {
    val summedX = l.x + r.x
    val summedY = l.y + r.y
    l.copy(x = summedX, y = summedY)
  }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment