Last active: September 7, 2021 15:13
Spark: create a unique sequential id per Spark partition
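For context before the three approaches below: Spark's built-in monotonically_increasing_id() already maintains a per-partition counter internally, but it packs the partition id into the upper bits of the result, so the values are unique across the whole DataFrame rather than starting at 0 inside each partition. A minimal sketch of that baseline, assuming the spark and randomDF values defined in the code below (the not_an_offset column name is illustrative):

    // Built-in baseline: ids are unique across the DataFrame, but within
    // each partition they start at partitionId * 2^33 instead of 0.
    randomDF
      .withColumn("not_an_offset", monotonically_increasing_id())
      .show(100)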
import org.apache.spark.sql.Column
import org.apache.spark.sql.Row
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.catalyst.expressions.LeafExpression
import org.apache.spark.sql.catalyst.expressions.Stateful
import org.apache.spark.sql.catalyst.expressions.codegen.Block.BlockHelper
import org.apache.spark.sql.catalyst.expressions.codegen.CodeGenerator
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext
import org.apache.spark.sql.catalyst.expressions.codegen.ExprCode
import org.apache.spark.sql.catalyst.expressions.codegen.FalseLiteral
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.DataType
import org.apache.spark.sql.types.LongType

object SparkSQLGeneratePartitionOffset {

  def main(args: Array[String]): Unit = {
    implicit val spark: SparkSession = SparkSession.builder().master("local").getOrCreate()

    // 50 rows spread across 5 partitions, with a pseudo-random "id" column.
    val randomDF = spark
      .range(0, 50, 1, 5)
      .withColumn("id", rand(50) * 10)

    // Approach 1: row_number over a window keyed by the physical partition id.
    // Correct, but the window forces a shuffle and a sort per partition key.
    spark.sparkContext.setJobDescription("Global sort")
    randomDF
      .withColumn("part_id", spark_partition_id())
      .withColumn("generated_id", row_number().over(Window.partitionBy("part_id").orderBy("id")))
      .show(100)

    // Approach 2: zipWithIndex inside mapPartitions. No shuffle, but every row
    // is decoded to an external Row and re-encoded, which costs CPU.
    spark.sparkContext.setJobDescription("custom zipWithIndex")
    val df = randomDF.withColumn("part_id", spark_partition_id())
    df
      .mapPartitions { it =>
        it
          .zipWithIndex
          .map { case (r, index) => Row.fromSeq(r.toSeq :+ index.toLong) }
      }(RowEncoder(df.schema.add("generated_id", LongType)))
      .show(100)

    // Approach 3: a custom catalyst expression that keeps a per-partition
    // counter. No shuffle and no encoding round trip.
    spark.sparkContext.setJobDescription("custom sql function")
    randomDF
      .withColumn("part_id", spark_partition_id())
      .withColumn("generated_id", new Column(SparkPartitionOffset()))
      .show(100)

    // Keep the application alive so the Spark UI can be inspected.
    Thread.sleep(1000000000)
  }

  case class SparkPartitionOffset() extends LeafExpression with Stateful {

    /**
     * Adapted from org.apache.spark.sql.catalyst.expressions.MonotonicallyIncreasingID.
     *
     * Record ID within each partition. By being transient, count's value is reset to 0 every time
     * we serialize, deserialize and initialize it.
     */
    @transient private[this] var count: Long = _

    override protected def initializeInternal(partitionIndex: Int): Unit = count = 0L

    override def nullable: Boolean = false

    override def dataType: DataType = LongType

    override protected def evalInternal(input: InternalRow): Long = {
      val currentCount = count
      count += 1
      currentCount
    }

    override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
      val countTerm = ctx.addMutableState(CodeGenerator.JAVA_LONG, "count")
      ctx.addPartitionInitializationStatement(s"$countTerm = 0L;")
      ev.copy(
        code = code"""
          final ${CodeGenerator.javaType(dataType)} ${ev.value} = $countTerm;
          $countTerm++;""",
        isNull = FalseLiteral
      )
    }

    override def prettyName: String = "spark_partition_offset"

    override def sql: String = s"$prettyName()"

    override def freshCopy(): SparkPartitionOffset = SparkPartitionOffset()
  }
}
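If a single id that is unique across the whole DataFrame is needed rather than a zero-based per-partition offset, the partition id can be packed into the high bits of the result, mirroring the bit layout MonotonicallyIncreasingID uses. A minimal sketch, assuming the part_id and generated_id columns produced above (the global_id column name is illustrative):

    // Pack the partition id into the upper bits: the lower 33 bits hold the
    // per-partition offset, matching MonotonicallyIncreasingID's layout.
    randomDF
      .withColumn("part_id", spark_partition_id())
      .withColumn("generated_id", new Column(SparkPartitionOffset()))
      .withColumn("global_id", col("part_id").cast(LongType) * lit(1L << 33) + col("generated_id"))
      .show(100)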