SVM Classifier with SGD
package com.ferabb.spark.app.ml

import com.ferabb.spark.SparkApp
import com.ferabb.spark.analysis.LuceneTextAnalyzer
import com.ferabb.spark.fusion.FusionMLModelSupport
import org.apache.commons.cli.{CommandLine, Option}
import org.apache.spark.SparkConf
import org.apache.spark.mllib.classification.SVMWithSGD
import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics
import org.apache.spark.mllib.feature.{HashingTF, Normalizer, StandardScaler}
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.sql.types.{StringType, StructField, StructType}
import org.apache.spark.sql.{Row, SaveMode, SparkSession}
import org.apache.spark.storage.StorageLevel

import scala.collection.JavaConverters._
import scala.collection.immutable
/**
 * An SVM classifier for text classification. Each document passes through the
 * following feature transformations before training:
 *
 *  - LuceneTextAnalyzer (tokenize and lowercase)
 *  - HashingTF
 *  - Normalizer
 *  - StandardScaler
 *
 * The SVM model is trained on the transformed training data and then evaluated
 * on the given test data (run through the same transformations). A minimal
 * standalone sketch of the train/evaluate loop follows the class definition below.
 */
object SVMClassifier {
  val DEFAULT_NUM_FEATURES = "1000000"
  val DEFAULT_NUM_ITERATIONS = "200"
  val DefaultZkHost = "localhost:9983"
  val DefaultCollection = "twitter_sentiment"
}
class SVMClassifier extends SparkApp.RDDProcessor {
  import SVMClassifier._

  def getName = "mllib-svm-scala"

  // zkHost and collection are read below via cli.getOptionValue; they are
  // presumably supplied by the SparkApp base options rather than declared here.
  def getOptions = Array(
    Option.builder().longOpt("indexTrainingData").hasArg.required(false).desc(
      "Path to training data to index").build(),
    Option.builder().longOpt("indexTestData").hasArg.required(false).desc(
      "Path to test data to index").build(),
    Option.builder().longOpt("sample").hasArg.required(false).desc(
      "Fraction (0 to 1) of the full dataset to sample from Solr; default is 1").build(),
    Option.builder().longOpt("numFeatures").hasArg.required(false).desc(
      s"Number of features; default is $DEFAULT_NUM_FEATURES").build(),
    Option.builder().longOpt("numIterations").hasArg.required(false).desc(
      s"Number of iterations; default is $DEFAULT_NUM_ITERATIONS").build(),
    Option.builder().longOpt("modelOutput").hasArg.required(false).desc(
      "Model output path; default is mllib-svm-sentiment").build(),
    Option.builder().longOpt("fusionHostAndPort").hasArg.required(false).desc(
      "Fusion host and port, e.g. localhost:8764").build(),
    Option.builder().longOpt("fusionUser").hasArg.required(false).desc(
      "Fusion user name").build(),
    Option.builder().longOpt("fusionPassword").hasArg.required(false).desc(
      "Fusion password").build(),
    Option.builder().longOpt("fusionRealm").hasArg.required(false).desc(
      "Fusion realm").build()
  )
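  // Hypothetical invocation (paths, jar name, and host are placeholders),
  // assuming the SparkApp driver dispatches to this processor by its getName:
  //   spark-submit --class com.ferabb.spark.SparkApp app.jar mllib-svm-scala \
  //     --indexTrainingData /path/to/training.csv --fusionHostAndPort localhost:8764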
  override def run(conf: SparkConf, cli: CommandLine): Int = {
    val sparkSession: SparkSession = SparkSession.builder().config(conf).getOrCreate()

    // Schema of the tweet-sentiment CSV input (no header row).
    val csvSchema = StructType(StructField("polarity", StringType, true) ::
      StructField("id", StringType, true) ::
      StructField("date", StringType, true) ::
      StructField("query", StringType, true) ::
      StructField("username", StringType, true) ::
      StructField("tweet_txt", StringType, true) :: Nil)

    val writeOptions = immutable.HashMap(
      "zkhost" -> cli.getOptionValue("zkHost", DefaultZkHost),
      "collection" -> cli.getOptionValue("collection", DefaultCollection),
      "soft_commit_secs" -> "10")

    // Optionally index the raw training CSV into Solr.
    val indexTrainingData = cli.getOptionValue("indexTrainingData")
    if (indexTrainingData != null) {
      var csvDF = sparkSession.read.format("com.databricks.spark.csv")
        .schema(csvSchema).option("header", "false").load(indexTrainingData)
      csvDF = csvDF.repartition(4)
      csvDF.write.format("solr").options(writeOptions).mode(SaveMode.Overwrite).save()
    }

    // Optionally index the raw test CSV, renaming its label column so test rows
    // can be distinguished from training rows in the same collection.
    val indexTestData = cli.getOptionValue("indexTestData")
    if (indexTestData != null) {
      var csvDF = sparkSession.read.format("com.databricks.spark.csv")
        .schema(csvSchema).option("header", "false").load(indexTestData)
      csvDF = csvDF.withColumnRenamed("polarity", "test_polarity")
      csvDF.write.format("solr").options(writeOptions).mode(SaveMode.Overwrite).save()
    }
    val contentFields = "tweet_txt"

    // Read training rows back from Solr: only documents with a training label
    // (polarity 0 or 4) and non-empty tweet text. splits/split_field increase
    // read parallelism by splitting each shard on the _version_ field.
    val trainOptions = immutable.HashMap(
      "zkhost" -> cli.getOptionValue("zkHost", DefaultZkHost),
      "collection" -> cli.getOptionValue("collection", DefaultCollection),
      "query" -> "+polarity:(0 OR 4) +tweet_txt:[* TO *]",
      "fields" -> "id,polarity,tweet_txt",
      "rows" -> "10000",
      "splits" -> "true",
      "split_field" -> "_version_",
      "splits_per_shard" -> "8")

    val sampleFraction = cli.getOptionValue("sample", "1.0").toDouble
    var trainingDataFromSolr = sparkSession.read.format("solr").options(trainOptions).load()
    trainingDataFromSolr = trainingDataFromSolr.sample(false, sampleFraction) // without replacement

    val inputCols = contentFields.split(" ").map(_.trim)

    // Lucene analyzer config: a standard tokenizer followed by a lowercase
    // filter, applied to every field (regex ".+").
    val stdTokLowerSchema = "{ \"analyzers\": [{ \"name\": \"std_tok_lower\", \"tokenizer\": { \"type\": \"standard\" },\n" +
      "  \"filters\": [{ \"type\": \"lowercase\" }]}],\n" +
      "  \"fields\": [{ \"regex\": \".+\", \"analyzer\": \"std_tok_lower\" }]}\n"
    val numFeatures = cli.getOptionValue("numFeatures", DEFAULT_NUM_FEATURES).toInt
    val numIterations = cli.getOptionValue("numIterations", DEFAULT_NUM_ITERATIONS).toInt
    // Convert one Solr row into a LabeledPoint: analyze the configured text
    // fields, hash the resulting terms into a sparse term-frequency vector, and
    // L2-normalize it.
    def rowToLabeledPoint(row: Row, numFeatures: Int, inputCols: Array[String], analyzerSchema: String): LabeledPoint = {
      val textAnalyzer = new LuceneTextAnalyzer(analyzerSchema)
      val hashingTF = new HashingTF(numFeatures)
      val normalizer = new Normalizer()
      val polarity = row.getString(row.fieldIndex("polarity"))
      val fields = new java.util.HashMap[String, String]()
      for (i <- 0 until inputCols.length) {
        val value = row.getString(row.fieldIndex(inputCols(i)))
        if (value != null) {
          fields.put(inputCols(i), value)
        }
      }
      val analyzedFields = textAnalyzer.analyzeJava(fields)
      val terms = new java.util.LinkedList[String]()
      analyzedFields.values().asScala.foreach(v => terms.addAll(v))
      // Polarity "0" is the negative class; any other value (i.e. "4") is positive.
      val sentimentLabel = if (polarity == "0") 0.0 else 1.0
      LabeledPoint(sentimentLabel, normalizer.transform(hashingTF.transform(terms)))
    }
    // Training pipeline: rows -> LabeledPoints, fit a StandardScaler on the raw
    // features, scale them, and cache the result for the iterative SGD training.
    val trainingData = trainingDataFromSolr.rdd.map(row => rowToLabeledPoint(row, numFeatures, inputCols, stdTokLowerSchema))
    val standardScaler = new StandardScaler().fit(trainingData.map(_.features))
    val trainRDD = trainingData
      .map(x => LabeledPoint(x.label, standardScaler.transform(x.features)))
      .persist(StorageLevel.MEMORY_ONLY_SER)
    val model = SVMWithSGD.train(trainRDD, numIterations)
    // Read the test split back from Solr (rows carrying test_polarity) and run
    // it through the same analyze/hash/normalize/scale pipeline.
    val testOptions = immutable.HashMap(
      "zkhost" -> cli.getOptionValue("zkHost", DefaultZkHost),
      "collection" -> cli.getOptionValue("collection", DefaultCollection),
      "query" -> "+test_polarity:[* TO *] +tweet_txt:[* TO *]",
      "fields" -> "id,test_polarity,tweet_txt")
    var testDataFromSolr = sparkSession.read.format("solr").options(testOptions).load()
    testDataFromSolr = testDataFromSolr.withColumnRenamed("test_polarity", "polarity")
    testDataFromSolr.show()

    val testVectors = testDataFromSolr.rdd
      .map(row => rowToLabeledPoint(row, numFeatures, inputCols, stdTokLowerSchema))
      .map(x => LabeledPoint(x.label, standardScaler.transform(x.features)))

    val scoreAndLabels = testVectors.map { p =>
      val score = model.predict(p.features)
      println(">> model predicted: " + score + ", actual: " + p.label)
      (score, p.label)
    }

    val metrics = new BinaryClassificationMetrics(scoreAndLabels)
    val auROC = metrics.areaUnderROC
    println("Area under ROC = " + auROC)
    // If a Fusion endpoint was given, collect the metadata a consumer needs to
    // reproduce the feature pipeline at query time. The map is never used below;
    // presumably it was meant to be handed to FusionMLModelSupport (imported
    // above) along with the model, but that call is not part of this gist.
    if (cli.getOptionValue("fusionHostAndPort") != null) {
      val metadata = new java.util.HashMap[String, String]()
      metadata.put("numFeatures", numFeatures.toString) // the configured value, not a hardcoded one
      metadata.put("featureFields", contentFields)
      metadata.put("analyzerJson", stdTokLowerSchema)
      metadata.put("normalizer", "Y")
      metadata.put("standardscaler", "Y")
      metadata.put("mean", standardScaler.mean.toString)
      metadata.put("std", standardScaler.std.toString)
    } else {
      model.save(sparkSession.sparkContext, cli.getOptionValue("modelOutput", "mllib-svm-sentiment"))
    }

    0
  }
}
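A minimal, self-contained sketch of the same SVMWithSGD train/evaluate loop, useful for trying the MLlib calls without the Solr/Fusion plumbing above. Everything in it (object name, synthetic clusters, split ratios) is illustrative and not part of the gist; only the MLlib API usage mirrors the class above.

import org.apache.spark.mllib.classification.SVMWithSGD
import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.sql.SparkSession

object SVMSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("svm-sketch").getOrCreate()
    val sc = spark.sparkContext
    val rng = new scala.util.Random(42)

    // Two noisy 2-D clusters: label 1.0 around (1, 1), label 0.0 around (-1, -1).
    val data = sc.parallelize(Seq.fill(400) {
      val label = rng.nextInt(2).toDouble
      val c = if (label == 1.0) 1.0 else -1.0
      LabeledPoint(label, Vectors.dense(c + 0.4 * rng.nextGaussian(), c + 0.4 * rng.nextGaussian()))
    })

    val Array(train, test) = data.randomSplit(Array(0.8, 0.2), seed = 11L)
    val model = SVMWithSGD.train(train.cache(), 100)
    model.clearThreshold() // return raw margins rather than hard 0/1 predictions

    val scoreAndLabels = test.map(p => (model.predict(p.features), p.label))
    println("Area under ROC = " + new BinaryClassificationMetrics(scoreAndLabels).areaUnderROC)
    spark.stop()
  }
}

One difference from the gist: the sketch calls clearThreshold() so predict emits raw margins, giving BinaryClassificationMetrics a real score distribution to sweep over. With the default threshold left in place, as in the class above, scoreAndLabels holds hard 0/1 predictions and the ROC curve collapses to a single operating point.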