// Spark imports used by the code below; the PartitionTransform classes and the
// Iceberg*Transform expressions come from the DataSourceV2 partitioning support in this branch.
import org.apache.spark.SparkException
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.execution.{ProjectExec, SparkPlan}

/**
 * The base physical plan for writing data into data source v2.
 */
abstract class V2TableWriteExec(
    options: Map[String, String],
    query: SparkPlan) extends SparkPlan {
  import org.apache.spark.sql.sources.v2.DataSourceV2Implicits._

  // partition transforms declared by the target table
  def partitioning: Seq[PartitionTransform]

  override def children: Seq[SparkPlan] = Seq(query)
  override def output: Seq[Attribute] = Nil

  ...
  // Resolve each partition transform to a clustering expression over the query's output,
  // failing if the source column is missing and skipping transforms this plan cannot cluster by.
  @transient lazy val clusteringExpressions: Seq[Expression] = partitioning.flatMap {
    case identity: Identity =>
      Some(query.output.find(attr => identity.reference == attr.name)
          .getOrElse(throw new SparkException(s"Missing attribute: ${identity.name}")))
    case year: Year =>
      Some(query.output.find(attr => year.reference == attr.name)
          .map(attr => IcebergYearTransform(attr))
          .getOrElse(throw new SparkException(s"Missing attribute: ${year.name}")))
    case month: Month =>
      Some(query.output.find(attr => month.reference == attr.name)
          .map(attr => IcebergMonthTransform(attr))
          .getOrElse(throw new SparkException(s"Missing attribute: ${month.name}")))
    case date: Date =>
      Some(query.output.find(attr => date.reference == attr.name)
          .map(attr => IcebergDayTransform(attr))
          .getOrElse(throw new SparkException(s"Missing attribute: ${date.name}")))
    case hour: DateAndHour =>
      Some(query.output.find(attr => hour.reference == attr.name)
          .map(attr => IcebergHourTransform(attr))
          .getOrElse(throw new SparkException(s"Missing attribute: ${hour.name}")))
    case bucket: Bucket if bucket.references.length == 1 =>
      Some(query.output.find(attr => bucket.references.head == attr.name)
          .map(attr => IcebergBucketTransform(bucket.numBuckets, attr))
          .getOrElse(throw new SparkException(s"Missing attribute: ${bucket.name}")))
    case _ =>
      None
  }
  override def requiredChildDistribution: Seq[Distribution] = {
    // add a required distribution if the data is not clustered or ordered
    lazy val requiredDistribution = {
      val maybeBucketedAttr = clusteringExpressions.collectFirst {
        case IcebergBucketTransform(_, attr: Attribute) =>
          attr
      }

      maybeBucketedAttr match {
        case Some(bucketedAttr) =>
          OrderedDistribution(orderingExpressions :+ SortOrder(bucketedAttr, Ascending))
        case _ =>
          ClusteredDistribution(clusteringExpressions)
      }
    }

    // only override output partitioning if the data is obviously not distributed for the write
    val distribution = query.outputPartitioning match {
      case _ if clusteringExpressions.isEmpty =>
        UnspecifiedDistribution
      case UnknownPartitioning(_) =>
        requiredDistribution
      case RoundRobinPartitioning(_) =>
        requiredDistribution
      case _ =>
        UnspecifiedDistribution
    }

    distribution :: Nil
  }
  // if the expression is an attribute produced by a child Project, return the
  // expression that the attribute aliases
  private def unwrapAlias(plan: SparkPlan, expr: Expression): Option[Expression] = {
    plan match {
      case ProjectExec(exprs, _) =>
        expr match {
          case attr: Attribute =>
            val alias = exprs.find {
              case a: Alias if a.exprId == attr.exprId => true
              case _ => false
            }
            alias.map(_.asInstanceOf[Alias].child)
          case _ =>
            None
        }
      case _ =>
        None
    }
  }
  @transient lazy val orderingExpressions: Seq[SortOrder] = {
    clusteringExpressions.map { expr =>
      // unwrap aliases that may be added to match up column names to the table
      // for example: event_type#835 AS event_type#2278
      val unaliased = unwrapAlias(query, expr)

      // match the direction of any child ordering because clustering for tasks is what matters
      val existingOrdering = query.outputOrdering.find {
        case SortOrder(child, _, _) =>
          expr.semanticEquals(child) || unaliased.exists(_.semanticEquals(child))
        case _ =>
          false
      }

      existingOrdering.getOrElse(SortOrder(expr, Ascending))
    }
  }
  override def requiredChildOrdering: Seq[Seq[SortOrder]] = {
    requiredChildDistribution match {
      case Seq(OrderedDistribution(order)) =>
        // if this requires an ordered distribution, require the same sort order
        order :: Nil
      case _ =>
        // otherwise, request a local ordering to avoid creating too many output files
        orderingExpressions :: Nil
    }
  }
}
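To make the matching idea in clusteringExpressions concrete, here is a minimal, dependency-free sketch of the same pattern: resolve each partition transform's source column against the query output by name, wrap it in a transform expression, and fail fast when the column is missing. The types below are made up for illustration only; they are not the Spark or Iceberg classes, and the sketch drops the "skip unrecognized transforms" case.

object TransformMatchingSketch {
  // stand-ins for the transform and expression types used in the plan above
  sealed trait Transform { def reference: String }
  case class Identity(reference: String) extends Transform
  case class Day(reference: String) extends Transform
  case class Bucket(numBuckets: Int, reference: String) extends Transform

  sealed trait Expr
  case class Column(name: String) extends Expr
  case class DayOf(child: Column) extends Expr
  case class BucketOf(numBuckets: Int, child: Column) extends Expr

  // resolve each transform's source column by name, like clusteringExpressions does
  def clustering(output: Seq[Column], partitioning: Seq[Transform]): Seq[Expr] =
    partitioning.map { t =>
      val col = output.find(_.name == t.reference)
          .getOrElse(throw new IllegalArgumentException(s"Missing attribute: ${t.reference}"))
      t match {
        case _: Identity => col
        case _: Day => DayOf(col)
        case b: Bucket => BucketOf(b.numBuckets, col)
      }
    }

  def main(args: Array[String]): Unit = {
    val output = Seq(Column("id"), Column("ts"), Column("event_type"))
    // a table partitioned by day(ts) and bucket(16, id)
    println(clustering(output, Seq(Day("ts"), Bucket(16, "id"))))
    // prints: List(DayOf(Column(ts)), BucketOf(16,Column(id)))
  }
}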
Hi Ryan, from the code here it looks like the Spark and Iceberg code are mixed together. I'm not sure how we would inject the Iceberg code into Spark; do we need to expose an interface in V2TableWriteExec?
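One possible shape for the kind of hook that question is asking about, purely as a rough sketch with hypothetical names rather than the branch's actual design: the exec would depend only on an abstract trait, and the source (e.g. Iceberg) would supply the clustering expression for each of its partition transforms, so the Iceberg*Transform construction in clusteringExpressions would move behind that trait.

// hypothetical sketch, not an actual interface in this branch
trait WriteClusteringSupport {
  // build the clustering expression for one partition transform over the query
  // output, or None if the transform is not recognized by this source
  def clusteringExpression(
      transform: PartitionTransform,
      output: Seq[Attribute]): Option[Expression]
}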