Snippet for extracting a nested column from an input row in Spark
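The snippet below builds an extractor for a single nested field path (here topCol.col1): genExtractor turns a one-field-per-level "access schema" into a chain of UnresolvedExtractValue expressions, the chain is resolved by running SimpleAnalyzer over a dummy Project on a LocalRelation that carries the full input schema, and the resolved projection is evaluated with InterpretedProjection against a GenericInternalRow. The imports are copied from Spark's DataFrameSuite, so most of them are not strictly needed here.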
import java.io.{ByteArrayOutputStream, File}
import java.nio.charset.StandardCharsets
import java.sql.{Date, Timestamp}
import java.util.UUID
import java.util.concurrent.atomic.AtomicLong

import scala.util.Random

import org.scalatest.Matchers._

import org.apache.spark.SparkException
import org.apache.spark.scheduler.{SparkListener, SparkListenerJobEnd}
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.optimizer.ConvertToLocalRelation
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.execution.{FilterExec, QueryExecution, WholeStageCodegenExec}
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
import org.apache.spark.sql.execution.aggregate.HashAggregateExec
import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec, ShuffleExchangeExec}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.{ExamplePoint, ExamplePointUDT, SharedSparkSession}
import org.apache.spark.sql.test.SQLTestData.{DecimalData, NullStrings, TestData2}
import org.apache.spark.sql.types._
import org.apache.spark.util.Utils
import org.apache.spark.util.random.XORShiftRandom

// Note: depending on the Spark version, ProjectionOverSchema may need an explicit
// import (e.g. from org.apache.spark.sql.execution.datasources, where it lives in
// some releases) if it is not already covered by the wildcard imports above.
class DataFrameSuite extends QueryTest
  with SharedSparkSession
  with AdaptiveSparkPlanHelper {
  import testImplicits._

  // Recursively builds an extractor expression for a single-branch access schema:
  // every struct level must contain exactly one field, and the resulting chain of
  // UnresolvedExtractValue nodes follows that path down to the leaf field.
  def genExtractor(
      structField: StructField,
      optChild: Option[Expression]): Expression = structField.dataType match {
    case StructType(fields) if fields.length == 1 =>
      val nextChild = optChild.map { child =>
        UnresolvedExtractValue(child, Literal(fields(0).name))
      }.getOrElse {
        // The root field.
        UnresolvedExtractValue(
          UnresolvedAttribute(structField.name),
          Literal(fields(0).name))
      }
      genExtractor(fields(0), Some(nextChild))
    case StructType(fields) =>
      throw new AnalysisException("The accessed struct type should have only one field.")
    // The leaf field.
    case _ =>
      optChild.getOrElse(UnresolvedAttribute(structField.name))
  }
test("test") { | |
val nestedSchema = StructType( | |
StructField("col1", StringType) :: | |
StructField("col2", StringType) :: | |
StructField("col3", IntegerType) :: Nil) | |
val values = Array("value1", "value2", 1) | |
// val nestedRow = new GenericRowWithSchema(values, nestedSchema) | |
val schema = StructType(StructField("topCol", nestedSchema) :: Nil) | |
// val row = new GenericRowWithSchema(Array(nestedRow), schema) | |
val nestedRow = new GenericInternalRow(values) | |
val inputRow = new GenericInternalRow(Array(nestedRow.asInstanceOf[Any])) | |
val accessField = StructField("topCol", | |
StructType(StructField("col1", StringType) :: Nil)) | |
val extractorProjection = ProjectionOverSchema(schema) | |
val extractors = Seq(genExtractor(accessField, None)).map { extractor => | |
val projected = extractor.transform { | |
case extractorProjection(expr) => expr | |
} | |
projected match { | |
case n: NamedExpression => n | |
case other => Alias(other, other.prettyName)() | |
} | |
} | |
val attrs = schema.toAttributes | |
val dummyPlan = Project(extractors, LocalRelation(attrs)) | |
val analyzedPlan = SimpleAnalyzer.execute(dummyPlan) | |
SimpleAnalyzer.checkAnalysis(analyzedPlan) | |
val resolvedExtractors = analyzedPlan match { | |
case Project(projectList, _) => projectList | |
case _ => throw new AnalysisException(s"wrong analyzed plan: $analyzedPlan") | |
} | |
val project = new InterpretedProjection(resolvedExtractors, attrs) | |
val outputRow = project(inputRow) | |
println(outputRow) | |
} | |
} |
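Running the test resolves topCol.col1 against the sample row and prints the projected internal row; with the data above it should print the single extracted value, something like [value1].

For comparison, here is a minimal sketch of the same extraction through the public DataFrame API, without touching catalyst internals. The local SparkSession setup and sample data are illustrative assumptions, not part of the gist above:

import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.types._

val spark = SparkSession.builder().master("local[1]").appName("nested-column").getOrCreate()
val nestedSchema = StructType(
  StructField("col1", StringType) ::
  StructField("col2", StringType) ::
  StructField("col3", IntegerType) :: Nil)
val schema = StructType(StructField("topCol", nestedSchema) :: Nil)
val df = spark.createDataFrame(
  spark.sparkContext.parallelize(Seq(Row(Row("value1", "value2", 1)))),
  schema)
// Dot notation on a struct column resolves the nested field; the result is a
// single column named "col1" containing "value1".
df.select(col("topCol.col1")).show()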