Created
August 20, 2016 19:34
-
-
Save JRuumis/5302d62fd9ae9519e45dba2d3e6b023d to your computer and use it in GitHub Desktop.
every night in my dreams...
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import org.apache.log4j.{Level, Logger}
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.classification.RandomForestClassifier
import org.apache.spark.ml.feature.{StringIndexer, VectorAssembler, VectorIndexer}
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.lit
| /** | |
| * Created by Janis Rumnieks on 19/08/2016. | |
| */ | |
| object Titanic { | |
| def main(args: Array[String]): Unit = { | |
| Logger.getLogger("org").setLevel(Level.ERROR) | |
| Logger.getLogger("akka").setLevel(Level.ERROR) | |
| val trainDataFile = """C:\Developer\Kaggle\Titanic\train.csv""" | |
| val testDataFile = """C:\Developer\Kaggle\Titanic\test.csv""" | |
| val spark = SparkSession | |
| .builder() | |
| .appName("Titanic - every night in my dreams...") | |
| .master("local[*]") | |
| .config("spark.sql.warehouse.dir", ".") | |
| .getOrCreate() | |
| //PassengerId,Survived,Pclass,Name,Sex,Age,SibSp,Parch,Ticket,Fare,Cabin,Embarked | |
| //val trainDataFrame = spark.createDataFrame(trainContent.tail.toSeq).toDF("Survived", "Pclass", "Name", "Sex", "Age", "SibSp", "Parch", "Ticket", "Fare", "Cabin", "Embarked") | |
| //val testDataFrame = spark.createDataFrame(testContent.tail.toSeq).toDF("Survived", "Pclass", "Name", "Sex", "Age", "SibSp", "Parch", "Ticket", "Fare", "Cabin", "Embarked") | |
| val trainDataFrame = spark.read.option("header",true).csv(trainDataFile).select("PassengerId","Survived","Pclass","Name","Sex","Age","SibSp","Parch","Ticket","Fare","Cabin","Embarked") | |
| trainDataFrame.show(10,false) | |
| val testDataFrame1 = spark.read.option("header",true).csv(testDataFile) | |
| testDataFrame1.show(10,false) | |
| val testDataFrame = testDataFrame1 | |
| .withColumn("Survived", testDataFrame1("PassengerId").cast("Int")-testDataFrame1("PassengerId").cast("Int") ) | |
| .select("PassengerId","Survived","Pclass","Name","Sex","Age","SibSp","Parch","Ticket","Fare","Cabin","Embarked") | |
| testDataFrame.show(10,false) | |
| val indexPclass = new StringIndexer() | |
| .setInputCol("Pclass") | |
| .setOutputCol("PclassIndexed") | |
| val indexSex = new StringIndexer() | |
| .setInputCol("Sex") | |
| .setOutputCol("SexIndexed") | |
| val indexAge = new StringIndexer() | |
| .setInputCol("Age") | |
| .setOutputCol("AgeIndexed") | |
| val indexSibSp = new StringIndexer() | |
| .setInputCol("SibSp") | |
| .setOutputCol("SibSpIndexed") | |
| val indexParch = new StringIndexer() | |
| .setInputCol("Parch") | |
| .setOutputCol("ParchIndexed") | |
| val indexFare = new StringIndexer() | |
| .setInputCol("Fare") | |
| .setOutputCol("FareIndexed") | |
| val indexCabin = new StringIndexer() | |
| .setInputCol("Cabin") | |
| .setOutputCol("CabinIndexed") | |
| val indexEmbarked = new StringIndexer() | |
| .setInputCol("Embarked") | |
| .setOutputCol("EmbarkedIndexed") | |
| val indexSurvived = new StringIndexer() | |
| .setInputCol("Survived") | |
| .setOutputCol("SurvivedIndexed") | |
| //.fit(trainDataFrame) | |
| val featureIndexerPipeline = new Pipeline().setStages(Array(indexPclass,indexSex,indexAge,indexSibSp,indexParch,indexFare,indexCabin,indexEmbarked,indexSurvived)) | |
| /* | |
| val indexSurvived = new StringIndexer() | |
| .setInputCol("Survived") | |
| .setOutputCol("SurvivedIndexed") | |
| .fit(trainDataFrame) | |
| */ | |
| val trainDataFrameIndexModel = featureIndexerPipeline.fit(trainDataFrame) | |
| val trainDataFrameIndexed = trainDataFrameIndexModel.transform(trainDataFrame) | |
| //trainDataFrameIndexed.show(20,false) | |
| val testDataFrameIndexModel = featureIndexerPipeline.fit(testDataFrame) | |
| val testDataFrameIndexed = testDataFrameIndexModel.transform(testDataFrame) | |
| //testDataFrameIndexed.show(20,false) | |
| val titanicFeatureAssembler = new VectorAssembler() | |
| .setInputCols(Array("PclassIndexed", "SexIndexed", "AgeIndexed", "SibSpIndexed", /*"ParchIndexed",*/ "FareIndexed", "CabinIndexed", "EmbarkedIndexed")) | |
| //.setInputCols(Array("PclassIndexed", "SexIndexed", "EmbarkedIndexed")) | |
| .setOutputCol("features") | |
| val trainDataFrameVectored = titanicFeatureAssembler.transform(trainDataFrameIndexed) | |
| val testDataFrameVectored = titanicFeatureAssembler.transform(testDataFrameIndexed) | |
| //println("TRAIN VECTORED") | |
| //trainDataFrameVectored.show(138,false) | |
| //println("TEST VECTORED VECTORED") | |
| //testDataFrameVectored.show(138,false) | |
| val titanicFeatureIndexer = new VectorIndexer() | |
| .setInputCol("features") | |
| .setOutputCol("indexedFeatures") | |
| .setMaxCategories(10) | |
| .fit(trainDataFrameVectored) | |
| val titanicRandomForestClassifier = new RandomForestClassifier() | |
| .setLabelCol("SurvivedIndexed") | |
| .setFeaturesCol("indexedFeatures") | |
| .setNumTrees(10) | |
| .setMaxBins(350) | |
| val titanicPipeline = new Pipeline() | |
| .setStages(Array(titanicFeatureIndexer, titanicRandomForestClassifier)) | |
| val titanicModel = titanicPipeline.fit(trainDataFrameVectored) | |
| val titanicPredictions = titanicModel.transform(testDataFrameVectored) | |
| println(s"VOLUMES: trainDataFrame: ${trainDataFrame.count()} testDataFrame: ${testDataFrame.count()} testDataFrameIndexed: ${testDataFrameIndexed.count()} testDataFrameVectored: ${testDataFrameVectored.count()}, titanicPredictions: ${titanicPredictions.count()}") | |
| //titanicPredictions.show(138,false) // between 137 and 140 fail ======= 138 | |
| titanicPredictions.select("PassengerId","prediction").show(20) | |
| val survivors = titanicPredictions.select("PassengerId","prediction").collect() | |
| //val survivors = titanicPredictions.collect() | |
| import java.io._ | |
| val pw = new PrintWriter(new File("""C:\Developer\Kaggle\Titanic\janis_output_10.csv""")) | |
| pw.write("PassengerId,Survived\n") | |
| survivors foreach ( row => pw.write(s"${row.getString(0)},${row.getDouble(1).toInt}\n") ) | |
| pw.close | |
| } | |
| } |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment