Last active
August 9, 2016 22:24
-
-
Save vlad17/c22bb4e6679c9e65fcc3e93a92bd3c30 to your computer and use it in GitHub Desktop.
[SPARK-16718] benchmark for million song dataset
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// See gbm.R for context | |
// run with options: | |
// spark-shell --driver-memory 20G --executor-memory 4G --driver-java-options="-Xss500M" -i gbt.spark | |
import org.apache.spark.sql.DataFrame | |
import sys.process._ | |
import java.io._ | |
import org.apache.spark.ml.feature.VectorAssembler | |
import org.apache.spark.ml.classification.GBTClassifier | |
import org.apache.spark.ml.evaluation._ | |
// Download the Million Song Dataset year-prediction CSV (zipped) if it is not
// already present on disk.
val csvLoc = "/tmp/YearPredictionMSD.txt"
val fileLoc = "http://archive.ics.uci.edu/ml/machine-learning-databases/00203/YearPredictionMSD.txt.zip"
if (!new File(csvLoc).exists) {
  // BUG FIX: the archive is saved as "$csvLoc.zip", but unzip was invoked on
  // the bare "$csvLoc" path — it only worked because Info-ZIP's unzip retries
  // with a ".zip" suffix when the named file does not exist. Name the archive
  // explicitly, and fail loudly instead of discarding the exit codes.
  val downloadStatus = s"wget -q -O $csvLoc.zip $fileLoc".!
  require(downloadStatus == 0, s"wget exited with code $downloadStatus")
  val unzipStatus = s"unzip -o -qq $csvLoc.zip -d /tmp".!
  require(unzipStatus == 0, s"unzip exited with code $unzipStatus")
}
// Load the raw CSV (no header; schema inferred) and reduce it to the same
// binary classification problem the R gbm example uses: label = 1.0 when the
// year (first column) precedes 2002, else 0.0. The year column itself is
// dropped so it cannot leak into the features.
val rawDf = spark.read
  .option("inferSchema", true)
  .option("header", false)
  .csv(csvLoc)
val yearCol = rawDf.columns.head
val labeled = rawDf
  .selectExpr("*", s"cast($yearCol < 2002 as double) as label")
  .drop(yearCol)
// Assemble every remaining (non-label) column into a single "features" vector.
val featureCols = labeled.columns.filterNot(_ == "label")
val assembler = new VectorAssembler()
  .setInputCols(featureCols)
  .setOutputCol("features")
val clean = assembler.transform(labeled)
// Deterministic train/test split by row position: the dataset's documented
// convention puts the first 463715 rows in train and the remainder in test.
val cutoff = 463715
val indexed = clean.rdd.zipWithIndex
// Keep only the rows whose zero-based position satisfies `keep`, preserving
// the original schema.
def filterIndex(keep: Long => Boolean): DataFrame = {
  val rows = indexed.collect { case (row, idx) if keep(idx) => row }
  spark.createDataFrame(rows, clean.schema)
}
val train = filterIndex(_ < cutoff)
val test = filterIndex(_ >= cutoff)
// Estimator configuration shared by both benchmark runs.
val ntrees = 700
val shrinkage = 0.001

// Build a fresh, independently-configured estimator per impurity setting.
// BUG FIX: the original derived the second estimator via
// `varianceBased.setImpurity("loss-based")` — but spark.ml setters mutate the
// receiver in place, so both handles pointed at the SAME object (the captured
// log shows the identical uid gbtc_cef98ba41324 for both) and BOTH fits ran
// with "loss-based" impurity. A new instance per run keeps them distinct.
// NOTE(review): lossType "bernoulli" and impurity "variance"/"loss-based" are
// not accepted by stock GBTClassifier; this presumably targets the patched
// SPARK-16718 build — confirm against that branch.
def makeGbt(impurity: String): GBTClassifier =
  new GBTClassifier()
    .setMaxIter(ntrees)
    .setSubsamplingRate(0.75)
    .setMaxDepth(3)
    .setLossType("bernoulli")
    .setMinInstancesPerNode(10)
    .setStepSize(shrinkage)
    .setImpurity(impurity)
    .setLabelCol("label")
    .setSeed(123)

val varianceBased = makeGbt("variance")
val lossBased = makeGbt("loss-based")

// Fit `est` on `df`, returning the model and the elapsed wall-clock nanoseconds.
// Also fixes the duplicate `val start` declaration in the original, which only
// compiled because the spark-shell REPL permits re-declaration.
def timedFit(est: GBTClassifier, df: DataFrame) = {
  val start = System.nanoTime
  val model = est.fit(df)
  (model, System.nanoTime - start)
}

train.cache()
val (varianceModel, varianceTime) = timedFit(varianceBased, train)
val (lossModel, lossTime) = timedFit(lossBased, train)
train.unpersist(true)

test.cache()
val predVariance = varianceModel.transform(test)
val predLoss = lossModel.transform(test)
// Print f1, weighted precision, weighted recall, and accuracy for a prediction
// DataFrame (expects the default "label"/"prediction" columns).
// BUG FIX: the original never called setMetricName, so every iteration
// evaluated the evaluator's default metric ("f1") — visible in the captured
// output, where all four printed lines carry the identical value
// 0.6547312025559994. Set the metric before each evaluate() call.
def evaluate(df: DataFrame): Unit = {
  val eval = new MulticlassClassificationEvaluator()
  for (metric <- Seq("f1", "weightedPrecision", "weightedRecall", "accuracy")) {
    println(s" $metric = ${eval.setMetricName(metric).evaluate(df)}")
  }
}
// Report wall-clock training time (in whole seconds) and test-set metrics for
// each impurity run, then compute the majority-class baseline for comparison.
def report(banner: String, nanos: Long, predictions: DataFrame): Unit = {
  val seconds = (nanos / 1e9).toLong
  println(s"$banner $seconds")
  evaluate(predictions)
}
report("variance impurity perf (midpoint thresh) seconds", varianceTime, predVariance)
report("loss-based impurity perf (midpoint thresh) seconds", lossTime, predLoss)
// Fraction of the test set in its most common class — i.e. the accuracy of a
// trivial always-predict-majority classifier.
val counts = test.groupBy("label").count().select("count").as[Double].collect()
counts.max / counts.sum
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Using Spark's default log4j profile: org/apache/spark/log4j-defaults.properties | |
Setting default log level to "WARN". | |
To adjust logging level use sc.setLogLevel(newLevel). | |
16/08/09 14:05:56 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable | |
16/08/09 14:05:56 WARN Utils: Your hostname, vlad-databricks resolves to a loopback address: 127.0.1.1; using 192.168.1.23 instead (on interface enp0s31f6) | |
16/08/09 14:05:56 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address | |
16/08/09 14:05:57 WARN SparkContext: Use an existing SparkContext, some configuration may not take effect. | |
Spark context Web UI available at http://192.168.1.23:4040 | |
Spark context available as 'sc' (master = local[*], app id = local-1470776757418). | |
Spark session available as 'spark'. | |
Loading /home/vlad/Desktop/gbt.spark... | |
import org.apache.spark.sql.DataFrame | |
import sys.process._ | |
import java.io._ | |
import org.apache.spark.ml.feature.VectorAssembler | |
import org.apache.spark.ml.classification.GBTClassifier | |
import org.apache.spark.ml.evaluation._ | |
csvLoc: String = /tmp/YearPredictionMSD.txt | |
fileLoc: String = http://archive.ics.uci.edu/ml/machine-learning-databases/00203/YearPredictionMSD.txt.zip | |
warning: there were two feature warnings; re-run with -feature for details | |
firstCol: String = _c0 | |
binary: org.apache.spark.sql.DataFrame = [_c1: double, _c2: double ... 89 more fields] | |
explanatory: Array[String] = Array(_c1, _c2, _c3, _c4, _c5, _c6, _c7, _c8, _c9, _c10, _c11, _c12, _c13, _c14, _c15, _c16, _c17, _c18, _c19, _c20, _c21, _c22, _c23, _c24, _c25, _c26, _c27, _c28, _c29, _c30, _c31, _c32, _c33, _c34, _c35, _c36, _c37, _c38, _c39, _c40, _c41, _c42, _c43, _c44, _c45, _c46, _c47, _c48, _c49, _c50, _c51, _c52, _c53, _c54, _c55, _c56, _c57, _c58, _c59, _c60, _c61, _c62, _c63, _c64, _c65, _c66, _c67, _c68, _c69, _c70, _c71, _c72, _c73, _c74, _c75, _c76, _c77, _c78, _c79, _c80, _c81, _c82, _c83, _c84, _c85, _c86, _c87, _c88, _c89, _c90) | |
clean: org.apache.spark.sql.DataFrame = [_c1: double, _c2: double ... 90 more fields] | |
cutoff: Int = 463715 | |
filterIndex: (cond: Long => Boolean)org.apache.spark.sql.DataFrame | |
train: org.apache.spark.sql.DataFrame = [_c1: double, _c2: double ... 90 more fields] | |
test: org.apache.spark.sql.DataFrame = [_c1: double, _c2: double ... 90 more fields] | |
ntrees: Int = 700 | |
shrinkage: Double = 0.001 | |
varianceBased: org.apache.spark.ml.classification.GBTClassifier = gbtc_cef98ba41324 | |
lossBased: varianceBased.type = gbtc_cef98ba41324 | |
16/08/09 14:06:12 WARN Utils: Truncated the string representation of a plan since it was too large. This behavior can be adjusted by setting 'spark.debug.maxToStringFields' in SparkEnv.conf. | |
res1: train.type = [_c1: double, _c2: double ... 90 more fields] | |
start: Long = 67129448801407 | |
16/08/09 14:06:16 WARN Executor: 1 block locks were not released by TID = 22: | |
16/08/09 14:06:16 WARN Executor: 1 block locks were not released by TID = 23: | |
varianceModel: org.apache.spark.ml.classification.GBTClassificationModel = GBTClassificationModel (uid=gbtc_cef98ba41324) with 700 trees | |
varianceTime: Long = 1366514285281 | |
start: Long = 68496127610682 | |
16/08/09 14:28:59 WARN Executor: 1 block locks were not released by TID = 51123: | |
16/08/09 14:28:59 WARN Executor: 1 block locks were not released by TID = 51124: | |
lossTime: Long = 1390297870579 | |
res2: train.type = [_c1: double, _c2: double ... 90 more fields] | |
res3: test.type = [_c1: double, _c2: double ... 90 more fields] | |
predVariance: org.apache.spark.sql.DataFrame = [_c1: double, _c2: double ... 91 more fields] | |
predLoss: org.apache.spark.sql.DataFrame = [_c1: double, _c2: double ... 91 more fields] | |
evaluate: (df: org.apache.spark.sql.DataFrame)Unit | |
variance impurity perf (midpoint thresh) seconds 1366 | |
f1 = 0.6547312025559994 | |
weightedPrecision = 0.6547312025559994 | |
weightedRecall = 0.6547312025559994 | |
accuracy = 0.6547312025559994 | |
loss-based impurity perf (midpoint thresh) seconds 1390 | |
f1 = 0.6547312025559994 | |
weightedPrecision = 0.6547312025559994 | |
weightedRecall = 0.6547312025559994 | |
accuracy = 0.6547312025559994 | |
counts: Array[Double] = Array(26812.0, 24818.0) | |
res8: Double = 0.5193104784040287 | |
Welcome to | |
____ __ | |
/ __/__ ___ _____/ /__ | |
_\ \/ _ \/ _ `/ __/ '_/ | |
/___/ .__/\_,_/_/ /_/\_\ version 2.1.0-SNAPSHOT | |
/_/ | |
Using Scala version 2.11.8 (Java HotSpot(TM) 64-Bit Server VM, Java 1.8.0_91) | |
Type in expressions to have them evaluated. | |
Type :help for more information. | |
scala> :quit |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment