Merge records in two DataFrames by id columns
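The first file unions the two DataFrames with a source tag, repartitions by the key columns, sorts within each partition so that every left row lands directly before its matching right row, and then zips each consecutive pair into a single merged row, avoiding a regular join. A minimal usage sketch (hypothetical column names, assuming the DataFrameUnionUtils implicit below is in scope and each key appears exactly once on each side):

import sparkSession.implicits._

val left   = Seq((1L, "a"), (2L, "b")).toDF("id", "data")
val right  = Seq((1L, "x"), (2L, "y")).toDF("id", "hdl")
val merged = left.mergeBy(right, keys = Seq("id")) // resulting columns: id, data, hdl
merged.show()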
import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.StructField
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.Row
import org.apache.spark.sql.SparkSession
import org.scalatest.FunSuite

import java.sql.Timestamp
import java.time.LocalDateTime
case class MergedRecord(id: Long, id2: Long, data: String, ts: Timestamp, hdl: String)

class MergeDataframe extends FunSuite {

  private val initialTimestamp = LocalDateTime.now().minusDays(1)

  implicit val sparkSession: SparkSession = SparkSession.builder()
    .config("spark.sql.adaptive.enabled", "true") // without it the merge generates too many files
    .config("spark.sql.adaptive.coalescePartitions.enabled", "true")
    .master("local")
    .getOrCreate()

  import sparkSession.implicits._
test("merge by key") { | |
val idCol = "id" | |
val idCol2 = "id2" | |
val leftData = (0 until 100).map(id => (id, id + 100, "data1", Timestamp.valueOf(initialTimestamp))) | |
val rightData = (0 until 100).map(id => (id, id + 100, "hdlData")) | |
val leftDf = leftData.toDF(idCol, idCol2, "data", "ts") | |
val rightDf = rightData.toDF(idCol, idCol2, "hdl") | |
val mergedRecords = leftDf | |
.mergeBy(rightDf, List(idCol, idCol2), Some(3)) | |
.sort(idCol) | |
.as[MergedRecord] | |
mergedRecords.show(numRows = 100, truncate = false) | |
val expected = leftData.map(entry => MergedRecord(entry._1, entry._2, entry._3, entry._4, "hdlData")).toArray | |
val result = mergedRecords.collect() | |
result.zip(expected).foreach { case (l, r) => assert(l == r) } | |
} | |
  implicit class DataFrameUnionUtils(df: DataFrame) {

    private val sourceTypeColumn = "type"

    def schemaByName(colName: String): StructField = df.schema.fields(df.schema.fieldIndex(colName))

    def mergeBy(rightDf: DataFrame, keys: Seq[String], numPartitions: Option[Int] = None): DataFrame = {
      // TODO Implement for case when different number of rows, treat no row as null
      assert(
        rightDf.schema.fields.take(keys.size).sameElements(df.schema.fields.take(keys.size)),
        "key fields must have the same names, types and order"
      )
      val keyFields = df.schema.fields.take(keys.size).toList
      val dfFields = df.schema.fields.drop(keys.size).toList
      val rightDfFields = rightDf.schema.fields.drop(keys.size).toList
      val mergedSchema = StructType(keyFields ::: dfFields ::: rightDfFields)

      // Pad each side with typed null columns for the other side's fields, then project to the merged layout.
      def alignSchema(df: DataFrame, missingFields: Seq[StructField]) =
        missingFields
          .foldLeft(df)((df, field) => df.withColumn(field.name, lit(null).cast(field.dataType)))
          .select(mergedSchema.map(f => col(f.name)): _*)

      val dfWithNullForRight = alignSchema(df, rightDfFields).withColumn(sourceTypeColumn, lit(1))
      val rightDfWithNullForDf = alignSchema(rightDf, dfFields).withColumn(sourceTypeColumn, lit(2))
      val dfColumnsByIndex = (keyFields ::: dfFields).map(f => dfWithNullForRight.schema.fieldIndex(f.name))
      val rightDfByIndex = rightDfFields.map(_.name).map(rightDfWithNullForDf.schema.fieldIndex)

      dfWithNullForRight
        .union(rightDfWithNullForDf)
        .transform(t =>
          numPartitions
            .map(num => t.repartition(num, keys.map(col): _*))
            .getOrElse(t.repartition(keys.map(col): _*))
        )
        // Sorting by the keys plus the source tag puts each left row (tag 1) directly before its right row (tag 2).
        .sortWithinPartitions((keys.map(col) :+ col(sourceTypeColumn)): _*)
        .mapPartitions { it =>
          it
            .grouped(2) // assumes exactly one row per key on each side; see the sketch after this file
            .map { case Seq(l, r) => Row.fromSeq(dfColumnsByIndex.map(l.get) ::: rightDfByIndex.map(r.get)) }
        }(RowEncoder(mergedSchema))
    }
  }
}
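The grouped(2) pairing above is what the TODO refers to: it assumes every key occurs exactly once on each side, and a key present on only one side would shift all later pairs. Below is a hedged sketch, not the author's implementation, of how the tail of mergeBy could instead use the sliceBy helper from the second file, emitting nulls for a missing side. All vals are the ones already defined inside mergeBy; typeIdx is introduced here for illustration, and the numPartitions option is omitted for brevity. It assumes the SliceBySubsequence implicit is imported.

val typeIdx = mergedSchema.size // sourceTypeColumn is appended right after the merged columns
dfWithNullForRight
  .union(rightDfWithNullForDf)
  .repartition(keys.map(col): _*)
  .sortWithinPartitions((keys.map(col) :+ col(sourceTypeColumn)): _*)
  .mapPartitions { it =>
    it.sliceBy(row => keys.indices.map(row.get).toVector).map { group =>
      val rows  = group.toList                      // consume the group before advancing the slicer
      val left  = rows.find(_.getInt(typeIdx) == 1) // row from df, if present
      val right = rows.find(_.getInt(typeIdx) == 2) // row from rightDf, if present
      val keyValues = keys.indices.map(rows.head.get).toList
      val leftCols  = left.map(r => dfFields.indices.toList.map(i => r.get(keys.size + i)))
        .getOrElse(dfFields.map(_ => null))
      val rightCols = right.map(r => rightDfByIndex.map(r.get))
        .getOrElse(rightDfFields.map(_ => null))
      Row.fromSeq(keyValues ::: leftCols ::: rightCols)
    }
  }(RowEncoder(mergedSchema))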
import scala.collection.AbstractIterator

implicit class SliceBySubsequence[T, K](it: Iterator[T]) extends Serializable {

  /**
   * @param key The function used to extract the key from an iterator entry.
   * @return An iterator over subsequences (iterators) whose entries share the same key.
   *
   * Note on reuse: after calling this method, one should discard the iterator it was called on
   * and use only the iterator that was returned. Using the old iterator is undefined, subject to change,
   * and may result in changes to the new iterator as well.
   */
  def sliceBy(key: T => K): Iterator[Iterator[T]] = new AbstractIterator[Iterator[T]] {
    private var bufferedIt = it.buffered
    def hasNext: Boolean = bufferedIt.hasNext
    def next(): Iterator[T] =
      if (bufferedIt.hasNext) {
        val hd = bufferedIt.head // peek without consuming (avoids BufferedIterator.headOption, which needs Scala 2.13+)
        // span splits off the longest prefix sharing hd's key; the remainder becomes the new cursor
        val (subsequence, rest) = bufferedIt.span(r => key(r) == key(hd))
        bufferedIt = rest.buffered
        subsequence
      } else Iterator.empty
  }
}
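A quick usage example of sliceBy on a plain Scala iterator (assuming the implicit class above is in scope, e.g. nested in an enclosing object and imported):

val events = Iterator(("a", 1), ("a", 2), ("b", 3), ("c", 4), ("c", 5))
events.sliceBy(_._1).foreach { group =>
  // each subsequence must be fully consumed before asking for the next one
  println(group.toList) // List((a,1), (a,2)), then List((b,3)), then List((c,4), (c,5))
}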