@Orbifold
Last active August 11, 2019 09:24
Most basic example of using MLlib on Spark.
{
"cells": [
{
"cell_type": "markdown",
"source": [
"# Titanic with MLlib\n",
"\n",
"Pretty much the most straightforward way to make things happen with MLlib on Spark.\n",
"Note the pipeline, which is similar to sklearn's.\n",
"\n"
],
"metadata": {}
},
{
"cell_type": "markdown",
"source": [
"Explicitly set the JAVA_HOME environment variable, because otherwise you get a '*Java gateway process exited before sending its port number*' error on macOS."
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"import findspark, os\n",
"\n",
"# Point to a local JDK explicitly; adjust the path to your installation.\n",
"os.environ[\"JAVA_HOME\"] = \"/Library/Java/JavaVirtualMachines/jdk1.8.0_202.jdk/Contents/Home\"\n",
"\n",
"# Locate the Spark installation and make pyspark importable.\n",
"print(findspark.find())\n",
"findspark.init()\n"
],
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"/usr/local/Cellar/apache-spark/2.4.3/libexec/\n"
]
}
],
"execution_count": 1,
"metadata": {
"collapsed": false,
"outputHidden": false,
"inputHidden": false
}
},
{
"cell_type": "markdown",
"source": [
"Get the titanic data from [Kaggle](https://www.kaggle.com/c/titanic/data):"
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"from pyspark.sql import SparkSession\n",
"from pyspark.ml.classification import LogisticRegression\n",
"\n",
"spark = SparkSession.builder.appName('titanic_logreg').getOrCreate()\n",
"df = spark.read.csv('train.csv', inferSchema = True, header = True)\n",
"df.show(3)"
],
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"+-----------+--------+------+--------------------+------+----+-----+-----+----------------+-------+-----+--------+\n",
"|PassengerId|Survived|Pclass| Name| Sex| Age|SibSp|Parch| Ticket| Fare|Cabin|Embarked|\n",
"+-----------+--------+------+--------------------+------+----+-----+-----+----------------+-------+-----+--------+\n",
"| 1| 0| 3|Braund, Mr. Owen ...| male|22.0| 1| 0| A/5 21171| 7.25| null| S|\n",
"| 2| 1| 1|Cumings, Mrs. Joh...|female|38.0| 1| 0| PC 17599|71.2833| C85| C|\n",
"| 3| 1| 3|Heikkinen, Miss. ...|female|26.0| 0| 0|STON/O2. 3101282| 7.925| null| S|\n",
"+-----------+--------+------+--------------------+------+----+-----+-----+----------------+-------+-----+--------+\n",
"only showing top 3 rows\n",
"\n"
]
}
],
"execution_count": 2,
"metadata": {}
},
{
"cell_type": "markdown",
"source": [
"The schema looks like this:"
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"df.printSchema()"
],
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"root\n",
" |-- PassengerId: integer (nullable = true)\n",
" |-- Survived: integer (nullable = true)\n",
" |-- Pclass: integer (nullable = true)\n",
" |-- Name: string (nullable = true)\n",
" |-- Sex: string (nullable = true)\n",
" |-- Age: double (nullable = true)\n",
" |-- SibSp: integer (nullable = true)\n",
" |-- Parch: integer (nullable = true)\n",
" |-- Ticket: string (nullable = true)\n",
" |-- Fare: double (nullable = true)\n",
" |-- Cabin: string (nullable = true)\n",
" |-- Embarked: string (nullable = true)\n",
"\n"
]
}
],
"execution_count": 3,
"metadata": {}
},
{
"cell_type": "markdown",
"source": [
"The columns are:"
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"df.columns"
],
"outputs": [
{
"output_type": "execute_result",
"execution_count": 4,
"data": {
"text/plain": [
"['PassengerId',\n",
" 'Survived',\n",
" 'Pclass',\n",
" 'Name',\n",
" 'Sex',\n",
" 'Age',\n",
" 'SibSp',\n",
" 'Parch',\n",
" 'Ticket',\n",
" 'Fare',\n",
" 'Cabin',\n",
" 'Embarked']"
]
},
"metadata": {}
}
],
"execution_count": 4,
"metadata": {}
},
{
"cell_type": "markdown",
"source": [
"The passenger id, name, ticket and cabin won't be used as features, so select only the label ('Survived') and the feature columns:"
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"my_col = df.select(['Survived','Pclass','Sex','Age','SibSp','Parch','Fare','Embarked'])"
],
"outputs": [],
"execution_count": null,
"metadata": {
"collapsed": false,
"outputHidden": false,
"inputHidden": false
}
},
{
"cell_type": "markdown",
"source": [
"Neither should incomplete rows be included. A better approach would be to replace missing values with the mean (for numerical columns) and the mode (for categorical columns), but let's keep it simple here:"
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"final_data = my_col.na.drop()"
],
"outputs": [],
"execution_count": 6,
"metadata": {}
},
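{
"cell_type": "markdown",
"source": [
"As a sketch of the imputation alternative mentioned above (an addition, not part of the original flow; the `AgeImputed`/`FareImputed` column names are illustrative): `Imputer` fills numeric NAs with the mean, and the mode of a categorical column can be obtained via `groupBy`:"
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"# Sketch only: impute missing values instead of dropping rows.\n",
"from pyspark.ml.feature import Imputer\n",
"\n",
"# Mean-impute the numeric columns (output names must not already exist).\n",
"imputer = Imputer(inputCols = ['Age', 'Fare'], outputCols = ['AgeImputed', 'FareImputed'], strategy = 'mean')\n",
"imputed = imputer.fit(my_col).transform(my_col)\n",
"\n",
"# Fill the categorical 'Embarked' column with its most frequent value.\n",
"mode = my_col.groupBy('Embarked').count().orderBy('count', ascending = False).first()[0]\n",
"imputed = imputed.na.fill({'Embarked': mode})"
],
"outputs": [],
"execution_count": null,
"metadata": {}
},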
{
"cell_type": "markdown",
"source": [
"Some manipulation of the features: index the categorical columns and one-hot encode them."
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"from pyspark.ml.feature import (VectorAssembler, StringIndexer, VectorIndexer, OneHotEncoder)\n",
"\n",
"gender_indexer = StringIndexer(inputCol = 'Sex', outputCol = 'SexIndex')\n",
"gender_encoder = OneHotEncoder(inputCol='SexIndex', outputCol = 'SexVec')"
],
"outputs": [],
"execution_count": 7,
"metadata": {}
},
{
"cell_type": "code",
"source": [
"embark_indexer = StringIndexer(inputCol = 'Embarked', outputCol = 'EmbarkIndex')\n",
"embark_encoder = OneHotEncoder(inputCol = 'EmbarkIndex', outputCol = 'EmbarkVec')"
],
"outputs": [],
"execution_count": 8,
"metadata": {}
},
{
"cell_type": "markdown",
"source": [
"The assembler is the MLlib way to gather all feature columns into a single vector column:"
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"assembler = VectorAssembler(inputCols = ['Pclass', 'SexVec', 'Age', 'SibSp', 'Parch', 'Fare', 'EmbarkVec'], outputCol = 'features')"
],
"outputs": [],
"execution_count": 9,
"metadata": {}
},
{
"cell_type": "markdown",
"source": [
"The pipeline is a recipe taking the raw data to the transformed data:"
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"from pyspark.ml import Pipeline\n",
"\n",
"log_reg = LogisticRegression(featuresCol = 'features', labelCol = 'Survived')"
],
"outputs": [],
"execution_count": 10,
"metadata": {}
},
{
"cell_type": "code",
"source": [
"pipeline = Pipeline(stages = [gender_indexer, embark_indexer, \n",
" gender_encoder, embark_encoder,\n",
" assembler, log_reg])"
],
"outputs": [],
"execution_count": 11,
"metadata": {}
},
{
"cell_type": "markdown",
"source": [
"From here on everything is classic, much like sklearn:"
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"train, test = final_data.randomSplit([0.7, 0.3])"
],
"outputs": [],
"execution_count": 12,
"metadata": {}
},
{
"cell_type": "code",
"source": [
"fit_model = pipeline.fit(train)"
],
"outputs": [],
"execution_count": 13,
"metadata": {}
},
{
"cell_type": "code",
"source": [
"results = fit_model.transform(test)"
],
"outputs": [],
"execution_count": 14,
"metadata": {}
},
{
"cell_type": "code",
"source": [
"results.select('prediction', 'Survived').show(3)"
],
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"+----------+--------+\n",
"|prediction|Survived|\n",
"+----------+--------+\n",
"| 1.0| 0|\n",
"| 1.0| 0|\n",
"| 1.0| 0|\n",
"+----------+--------+\n",
"only showing top 3 rows\n",
"\n"
]
}
],
"execution_count": 15,
"metadata": {}
},
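{
"cell_type": "markdown",
"source": [
"A quick sanity check (an addition to the original): cross-tabulate predictions against the label to get a rough confusion matrix:"
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"# Counts per (prediction, label) pair, i.e. a rough confusion matrix.\n",
"results.groupBy('prediction', 'Survived').count().show()"
],
"outputs": [],
"execution_count": null,
"metadata": {}
},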
{
"cell_type": "markdown",
"source": [
"To get the area under the ROC curve (AUC), you can use something like this:"
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"from pyspark.ml.evaluation import BinaryClassificationEvaluator\n",
"\n",
"# 'evaluator' rather than 'eval', which would shadow the Python builtin.\n",
"evaluator = BinaryClassificationEvaluator(rawPredictionCol = 'rawPrediction', labelCol = 'Survived')\n",
"AUC = evaluator.evaluate(results)\n",
"AUC"
],
"outputs": [
{
"output_type": "execute_result",
"execution_count": 16,
"data": {
"text/plain": [
"0.850359856057577"
]
},
"metadata": {}
}
],
"execution_count": 16,
"metadata": {}
},
{
"cell_type": "markdown",
"source": [
"which is decent, although high accuracy was not the aim of this example."
],
"metadata": {}
}
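,
{
"cell_type": "markdown",
"source": [
"Finally, release the cluster resources (an addition; the original gist leaves the session open):"
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"# Shut down the SparkSession cleanly.\n",
"spark.stop()"
],
"outputs": [],
"execution_count": null,
"metadata": {}
}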
],
"metadata": {
"kernelspec": {
"name": "python3",
"language": "python",
"display_name": "Python 3"
},
"language_info": {
"name": "python",
"version": "3.7.2",
"mimetype": "text/x-python",
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"pygments_lexer": "ipython3",
"nbconvert_exporter": "python",
"file_extension": ".py"
},
"kernel_info": {
"name": "python3"
}
},
"nbformat": 4,
"nbformat_minor": 2
}