Skip to content

Instantly share code, notes, and snippets.

@afsalthaj
Created September 24, 2019 13:32
Show Gist options
  • Save afsalthaj/a3509d451f365075aa9691bdc26995e6 to your computer and use it in GitHub Desktop.
Save afsalthaj/a3509d451f365075aa9691bdc26995e6 to your computer and use it in GitHub Desktop.
/**
 * A [[CrossValidator]] that fits and evaluates every (fold, parameter map) pair concurrently
 * using ZIO.
 *
 * Warning: this validator focuses predominantly on performance and departs significantly from
 * the original `CrossValidator`:
 *  - all folds and all parameter maps are fitted and evaluated in parallel with `effectBlocking`,
 *    with no limit on the degree of parallelism (`parallelism` is not read);
 *  - `collectSubModels` is not honoured;
 *  - each fold's training and validation sets are collapsed to a single partition and
 *    checkpointed (once per fold) to break lineage, so a checkpoint directory must be set on
 *    the `SparkContext`.
 *
 * TODO: Not tested thoroughly. Use at your own risk. An initial run succeeded and was faster
 * than the original `CrossValidator`.
 */
class BetterCrossValidator extends CrossValidator {

  /**
   * Fits the estimator with every parameter map on every fold. It then averages the evaluator
   * metric of each parameter map across folds and refits the best parameter map on the full
   * dataset.
   *
   * @param dataset input data, split into `numFolds` folds using `seed`
   * @return a [[CrossValidatorModel]] holding the best model and the average metric of each
   *         parameter map, in `estimatorParamMaps` order
   */
  override def fit(dataset: Dataset[_]): CrossValidatorModel = {
    val schema = dataset.schema
    transformSchema(schema, logging = true)
    val sparkSession = dataset.sparkSession
    val est = $(estimator)
    val eval = $(evaluator)
    val epm = $(estimatorParamMaps)
    val runtime = new DefaultRuntime {}

    // NOTE(review): each fold is collapsed to a single partition. This trades Spark-level
    // parallelism for lower per-task overhead and is presumably only suitable for small data.
    val splits = MLUtils.kFold(dataset.toDF.rdd, $(numFolds), $(seed)).map {
      case (trainingRdd, validationRdd) =>
        (
          sparkSession.createDataFrame(trainingRdd, schema).repartition(1),
          sparkSession.createDataFrame(validationRdd, schema).repartition(1)
        )
    }

    // Every fold and every parameter map runs concurrently; parallelism is not limited until
    // the results are transposed below. Each fold is checkpointed exactly once, and the
    // checkpointed data is shared by all parameter maps of that fold (previously it was
    // re-checkpointed for every parameter map).
    val betterProgram: ZIO[Blocking, Throwable, List[List[Double]]] =
      ZIO.foreachPar(splits) {
        case (trainingSplit, validationSplit) =>
          for {
            training   <- checkpointOnce(trainingSplit)
            validation <- checkpointOnce(validationSplit)
            foldMetrics <- ZIO.foreachPar(epm) { paramMap =>
              for {
                model  <- effectBlocking[Model[_]](est.fit(training, paramMap).asInstanceOf[Model[_]])
                metric <- effectBlocking(eval.evaluate(model.transform(validation, paramMap)))
              } yield metric
            }
          } yield foldMetrics
      }

    // One inner list per fold, each ordered like `epm` (foreachPar keeps input order).
    val values: Seq[List[Double]] = runtime.unsafeRun(betterProgram)
    // Average metric of each parameter map across folds; its size is the number of param maps.
    val metrics = values.transpose.map(_.sum / $(numFolds))
    val (_, bestIndex) =
      if (eval.isLargerBetter) metrics.zipWithIndex.maxBy(_._1) else metrics.zipWithIndex.minBy(_._1)
    val bestModel = est.fit(dataset, epm(bestIndex)).asInstanceOf[Model[_]]
    copyValues(new CrossValidatorModel(uid, bestModel, metrics.toArray).setParent(this))
  }

  /**
   * Eagerly checkpoints `ds` to break its lineage, which keeps query plans and generated code
   * small.
   *
   * `ds` is cached first so that the checkpoint does not compute the data twice. The cache is
   * released afterwards, whether or not the checkpoint succeeded. The returned Dataset reads
   * from the checkpoint, so it does not depend on the released cache.
   */
  private def checkpointOnce[T](ds: Dataset[T]): ZIO[Blocking, Throwable, Dataset[T]] =
    effectBlocking {
      ds.cache()
      ds.checkpoint()
    }.ensuring(ZIO.effectTotal(ds.unpersist()))
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment