Created
August 18, 2016 18:24
-
-
Save JRuumis/5f7600e54eece42dd0c4c1e1543df84e to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import org.apache.log4j.{Level, Logger} | |
| import org.apache.spark.ml.classification.RandomForestClassifier | |
| import org.apache.spark.ml.linalg.Vectors | |
| import org.apache.spark.sql.{DataFrame, SparkSession} | |
| /** | |
| * Created by Janis Rumnieks on 15/08/2016. | |
| */ | |
object DigitRecognizer4 {

  /**
   * Parses a Kaggle digit-recognizer CSV into numeric rows.
   *
   * Skips the header row (which is non-numeric and would fail `toDouble`),
   * converts every remaining cell to `Double`, and materializes the result
   * BEFORE closing the file so the lazy iterator is fully consumed while the
   * handle is still open. The source is always closed, even on parse failure.
   *
   * @param csvPath path to a comma-separated file with one header row
   * @return one `Array[Double]` per data row
   */
  private def readCsvRows(csvPath: String): List[Array[Double]] = {
    val source = scala.io.Source.fromFile(csvPath)
    try {
      // drop(1) on the raw line iterator skips the header before any
      // toDouble conversion is attempted on it.
      source.getLines().drop(1).map(_.split(",").map(_.toDouble)).toList
    } finally {
      source.close()
    }
  }

  /**
   * Loads labelled training data: first column is the label, the rest are
   * pixel features packed into a dense vector.
   *
   * @param sparkSession active session used to build the DataFrame
   * @param csvPath      training CSV (label,pixel0,...,pixelN) with header
   * @return DataFrame with columns "label" (Double) and "features" (Vector)
   */
  def labelAndFeaturesFromCsv(sparkSession: SparkSession, csvPath: String): DataFrame = {
    val labelled = readCsvRows(csvPath).map(row => (row.head, Vectors.dense(row.tail)))
    sparkSession.createDataFrame(labelled).toDF("label", "features")
  }

  /**
   * Loads unlabelled test data: every column is a feature. A dummy 0.0 label
   * is attached so the resulting schema matches the training DataFrame.
   *
   * @param sparkSession active session used to build the DataFrame
   * @param csvPath      test CSV (pixel0,...,pixelN) with header
   * @return DataFrame with columns "label" (always 0.0) and "features"
   */
  def justFeaturesFromCsv(sparkSession: SparkSession, csvPath: String): DataFrame = {
    val unlabelled = readCsvRows(csvPath).map(row => (0.0, Vectors.dense(row)))
    sparkSession.createDataFrame(unlabelled).toDF("label", "features")
  }

  /**
   * Trains a random forest on the training CSV, predicts labels for the test
   * CSV, and writes a Kaggle submission file (ImageId,Label).
   */
  def main(args: Array[String]): Unit = {
    //val trainDataFile = """C:\Developer\Kaggle\DigitRecogniser\train.csv"""
    val trainDataFile = """C:\Developer\Kaggle\DigitRecogniser\train_small_3.csv"""
    val testDataFile = """C:\Developer\Kaggle\DigitRecogniser\test.csv"""

    // disable INFO messages
    Logger.getLogger("org").setLevel(Level.ERROR)
    Logger.getLogger("akka").setLevel(Level.ERROR)

    val sparkSession = SparkSession
      .builder()
      .appName("Spark SQL - Digit Recognition")
      .master("local[*]")
      .config("spark.sql.warehouse.dir", ".")
      .getOrCreate()

    val trainDataFrame: DataFrame = labelAndFeaturesFromCsv(sparkSession, trainDataFile)
    val testDataFrame: DataFrame = justFeaturesFromCsv(sparkSession, testDataFile)

    val randomForestEstimator = new RandomForestClassifier()
      .setLabelCol("label")
      .setFeaturesCol("features")
      .setNumTrees(10)

    val model = randomForestEstimator.fit(trainDataFrame)
    val predictions = model.transform(testDataFrame)
    predictions.show(20)

    // Collect predicted class indices as Ints for the submission file.
    val predictionLabels = predictions.select("prediction").collect().map(_.getDouble(0).toInt)
    //println ( predictionLabels.take(20) mkString(",") )

    // Kaggle expects 1-based ImageIds paired with the predicted digit.
    val imageIdAndPredictedLabel = (1 to predictionLabels.length) zip predictionLabels

    import java.io._
    val pw = new PrintWriter(new File("""C:\Developer\Kaggle\DigitRecogniser\janis_output.csv"""))
    try {
      pw.write("ImageId,Label\n")
      imageIdAndPredictedLabel.foreach(row => pw.write(s"${row._1},${row._2}\n"))
    } finally {
      // Always release the file handle, even if a write fails.
      pw.close()
    }

    println(s"Predictions: ${predictions.count()}")
    sparkSession.stop()
  }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment