-
-
Save jamesrajendran/321e9511d47a7fc7848e889041517957 to your computer and use it in GitHub Desktop.
Machine Learning
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Spark ML example: spam detection on the Enron e-mail corpus
# Original source (courtesy of the jayantshekhar/strata-2016 repository): https://github.com/jayantshekhar/strata-2016/blob/master/src/main/scala/com/cloudera/spark/spamdetection/Spam.scala
------------------------
// scalastyle:off println | |
package com.cloudera.spark.spamdetection | |
import scala.beans.BeanInfo | |
import org.apache.spark.{SparkConf, SparkContext} | |
import org.apache.spark.ml.Pipeline | |
import org.apache.spark.ml.classification.LogisticRegression | |
import org.apache.spark.ml.feature.{HashingTF, Tokenizer} | |
import org.apache.spark.sql.SQLContext | |
import com.cloudera.spark.mllib.SparkConfUtil | |
import scala.reflect.runtime.universe | |
import org.apache.spark.ml.feature.IDF | |
/**
 * One labelled e-mail document fed to the spam-detection pipeline.
 *
 * `@BeanInfo` was dropped: it is deprecated (Scala 2.12) / removed (Scala 2.13), and
 * Spark's `toDF()` derives the DataFrame schema from the case class itself.
 *
 * @param file  path of the source file, as reported by `SparkContext.wholeTextFiles`
 * @param text  full contents of that file
 * @param label class label; `Spam.main` uses 1.0 for spam and 0.0 for ham
 */
final case class SpamDocument(file: String, text: String, label: Double)
/**
 * Trains a TF-IDF + logistic-regression spam classifier on the Enron spam/ham corpus
 * (http://www.aueb.gr/users/ion/data/enron-spam/) and prints predictions for a held-out split.
 */
object Spam {

  /** Data root used when no command-line argument is given; must contain `spam/` and `ham/`. */
  private val DefaultDataDir = "data/enron"

  /**
   * Entry point.
   *
   * @param args optional; `args(0)` overrides the data root directory (default `data/enron`).
   *             The directory must contain `spam/` and `ham/` subdirectories, one e-mail per file.
   */
  def main(args: Array[String]): Unit = {
    val dataDir = args.headOption.getOrElse(DefaultDataDir)

    val conf = new SparkConf().setAppName("Spam")
    // Project helper; presumably applies environment-specific settings (e.g. master) — confirm.
    SparkConfUtil.setConf(conf)
    val sc = new SparkContext(conf)
    // Stop the context even if loading, training or prediction throws.
    try {
      run(sc, dataDir)
    } finally {
      sc.stop()
    }
  }

  /** Loads the corpus, fits the pipeline on a 70% split and shows predictions for the other 30%. */
  private def run(sc: SparkContext, dataDir: String): Unit = {
    val sqlContext = new SQLContext(sc)

    // Read in the spam files (label 1.0) and the ham files (label 0.0).
    val spamdf = readLabeled(sc, sqlContext, s"$dataDir/spam", 1.0)
    spamdf.show()
    val hamdf = readLabeled(sc, sqlContext, s"$dataDir/ham", 0.0)
    hamdf.show()

    val alldf = spamdf.unionAll(hamdf)
    alldf.show()

    // No seed is passed, so the split (and therefore the results) can differ between runs.
    val Array(trainingData, testData) = alldf.randomSplit(Array(0.7, 0.3))

    val pipelineModel = buildPipeline().fit(trainingData)
    println(pipelineModel.toString())

    // Make predictions on the held-out documents and display them.
    val predictions = pipelineModel.transform(testData)
    predictions.select("file", "text", "label", "features", "prediction").show(300)
  }

  /**
   * Reads every file under `dir` as one document (with minPartitions = 1) and tags it with `label`.
   *
   * @return a DataFrame with the columns of [[SpamDocument]]: `file`, `text`, `label`
   */
  private def readLabeled(sc: SparkContext, sqlContext: SQLContext, dir: String, label: Double) = {
    import sqlContext.implicits._
    sc.wholeTextFiles(dir, 1)
      .map { case (file, text) => SpamDocument(file, text, label) }
      .toDF()
  }

  /**
   * Builds the four-stage ML pipeline:
   * tokenizer (text -> words), hashingTF (words -> rawFeatures),
   * idf (rawFeatures -> features) and logistic regression (features, label -> prediction).
   */
  private def buildPipeline(): Pipeline = {
    val tokenizer = new Tokenizer()
      .setInputCol("text")
      .setOutputCol("words")
    val hashingTF = new HashingTF()
      .setInputCol(tokenizer.getOutputCol)
      .setOutputCol("rawFeatures")
    val idf = new IDF()
      .setInputCol(hashingTF.getOutputCol)
      .setOutputCol("features")
    val lr = new LogisticRegression()
      .setMaxIter(5)
      .setLabelCol("label")
      .setFeaturesCol(idf.getOutputCol)
    new Pipeline().setStages(Array(tokenizer, hashingTF, idf, lr))
  }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment