Skip to content

Instantly share code, notes, and snippets.

@aazout
Created June 15, 2015 21:56
Show Gist options
  • Save aazout/49b79d205dd2c04f14c3 to your computer and use it in GitHub Desktop.
Save aazout/49b79d205dd2c04f14c3 to your computer and use it in GitHub Desktop.
Estimator implementation using spark.ml
package com.aol.advertisting.execution.ml
import org.apache.spark.ml.classification._
import org.apache.spark.ml.param._
import org.apache.spark.sql.{DataFrame, Row}
import org.apache.spark.mllib.linalg.{Vector, VectorUDT}
import org.apache.spark.sql.types.{DataType, StructType}
import java.io._
import scala.sys.process._
/**
 * Parameter trait shared by [[ExternalBatchClassifier]] and [[ExternalBatchClassificationModel]].
 *
 * It currently declares no parameters of its own; it exists so estimator-specific params can be
 * added in one place and seen by both the estimator and the model.
 */
trait ExternalBatchClassifierParams extends Params {
  // Intentionally empty for now.
}
/**
 * Estimator that trains one external Vowpal Wabbit (`vw`) model per campaign.
 *
 * `train` selects the `campaignId` column (read as an `Int`) and the `trainingLine` column
 * (read as a `String`), groups the lines by campaign, writes each group to a temporary file on
 * the executor and runs the `vw` command-line tool on it. The `vw` binary must therefore be on
 * the PATH of every executor.
 * NOTE(review): the training lines presumably already use VW's input format — confirm with the
 * producer of `dataset`.
 *
 * @param uid unique identifier of this estimator; the fitted model reuses it
 */
class ExternalBatchClassifier(override val uid: String)
  extends Classifier[Vector, ExternalBatchClassifier, ExternalBatchClassificationModel]
  with ExternalBatchClassifierParams {

  /**
   * Opens `f` for writing, hands the writer to `op`, and closes the writer even if `op` throws.
   *
   * @param f  file to create or overwrite
   * @param op callback that writes through the supplied writer
   */
  def printToFile(f: java.io.File)(op: java.io.PrintWriter => Unit): Unit =
    ExternalBatchClassifier.withWriter(f)(op)

  /**
   * Trains one `vw` model per campaign and returns a binary classification model.
   *
   * For each campaign, `vw` writes its weights to `<campaignId>weights.model` in the executor's
   * current working directory.
   *
   * @param dataset training data containing `campaignId` and `trainingLine` columns
   * @return a model with `numClasses = 2` whose parent is this estimator
   */
  override def train(dataset: DataFrame): ExternalBatchClassificationModel = {
    val trainingRows = dataset.select("campaignId", "trainingLine").rdd
    // collect() is the action that actually runs the per-campaign training. The returned model
    // file names are not handed to the model yet. TODO: wire them in once prediction exists.
    trainingRows
      .groupBy(row => row.getInt(0))
      .map { case (campaignId, rows) =>
        ExternalBatchClassifier.trainCampaign(campaignId, rows.map(_.getString(1)))
      }
      .collect()
    new ExternalBatchClassificationModel(uid, 2).setParent(this)
  }

  /** Returns `schema` unchanged; no validation or output-column addition is performed. */
  override def validateAndTransformSchema(
      schema: StructType,
      fitting: Boolean,
      featuresDataType: DataType): StructType = {
    // TODO: Support casting Array[Double] and Array[Float] to Vector when FeaturesType = Vector
    schema
  }
}

/**
 * Executor-side helpers for [[ExternalBatchClassifier]]. They live in the companion object so the
 * Spark closure in `train` can call them without capturing (and serializing) the estimator.
 */
object ExternalBatchClassifier {

  /** Opens `f` for writing, runs `op` on the writer, and always closes the writer. */
  private def withWriter(f: File)(op: PrintWriter => Unit): Unit = {
    val writer = new PrintWriter(f)
    try op(writer) finally writer.close()
  }

  /**
   * Writes `lines` to a temporary file, runs `vw` on it, and returns the name of the model file
   * `vw` is asked to write (relative to the current working directory).
   *
   * The temporary training file is deleted as soon as `vw` returns, rather than at JVM exit, so
   * long-lived executors do not accumulate training files.
   *
   * @param campaignId campaign whose lines are being trained on
   * @param lines      training lines for that campaign, one per output line
   * @return the model file name passed to `vw --invert_hash`
   * @throws RuntimeException if `vw` exits with a non-zero status (raised by `!!`)
   */
  private def trainCampaign(campaignId: Int, lines: Iterable[String]): String = {
    // File.createTempFile rejects prefixes shorter than three characters, so a bare id such as
    // "7" would throw IllegalArgumentException; add a fixed tag around the id.
    val trainingFile = File.createTempFile(s"campaign-$campaignId-", ".vw")
    try {
      withWriter(trainingFile)(writer => lines.foreach(line => writer.println(line)))
      // The model file; this could be stored in a DB, etc. Working dir for now.
      val modelFileName = campaignId.toString + "weights.model"
      // Arguments are passed as a Seq so a temp path containing spaces stays a single argument.
      val output = Process(Seq(
        "vw", "-d", trainingFile.getPath,
        "--loss_function=logistic", "--l2=0.00000001", "--l1=0.00000001",
        "-b", "20", "--invert_hash", modelFileName)).!!
      println("VW Output: " + output)
      modelFileName
    } finally {
      trainingFile.delete()
    }
  }
}
/**
 * Classification model paired with [[ExternalBatchClassifier]].
 *
 * Prediction is not implemented yet: `predictRaw` always throws `NotImplementedError`, so any
 * code path that asks this model for raw predictions fails.
 *
 * @param uid        unique identifier of this model
 * @param numClasses number of output classes reported to Spark
 */
class ExternalBatchClassificationModel(
    override val uid: String,
    override val numClasses: Int)
  extends ClassificationModel[Vector, ExternalBatchClassificationModel]
  with ExternalBatchClassifierParams {

  /** Not implemented yet; always throws `NotImplementedError`. */
  override def predictRaw(features: Vector): Vector =
    throw new NotImplementedError // TODO: implement scoring of `features`

  /** Returns `schema` unchanged; no validation is performed. */
  override def validateAndTransformSchema(
      schema: StructType,
      fitting: Boolean,
      featuresDataType: DataType): StructType =
    // TODO: Support casting Array[Double] and Array[Float] to Vector when FeaturesType = Vector
    schema
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment