Created
June 15, 2015 21:56
-
-
Save aazout/49b79d205dd2c04f14c3 to your computer and use it in GitHub Desktop.
Estimator implementation using spark.ml
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package com.aol.advertisting.execution.ml | |
import org.apache.spark.ml.classification._ | |
import org.apache.spark.ml.param._ | |
import org.apache.spark.sql.{DataFrame, Row} | |
import org.apache.spark.mllib.linalg.{Vector, VectorUDT} | |
import org.apache.spark.sql.types.{DataType, StructType} | |
import java.io._ | |
import scala.sys.process._ | |
/**
 * Parameter trait mixed into both [[ExternalBatchClassifier]] and
 * [[ExternalBatchClassificationModel]].
 *
 * It currently declares no parameters of its own; it only extends `Params`.
 */
trait ExternalBatchClassifierParams extends Params
/**
 * Classifier that trains one external model per campaign by shelling out to the
 * `vw` command-line tool from inside a Spark job.
 *
 * The input `DataFrame` must contain the columns `campaignId` (read with `getInt`)
 * and `trainingLine` (read with `getString`). Rows are grouped by campaign, each
 * group is written to a temporary file, and `vw` is run on that file.
 *
 * @param uid unique identifier of this estimator instance
 */
class ExternalBatchClassifier(override val uid: String)
  extends Classifier[Vector, ExternalBatchClassifier, ExternalBatchClassificationModel]
  with ExternalBatchClassifierParams {

  /**
   * Opens a `PrintWriter` on `f`, passes it to `op`, and always closes the writer
   * afterwards, even when `op` throws.
   *
   * @param f  file to write to
   * @param op callback that receives the open writer
   */
  def printToFile(f: java.io.File)(op: java.io.PrintWriter => Unit): Unit = {
    val p = new java.io.PrintWriter(f)
    try { op(p) } finally { p.close() }
  }

  /**
   * Runs `vw` once per distinct `campaignId` and returns a model whose parent is
   * this estimator.
   *
   * The model file names produced on the workers are collected but not yet passed
   * to the returned model; the number of classes is hard-coded to 2.
   *
   * @param dataset training data with `campaignId` and `trainingLine` columns
   * @return a new [[ExternalBatchClassificationModel]] sharing this estimator's `uid`
   */
  override def train(dataset: DataFrame): ExternalBatchClassificationModel = {
    val trainingData = dataset.select("campaignId", "trainingLine")

    // NOTE(review): debug preview kept from the original; it triggers an extra Spark job.
    trainingData.show()

    // collect() forces the per-campaign training to run; the resulting model file
    // names are currently discarded because the model has nowhere to store them.
    trainingData.rdd
      .groupBy(row => row.getInt(0))
      .map { case (campaignId, rows) =>
        // createTempFile rejects prefixes shorter than three characters, so a bare
        // campaign id such as "7" or "42" cannot be used as the prefix on its own.
        val trainingFile = File.createTempFile(s"campaign-$campaignId-", "model")
        try {
          printToFile(trainingFile) { writer =>
            rows.foreach(row => writer.println(row.getString(1)))
          }

          // The model file; this could be stored in a DB, etc. Working dir for now.
          val modelFileName = campaignId.toString + "weights.model"

          // Pass arguments as a Seq so a temp path containing whitespace is not
          // split into several arguments.
          val command = Seq(
            "vw",
            "-d", trainingFile.getPath,
            "--loss_function=logistic",
            "--l2=0.00000001",
            "--l1=0.00000001",
            "-b", "20",
            "--invert_hash", modelFileName)
          val output = Process(command).!!
          println("VW Output: " + output)

          modelFileName
        } finally {
          // Worker JVMs can be long-lived, so remove the training file as soon as
          // vw has finished instead of waiting for JVM exit.
          trainingFile.delete()
        }
      }
      .collect()

    new ExternalBatchClassificationModel(uid, 2).setParent(this)
  }

  /**
   * Returns `schema` unchanged; no validation or transformation is performed.
   */
  override def validateAndTransformSchema(
      schema: StructType,
      fitting: Boolean,
      featuresDataType: DataType): StructType = {
    // TODO: Support casting Array[Double] and Array[Float] to Vector when FeaturesType = Vector
    schema
  }
}
/**
 * Model returned by [[ExternalBatchClassifier]].
 *
 * Prediction is not implemented yet: `predictRaw` is defined as `???` and throws
 * `NotImplementedError` when called.
 *
 * @param uid        unique identifier of this model instance
 * @param numClasses number of classes the model distinguishes
 */
class ExternalBatchClassificationModel(
    override val uid: String,
    override val numClasses: Int)
  extends ClassificationModel[Vector, ExternalBatchClassificationModel]
  with ExternalBatchClassifierParams {

  /** Not implemented; calling it throws `NotImplementedError`. */
  override def predictRaw(features: Vector): Vector = ???

  /**
   * Returns `schema` unchanged; no validation or transformation is performed.
   */
  // TODO: Support casting Array[Double] and Array[Float] to Vector when FeaturesType = Vector
  override def validateAndTransformSchema(
      schema: StructType,
      fitting: Boolean,
      featuresDataType: DataType): StructType = schema
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment