Last active
July 14, 2017 00:57
-
-
Save sbcd90/91063761a3950348cea6576d6f0ae3a0 to your computer and use it in GitHub Desktop.
A Spark app showing how user-specific data types (UDTs) can be made generic using byte-array serialize/deserialize and UTF8String.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package org.apache.spark.sql

import java.io.{ByteArrayInputStream, ByteArrayOutputStream, ObjectInputStream, ObjectOutputStream}

import org.apache.spark.SparkConf
import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
/** Container for a heterogeneous array of values, registered as a Spark UDT.
  *
  * Equality and hashing follow `java.util.Arrays` semantics: two lists are
  * equal iff their elements are pairwise equal, and the hash folds each
  * element into an accumulator with the conventional 31-multiplier.
  */
@SQLUserDefinedType(udt = classOf[EmbeddedListUDT])
class EmbeddedList(val elements: Array[Any]) extends Serializable {

  // Fold the elements with the classic 31-based accumulator; null hashes as 0.
  override def hashCode(): Int =
    elements.foldLeft(1) { (acc, elem) =>
      31 * acc + (if (elem == null) 0 else elem.hashCode())
    }

  // Element-wise equality against another EmbeddedList only; anything else is unequal.
  override def equals(other: scala.Any): Boolean = other match {
    case that: EmbeddedList => this.elements.sameElements(that.elements)
    case _                  => false
  }

  override def toString: String = elements.mkString(", ")
}
/** UDT mapping an [[EmbeddedList]] to `ArrayType(StringType)`: each element is
  * Java-serialized and its raw bytes are wrapped in a `UTF8String`.
  *
  * NOTE(review): packing arbitrary serialized bytes into a UTF8String assumes
  * Spark treats the payload as opaque bytes end-to-end; a `BinaryType`-based
  * sqlType would be safer — confirm before production use.
  */
class EmbeddedListUDT extends UserDefinedType[EmbeddedList] {

  override def sqlType: DataType = ArrayType(StringType)

  /** Java-serializes every element and wraps the resulting bytes.
    *
    * Fix: the ObjectOutputStream buffers internally, so it must be closed
    * (which flushes) BEFORE reading `out.toByteArray`; the original read the
    * byte array without a flush, risking truncated serialized data.
    */
  override def serialize(obj: EmbeddedList): Any = {
    new GenericArrayData(obj.elements.map { elem =>
      val out = new ByteArrayOutputStream()
      val os = new ObjectOutputStream(out)
      try {
        os.writeObject(elem)
      } finally {
        os.close() // flush + close so toByteArray sees the complete stream
      }
      UTF8String.fromBytes(out.toByteArray)
    })
  }

  /** Reverses [[serialize]]: reads each UTF8String's bytes back through Java
    * deserialization. Fails fast on any input that is not ArrayData.
    */
  override def deserialize(datum: Any): EmbeddedList = {
    datum match {
      case values: ArrayData =>
        new EmbeddedList(values.toArray[UTF8String](StringType).map { elem =>
          val is = new ObjectInputStream(new ByteArrayInputStream(elem.getBytes))
          try is.readObject()
          finally is.close() // release stream resources even on a bad payload
        })
      case other => sys.error(s"Cannot deserialize $other")
    }
  }

  override def userClass: Class[EmbeddedList] = classOf[EmbeddedList]

  // Serialized payloads already tolerate null elements; the UDT itself is its
  // own nullable form.
  private[spark] override def asNullable = this
}
/** Demo entry point: builds a DataFrame over the EmbeddedList UDT and runs a
  * show/filter round-trip to exercise serialize/deserialize.
  */
object EmbeddedListTestApp {

  // Explicit main instead of the App trait: avoids DelayedInit field
  // initialization pitfalls while keeping the object runnable the same way.
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf().setAppName("TestApp29").setMaster("local[*]")
    val spark = SparkSession.builder().config(conf).getOrCreate()
    try {
      val schema = StructType(Array(StructField("id", new EmbeddedListUDT, nullable = false)))
      // createDataFrame directly on SparkSession; sqlContext is a legacy shim.
      val df = spark.createDataFrame(
        spark.sparkContext.parallelize(Seq(
          Row(new EmbeddedList(Array(1, 2))),
          Row(new EmbeddedList(Array(2, 3))))),
        schema)
      df.show()
      df.printSchema()
      // The filter deserializes the UDT back into an EmbeddedList per row.
      df.filter(row => row.getAs[EmbeddedList]("id").elements.apply(0) == 1).show()
    } finally {
      spark.stop() // release local Spark resources even if the demo fails
    }
  }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment