Skip to content

Instantly share code, notes, and snippets.

@jamesrajendran
Last active May 26, 2017 09:30
Show Gist options
  • Save jamesrajendran/321e9511d47a7fc7848e889041517957 to your computer and use it in GitHub Desktop.
Save jamesrajendran/321e9511d47a7fc7848e889041517957 to your computer and use it in GitHub Desktop.
Machine Learning
# Spark ML example
# Here is the original code URL (courtesy): https://github.com/jayantshekhar/strata-2016/blob/master/src/main/scala/com/cloudera/spark/spamdetection/Spam.scala
------------------------
// scalastyle:off println
package com.cloudera.spark.spamdetection
import scala.beans.BeanInfo
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.feature.{HashingTF, Tokenizer}
import org.apache.spark.sql.SQLContext
import com.cloudera.spark.mllib.SparkConfUtil
import scala.reflect.runtime.universe
import org.apache.spark.ml.feature.IDF
/** One email document from the Enron corpus.
 *
 *  @param file  path of the source file the text was read from
 *  @param text  full raw text of the email
 *  @param label class label: 1.0 = spam, 0.0 = ham (see usage in Spam.main)
 *
 *  NOTE(review): `@BeanInfo` is deprecated since Scala 2.12 (removed in 2.13).
 *  It is kept here unchanged for compatibility with the original example;
 *  `toDF()` via `sqlContext.implicits` does not appear to need it — confirm
 *  before removing.
 */
@BeanInfo
final case class SpamDocument(file: String, text: String, label: Double)
/** Trains a spam-vs-ham classifier on the Enron email corpus.
 *
 *  Pipeline stages: Tokenizer -> HashingTF -> IDF -> LogisticRegression.
 *  Expects the data on local disk under:
 *    data/enron/spam  (labelled 1.0)
 *    data/enron/ham   (labelled 0.0)
 *
 *  Fix over the original: the SparkContext is now stopped in a `finally`
 *  block, so cluster resources are released even if data loading or
 *  training throws.
 */
object Spam {

  def main(args: Array[String]) {
    val conf = new SparkConf().setAppName("Spam")
    SparkConfUtil.setConf(conf)
    val sc = new SparkContext(conf)
    try {
      val sqlContext = new SQLContext(sc)
      import sqlContext.implicits._

      // Corpus description: http://www.aueb.gr/users/ion/data/enron-spam/
      // Read the spam files as (path, contents) pairs; label them 1.0.
      val spamrdd = sc.wholeTextFiles("data/enron/spam", 1)
      val spamdf = spamrdd.map(d => SpamDocument(d._1, d._2, 1)).toDF()
      spamdf.show()

      // Read the ham files; label them 0.0.
      val hamrdd = sc.wholeTextFiles("data/enron/ham", 1)
      val hamdf = hamrdd.map(d => SpamDocument(d._1, d._2, 0)).toDF()
      hamdf.show()

      // Combine both classes into a single corpus.
      val alldf = spamdf.unionAll(hamdf)
      alldf.show()

      // 70/30 train/test split (unseeded, so results vary between runs).
      val Array(trainingData, testData) = alldf.randomSplit(Array(0.7, 0.3))

      // Configure the ML pipeline: tokenizer -> hashingTF -> idf -> lr.
      val tokenizer = new Tokenizer()
        .setInputCol("text")
        .setOutputCol("words")
      val hashingTF = new HashingTF()
        .setInputCol(tokenizer.getOutputCol)
        .setOutputCol("rawFeatures")
      val idf = new IDF().setInputCol("rawFeatures").setOutputCol("features")
      val lr = new LogisticRegression()
        .setMaxIter(5)
      lr.setLabelCol("label")
      lr.setFeaturesCol("features")
      val pipeline = new Pipeline()
        .setStages(Array(tokenizer, hashingTF, idf, lr))

      // Fit the whole pipeline on the training split.
      val lrModel = pipeline.fit(trainingData)
      println(lrModel.toString())

      // Score the held-out split and display the predictions.
      val predictions = lrModel.transform(testData)
      predictions.select("file", "text", "label", "features", "prediction").show(300)
    } finally {
      // Always release the SparkContext, even on failure above.
      sc.stop()
    }
  }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment