// Last active December 29, 2015 07:45
// Bisecting k-means for hierarchical clustering in Spark
/**
 * bisecting <master> <input> <nNodes> <subIterations>
 *
 * Divisive hierarchical clustering using bisecting k-means.
 * Assumes the input is a text file in which each row is a data point
 * given as numbers separated by spaces.
 */
import org.apache.spark.SparkContext
import org.apache.spark.SparkContext._
import org.apache.spark.rdd._
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import scala.collection.mutable.ArrayBuffer
/**
 * A node of the hierarchical cluster tree.
 *
 * @param key      identifier used to locate this node in the tree
 * @param center   coordinates of the cluster center
 * @param children sub-clusters of this node, or `None` while it has none
 */
case class Cluster(
  var key: Int,
  var center: Array[Double],
  var children: Option[List[Cluster]]
)
/**
 * Divisive hierarchical clustering built from repeated two-way k-means splits.
 *
 * Usage: bisecting <master> <input> <nNodes> <subIterations>
 */
object bisecting {

  /**
   * Recursively search the cluster tree for the node with the given key and
   * attach `children` to it.
   *
   * @param node     root of the (sub)tree to search
   * @param key      key of the node that should receive the children
   * @param children clusters to attach to the matching node
   */
  def insert(node: Cluster, key: Int, children: List[Cluster]): Unit = {
    if (node.key == key) {
      node.children = Some(children)
    } else {
      // Leaves (children == None) simply end the recursion.
      node.children.foreach(_.foreach(child => insert(child, key, children)))
    }
  }

  /**
   * Parse one line of whitespace-separated numbers into a dense vector.
   *
   * Leading/trailing and repeated whitespace is tolerated; a token that is
   * not a number still raises NumberFormatException.
   */
  def parseVector(line: String): Vector = {
    val values = line.trim.split("\\s+").map(_.toDouble)
    Vectors.dense(values)
  }

  /** Squared Euclidean distance between two vectors. */
  def squaredDist(vec1: Vector, vec2: Vector): Double = {
    vec1.toArray.zip(vec2.toArray).map { case (x, y) => math.pow(x - y, 2) }.sum
  }

  /** Element-wise sum of two vectors. */
  def plus(vec1: Vector, vec2: Vector): Vector = {
    Vectors.dense(vec1.toArray.zip(vec2.toArray).map { case (x, y) => x + y })
  }

  /** Divide every element of a vector by a scalar. */
  def divide(vec1: Vector, s: Double): Vector = {
    Vectors.dense(vec1.toArray.map(x => x / s))
  }

  /**
   * Index (0 or 1) of whichever of the two centers is closer to `p`.
   * Ties go to center 1, because the comparison is strict.
   */
  def closestPoint(p: Vector, centers: Array[Vector]): Int = {
    if (squaredDist(p, centers(0)) < squaredDist(p, centers(1))) 0 else 1
  }

  /**
   * Split a cluster in two with k-means (k = 2).
   *
   * Runs `subIterations` independent k-means restarts, each seeded with its
   * iteration index, and keeps the pair of centers with the smallest total
   * squared distance of points to their assigned center.
   *
   * @param cluster       points of the cluster to split
   * @param subIterations number of random restarts to try
   * @return the two chosen centers
   * @throws IllegalArgumentException if fewer than 2 points can be sampled
   */
  def split(cluster: RDD[Vector], subIterations: Int): Array[Vector] = {
    val convergeDist = 0.001

    // Draw two starting centers; fail with a clear message instead of an
    // ArrayIndexOutOfBoundsException later when the cluster is too small.
    def sampleCenters(seed: Int): Array[Vector] = {
      val sampled = cluster.takeSample(false, 2, seed)
      require(sampled.length == 2, "cannot split a cluster with fewer than 2 points")
      sampled
    }

    var best = Double.PositiveInfinity
    var centersFinal: Option[Array[Vector]] = None

    // try multiple splits and keep the best
    for (iter <- 0 until subIterations) {
      val centers = sampleCenters(iter)
      var tempDist = 1.0

      // do iterations of k-means till convergence
      while (tempDist > convergeDist) {
        val newPoints = cluster
          .map(p => (closestPoint(p, centers), (p, 1)))
          .reduceByKey { case ((sum1, n1), (sum2, n2)) => (plus(sum1, sum2), n1 + n2) }
          .map { case (idx, (sum, count)) => (idx, divide(sum, count.toDouble)) }
          .collectAsMap()

        tempDist = 0.0
        for (i <- 0 until 2) {
          // A center that attracted no points has no entry; leave it in place.
          newPoints.get(i).foreach { updated =>
            tempDist += squaredDist(centers(i), updated)
            centers(i) = updated
          }
        }
      }

      // check within-class similarity of clusters in a single pass;
      // fold with a zero value also copes with an empty side of the split
      val totalDist = cluster
        .map(x => squaredDist(x, centers(closestPoint(x, centers))))
        .fold(0.0)(_ + _)

      if (totalDist < best) {
        best = totalDist
        centersFinal = Some(centers)
      }
    }

    // Only sample a fallback pair (seed 1) when no restart produced a result.
    centersFinal.getOrElse(sampleCenters(1))
  }

  /**
   * Entry point: repeatedly bisect the largest cluster until `nNodes`
   * clusters exist, recording each split in a cluster tree.
   *
   * @param args master, input path, number of clusters, k-means restarts
   */
  def main(args: Array[String]): Unit = {

    if (args.length < 4) {
      System.err.println("Usage: bisecting <master> <input> <nNodes> <subIterations>")
      System.exit(1)
    }

    // collect arguments
    val master = args(0)
    val input = args(1)
    val nNodes = args(2).toDouble // kept as Double so existing "10.0"-style arguments still work
    val subIterations = args(3).toInt

    // create spark context
    val sc = new SparkContext(master, "bisecting")

    try {
      // load raw data; cached because every split makes many passes over it
      val data = sc.textFile(input).map(parseVector).cache()

      // create array buffer with first cluster and compute its center
      val clusters = ArrayBuffer[(Int, RDD[Vector])]((0, data))
      val n = data.count()
      require(n > 0, "input contains no data points")
      val center = data.reduce((x, y) => plus(x, y)).toArray.map(x => x / n)
      val tree = Cluster(0, center, None)
      var count = 1

      // start timer
      val startTime = System.nanoTime

      while (clusters.size < nNodes) {

        println(clusters.size.toString + " clusters, starting new iteration")

        // find largest cluster for splitting (ties resolved by highest index)
        val ind = clusters.map(_._2.count()).zipWithIndex.max._2
        val (parentKey, parentPoints) = clusters(ind)

        // split into 2 clusters using k-means
        val centers = split(parentPoints, subIterations) // find 2 cluster centers
        val cluster1 = parentPoints.filter(x => closestPoint(x, centers) == 0)
        val cluster2 = parentPoints.filter(x => closestPoint(x, centers) == 1)

        // get new indices
        val newInd1 = count
        val newInd2 = count + 1
        count += 2

        // update tree with results
        insert(tree, parentKey, List(
          Cluster(newInd1, centers(0).toArray, None),
          Cluster(newInd2, centers(1).toArray, None)))

        // remove old cluster, add the 2 new ones
        clusters.remove(ind)
        clusters.append((newInd1, cluster1))
        clusters.append((newInd2, cluster2))
      }

      println("Bisecting took: " + (System.nanoTime - startTime) / 1e9 + "s")
    } finally {
      // release cluster resources even when clustering fails
      sc.stop()
    }
  }
}
// Copy link
//
// yu-iskw commented Sep 25, 2014
//
// Great work. If you have benchmark results for this code, could you share them with me?
// I am trying to implement a top-down hierarchical clustering like this. However, it is very slow...

