Example of StreamSinkProvider for structured streaming with custom query execution
package org.apache.spark.sql.sadikovi
import java.io.{ObjectInputStream, ObjectOutputStream}
import java.util.UUID
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.Path
import org.apache.hadoop.mapreduce.{JobContext, TaskAttemptContext}
import org.apache.spark.internal.io._
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst._
import org.apache.spark.sql.catalyst.encoders._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.optimizer._
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.util._
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.datasources._
import org.apache.spark.sql.execution.datasources.parquet._
import org.apache.spark.sql.execution.streaming._
import org.apache.spark.sql.functions._
import org.apache.spark.sql.sources._
import org.apache.spark.sql.streaming._
import org.apache.spark.sql.types._
import org.apache.spark.util._
// == streaming ==
class DefaultSource extends StreamSinkProvider {
  // Resets the compiler-generated "bitmap$0" field that tracks which lazy vals
  // have been initialized, forcing them to be recomputed on next access.
  def resetLazy(obj: Any, clazz: Class[_]): Unit = {
    val bitmap = clazz.getDeclaredField("bitmap$0")
    bitmap.setAccessible(true)
    if (bitmap.getType == classOf[Boolean]) {
      bitmap.setBoolean(obj, false)
    } else if (bitmap.getType == classOf[Byte]) {
      bitmap.setByte(obj, 0)
    } else if (bitmap.getType == classOf[Short]) {
      bitmap.setShort(obj, 0)
    } else {
      // assume it is an integer
      bitmap.setInt(obj, 0)
    }
  }

  // Swaps the logical plan of an existing QueryExecution and resets its lazy
  // fields (on the class and its parent) so the analyzed/optimized/executed
  // plans are rebuilt from the modified plan.
  def updateQueryExecution(qe: QueryExecution, modified: LogicalPlan): Unit = {
    resetLazy(qe, qe.getClass)
    val parentClass = qe.getClass.getSuperclass
    resetLazy(qe, parentClass)
    // update plan
    val planField = parentClass.getDeclaredField("logical")
    planField.setAccessible(true)
    planField.set(qe, modified)
  }

  override def createSink(
      sqlContext: SQLContext,
      parameters: Map[String, String],
      partitionColumns: Seq[String],
      outputMode: OutputMode): Sink = {
    new Sink() {
      override def addBatch(batchId: Long, data: DataFrame): Unit = {
        val original = data.queryExecution.logical
        println(data.queryExecution)
        println(s"""
          | Plan: ${original.getClass}
          | Output: ${original.output.toList}
          | Expressions: ${original.expressions.toList}
          """.stripMargin)
        // Rewrite the output columns: cast dates to timestamps and convert
        // timestamps from UTC into the JVM default time zone.
        val modified = {
          val timezone = DateTimeUtils.defaultTimeZone().getID()
          val exprs = original.output.map {
            case date if date.dataType == DateType =>
              Alias(Cast(date, TimestampType), date.name)()
            case ts if ts.dataType == TimestampType =>
              Alias(FromUTCTimestamp(ts, Literal(timezone)), ts.name)()
            case other => other
          }
          CollapseProject.apply(Project(exprs, original))
        }
        // Update query execution by resetting the fields
        val qe = data.queryExecution
        updateQueryExecution(qe, modified)
        qe.analyzed.foreachUp {
          case p if p.isStreaming =>
            println(s"[modified] => Streaming source: $p")
          case _ =>
        }
        println(s"[modified] isStreaming: ${qe.analyzed.isStreaming}")
        println(s"[modified] class: ${qe.getClass}")
        println(qe.analyzed.schema)
        println(qe)

        // Write the batch directly as Parquet files using the no-op commit
        // protocol defined below.
        val hadoopConf = data.sparkSession.sessionState.newHadoopConf()
        val committer = FileCommitProtocol.instantiate(
          className = classOf[SimpleCommitProtocol].getCanonicalName,
          jobId = batchId.toString,
          outputPath = parameters("path"))
        FileFormatWriter.write(
          sparkSession = data.sparkSession,
          plan = qe.executedPlan,
          fileFormat = new ParquetFileFormat(),
          committer = committer,
          outputSpec = FileFormatWriter.OutputSpec(parameters("path"), Map.empty, qe.analyzed.output),
          hadoopConf = hadoopConf,
          partitionColumns = Nil,
          bucketSpec = None,
          statsTrackers = Nil,
          options = parameters)
      }
    }
  }

  override def toString(): String = {
    this.getClass.getCanonicalName
  }
}
// Minimal commit protocol that writes task files directly into the final
// output directory; job and task setup/commit/abort are no-ops.
class SimpleCommitProtocol(jobId: String, path: String)
  extends FileCommitProtocol with Serializable {

  override def setupJob(jobContext: JobContext): Unit = { }

  override def commitJob(
      jobContext: JobContext,
      taskCommits: Seq[FileCommitProtocol.TaskCommitMessage]): Unit = { }

  override def abortJob(jobContext: JobContext): Unit = { }

  override def setupTask(taskContext: TaskAttemptContext): Unit = { }

  // Builds the final file name "part-<task>-<uuid><ext>" under the output path
  // (and optional partition directory); files are written in place, so there
  // is nothing to move at commit time.
  override def newTaskTempFile(
      taskContext: TaskAttemptContext, dir: Option[String], ext: String): String = {
    val split = taskContext.getTaskAttemptID.getTaskID.getId
    val uuid = UUID.randomUUID.toString
    val filename = f"part-$split%05d-$uuid$ext"
    dir.map { d =>
      new Path(new Path(path, d), filename).toString
    }.getOrElse {
      new Path(path, filename).toString
    }
  }

  override def newTaskTempFileAbsPath(
      taskContext: TaskAttemptContext, absoluteDir: String, ext: String): String = ""

  override def commitTask(taskContext: TaskAttemptContext): FileCommitProtocol.TaskCommitMessage =
    FileCommitProtocol.EmptyTaskCommitMessage

  override def abortTask(taskContext: TaskAttemptContext): Unit = { }
}
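
A minimal sketch of how such a sink could be attached to a streaming query, assuming the file above is compiled into the application jar. The input source, schema, object name, and all paths below are placeholders for illustration and are not part of the gist.

// Hypothetical wiring of the custom sink; input format, schema, and paths are placeholders.
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.types._

object CustomSinkExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().appName("custom-sink-example").getOrCreate()

    // Schema for a file-based streaming source (illustrative only).
    val schema = new StructType()
      .add("id", LongType)
      .add("event_date", DateType)
      .add("event_ts", TimestampType)

    val input = spark.readStream.schema(schema).json("/tmp/stream-input")

    // The format name is the package containing DefaultSource, so Spark
    // resolves the sink through the StreamSinkProvider defined above.
    val query = input.writeStream
      .format("org.apache.spark.sql.sadikovi")
      .option("path", "/tmp/stream-output")
      .option("checkpointLocation", "/tmp/stream-checkpoint")
      .start()

    query.awaitTermination()
  }
}

With this wiring, every micro-batch is passed to addBatch above, which rewrites date and timestamp columns in the logical plan and writes the result as Parquet files under the "path" option.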