@ahoy-jon
Last active February 3, 2020 11:08
DataFrame.cogroup is the new HList.flatMap (UNFORTUNATELY, THIS IS VERY SLOW)
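The idea: key each DataFrame with a Row => T function, cogroup the keyed RDDs with CoGroupedRDD, and wrap the result back into a DataFrame with one row per key and one array-of-struct column per input (_1, _2, ...). The per-row Catalyst conversion at the end is presumably a large part of why the title calls this very slow. Below is a minimal sketch of the intended usage, assuming Spark 1.x with a sqlContext in scope (e.g. spark-shell); the data and column names (lifeId, score, label) are made up, and the API itself is defined in the file that follows.

import org.apache.spark.sql.Row
import org.apache.spark.sql.utils.CogroupDF._

// Two toy DataFrames sharing a "lifeId" column (illustrative data only).
val a = sqlContext.createDataFrame(Seq(("u1", 1), ("u1", 2), ("u2", 3))).toDF("lifeId", "score")
val b = sqlContext.createDataFrame(Seq(("u1", "x"), ("u3", "y"))).toDF("lifeId", "label")

val key: Row => String = _.getAs[String]("lifeId")

// One row per distinct key; all matching rows from each input are collected
// into the array-of-struct columns _1 (from a) and _2 (from b).
val grouped = a.keyBy(key).cogroupDf(b.keyBy(key))
grouped.printSchema()
// schema: key (string), _1 (array<struct<lifeId,score>>), _2 (array<struct<lifeId,label>>)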
package org.apache.spark.sql.utils

import org.apache.spark.Partitioner
import org.apache.spark.rdd.{CoGroupedRDD, RDD}
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, ScalaReflection}
import org.apache.spark.sql.execution.LogicalRDD
import org.apache.spark.sql.types.{ArrayType, StructField, StructType}
import org.apache.spark.sql.{DataFrame, Row, SQLContext}

import scala.reflect.ClassTag
import scala.reflect.runtime.universe.TypeTag

object CogroupDF {

  /** A DataFrame paired with the function that extracts its cogroup key. */
  case class KeyedDataFrame[T : TypeTag : ClassTag](dataFrame: DataFrame, keyFn: Row => T) {
    def cogroupDf(kdfs: KeyedDataFrame[T]*): DataFrame =
      CogroupDF.cogroupDf[T]((dataFrame, keyFn), kdfs.map(kdf => (kdf.dataFrame, kdf.keyFn)): _*)
  }

  implicit class CogroupDfWrapper(dataFrame: DataFrame) {
    // Note: the context bound must be ClassTag (the original `T : TypeTag : Class`
    // does not compile).
    def keyBy[T : TypeTag : ClassTag](keyFn: Row => T): KeyedDataFrame[T] =
      KeyedDataFrame(dataFrame, keyFn)
  }

  private def cogroupDf[T : TypeTag : ClassTag](kdf: (DataFrame, Row => T),
                                                kdfs: (DataFrame, Row => T)*): DataFrame = {
    // Result schema: the key, then one array-of-struct column per input
    // DataFrame, named _1, _2, ...
    val schema = StructType(
      StructField("key", ScalaReflection.schemaFor[T].dataType) ::
        (kdf :: kdfs.toList).zipWithIndex.map {
          case ((df, _), i) => StructField("_" + (i + 1), ArrayType(df.schema), nullable = false)
        }
    )

    val converter = CatalystTypeConverters.createToCatalystConverter(schema)

    val fst: RDD[(T, Row)] = kdf._1.rdd.keyBy(kdf._2)

    val map: RDD[Row] = fst.withScope {
      val rdds: List[RDD[(T, Row)]] = kdfs.map { case (df, f) => df.rdd.keyBy(f) }.toList
      // Cogroup all the keyed RDDs; each value is one Iterable of rows per input.
      val d = new CoGroupedRDD[T](fst :: rdds, Partitioner.defaultPartitioner(fst, rdds: _*))
      d.mapValues(_.map(_.toSeq).toList)
        .map { case (key, groups) => Row.fromSeq(key :: groups) }
        // Convert every external Row into Catalyst's representation, one row
        // at a time; this per-row conversion is likely much of the slowness.
        .map(converter(_).asInstanceOf[Row])
    }

    val sqlContext: SQLContext = kdf._1.sqlContext
    DataFrame(sqlContext, LogicalRDD(schema.toAttributes, map)(sqlContext))
  }
}
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.{DataFrame, Row, SQLContext}

class DefaultSparkEnv {
  val conf: SparkConf = new SparkConf().setAppName("Workshop").setMaster("local[*]")
  val sc: SparkContext = new SparkContext(conf)
  implicit val sqlContext: SQLContext = new SQLContext(sc)
}

object Usage {
  def main(args: Array[String]): Unit = {
    val defaultSparkEnv: DefaultSparkEnv = new DefaultSparkEnv
    import defaultSparkEnv._

    val yo: DataFrame = sqlContext.read.load("yo.parquet")
    val lo: DataFrame = sqlContext.read.load("lo.parquet")

    import org.apache.spark.sql.utils.CogroupDF._

    // Key both DataFrames by the same column, then cogroup them.
    val rowToString: Row => String = _.getAs[String]("lifeId")
    val yolo: DataFrame = yo.keyBy(rowToString).cogroupDf(lo.keyBy(rowToString))

    yolo.write.save("yolo.parquet")
  }
}
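A possible follow-up, not in the original gist: the cogrouped output can be flattened back into plain rows with explode, e.g. to inspect one side. The column names key and _1 come from the schema built in cogroupDf; the variable names here are illustrative and continue from the main above.

import org.apache.spark.sql.functions.explode

// Unnest the rows that came from `yo`; each element of _1 is a full original row.
val yoSide = yolo.select(yolo("key"), explode(yolo("_1")).as("yoRow"))
yoSide.show()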