@ebuildy · Last active July 9, 2017
Play with Naive Bayes classification with Apache Spark
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.classification._
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
import org.apache.spark.ml.feature._
import org.apache.spark.sql.functions._
// Map the category string to a numeric label
val categoryIndexer = new StringIndexer()
  .setInputCol("category")
  .setOutputCol("category_id")

// Split the body into words
val tokenizer = new Tokenizer()
  .setInputCol("body")
  .setOutputCol("words")

// Drop French and English stop words, plus a few extra French fillers
val remover = new StopWordsRemover()
  .setInputCol("words")
  .setOutputCol("words_clean")
  .setStopWords(
    StopWordsRemover.loadDefaultStopWords("french") ++
    StopWordsRemover.loadDefaultStopWords("english") ++
    Array("les", "cet", "cette", "tout", "ans", "tous", "toutes"))

// Hash words into a fixed-size term-frequency vector
val tfHasher = new HashingTF()
  .setInputCol("words_clean")
  .setOutputCol("rawFeatures")
  .setNumFeatures(50000) // adjust according to the quality you want

// Weight raw term frequencies by inverse document frequency
val idfMaker = new IDF()
  .setInputCol("rawFeatures")
  .setOutputCol("features")

val classifier = new NaiveBayes()
  .setFeaturesCol("features")
  .setLabelCol("category_id")
// Load data from a Parquet file; the schema is (body: String, category: String).
// Lower-case the body and keep only letters (including a few French accents).
val inputDataDF = spark.read.parquet("/Users/tom/bigdata/poc1/news")
  .where("category is not null")
  .withColumn("body", regexp_replace(lower($"body"), "[^a-zéêàèA-Z]+", " "))
// Split train and test data (to check quality)
val Array(trainingData, testData) = inputDataDF.randomSplit(Array(0.9, 0.1), seed = 1234L)

// Save the label mapping (String <-> Double) to retrieve it later
val labels = categoryIndexer.fit(trainingData)
  .transform(trainingData)
  .select("category", "category_id")
  .dropDuplicates("category")

labels.write.mode("overwrite").parquet("/Users/tom/bigdata/poc/labels")
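// Alternative (a sketch, not in the original gist): an IndexToString stage can
// decode predictions back to category names without a separate labels table.
// "indexerModel" and "labelConverter" are illustrative names.
import org.apache.spark.ml.feature.IndexToString

val indexerModel = categoryIndexer.fit(trainingData)
val labelConverter = new IndexToString()
  .setInputCol("prediction")
  .setOutputCol("predicted_category")
  .setLabels(indexerModel.labels)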
val pipeline = new Pipeline()
  .setStages(Array(categoryIndexer, tokenizer, remover, tfHasher, idfMaker, classifier))

val model = pipeline.fit(trainingData)

// Persist the fitted pipeline for reuse
model.write.overwrite.save("/Users/tom/bigdata/poc/classifier")
// Test the pipeline on the held-out data to check the accuracy (> 0.7 should be fine...)
val predictions = model.transform(testData)

val evaluator = new MulticlassClassificationEvaluator()
  .setLabelCol("category_id")
  .setPredictionCol("prediction")
  .setMetricName("accuracy")

val accuracy = evaluator.evaluate(predictions)
println(s"Test set accuracy = $accuracy")
// --- Reuse the saved pipeline on new documents ---
import org.apache.spark.ml.PipelineModel
import org.apache.spark.sql.Row
val labels = spark.read.parquet("/Users/tom/bigdata/poc/labels")
val model = PipelineModel.load("/Users/tom/bigdata/poc/classifier")

val test = spark.createDataFrame(Seq(
  (1, "lady gaga annule son concert"),
  (2, "lady gaga annule sa visite"),
  (3, "le nouveau film de Luc Besson"),
  (4, "Linux a mis à jour ses paquets"),
  (5, "la bourse décroche encore à cause des incertitudes sur la zone euro")
)).toDF("id", "body")
// Predict, then map the numeric prediction back to the category name
val predictions = model.transform(test)

predictions
  .join(labels, labels.col("category_id") === predictions.col("prediction"))
  .select("body", "category")
  .collect()
  .foreach { case Row(body: String, category: String) =>
    println(s"($body) --> $category")
  }
Output:
(lady gaga annule son concert) --> music
(lady gaga annule sa visite) --> music
(le nouveau film de Luc Besson) --> cinema
(Linux a mis à jour ses paquets) --> video_game
(la bourse décroche encore à cause des incertitudes sur la zone euro) --> economy
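// To also see how confident the model is for each document, the NaiveBayes stage
// emits a "probability" vector column (a sketch, using the DataFrame from above):
predictions.select("body", "prediction", "probability").show(false)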