Spark LDA benchmark: generates a synthetic bag-of-words corpus from random (document, word, count) triples and runs MLlib's LDA (EM or online optimizer) on it.
package com.frontier45.LDABench

import org.apache.spark.mllib.clustering.LDA
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.{SparkConf, SparkContext}

import scala.util.Random

/**
  * Created by du on 3/15/16.
  */
object LDABench {

  case class RunArgs(
    T: Int = 100000,                     // total number of random tokens to generate
    W: Int = 10000,                      // vocabulary size
    D: Int = 10000,                      // number of documents
    K: Int = 10,                         // number of topics
    n_partitions: Int = 100,
    n_iterations: Int = 10,
    spark_master: Option[String] = None, // used only if spark.master is not already set
    optimizer: String = "em",            // "em" or "online"
    use_cache: Boolean = false,          // cache and count the corpus before LDA
    checkpoint_interval: Int = 10
  )
  def main(args: Array[String]): Unit = {
    val parser = new scopt.OptionParser[RunArgs]("LDABench") {
      opt[Int]('t', "num_trx").valueName("NUM_TRX")
        .action((x, c) => c.copy(T = x))
      opt[Int]('w', "num_words").valueName("NUM_WORDS")
        .action((x, c) => c.copy(W = x))
      opt[Int]('d', "num_documents").valueName("NUM_DOCUMENTS")
        .action((x, c) => c.copy(D = x))
      opt[Int]('k', "num_topics").valueName("NUM_TOPICS")
        .text("Number of Topics")
        .action((x, c) => c.copy(K = x))
      opt[Int]("num-iterations").valueName("NUM_ITERATIONS")
        .text("Number of iterations")
        .action((x, c) => c.copy(n_iterations = x))
      opt[Int]("num-partitions").valueName("NUM_PARTITIONS")
        .text("Number of Partitions")
        .action((x, c) => c.copy(n_partitions = x))
      opt[String]("spark-master").valueName("URL")
        .text("Spark Master URL")
        .action((x, c) => c.copy(spark_master = Some(x)))
      opt[String]("optimizer").valueName("OPTIMIZER")
        .action((x, c) => c.copy(optimizer = x))
      opt[Unit]("use-cache")
        .action((_, c) => c.copy(use_cache = true))
      opt[Int]('c', "checkpoint-interval").valueName("INTERVAL")
        .action((x, c) => c.copy(checkpoint_interval = x))
    }

    parser.parse(args, RunArgs())
      .map(run)
      .getOrElse(System.exit(1))
  }
  def run(args: RunArgs): Unit = {
    val conf = new SparkConf()
    // Respect any app name / master already supplied (e.g. via spark-submit);
    // otherwise fall back to defaults.
    if (conf.getOption("spark.app.name").isEmpty) {
      conf.setAppName("com.frontier45.LDABench")
    }
    if (conf.getOption("spark.master").isEmpty) {
      conf.setMaster(args.spark_master.getOrElse("local[*]"))
    }
    implicit val sc = new SparkContext(conf)
    startJob(args)
    sc.stop()
  }
  def startJob(args: RunArgs)(implicit sc: SparkContext): Unit = {
    // Build a synthetic corpus: T random (doc, word, count) triples,
    // aggregated into one sparse term-count vector per document.
    val documents = sc.parallelize(1 to args.n_partitions)
      .repartition(args.n_partitions)
      .flatMap { _ =>
        val r = new Random()
        val itr = 1 to (args.T / args.n_partitions)
        itr.toStream.map(_ => (r.nextInt(args.D), r.nextInt(args.W), r.nextInt(10)))
      }
      .map { case (d, w, c) => ((d, w), c) }
      .reduceByKey(_ + _)  // sum counts per (doc, word) pair
      .map { case ((d, w), c) => (d, (w, c.toDouble)) }
      .groupByKey()        // collect each document's (word, count) pairs
      .map(t => (t._1.toLong, Vectors.sparse(args.W, t._2.toSeq)))
      .repartition(args.n_partitions)

    if (args.use_cache) {
      // Materialize the corpus first so its generation is not timed as part of LDA.
      documents.cache()
      val c = documents.count()
      println("count = %d".format(c))
    }

    val lda = new LDA()
      .setK(args.K)
      .setMaxIterations(args.n_iterations)
      .setOptimizer(args.optimizer)
      .setCheckpointInterval(args.checkpoint_interval)
    lda.run(documents)
    println("OK")
  }
}
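
/*
 * Build dependencies — a minimal sketch, not part of the original gist.
 * The Spark and scopt versions below are assumptions matching this gist's
 * early-2016 vintage; adjust them to your cluster.
 *
 *   // build.sbt
 *   name := "LDABench"
 *   scalaVersion := "2.10.6"
 *   libraryDependencies ++= Seq(
 *     "org.apache.spark" %% "spark-core"  % "1.6.1" % "provided",
 *     "org.apache.spark" %% "spark-mllib" % "1.6.1" % "provided",
 *     "com.github.scopt" %% "scopt"       % "3.3.0"
 *   )
 */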
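
// A local smoke test — a hypothetical driver, not part of the original gist.
// It reuses LDABench.main with a smaller corpus than the defaults and the
// "online" optimizer (MLlib's setOptimizer accepts "em" or "online"). With
// no --spark-master and no external spark.master, run() falls back to local[*].
object LDABenchSmoke {
  def main(cmdArgs: Array[String]): Unit = {
    LDABench.main(Array(
      "-t", "100000",           // random tokens to generate
      "-d", "1000",             // documents
      "-w", "5000",             // vocabulary size
      "-k", "20",               // topics
      "--num-iterations", "5",
      "--num-partitions", "8",
      "--optimizer", "online",
      "--use-cache"             // materialize the corpus before running LDA
    ))
  }
}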