Created
September 24, 2019 13:32
-
-
Save afsalthaj/a3509d451f365075aa9691bdc26995e6 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/**
 * A [[CrossValidator]] whose `fit` evaluates every (fold, parameter map) combination
 * concurrently through ZIO `effectBlocking` effects.
 *
 * Warning: this custom validator predominantly focuses on performance improvements and
 * deviates in a few significant ways from the original cross validator.
 * It has not been thoroughly tested, although we made a successful run with better performance.
 *
 * TODO: Not tested thoroughly. Use at your own risk.
 */
class BetterCrossValidator extends CrossValidator {

  /**
   * Runs k-fold cross validation over all estimator parameter maps and refits the best one.
   *
   * @param dataset the input dataset; it is split into `numFolds` (training, validation) pairs
   *                using `seed`, and is also used in full to fit the final best model
   * @return a [[CrossValidatorModel]] holding the model fitted on the full dataset with the
   *         best-scoring parameter map, plus the per-parameter-map metric averaged over folds
   */
  override def fit(dataset: Dataset[_]): CrossValidatorModel = {
    val schema = dataset.schema
    transformSchema(schema, logging = true)
    val sparkSession = dataset.sparkSession
    val est = $(estimator)
    val eval = $(evaluator)
    val epm = $(estimatorParamMaps)
    val runtime = new DefaultRuntime {}

    // One (training, validation) DataFrame pair per fold, each collapsed to a single partition.
    // NOTE(review): an earlier comment here said these datasets are cached before being
    // checkpointed because checkpointing computes the data twice, but no cache/persist call
    // is made in this method -- confirm whether caching was intended.
    val splits = MLUtils.kFold(dataset.toDF.rdd, $(numFolds), $(seed)).map {
      case (a, b) =>
        (
          sparkSession.createDataFrame(a, schema).repartition(1),
          sparkSession.createDataFrame(b, schema).repartition(1)
        )
    }

    // A better concurrent execution that doesn't limit parallelisation until the transposing stage:
    // folds run in parallel, and within each fold all parameter maps run in parallel.
    val betterProgram: ZIO[Blocking, Throwable, List[List[Double]]] =
      ZIO.foreachPar(splits) {
        case (training, validation) =>
          for {
            // Checkpoint each fold's data exactly once (to break lineage and hence code
            // generation) and share the result across every parameter map. Doing this inside
            // the per-parameter-map loop would repeat the same eager checkpoint
            // `epm.length` times per fold for no benefit.
            checkpointedTraining   <- effectBlocking(training.checkpoint())
            checkpointedValidation <- effectBlocking(validation.checkpoint())
            foldMetrics <- ZIO.foreachPar(epm) { ppMap =>
              for {
                model <- effectBlocking[Model[_]](
                  est.fit(checkpointedTraining, ppMap).asInstanceOf[Model[_]]
                )
                metric <- effectBlocking(
                  eval.evaluate(model.transform(checkpointedValidation, ppMap))
                )
              } yield metric
            }
          } yield foldMetrics
      }

    // Outer list: one entry per fold; inner list: one metric per parameter map.
    val values: Seq[List[Double]] = runtime.unsafeRun(betterProgram)

    // The average metric of each parameter combination across folds; hence its size is the
    // number of parameter combinations.
    val metrics = values.transpose.map(_.sum / $(numFolds))
    val (_, bestIndex) =
      if (eval.isLargerBetter) metrics.zipWithIndex.maxBy(_._1)
      else metrics.zipWithIndex.minBy(_._1)

    // Refit on the complete dataset using the winning parameter map.
    val bestModel = est.fit(dataset, epm(bestIndex)).asInstanceOf[Model[_]]
    copyValues(new CrossValidatorModel(uid, bestModel, metrics.toArray).setParent(this))
  }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment