@huafengw · Created May 28, 2018 08:34
mleap example: serializing a Spark ML pipeline to an MLeap bundle and loading it back, with a custom bundle op for IndexToString
package com.vip.mlp.dag

import ml.combust.bundle.{BundleContext, BundleFile}
import ml.combust.bundle.dsl.{Model, NodeShape, Value}
import ml.combust.bundle.op.OpModel
import ml.combust.mleap.spark.SparkSupport._
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.bundle.{ParamSpec, SimpleParamSpec, SimpleSparkOp, SparkBundleContext}
import org.apache.spark.ml.classification.DecisionTreeClassifier
import org.apache.spark.ml.feature.{IndexToString, StringIndexer, VectorIndexer}
import org.apache.spark.sql.SparkSession
import resource._

object Test extends App {
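  // A custom bundle op that teaches MLeap how to (de)serialize Spark's
  // IndexToString stage: the OpModel persists the `labels` param, and
  // sparkInputs/sparkOutputs map the stage's input and output columns.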
  class IndexToStringOp extends SimpleSparkOp[IndexToString] {
    override val Model: OpModel[SparkBundleContext, IndexToString] = new OpModel[SparkBundleContext, IndexToString] {
      override val klazz: Class[IndexToString] = classOf[IndexToString]

      override def opName: String = "index_to_string"

      override def store(model: Model, obj: IndexToString)(implicit context: BundleContext[SparkBundleContext]): Model = {
        model.withValue("labels", Value.stringList(obj.getLabels.toList))
      }

      override def load(model: Model)(implicit context: BundleContext[SparkBundleContext]): IndexToString = {
        new IndexToString().setLabels(model.value("labels").getStringList.toArray)
      }
    }

    override def sparkInputs(obj: IndexToString): Seq[ParamSpec] = {
      Seq("input" -> obj.inputCol)
    }

    override def sparkOutputs(obj: IndexToString): Seq[SimpleParamSpec] = {
      Seq("output" -> obj.outputCol)
    }

    override def sparkLoad(uid: String, shape: NodeShape, model: IndexToString): IndexToString = {
      new IndexToString(uid = uid).setLabels(model.getLabels)
    }
  }
  val spark = SparkSession
    .builder()
    .master("local")
    .getOrCreate()
  // Load the data stored in LIBSVM format as a DataFrame.
  val data = spark.read.format("libsvm").load("/Users/huafengw/workspace/spark/data/mllib/sample_libsvm_data.txt")

  // Index labels, adding metadata to the label column.
  // Fit on the whole dataset to include all labels in the index.
  val labelIndexer = new StringIndexer()
    .setInputCol("label")
    .setOutputCol("indexedLabel")
    .fit(data)

  // Automatically identify categorical features, and index them.
  val featureIndexer = new VectorIndexer()
    .setInputCol("features")
    .setOutputCol("indexedFeatures")
    .setMaxCategories(4) // features with > 4 distinct values are treated as continuous
    .fit(data)

  // Split the data into training and test sets (30% held out for testing).
  val Array(trainingData, testData) = data.randomSplit(Array(0.7, 0.3))

  // Train a DecisionTree model.
  val dt = new DecisionTreeClassifier()
    .setLabelCol("indexedLabel")
    .setFeaturesCol("indexedFeatures")

  // Convert indexed labels back to original labels.
  val labelConverter = new IndexToString()
    .setInputCol("prediction")
    .setOutputCol("predictedLabel")
    .setLabels(labelIndexer.labels)

  // Chain indexers and tree in a Pipeline.
  val pipeline = new Pipeline()
    .setStages(Array(labelIndexer, featureIndexer, dt, labelConverter))

  // Train the model. This also runs the indexers.
  val model = pipeline.fit(trainingData)
  val test = model.transform(testData)
  test.show()
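  // MLeap serialization needs the transformed DataFrame so it can infer the
  // types of every input and output field in the pipeline.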
  val context = SparkBundleContext.defaultContext.withDataset(test)
  context.bundleRegistry.register(new IndexToStringOp())

  // Serialize the fitted pipeline to an MLeap bundle on disk.
  for (bundle <- managed(BundleFile("jar:file:/tmp/mleap-examples/simple-json.zip"))) {
    model.writeBundle.save(bundle)(context).get
  }

  // Deserialize the bundle back into a Spark transformer and re-apply it.
  val zipBundle = (for (bundle <- managed(BundleFile("jar:file:/tmp/mleap-examples/simple-json.zip"))) yield {
    bundle.loadSparkBundle()(context).get
  }).opt.get

  val result = zipBundle.root.transform(testData)
  result.show()
}
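
Not part of the original gist: a minimal sketch of loading the same bundle without Spark via mleap-runtime, following the pattern in the MLeap docs. It assumes mleap-runtime is on the classpath, and that a matching runtime op is registered for the custom "index_to_string" node; without one, deserialization outside Spark would fail.

import ml.combust.bundle.BundleFile
import ml.combust.mleap.runtime.MleapSupport._
import resource._

// Load the same bundle into an MLeap runtime transformer (no SparkContext needed).
val mleapBundle = (for (bundle <- managed(BundleFile("jar:file:/tmp/mleap-examples/simple-json.zip"))) yield {
  bundle.loadMleapBundle().get
}).opt.get

// The root transformer scores ml.combust.mleap.runtime.frame.DefaultLeapFrame
// instances instead of Spark DataFrames.
val mleapPipeline = mleapBundle.root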