Last active
July 6, 2019 19:12
-
-
Save erikerlandson/b0e106a4dbaf7f80b4f4f3a21f05f892 to your computer and use it in GitHub Desktop.
Benchmarking Description for Spark UDIA pull request
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package org.apache.spark.countSerDe | |
import org.apache.spark.internal.Logging | |
import org.apache.spark.sql.Row | |
import org.apache.spark.sql.catalyst.InternalRow | |
import org.apache.spark.sql.catalyst.expressions.GenericInternalRow | |
import org.apache.spark.sql.catalyst.util._ | |
import org.apache.spark.sql.expressions.MutableAggregationBuffer | |
import org.apache.spark.sql.expressions.UserDefinedAggregateFunction | |
import org.apache.spark.sql.expressions.UserDefinedImperativeAggregator | |
import org.apache.spark.sql.types._ | |
/**
 * Aggregation state for the serialization-counting benchmark aggregators.
 *
 * @param nSer   number of serializations recorded so far (bumped by the serializers)
 * @param nDeSer number of deserializations recorded so far (bumped by the deserializers)
 * @param sum    running sum of the aggregated input values
 */
@SQLUserDefinedType(udt = classOf[CountSerDeUDT])
final case class CountSerDeSQL(nSer: Int, nDeSer: Int, sum: Double)
/**
 * Spark SQL user-defined type for [[CountSerDeSQL]].
 *
 * Every call to `serialize` bumps `nSer` by one and every call to `deserialize` bumps `nDeSer`
 * by one, so the value that comes out of an aggregation records how many UDT conversions it
 * went through.
 */
class CountSerDeUDT extends UserDefinedType[CountSerDeSQL] {
  def userClass: Class[CountSerDeSQL] = classOf[CountSerDeSQL]

  override def typeName: String = "count-ser-de"

  // The type has no parameters, so all instances are interchangeable.
  override def equals(obj: Any): Boolean = obj match {
    case _: CountSerDeUDT => true
    case _ => false
  }

  override def hashCode(): Int = classOf[CountSerDeUDT].getName.hashCode()

  private[spark] override def asNullable: CountSerDeUDT = this

  /** Catalyst layout: three non-nullable columns (nSer, nDeSer, sum). */
  def sqlType: DataType = StructType(Seq(
    StructField("nSer", IntegerType, nullable = false),
    StructField("nDeSer", IntegerType, nullable = false),
    StructField("sum", DoubleType, nullable = false)))

  /** Converts to a Catalyst row, counting this conversion in `nSer`. */
  def serialize(sql: CountSerDeSQL): Any = {
    val CountSerDeSQL(serCount, deSerCount, total) = sql
    val out = new GenericInternalRow(3)
    out.setInt(0, serCount + 1)
    out.setInt(1, deSerCount)
    out.setDouble(2, total)
    out
  }

  /** Rebuilds the state from a 3-field Catalyst row, counting this conversion in `nDeSer`. */
  def deserialize(any: Any): CountSerDeSQL = any match {
    case row: InternalRow if row.numFields == 3 =>
      val serCount = row.getInt(0)
      val deSerCount = row.getInt(1)
      val total = row.getDouble(2)
      CountSerDeSQL(serCount, deSerCount + 1, total)
    case unexpected =>
      throw new Exception(s"failed to deserialize: $unexpected")
  }
}

/** Singleton instance used wherever a `CountSerDeUDT` data type is required. */
case object CountSerDeUDT extends CountSerDeUDT
/**
 * `UserDefinedImperativeAggregator` that sums a double column and counts how often its state
 * passes through this aggregator's own byte-array codec.
 *
 * `serialize` bumps `nSer` and `deserialize` bumps `nDeSer`, so the result records the number
 * of codec round trips the aggregation needed.
 */
case object CountSerDeUDIA extends UserDefinedImperativeAggregator[CountSerDeSQL] {
  import org.apache.spark.unsafe.Platform

  // Packed layout, as offsets relative to Platform.BYTE_ARRAY_OFFSET:
  // nSer (Int, 4 bytes), nDeSer (Int, 4 bytes), sum (Double, 8 bytes).
  private final val NSerAt = 0
  private final val NDeSerAt = 4
  private final val SumAt = 8
  private final val PackedBytes = 16

  def inputSchema: StructType = StructType(StructField("x", DoubleType) :: Nil)
  def resultType: DataType = CountSerDeUDT
  def deterministic: Boolean = false
  def initial: CountSerDeSQL = CountSerDeSQL(0, 0, 0)

  /** Adds the input value to the running sum; a null input leaves the state unchanged. */
  def update(agg: CountSerDeSQL, input: Row): CountSerDeSQL =
    if (input.isNullAt(0)) agg else agg.copy(sum = agg.sum + input.getDouble(0))

  /** Combines two partial states field by field. */
  def merge(agg1: CountSerDeSQL, agg2: CountSerDeSQL): CountSerDeSQL =
    CountSerDeSQL(agg1.nSer + agg2.nSer, agg1.nDeSer + agg2.nDeSer, agg1.sum + agg2.sum)

  def evaluate(agg: CountSerDeSQL): Any = agg

  /** Packs the state into 16 bytes, counting this call in `nSer`. */
  def serialize(agg: CountSerDeSQL): Array[Byte] = {
    val CountSerDeSQL(ns, nd, s) = agg
    val bytes = new Array[Byte](PackedBytes)
    Platform.putInt(bytes, Platform.BYTE_ARRAY_OFFSET + NSerAt, ns + 1)
    Platform.putInt(bytes, Platform.BYTE_ARRAY_OFFSET + NDeSerAt, nd)
    Platform.putDouble(bytes, Platform.BYTE_ARRAY_OFFSET + SumAt, s)
    bytes
  }

  /**
   * Unpacks a state written by `serialize`, counting this call in `nDeSer`.
   *
   * @throws IllegalArgumentException if `data` is shorter than the packed layout
   */
  def deserialize(data: Array[Byte]): CountSerDeSQL = {
    // Platform reads raw memory (sun.misc.Unsafe) without bounds checks, so validate first.
    require(data.length >= PackedBytes,
      s"expected at least $PackedBytes bytes, got ${data.length}")
    val ns = Platform.getInt(data, Platform.BYTE_ARRAY_OFFSET + NSerAt)
    val nd = Platform.getInt(data, Platform.BYTE_ARRAY_OFFSET + NDeSerAt)
    val s = Platform.getDouble(data, Platform.BYTE_ARRAY_OFFSET + SumAt)
    CountSerDeSQL(ns, nd + 1, s)
  }
}
/**
 * Legacy `UserDefinedAggregateFunction` version of the serialization-counting aggregator.
 *
 * The single buffer column has type [[CountSerDeUDT]], so buffer reads and writes that go
 * through that UDT are counted in the `nSer`/`nDeSer` fields of the result.
 */
case object CountSerDeUDAF extends UserDefinedAggregateFunction {
  def deterministic: Boolean = false
  def inputSchema: StructType = StructType(StructField("x", DoubleType) :: Nil)
  def bufferSchema: StructType = StructType(StructField("count-ser-de", CountSerDeUDT) :: Nil)
  def dataType: DataType = CountSerDeUDT

  def initialize(buf: MutableAggregationBuffer): Unit = {
    buf(0) = CountSerDeSQL(0, 0, 0)
  }

  /** Adds the input value to the running sum; a null input leaves the buffer untouched. */
  def update(buf: MutableAggregationBuffer, input: Row): Unit = {
    // Skip nulls (as the TDigest aggregators do) instead of failing in getDouble.
    if (!input.isNullAt(0)) {
      val sql = buf.getAs[CountSerDeSQL](0)
      buf(0) = sql.copy(sum = sql.sum + input.getDouble(0))
    }
  }

  /** Combines two partial buffers field by field into `buf1`. */
  def merge(buf1: MutableAggregationBuffer, buf2: Row): Unit = {
    val sql1 = buf1.getAs[CountSerDeSQL](0)
    val sql2 = buf2.getAs[CountSerDeSQL](0)
    buf1(0) = CountSerDeSQL(sql1.nSer + sql2.nSer, sql1.nDeSer + sql2.nDeSer, sql1.sum + sql2.sum)
  }

  def evaluate(buf: Row): Any = buf.getAs[CountSerDeSQL](0)
}
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package org.apache.spark.tdigest | |
import org.isarnproject.sketches.TDigest | |
import org.isarnproject.sketches.tdmap.TDigestMap | |
import org.apache.spark.internal.Logging | |
import org.apache.spark.sql.Row | |
import org.apache.spark.sql.catalyst.InternalRow | |
import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeArrayData} | |
import org.apache.spark.sql.catalyst.util._ | |
import org.apache.spark.sql.expressions.MutableAggregationBuffer | |
import org.apache.spark.sql.expressions.UserDefinedAggregateFunction | |
import org.apache.spark.sql.expressions.UserDefinedImperativeAggregator | |
import org.apache.spark.sql.types._ | |
/**
 * Wrapper that lets a `TDigest` be stored in Spark SQL through [[TDigestUDT]].
 *
 * @param tdigest the wrapped sketch
 */
@SQLUserDefinedType(udt = classOf[TDigestUDT])
final case class TDigestSQL(tdigest: TDigest)
/**
 * Spark SQL user-defined type for [[TDigestSQL]].
 *
 * A digest is stored as a Catalyst struct. The struct holds the digest's scalar parameters plus
 * two parallel double arrays:
 *  - `clustX`: the cluster-map keys.
 *  - `clustM`: the cluster-map values (presumably cluster masses).
 */
class TDigestUDT extends UserDefinedType[TDigestSQL] {
  def userClass: Class[TDigestSQL] = classOf[TDigestSQL]

  override def pyUDT: String = "isarnproject.sketches.udt.tdigest.TDigestUDT"

  override def typeName: String = "tdigest"

  // The type has no parameters, so all instances are interchangeable.
  override def equals(obj: Any): Boolean = obj match {
    case _: TDigestUDT => true
    case _ => false
  }

  override def hashCode(): Int = classOf[TDigestUDT].getName.hashCode()

  private[spark] override def asNullable: TDigestUDT = this

  /** Catalyst layout: (delta, maxDiscrete, nclusters, clustX[], clustM[]), all non-nullable. */
  def sqlType: DataType = StructType(Seq(
    StructField("delta", DoubleType, nullable = false),
    StructField("maxDiscrete", IntegerType, nullable = false),
    StructField("nclusters", IntegerType, nullable = false),
    StructField("clustX", ArrayType(DoubleType, containsNull = false), nullable = false),
    StructField("clustM", ArrayType(DoubleType, containsNull = false), nullable = false)))

  def serialize(tdsql: TDigestSQL): Any = serializeTD(tdsql.tdigest)

  def deserialize(datum: Any): TDigestSQL = TDigestSQL(deserializeTD(datum))

  /** Encodes a digest as a 5-field Catalyst row matching `sqlType`. */
  def serializeTD(td: TDigest): InternalRow = {
    val TDigest(d, maxDisc, nClust, clusterMap) = td
    val keyArr = clusterMap.keys.toArray
    val valArr = clusterMap.values.toArray
    val out = new GenericInternalRow(5)
    out.setDouble(0, d)
    out.setInt(1, maxDisc)
    out.setInt(2, nClust)
    out.update(3, UnsafeArrayData.fromPrimitiveArray(keyArr))
    out.update(4, UnsafeArrayData.fromPrimitiveArray(valArr))
    out
  }

  /**
   * Decodes a row produced by `serializeTD`, rebuilding the cluster map entry by entry.
   *
   * @throws IllegalArgumentException if the row does not have exactly 5 fields
   * @throws Exception if `datum` is not an `InternalRow`
   */
  def deserializeTD(datum: Any): TDigest = datum match {
    case row: InternalRow =>
      require(row.numFields == 5, s"expected row length 5, got ${row.numFields}")
      val keyArr = row.getArray(3).toDoubleArray()
      val valArr = row.getArray(4).toDoubleArray()
      val clusterMap = keyArr.zip(valArr)
        .foldLeft(TDigestMap.empty) { case (acc, kv) => acc + kv }
      TDigest(row.getDouble(0), row.getInt(1), row.getInt(2), clusterMap)
    case other =>
      throw new Exception(s"failed to deserialize: $other")
  }
}

/** Singleton instance used wherever a `TDigestUDT` data type is required. */
case object TDigestUDT extends TDigestUDT
/**
 * Legacy `UserDefinedAggregateFunction` that accumulates a double column into a t-digest.
 *
 * The single buffer column holds a [[TDigestSQL]]; null inputs are skipped.
 *
 * @param deltaV       `delta` argument forwarded to `TDigest.empty`
 * @param maxDiscreteV `maxDiscrete` argument forwarded to `TDigest.empty`
 */
case class TDigestUDAF(deltaV: Double, maxDiscreteV: Int)
  extends UserDefinedAggregateFunction {

  def deterministic: Boolean = false

  def inputSchema: StructType = StructType(Seq(StructField("x", DoubleType)))

  def bufferSchema: StructType = StructType(Seq(StructField("tdigest", TDigestUDT)))

  def dataType: DataType = TDigestUDT

  def initialize(buf: MutableAggregationBuffer): Unit = {
    val fresh = TDigest.empty(deltaV, maxDiscreteV)
    buf(0) = TDigestSQL(fresh)
  }

  /** Folds one non-null input value into the buffered digest. */
  def update(buf: MutableAggregationBuffer, input: Row): Unit =
    if (input.isNullAt(0)) () else {
      val digest = buf.getAs[TDigestSQL](0).tdigest
      buf(0) = TDigestSQL(digest + input.getDouble(0))
    }

  /** Combines the digest in `buf2` into the one in `buf1`. */
  def merge(buf1: MutableAggregationBuffer, buf2: Row): Unit = {
    val mine = buf1.getAs[TDigestSQL](0).tdigest
    val theirs = buf2.getAs[TDigestSQL](0).tdigest
    buf1(0) = TDigestSQL(mine ++ theirs)
  }

  def evaluate(buf: Row): Any = buf.getAs[TDigestSQL](0)
}
/**
 * `UserDefinedImperativeAggregator` that accumulates a double column into a t-digest.
 *
 * The working state is a plain `TDigest`; it is wrapped in [[TDigestSQL]] only when the final
 * result is produced. Intermediate state is converted to and from bytes with Java serialization.
 *
 * @param deltaV       `delta` argument forwarded to `TDigest.empty`
 * @param maxDiscreteV `maxDiscrete` argument forwarded to `TDigest.empty`
 */
case class TDigestUDIA(deltaV: Double, maxDiscreteV: Int)
  extends UserDefinedImperativeAggregator[TDigest] {
  def inputSchema: StructType = StructType(StructField("x", DoubleType) :: Nil)
  def resultType: DataType = TDigestUDT
  def deterministic: Boolean = false
  def initial: TDigest = TDigest.empty(deltaV, maxDiscreteV)

  /** Adds one input value to the digest; a null input leaves it unchanged. */
  def update(agg: TDigest, input: Row): TDigest =
    if (input.isNullAt(0)) agg else agg + input.getDouble(0)

  /** Combines two partial digests. */
  def merge(agg1: TDigest, agg2: TDigest): TDigest = agg1 ++ agg2

  def evaluate(agg: TDigest): Any = TDigestSQL(agg)

  import java.io._

  // scalastyle:off classforname
  /**
   * `ObjectInputStream` that first resolves classes through the class loader that loaded this
   * stream class. If that fails, it falls back to the default resolution.
   * NOTE(review): presumably this lets classes visible only to this code's loader (e.g. jars
   * added at runtime) deserialize correctly — confirm.
   */
  class ObjectInputStreamWithCustomClassLoader(
      inputStream: InputStream) extends ObjectInputStream(inputStream) {
    override def resolveClass(desc: java.io.ObjectStreamClass): Class[_] = {
      try { Class.forName(desc.getName, false, getClass.getClassLoader) }
      catch { case _: ClassNotFoundException => super.resolveClass(desc) }
    }
  }
  // scalastyle:on classforname

  /** Java-serializes a digest into a standalone byte array. */
  def serialize(agg: TDigest): Array[Byte] = {
    val bufout = new ByteArrayOutputStream()
    val obout = new ObjectOutputStream(bufout)
    // Close (which flushes) before reading the bytes back, so nothing still buffered inside
    // the ObjectOutputStream can be missing from the result.
    try obout.writeObject(agg) finally obout.close()
    bufout.toByteArray
  }

  /**
   * Inverse of `serialize`.
   *
   * NOTE: this uses Java deserialization, so `data` must come from `serialize`. It must never
   * come from an untrusted source.
   */
  def deserialize(data: Array[Byte]): TDigest = {
    val obin = new ObjectInputStreamWithCustomClassLoader(new ByteArrayInputStream(data))
    try obin.readObject().asInstanceOf[TDigest] finally obin.close()
  }
}
/** Minimal elapsed-time helpers used by the UDIA/UDAF benchmarks. */
object Benchmark {

  /**
   * Evaluates `blk` once and measures how long the evaluation took.
   *
   * Uses `System.nanoTime`, which is monotonic and intended for measuring elapsed time. The
   * previous `System.currentTimeMillis` is coarse and can jump when the wall clock is adjusted.
   *
   * @param blk the computation to time; by-name, evaluated exactly once inside the timed region
   * @return elapsed time in seconds, paired with the value produced by `blk`
   */
  def apply[T](blk: => T): (Double, T) = {
    val t0 = System.nanoTime()
    val v = blk
    val elapsedNanos = System.nanoTime() - t0
    (elapsedNanos.toDouble / 1.0e9, v)
  }

  /**
   * Times `samples` independent evaluations of `blk`.
   *
   * @param samples number of evaluations; zero or negative yields an empty array
   * @param blk     the computation to time; re-evaluated once per sample
   * @return one `(seconds, value)` pair per evaluation, in execution order
   */
  def sample[T](samples: Int)(blk: => T): Array[(Double, T)] =
    Array.fill(samples)(apply(blk))
}
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Using Scala version 2.12.8 (OpenJDK 64-Bit Server VM, Java 1.8.0_212) | |
Type in expressions to have them evaluated. | |
Type :help for more information. | |
scala> import scala.util.Random._, org.apache.spark.countSerDe._, org.apache.spark.sql.Row, org.apache.spark.tdigest._ | |
import scala.util.Random._ | |
import org.apache.spark.countSerDe._ | |
import org.apache.spark.sql.Row | |
import org.apache.spark.tdigest._ | |
scala> sc.setLogLevel("ERROR") | |
scala> val udaf = CountSerDeUDAF | |
udaf: org.apache.spark.countSerDe.CountSerDeUDAF.type = CountSerDeUDAF | |
scala> val udia = CountSerDeUDIA | |
udia: org.apache.spark.countSerDe.CountSerDeUDIA.type = CountSerDeUDIA | |
scala> val data = sc.parallelize(Vector.fill(2000000){(nextInt(2), nextGaussian, nextGaussian)}, 5).toDF("cat", "x1", "x2") | |
data: org.apache.spark.sql.DataFrame = [cat: int, x1: double ... 1 more field] | |
scala> val bs = Benchmark.sample(50) { data.agg(udaf($"x1"), udaf($"x2")).first } | |
bs: Array[(Double, org.apache.spark.sql.Row)] = Array((7.986,[CountSerDeSQL(2000012,2000012,-54.30894451561437),CountSerDeSQL(2000012,2000012,208.8748610102166)]), (7.725,[CountSerDeSQL(2000012,2000012,-54.30894451561437),CountSerDeSQL(2000012,2000012,208.8748610102166)]), (8.81,[CountSerDeSQL(2000012,2000012,-54.30894451561437),CountSerDeSQL(2000012,2000012,208.8748610102166)]), (7.62,[CountSerDeSQL(2000012,2000012,-54.30894451561437),CountSerDeSQL(2000012,2000012,208.8748610102166)]), (9.375,[CountSerDeSQL(2000012,2000012,-54.30894451561437),CountSerDeSQL(2000012,2000012,208.8748610102166)]), (8.504,[CountSerDeSQL(2000012,2000012,-54.30894451561437),CountSerDeSQL(2000012,2000012,208.8748610102166)]), (8.336,[CountSerDeSQL(2000012,2000012,-54.30894451561437),C... | |
scala> bs.map(_._1) | |
res0: Array[Double] = Array(7.986, 7.725, 8.81, 7.62, 9.375, 8.504, 8.336, 8.627, 9.151, 8.124, 9.028, 8.849, 9.401, 9.318, 8.703, 8.815, 8.966, 9.302, 9.111, 8.933, 9.022, 8.893, 9.653, 9.284, 8.205, 9.627, 8.947, 9.152, 9.078, 9.839, 9.08, 9.246, 7.5, 8.279, 8.793, 9.492, 9.099, 9.096, 9.132, 9.206, 8.585, 9.093, 8.946, 8.848, 8.366, 8.795, 9.01, 8.946, 9.011, 6.885) | |
scala> bs.map(_._1).sum / 50 | |
res1: Double = 8.835840000000001 | |
scala> val bs = Benchmark.sample(50) { data.agg(udia($"x1"), udia($"x2")).first } | |
bs: Array[(Double, org.apache.spark.sql.Row)] = Array((7.514,[CountSerDeSQL(6,6,-54.30894451561437),CountSerDeSQL(6,6,208.8748610102166)]), (5.283,[CountSerDeSQL(6,6,-54.30894451561437),CountSerDeSQL(6,6,208.8748610102166)]), (6.291,[CountSerDeSQL(6,6,-54.30894451561437),CountSerDeSQL(6,6,208.8748610102166)]), (5.625,[CountSerDeSQL(6,6,-54.30894451561437),CountSerDeSQL(6,6,208.8748610102166)]), (6.457,[CountSerDeSQL(6,6,-54.30894451561437),CountSerDeSQL(6,6,208.8748610102166)]), (4.961,[CountSerDeSQL(6,6,-54.30894451561437),CountSerDeSQL(6,6,208.8748610102166)]), (6.089,[CountSerDeSQL(6,6,-54.30894451561437),CountSerDeSQL(6,6,208.8748610102166)]), (5.305,[CountSerDeSQL(6,6,-54.30894451561437),CountSerDeSQL(6,6,208.8748610102166)]), (6.658,[CountSerDeSQL(6,6,-54... | |
scala> bs.map(_._1) | |
res2: Array[Double] = Array(7.514, 5.283, 6.291, 5.625, 6.457, 4.961, 6.089, 5.305, 6.658, 5.765, 5.521, 6.172, 6.694, 6.169, 5.507, 6.878, 6.485, 5.872, 5.577, 5.875, 6.315, 7.104, 5.829, 5.975, 5.826, 5.349, 6.513, 5.836, 6.129, 6.868, 6.09, 5.857, 7.068, 6.866, 6.203, 6.521, 6.108, 6.664, 5.94, 6.958, 6.163, 6.785, 6.013, 6.493, 6.23, 6.054, 7.12, 5.849, 6.898, 6.173) | |
scala> bs.map(_._1).sum / 50 | |
res3: Double = 6.209899999999999 | |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Using Scala version 2.12.8 (OpenJDK 64-Bit Server VM, Java 1.8.0_212) | |
Type in expressions to have them evaluated. | |
Type :help for more information. | |
scala> import scala.util.Random._, org.apache.spark.countSerDe._, org.apache.spark.sql.Row, org.apache.spark.tdigest._ | |
import scala.util.Random._ | |
import org.apache.spark.countSerDe._ | |
import org.apache.spark.sql.Row | |
import org.apache.spark.tdigest._ | |
scala> sc.setLogLevel("ERROR") | |
scala> val data = sc.parallelize(Vector.fill(100000){(nextInt(2), nextGaussian, nextGaussian)}, 5).toDF("cat", "x1", "x2") | |
data: org.apache.spark.sql.DataFrame = [cat: int, x1: double ... 1 more field] | |
scala> val udia = TDigestUDIA(0.5, 0) | |
udia: org.apache.spark.tdigest.TDigestUDIA = TDigestUDIA(0.5,0) | |
scala> val udaf = TDigestUDAF(0.5, 0) | |
udaf: org.apache.spark.tdigest.TDigestUDAF = TDigestUDAF(0.5,0) | |
scala> val bs = Benchmark.sample(10) { data.agg(udaf($"x1")).first } | |
bs: Array[(Double, org.apache.spark.sql.Row)] = Array((15.694,[TDigestSQL(TDigest(0.5,0,133,TDigestMap(-4.421763563584745 -> (1.0, 1.0), -4.254974064796795 -> (1.0, 2.0), -4.232552385036448 -> (1.0, 3.0), -4.090629063164773 -> (0.15027423113120597, 3.150274231131206), -4.026284332714937 -> (1.9991595262149298, 5.149433757346136), -4.011367664781805 -> (0.8505662426538643, 6.0), -3.8819633785079186 -> (2.6052762694868314, 8.605276269486831), -3.855805587944206 -> (2.2542573187962667, 10.859533588283098), -3.788192954654244 -> (2.9647075216215226, 13.82424110990462), -3.7259281425454516 -> (1.1757588900953793, 14.999999999999998), -3.630667958772349 -> (5.498887306391312, 20.49888730639131), -3.5248355152381885 -> (3.6185151659764703, 24.11740247236778), -3.46119... | |
scala> bs.map(_._1) | |
res0: Array[Double] = Array(15.694, 18.268, 18.733, 18.623, 18.717, 18.924, 18.921, 18.939, 18.815, 18.995) | |
scala> bs.map(_._1).sum / 10 | |
res1: Double = 18.462899999999998 | |
scala> val bs = Benchmark.sample(10) { data.agg(udia($"x1")).first } | |
bs: Array[(Double, org.apache.spark.sql.Row)] = Array((0.29,[TDigestSQL(TDigest(0.5,0,133,TDigestMap(-4.421763563584745 -> (1.0, 1.0), -4.254974064796795 -> (1.0, 2.0), -4.232552385036448 -> (1.0, 3.0), -4.090629063164773 -> (0.15027423113120597, 3.150274231131206), -4.026284332714937 -> (1.9991595262149298, 5.149433757346136), -4.011367664781805 -> (0.8505662426538643, 6.0), -3.8819633785079186 -> (2.6052762694868314, 8.605276269486831), -3.855805587944206 -> (2.2542573187962667, 10.859533588283098), -3.788192954654244 -> (2.9647075216215226, 13.82424110990462), -3.7259281425454516 -> (1.1757588900953793, 14.999999999999998), -3.630667958772349 -> (5.498887306391312, 20.49888730639131), -3.5248355152381885 -> (3.6185151659764694, 24.11740247236778), -3.4611932... | |
scala> bs.map(_._1) | |
res2: Array[Double] = Array(0.29, 0.27, 0.194, 0.182, 0.187, 0.183, 0.182, 0.184, 0.181, 0.182) | |
scala> bs.map(_._1).sum / 10 | |
res3: Double = 0.20350000000000001 | |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment