Created February 5, 2016 12:07
-
-
Save mitallast/cc4fab310e1758a966a8 to your computer and use it in GitHub Desktop.
Example Spark app that classifies items for sale into categories from their text (multinomial Naive Bayes).
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Cross-validated F1 metric as a function of the number of hashed features (data size: 50k).

| num features | CV F1 metric      |
|-------------:|-------------------|
| 1000         | 0.751825336615163 |
| 2000         | 0.813659787214313 |
| 4000         | 0.854789320437248 |
| 6000         | 0.872333611210510 |
| 8000         | 0.882095086103654 |
| 10000        | 0.887525006704344 |
| 20000        | 0.897327637816261 |
| 30000        | 0.899292503994883 |
| 40000        | 0.899953381840675 |
| 50000        | 0.900651888235837 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import org.apache.spark.ml.Pipeline | |
import org.apache.spark.ml.classification.NaiveBayes | |
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator | |
import org.apache.spark.ml.feature.{HashingTF, StringIndexer, Tokenizer} | |
import org.apache.spark.ml.tuning.{CrossValidator, ParamGridBuilder} | |
import org.apache.spark.sql.SQLContext | |
import org.apache.spark.{SparkConf, SparkContext} | |
/**
 * Trains a multinomial Naive Bayes classifier that predicts an item's `category` from its
 * `text` column, tuning the HashingTF feature-space size with 3-fold cross-validation, and
 * prints the evaluator's "f1" metric on a held-out 20% test split.
 *
 * Usage: `SimpleApp [inputParquetPath]`
 */
object SimpleApp {

  /** Parquet file read when no input path is passed as the first program argument. */
  private val DefaultInputPath = "/Users/mitallast/Sites/spark/sell.parquet"

  /** Standalone master used unless `spark.master` is already configured (e.g. by spark-submit). */
  private val DefaultMaster = "spark://127.0.0.1:7077"

  /** Executor memory used unless `spark.executor.memory` is already configured. */
  private val DefaultExecutorMemory = "10G"

  /**
   * Entry point.
   *
   * @param args optional; `args(0)` overrides the parquet input path
   */
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf()
      .setAppName("Test Naive Bayes")
      // setIfMissing (rather than setMaster / System.setProperty) keeps the original defaults
      // while letting spark-submit flags take precedence, and avoids mutating JVM-global
      // system properties.
      .setIfMissing("spark.master", DefaultMaster)
      .setIfMissing("spark.executor.memory", DefaultExecutorMemory)
    val sc = new SparkContext(conf)
    // Always release cluster resources, even if training or evaluation throws.
    try run(new SQLContext(sc), args.headOption.getOrElse(DefaultInputPath))
    finally sc.stop()
  }

  /** Loads the data, splits it 80/20, cross-validates on the 80% and prints F1 on the 20%. */
  private def run(sqlContext: SQLContext, inputPath: String): Unit = {
    val data = sqlContext.read.load(inputPath)
    // NOTE(review): no seed is passed, so the split (and the reported metric) varies per run.
    val Array(train, test) = data.randomSplit(Array(0.8, 0.2))
    train.cache()
    test.cache()

    val evaluator = new MulticlassClassificationEvaluator()
      .setLabelCol("label")
      .setPredictionCol("prediction")
      .setMetricName("f1")

    val model = buildCrossValidator(evaluator).fit(train)
    val metric = evaluator.evaluate(model.transform(test))
    println(f"F1 metric: $metric%f")
  }

  /**
   * Builds the category-from-text pipeline (index label, tokenize, hash, Naive Bayes) wrapped
   * in a 3-fold cross-validator over `HashingTF.numFeatures`.
   *
   * @param evaluator metric used to select the best parameter combination
   */
  private def buildCrossValidator(evaluator: MulticlassClassificationEvaluator): CrossValidator = {
    // NOTE(review): the label index is fitted on training data only; a category that appears
    // only in the test split may make transform fail. Consider StringIndexer.setHandleInvalid
    // if the Spark version in use supports it.
    val categoryIndexer = new StringIndexer()
      .setInputCol("category")
      .setOutputCol("label")
    val tokenizer = new Tokenizer()
      .setInputCol("text")
      .setOutputCol("words")
    // The value set here is overridden by every entry of the parameter grid below.
    val hashingTF = new HashingTF()
      .setInputCol("words")
      .setOutputCol("features")
      .setNumFeatures(10000)
    val classifier = new NaiveBayes()
      .setSmoothing(1.0)
      .setModelType("multinomial")

    val pipeline = new Pipeline()
      .setStages(Array(categoryIndexer, tokenizer, hashingTF, classifier))

    val grid = new ParamGridBuilder()
      .addGrid(hashingTF.numFeatures, Array(10000, 50000, 100000))
      .build()

    new CrossValidator()
      .setEstimator(pipeline)
      .setEvaluator(evaluator)
      .setEstimatorParamMaps(grid)
      .setNumFolds(3)
  }
}
Sign up for free to join this conversation on GitHub.
Already have an account?
Sign in to comment