// Source: GitHub gist gbraccialli/232b4317f76a555a772d (created November 10, 2015).
// Run inside spark-shell with the two third-party packages on the classpath:
//   spark-shell --packages sramirez:spark-infotheoretic-feature-selection:1.1,sramirez:spark-MDLP-discretization:1.0
import org.apache.spark.SparkContext._
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.mllib.feature._
import org.apache.spark.mllib.feature.MDLPDiscretizer
import org.apache.spark.mllib.classification.{LogisticRegressionWithLBFGS, LogisticRegressionModel}
import org.apache.spark.mllib.evaluation.MulticlassMetrics
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.classification.LogisticRegressionWithSGD
import org.apache.spark.mllib.optimization.SquaredL2Updater
import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics
// Load the sample dataset in LIBSVM format.
// cache(): `data` feeds the MDLP training below and is mapped again to build both
// `discrete` and `reduced`; without caching, each of those passes re-reads and
// re-parses the text file.
val data = MLUtils.loadLibSVMFile(sc,
  "file:///usr/hdp/2.3.2.0-2950/spark/data/mllib/sample_libsvm_data.txt").cache()

// Discretize the features with the Fayyad MDLP discretizer before feature selection.
// NOTE(review): the info-theoretic selector presumably needs categorical input
// (treating each distinct value as a category, like ChiSqSelector) — confirm in the
// spark-infotheoretic-feature-selection docs.
//
// Simpler alternative kept for reference: bucket each raw value by integer division by 16.
//val discretizedData = data.map { lp =>
//  LabeledPoint(lp.label, Vectors.dense(lp.features.toArray.map { x => (x / 16).floor } ) )
//}

// None = no feature indices are flagged as categorical, so presumably every feature
// is treated as continuous and gets discretized.
val categoricalFeat: Option[Seq[Int]] = None
val nBins = 25        // max number of thresholds per feature
val maxByPart = 10000 // max elements per partition

println("*** Discretization method: Fayyad discretizer (MDLP)")
println(s"*** Number of bins: $nBins")

val discretizer = MDLPDiscretizer.train(data, // RDD[LabeledPoint] (cached above)
  categoricalFeat, // categorical feature indices (None: all continuous)
  nBins,           // max number of thresholds by feature
  maxByPart)       // max elements per partition
discretizer // REPL echo of the fitted discretizer

// Apply the fitted discretizer to every point, keeping the label unchanged.
val discrete = data.map(i => LabeledPoint(i.label, discretizer.transform(i.features)))
discrete.first() // REPL sanity check: show one discretized point
// Feature selection with the "mrmr" (minimum-Redundancy Maximum-Relevance)
// information-theoretic criterion, scored on the discretized points.
val criterion = new InfoThCriterionFactory("mrmr")
val nToSelect = 5     // number of features to keep
val nPartitions = 100 // partitions used by the selector

println(s"*** FS criterion: ${criterion.getCriterion.toString}")
println(s"*** Number of features to select: $nToSelect")
println(s"*** Number of partitions: $nPartitions")

val featureSelector = InfoThSelector.train(criterion,
  discrete,    // RDD[LabeledPoint]
  nToSelect,   // number of features to select
  nPartitions) // number of partitions
featureSelector // REPL echo of the fitted selector

// The selector is fitted on `discrete` but deliberately applied to the original
// `data`: assuming transform() only keeps the selected feature positions, the model
// below trains on the raw (non-discretized) values of the chosen features.
val reduced = data.map(i => LabeledPoint(i.label, featureSelector.transform(i.features)))
reduced.first() // REPL sanity check: show one reduced point
// Split the reduced data into training (60%) and test (40%); the fixed seed keeps
// the split reproducible across runs.
val splits = reduced.randomSplit(Array(0.6, 0.4), seed = 11L)
val training = splits(0).cache() // read repeatedly by the 200 SGD iterations below
val test = splits(1)

// Disabled alternative: multiclass LBFGS model evaluated with MulticlassMetrics.
// NOTE(review): the active code below treats this as a binary problem
// (BinaryClassificationMetrics), so setNumClasses(2) is presumably what is wanted
// if this is ever re-enabled — setNumClasses(10) does not match.
//val model = new LogisticRegressionWithLBFGS().setNumClasses(10).run(training)
//model
//model.weights
//model.intercept
//val predictionAndLabels = test.map { case LabeledPoint(label, features) =>
//  val prediction = model.predict(features)
//  (prediction, label)
//}
//val metrics = new MulticlassMetrics(predictionAndLabels)
//val precision = metrics.precision
//println("Precision = " + precision)

// Binary logistic regression trained by SGD with an L2 (squared) regularizer.
val updater = new SquaredL2Updater()
val model = {
  val algorithm = new LogisticRegressionWithSGD()
  algorithm.optimizer
    .setNumIterations(200)
    .setStepSize(1.0)
    .setUpdater(updater)
    .setRegParam(0.1)
  // clearThreshold(): predict() then returns raw scores instead of thresholded 0/1
  // labels, which the PR/ROC curves below need in order to rank the test points.
  algorithm.run(training).clearThreshold()
}
model           // REPL echo of the fitted model
model.weights   // learned coefficients, one per selected feature
model.intercept

// Score the held-out split and compute threshold-independent metrics.
val predictionAndLabels = test.map { case LabeledPoint(label, features) =>
  (model.predict(features), label)
}
val metrics = new BinaryClassificationMetrics(predictionAndLabels)
println(s"Area under PR = ${metrics.areaUnderPR()}")
println(s"Area under ROC = ${metrics.areaUnderROC()}")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment