Skip to content

Instantly share code, notes, and snippets.

@MLnick
Created August 30, 2017 07:28
Show Gist options
  • Save MLnick/4b530c3363190999f3691853988adba8 to your computer and use it in GitHub Desktop.
Save MLnick/4b530c3363190999f3691853988adba8 to your computer and use it in GitHub Desktop.
Ensemble pipeline component in Spark
/**
 * Pipeline component that averages the predictions of several regression models.
 *
 * Each model in `models` is applied to the input dataset and the arithmetic mean of
 * their prediction columns is written to a new `ensemble` column.
 *
 * @param uid    unique identifier of this pipeline stage
 * @param models fitted regression models whose predictions are averaged; must be
 *               non-empty by the time [[transform]] is called
 */
class Ensemble(val uid: String, models: Seq[RegressionModel[_, _]]) extends Model[Ensemble] {

  import org.apache.spark.sql.functions._
  import org.apache.spark.sql.types.DoubleType

  /** Name of the column that receives the averaged prediction. */
  private final val OutputCol = "ensemble"

  /** Creates an ensemble with a randomly generated `uid`. */
  def this(models: Seq[RegressionModel[_, _]]) =
    this(Identifiable.randomUID("ensemble"), models)

  /**
   * Returns a copy of this ensemble with `extra` params applied.
   *
   * The member models are shared with the copy rather than copied themselves.
   */
  override def copy(extra: ParamMap): Ensemble =
    copyValues(new Ensemble(uid, models), extra).setParent(parent)

  /**
   * Appends the `ensemble` column holding the mean of all member-model predictions.
   *
   * The models are applied one after another to the same DataFrame so that every
   * prediction column belongs to a single query plan; columns taken from separately
   * transformed DataFrames cannot be resolved against `dataset`.
   *
   * NOTE(review): any output column other than the prediction column that a member
   * model adds is left in the result, and could clash between models — confirm
   * against the models actually used.
   *
   * @param dataset input data containing whatever columns the member models read
   * @return `dataset` plus the `ensemble` column
   * @throws IllegalArgumentException if there are no models or `ensemble` already exists
   */
  override def transform(dataset: Dataset[_]): DataFrame = {
    require(models.nonEmpty, "Ensemble requires at least one model to average.")
    transformSchema(dataset.schema)

    val start = (dataset.toDF(), Vector.empty[String])
    val (scored, predNames) = models.zipWithIndex.foldLeft(start) {
      case ((df, names), (model, idx)) =>
        // Move each prediction to a unique temporary name so the next model can
        // write its own prediction column without a name collision.
        val tmpName = s"${uid}_pred_$idx"
        val next = model.transform(df).withColumnRenamed(model.getPredictionCol, tmpName)
        (next, names :+ tmpName)
    }

    // Native column arithmetic instead of a Row-based UDF.
    val total = predNames.map(name => col(name)).reduce(_ + _)
    scored
      .withColumn(OutputCol, total / predNames.size)
      .drop(predNames: _*)
  }

  /**
   * Derives the output schema: the input schema plus a nullable double `ensemble` field.
   *
   * @throws IllegalArgumentException if `schema` already has an `ensemble` field
   */
  override def transformSchema(schema: StructType): StructType = {
    require(
      !schema.fieldNames.contains(OutputCol),
      s"Column $OutputCol already exists.")
    schema.add(OutputCol, DoubleType)
  }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment