Last active
March 25, 2016 16:50
-
-
Save jewer/79118ea5688c74c41015 to your computer and use it in GitHub Desktop.
logistic regression in spark
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import org.apache.spark.ml.feature._ | |
| import org.apache.spark.ml.classification._ | |
| import org.apache.spark.mllib.regression.LabeledPoint | |
| import org.apache.spark.mllib.linalg.Vector | |
| import org.apache.spark.sql._ | |
| import org.apache.spark.mllib.classification.{LogisticRegressionWithLBFGS, LogisticRegressionModel} | |
| import org.apache.spark.ml.feature.VectorAssembler | |
/**
 * Helper: reads the CSV file(s) at `path` into a DataFrame via the
 * `com.databricks.spark.csv` data source.
 *
 * The first line of each file is treated as a header, and column types are
 * inferred from the data instead of defaulting to strings.
 *
 * @param path       location of the CSV data (passed straight to `load`)
 * @param sqlContext SQL context used to build the reader
 * @return the loaded DataFrame
 */
def load(path: String, sqlContext: SQLContext): DataFrame =
  sqlContext.read
    .format("com.databricks.spark.csv")
    .option("header", "true")
    .option("inferSchema", "true")
    .load(path)
// Raw training data. Named `rawTrain` so it no longer collides with the
// `train` split defined further down.
val rawTrain = load("gs://dj-data-science-datasets/segmentation/training/*", sqlContext)

// Labels should be doubles.
val toDouble = sqlContext.udf.register("toDouble", (n: Int) => n.toDouble)

val labelCol = "SEG"
val withoutIds = rawTrain.drop("UU_ID").drop("post_evar3")
val training = withoutIds.withColumn(labelCol, toDouble(withoutIds(labelCol)))

// Make a vector out of all the features you care about: every remaining
// column except the label. Excluding the label by name (rather than taking
// `columns.tail`) avoids leaking it into the features if it is not the
// first column.
val featureCols = training.columns.filter(_ != labelCol)
val vectorizer = new VectorAssembler().setInputCols(featureCols).setOutputCol("features")
val training2 = vectorizer.transform(training).select("features", training.columns: _*)

// Only segments up to `maxSegment` are kept. MLlib requires labels to be
// class indices in 0 until numClasses, so the class count must be
// maxSegment + 1 for every kept label to be valid.
// NOTE(review): if SEG is actually 0-based with 8 segments (0..7), set
// maxSegment to 7 instead -- confirm against the data.
val maxSegment = 8
val numClasses = maxSegment + 1

// Make an RDD of LabeledPoint (i.e. a label plus a vector of features).
// Column 0 is the label and column 1 the assembled feature vector, as fixed
// by the select below.
val labeled = training2.select(labelCol, "features").rdd
  .map(row => LabeledPoint(row.getDouble(0), row.getAs[Vector](1)))
  .filter(_.label <= maxSegment)

// Hold out 20% for testing the model (fixed seed for reproducibility).
val splits = labeled.randomSplit(Array(0.8, 0.2), seed = 11L)
val train = splits(0).cache()
val test = splits(1).cache()

// Fit the model on the 80% split.
val model = new LogisticRegressionWithLBFGS().setNumClasses(numClasses).run(train)

// Check for errors: pair each held-out label with the model's prediction.
val labelAndPreds = test.map(point => (point.label, model.predict(point.features)))

// Fraction of held-out points that were misclassified. The name is kept for
// compatibility, but note that this is measured on `test`, not on `train`.
val trainErr =
  labelAndPreds.filter { case (label, prediction) => label != prediction }.count().toDouble / test.count()

// Fit a model to the entire labeled set. It gets its own name so the
// held-out model above stays available.
val fullModel = new LogisticRegressionWithLBFGS().setNumClasses(numClasses).run(labeled)

// NOTE(review): `scored` is not defined anywhere in this file; this snippet
// looks like a leftover from another experiment and will not compile until
// `scored` is provided. Kept as-is pending confirmation.
scored.take(10).map(x => {
  val r = x.asInstanceOf[Array[org.apache.spark.sql.Row]]
  r(0)(1).asInstanceOf[Vector]
})
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment