Example of a StreamSinkProvider for Structured Streaming with custom query execution
package org.apache.spark.sql.sadikovi

import java.io.{ObjectInputStream, ObjectOutputStream}
import java.util.UUID

import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.Path
import org.apache.hadoop.mapreduce.{JobContext, TaskAttemptContext}

import org.apache.spark.internal.io._
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst._
import org.apache.spark.sql.catalyst.encoders._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.optimizer._
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.util._
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.datasources._
import org.apache.spark.sql.execution.datasources.parquet._
import org.apache.spark.sql.execution.streaming._
import org.apache.spark.sql.functions._
import org.apache.spark.sql.sources._
import org.apache.spark.sql.streaming._
import org.apache.spark.sql.types._
import org.apache.spark.util._

// == streaming ==
class DefaultSource extends StreamSinkProvider {
  // Resets the compiler-generated bitmap that tracks which lazy vals of `obj`
  // (declared in `clazz`) have been initialized, forcing them to be recomputed.
  def resetLazy(obj: Any, clazz: Class[_]): Unit = {
    val bitmap = clazz.getDeclaredField("bitmap$0")
    bitmap.setAccessible(true)
    if (bitmap.getType == classOf[Boolean]) {
      bitmap.setBoolean(obj, false)
    } else if (bitmap.getType == classOf[Byte]) {
      bitmap.setByte(obj, 0)
    } else if (bitmap.getType == classOf[Short]) {
      bitmap.setShort(obj, 0)
    } else {
      // assume it is an integer
      bitmap.setInt(obj, 0)
    }
  }
  // Swaps the logical plan of an existing QueryExecution via reflection and
  // clears the lazy val bitmaps so that the analyzed/optimized/executed plans
  // are re-derived from the modified plan.
  def updateQueryExecution(qe: QueryExecution, modified: LogicalPlan): Unit = {
    resetLazy(qe, qe.getClass)
    val parentClass = qe.getClass.getSuperclass
    resetLazy(qe, parentClass)
    // update the logical plan field
    val planField = parentClass.getDeclaredField("logical")
    planField.setAccessible(true)
    planField.set(qe, modified)
  }
  override def createSink(
      sqlContext: SQLContext,
      parameters: Map[String, String],
      partitionColumns: Seq[String],
      outputMode: OutputMode): Sink = {
    new Sink() {
      override def addBatch(batchId: Long, data: DataFrame): Unit = {
        val original = data.queryExecution.logical
        println(data.queryExecution)
        println(s"""
          | Plan: ${original.getClass}
          | Output: ${original.output.toList}
          | Expressions: ${original.expressions.toList}
        """.stripMargin)
        // Rewrite the output: cast DateType columns to TimestampType and shift
        // TimestampType columns from UTC into the default time zone.
        val modified = {
          val timezone = DateTimeUtils.defaultTimeZone().getID()
          val exprs = original.output.map {
            case date if date.dataType == DateType =>
              Alias(Cast(date, TimestampType), date.name)()
            case ts if ts.dataType == TimestampType =>
              Alias(FromUTCTimestamp(ts, Literal(timezone)), ts.name)()
            case other => other
          }
          CollapseProject.apply(Project(exprs, original))
        }
        // Update query execution by resetting its cached fields
        val qe = data.queryExecution
        updateQueryExecution(qe, modified)
        qe.analyzed.foreachUp {
          case p if p.isStreaming =>
            println(s"[modified] => Streaming source: $p")
          case _ =>
        }
        println(s"[modified] isStreaming: ${qe.analyzed.isStreaming}")
        println(s"[modified] class: ${qe.getClass}")
        println(qe.analyzed.schema)
        println(qe)
        // Write the batch as Parquet, using the no-op commit protocol below
        val hadoopConf = data.sparkSession.sessionState.newHadoopConf()
        val committer = FileCommitProtocol.instantiate(
          className = classOf[SimpleCommitProtocol].getCanonicalName,
          jobId = batchId.toString,
          outputPath = parameters("path"))
        FileFormatWriter.write(
          sparkSession = data.sparkSession,
          plan = qe.executedPlan,
          fileFormat = new ParquetFileFormat(),
          committer = committer,
          outputSpec = FileFormatWriter.OutputSpec(parameters("path"), Map.empty, qe.analyzed.output),
          hadoopConf = hadoopConf,
          partitionColumns = Nil,
          bucketSpec = None,
          statsTrackers = Nil,
          options = parameters)
      }
    }
  }

  override def toString(): String = this.getClass.getCanonicalName
}
// Minimal commit protocol that writes task files directly under the output path
// and treats job/task setup, commit, and abort as no-ops.
class SimpleCommitProtocol(jobId: String, path: String) extends FileCommitProtocol with Serializable {
  override def setupJob(jobContext: JobContext): Unit = { }
  override def commitJob(jobContext: JobContext, taskCommits: Seq[FileCommitProtocol.TaskCommitMessage]): Unit = { }
  override def abortJob(jobContext: JobContext): Unit = { }
  override def setupTask(taskContext: TaskAttemptContext): Unit = { }

  // Generates a unique "part-<task>-<uuid><ext>" file name under the output path
  override def newTaskTempFile(taskContext: TaskAttemptContext, dir: Option[String], ext: String): String = {
    val split = taskContext.getTaskAttemptID.getTaskID.getId
    val uuid = UUID.randomUUID.toString
    val filename = f"part-$split%05d-$uuid$ext"
    dir.map { d =>
      new Path(new Path(path, d), filename).toString
    }.getOrElse {
      new Path(path, filename).toString
    }
  }

  override def newTaskTempFileAbsPath(taskContext: TaskAttemptContext, absoluteDir: String, ext: String): String = ""
  override def commitTask(taskContext: TaskAttemptContext): FileCommitProtocol.TaskCommitMessage = FileCommitProtocol.EmptyTaskCommitMessage
  override def abortTask(taskContext: TaskAttemptContext): Unit = { }
}
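For reference, a minimal, hypothetical sketch of how this sink could be attached to a streaming query. It is not part of the gist: the local Spark session, the socket source on localhost:9999, and the output/checkpoint paths are placeholder assumptions. The format name is the package containing DefaultSource, which Spark resolves to the class above; the "path" option is the one read in addBatch via parameters("path").

// Hypothetical driver code (not part of the gist); host, port, and paths are placeholders.
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.streaming.Trigger

object SinkDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[2]")
      .appName("custom-sink-demo")
      .getOrCreate()

    // Any streaming source works; a socket source keeps the example small.
    val input = spark.readStream
      .format("socket")
      .option("host", "localhost")
      .option("port", "9999")
      .load()

    val query = input.writeStream
      // Spark appends ".DefaultSource" to the provider name when resolving it.
      .format("org.apache.spark.sql.sadikovi")
      .option("path", "/tmp/custom-sink-output")
      .option("checkpointLocation", "/tmp/custom-sink-ckpt")
      .trigger(Trigger.ProcessingTime("10 seconds"))
      .start()

    query.awaitTermination()
  }
}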