Skip to content

Instantly share code, notes, and snippets.

@marcovivero
Created July 7, 2015 23:46
Show Gist options
  • Save marcovivero/8b84b7a86b6846af748c to your computer and use it in GitHub Desktop.
Save marcovivero/8b84b7a86b6846af748c to your computer and use it in GitHub Desktop.
/**
 * Splits `data` into `numFolds` train/test pairs for k-fold style evaluation.
 *
 * Every row is assigned to exactly one fold in `[0, numFolds)` based on the
 * hash (`##`) of its `colCol` value, so all rows sharing a `colCol` value land
 * in the same fold. For fold `k` the rows assigned to `k` form the test set
 * and all remaining rows form the training set.
 *
 * @param sqlContext passed through to the `AssociatedData` constructor
 * @param data       the full data set to split
 * @param rowCol     column name passed through to `AssociatedData` and `prepareTest`
 * @param colCol     name of the column whose values decide fold membership;
 *                   the fold UDFs take a `String`, so this is presumably a
 *                   string column -- TODO confirm against callers
 * @param tokenizer  passed through to the `AssociatedData` constructor
 * @param idf        passed through to the `AssociatedData` constructor
 * @param numFolds   number of folds; must be strictly positive
 * @return one `(training data, test observations)` pair per fold, ordered by
 *         fold index `0` to `numFolds - 1`
 * @throws IllegalArgumentException if `numFolds` is not positive
 */
def splitDataTest (
  sqlContext : SQLContext,
  data : DataFrame,
  rowCol : String,
  colCol : String,
  tokenizer : String => Array[String],
  idf : Boolean = true,
  numFolds : Int
) : Seq[(AssociatedData, RDD[TestObservation])] = {
  require(numFolds > 0, s"numFolds must be positive, got $numFolds")

  // Maps a value to its fold index in [0, numFolds). `##` may be negative and
  // `%` keeps the sign of the dividend, so a negative remainder is shifted up;
  // otherwise such rows would never be selected for any test fold.
  val foldOf : String => Int = value => {
    val remainder = value.## % numFolds
    if (remainder < 0) remainder + numFolds else remainder
  }

  // Fold indices must cover 0 until numFolds: a remainder can never equal
  // numFolds itself, so iterating 1 to numFolds would leave one test set empty
  // and never hold out the rows whose remainder is 0.
  (0 until numFolds).map { k =>
    val inTrain = functions.udf((value : String) => foldOf(value) != k)
    val inTest = functions.udf((value : String) => foldOf(value) == k)
    (
      new AssociatedData(sqlContext, data.filter(inTrain(data(colCol))), rowCol, colCol, tokenizer, idf),
      prepareTest(data.filter(inTest(data(colCol))), rowCol, colCol)
    )
  }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment