Skip to content

Instantly share code, notes, and snippets.

@mitallast
Created February 5, 2016 12:07
Show Gist options
  • Save mitallast/cc4fab310e1758a966a8 to your computer and use it in GitHub Desktop.
Save mitallast/cc4fab310e1758a966a8 to your computer and use it in GitHub Desktop.
Example app that classifies sell records into categories from their text
data size 50k
num features cv f1 metric
1000 0.751825336615163
2000 0.813659787214313
4000 0.854789320437248
6000 0.872333611210510
8000 0.882095086103654
10000 0.887525006704344
20000 0.897327637816261
30000 0.899292503994883
40000 0.899953381840675
50000 0.900651888235837
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.classification.NaiveBayes
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
import org.apache.spark.ml.feature.{HashingTF, StringIndexer, Tokenizer}
import org.apache.spark.ml.tuning.{CrossValidator, ParamGridBuilder}
import org.apache.spark.sql.SQLContext
import org.apache.spark.{SparkConf, SparkContext}
/** Trains and evaluates a multinomial Naive Bayes text classifier with Spark ML.
  *
  * The input is a Parquet file that provides a `category` column (the class to
  * predict) and a `text` column (the document). The text is tokenized, hashed
  * into term-frequency vectors and fed to Naive Bayes; the number of hashed
  * features is selected by 3-fold cross-validation, and the F1 score on a
  * held-out 20% split is printed.
  */
object SimpleApp {

  /** Parquet file that is read when no path is given on the command line. */
  private val DefaultDataPath = "/Users/mitallast/Sites/spark/sell.parquet"

  /** Program entry point.
    *
    * @param args optional command-line arguments; when present, `args(0)` is
    *             used as the input Parquet path instead of the built-in default
    */
  def main(args: Array[String]): Unit = {
    val dataPath = args.headOption.getOrElse(DefaultDataPath)

    // Configure executor memory on the SparkConf itself rather than through a
    // JVM-wide system property.
    val conf = new SparkConf()
      .setAppName("Test Naive Bayes")
      .setMaster("spark://127.0.0.1:7077")
      .set("spark.executor.memory", "10G")
    val sc = new SparkContext(conf)

    // Always release the application's cluster resources, even when a stage fails.
    try {
      val sqlContext = new SQLContext(sc)
      val raw = sqlContext.read.load(dataPath)

      // Fit the label index once on the complete data set, before any split.
      // When the indexer is fitted only on a training split (or on a single
      // cross-validation fold), a category that happens to be missing from it
      // makes the fitted model fail on the remaining rows with an
      // "Unseen label" error under the default handleInvalid setting.
      val categoryIndexer = new StringIndexer()
        .setInputCol("category")
        .setOutputCol("label")
      val data = categoryIndexer.fit(raw).transform(raw)

      val Array(train, test) = data.randomSplit(Array(0.8, 0.2))
      train.cache()
      test.cache()

      val tokenizer = new Tokenizer()
        .setInputCol("text")
        .setOutputCol("words")

      // The value set here is only a starting point; the parameter grid below
      // overrides numFeatures for every candidate model.
      val hashingTF = new HashingTF()
        .setInputCol("words")
        .setOutputCol("features")
        .setNumFeatures(10000)

      val classifier = new NaiveBayes()
        .setSmoothing(1.0)
        .setModelType("multinomial")

      val pipeline = new Pipeline()
        .setStages(Array(tokenizer, hashingTF, classifier))

      val evaluator = new MulticlassClassificationEvaluator()
        .setLabelCol("label")
        .setPredictionCol("prediction")
        .setMetricName("f1")

      val grid = new ParamGridBuilder()
        .addGrid(hashingTF.numFeatures, Array(10000, 50000, 100000))
        .build()

      val crossValidator = new CrossValidator()
        .setEstimator(pipeline)
        .setEvaluator(evaluator)
        .setEstimatorParamMaps(grid)
        .setNumFolds(3)

      // Select numFeatures on the training split, then score the selected
      // model on the held-out test split.
      val model = crossValidator.fit(train)
      val predictions = model.transform(test)
      val metric = evaluator.evaluate(predictions)
      println(f"F1 metric: $metric%f")
    } finally {
      sc.stop()
    }
  }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment