Spark ML to MLlib conversion to run BisectingKMeans clustering
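The snippet below assumes a DataFrame named scaledFeatures that already holds a standardized feature column std_features of type org.apache.spark.ml.linalg.Vector. A minimal sketch of how such a DataFrame might be built with VectorAssembler and StandardScaler, assuming a raw DataFrame rawDf with hypothetical numeric columns f1 and f2:

import org.apache.spark.ml.feature.{StandardScaler, VectorAssembler}

// Hypothetical setup: rawDf and the column names f1/f2 are placeholders.
val assembler = new VectorAssembler()
  .setInputCols(Array("f1", "f2"))
  .setOutputCol("features")

val scaler = new StandardScaler()
  .setInputCol("features")
  .setOutputCol("std_features")
  .setWithStd(true) // scale each feature to unit standard deviation

val assembled = assembler.transform(rawDf)
val scaledFeatures = scaler.fit(assembled).transform(assembled)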
import org.apache.spark.mllib.clustering.BisectingKMeans
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import spark.implicits._ // for the $"col" syntax (pre-imported in spark-shell)
// The std_features column is of type org.apache.spark.ml.linalg.Vector (the new ml package type)
scaledFeatures.select($"std_features").printSchema()
val tempFeatureRdd = scaledFeatures.select($"std_features").rdd
import scala.reflect.runtime.universe._

// Reflection helper to print the static type of a value
def getType[T: TypeTag](value: T) = typeOf[T]
println("-------BEFORE")
println("Type of RDD: "+getType(tempFeatureRdd))
println("Type of column: "+getType(tempFeatureRdd.first()))
/**
 * Map RDD[org.apache.spark.sql.Row] to RDD[org.apache.spark.mllib.linalg.Vector]:
 * the mllib BisectingKMeans accepts only the old mllib Vector type, so each
 * ml Vector must be converted with Vectors.fromML.
 */
val input = scaledFeatures
.select($"std_features")
.rdd
.map(v => Vectors.fromML(v.getAs[org.apache.spark.ml.linalg.Vector](0)))
.cache() // cache the input: iterative clustering algorithms scan the data repeatedly
println("-------AFTER")
println("Type of RDD: "+getType(input))
println("Type of column: "+getType(input.first()))
println("Total rows: "+input.count())
// Cluster the data into 9 clusters with BisectingKMeans.
val bkm = new BisectingKMeans().setK(9)
val model = bkm.run(input)
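// computeCost = sum of squared distances from each point to its nearest cluster center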
println(s"Compute Cost: ${model.computeCost(input)}")
model.clusterCenters.zipWithIndex.foreach { case (center, idx) =>
println(s"Cluster Center ${idx}: ${center}")
}
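To use the trained model, mllib's BisectingKMeansModel also exposes predict, for single vectors and for whole RDDs; a short sketch:

// Assign every input vector to its nearest cluster center
val assignments = model.predict(input) // RDD[Int] of cluster indices, aligned with input
assignments.take(5).foreach(println)

Note that since Spark 2.0 there is also org.apache.spark.ml.clustering.BisectingKMeans, which runs directly on DataFrames and would avoid the Row-to-Vector conversion above.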