Skip to content

Instantly share code, notes, and snippets.

@erikerlandson
Last active July 6, 2019 19:12
Show Gist options
  • Save erikerlandson/b0e106a4dbaf7f80b4f4f3a21f05f892 to your computer and use it in GitHub Desktop.
Save erikerlandson/b0e106a4dbaf7f80b4f4f3a21f05f892 to your computer and use it in GitHub Desktop.
Benchmarking Description for Spark UDIA pull request
package org.apache.spark.countSerDe
import org.apache.spark.internal.Logging
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.GenericInternalRow
import org.apache.spark.sql.catalyst.util._
import org.apache.spark.sql.expressions.MutableAggregationBuffer
import org.apache.spark.sql.expressions.UserDefinedAggregateFunction
import org.apache.spark.sql.expressions.UserDefinedImperativeAggregator
import org.apache.spark.sql.types._
/**
 * Aggregation value used to count serialization round trips.
 *
 * Bound to [[CountSerDeUDT]] through the `SQLUserDefinedType` annotation.
 *
 * @param nSer   counter that the serializers in this file increment by one per serialize call
 * @param nDeSer counter that the deserializers in this file increment by one per deserialize call
 * @param sum    running sum of the aggregated input values
 */
@SQLUserDefinedType(udt = classOf[CountSerDeUDT])
case class CountSerDeSQL(nSer: Int, nDeSer: Int, sum: Double)
/**
 * User-defined type that stores a [[CountSerDeSQL]] as a three-field struct
 * (`nSer`, `nDeSer`, `sum`) and instruments itself: every `serialize` writes
 * `nSer + 1`, every `deserialize` yields `nDeSer + 1`.
 */
class CountSerDeUDT extends UserDefinedType[CountSerDeSQL] {
  def userClass: Class[CountSerDeSQL] = classOf[CountSerDeSQL]

  override def typeName: String = "count-ser-de"

  private[spark] override def asNullable: CountSerDeUDT = this

  /** Struct layout of the stored value; all three fields are non-nullable. */
  def sqlType: DataType = StructType(Seq(
    StructField("nSer", IntegerType, nullable = false),
    StructField("nDeSer", IntegerType, nullable = false),
    StructField("sum", DoubleType, nullable = false)))

  /** Builds the three-field row, bumping the serialization counter by one. */
  def serialize(sql: CountSerDeSQL): Any =
    new GenericInternalRow(Array[Any](1 + sql.nSer, sql.nDeSer, sql.sum))

  /**
   * Reads a three-field row back, bumping the deserialization counter by one.
   * Anything that is not a three-field `InternalRow` is rejected with an exception.
   */
  def deserialize(any: Any): CountSerDeSQL = any match {
    case fields: InternalRow if fields.numFields == 3 =>
      val serCount = fields.getInt(0)
      val deserCount = 1 + fields.getInt(1)
      val total = fields.getDouble(2)
      CountSerDeSQL(serCount, deserCount, total)
    case u =>
      throw new Exception(s"failed to deserialize: $u")
  }

  /** Every instance of this UDT is equal to every other instance. */
  override def equals(obj: Any): Boolean = obj match {
    case _: CountSerDeUDT => true
    case _ => false
  }

  /** Hash derived from the class name, so it agrees with `equals`. */
  override def hashCode(): Int = {
    val className = classOf[CountSerDeUDT].getName
    className.hashCode()
  }
}
/** Singleton [[CountSerDeUDT]] value, used in this file wherever a `DataType` for the aggregate is needed. */
case object CountSerDeUDT extends CountSerDeUDT
/**
 * Imperative-aggregator variant of the instrumented sum: it accumulates the single
 * double input column into `sum` and counts how many times its own binary
 * `serialize` / `deserialize` methods are invoked.
 */
case object CountSerDeUDIA extends UserDefinedImperativeAggregator[CountSerDeSQL] {
  import org.apache.spark.unsafe.Platform

  def inputSchema: StructType = StructType(Seq(StructField("x", DoubleType)))

  def resultType: DataType = CountSerDeUDT

  def deterministic: Boolean = false

  /** Starting state: both counters and the sum are zero. */
  def initial: CountSerDeSQL = CountSerDeSQL(nSer = 0, nDeSer = 0, sum = 0.0)

  /** Adds the first input column to the running sum; the counters are carried over unchanged. */
  def update(agg: CountSerDeSQL, input: Row): CountSerDeSQL =
    CountSerDeSQL(agg.nSer, agg.nDeSer, agg.sum + input.getDouble(0))

  /** Field-wise addition of two partial states. */
  def merge(agg1: CountSerDeSQL, agg2: CountSerDeSQL): CountSerDeSQL =
    CountSerDeSQL(
      nSer = agg1.nSer + agg2.nSer,
      nDeSer = agg1.nDeSer + agg2.nDeSer,
      sum = agg1.sum + agg2.sum)

  /** The final result is the state itself. */
  def evaluate(agg: CountSerDeSQL): Any = agg

  /**
   * Packs the state into 16 bytes: an int (`nSer + 1`), an int (`nDeSer`) and a
   * double (`sum`), written at byte offsets 0, 4 and 8 of the array payload.
   */
  def serialize(agg: CountSerDeSQL): Array[Byte] = {
    val CountSerDeSQL(serCount, deserCount, total) = agg
    val base = Platform.BYTE_ARRAY_OFFSET
    val bytes = new Array[Byte](4 + 4 + 8)
    Platform.putInt(bytes, base, serCount + 1)
    Platform.putInt(bytes, base + 4, deserCount)
    Platform.putDouble(bytes, base + 8, total)
    bytes
  }

  /** Reverses `serialize`, reading the same three slots and returning `nDeSer + 1`. */
  def deserialize(data: Array[Byte]): CountSerDeSQL = {
    val base = Platform.BYTE_ARRAY_OFFSET
    val serCount = Platform.getInt(data, base)
    val deserCount = Platform.getInt(data, base + 4)
    val total = Platform.getDouble(data, base + 8)
    CountSerDeSQL(serCount, deserCount + 1, total)
  }
}
/**
 * `UserDefinedAggregateFunction` variant of the instrumented sum. The aggregation
 * buffer has a single column typed as [[CountSerDeUDT]], so the counters in the
 * result are whatever the UDT's serialize/deserialize methods accumulated.
 */
case object CountSerDeUDAF extends UserDefinedAggregateFunction {
  def deterministic: Boolean = false

  def inputSchema: StructType = StructType(Seq(StructField("x", DoubleType)))

  def bufferSchema: StructType = StructType(Seq(StructField("count-ser-de", CountSerDeUDT)))

  def dataType: DataType = CountSerDeUDT

  /** Seeds buffer slot 0 with an all-zero state. */
  def initialize(buf: MutableAggregationBuffer): Unit =
    buf.update(0, CountSerDeSQL(nSer = 0, nDeSer = 0, sum = 0.0))

  /** Reads the state from slot 0, adds the first input column to its sum and writes it back. */
  def update(buf: MutableAggregationBuffer, input: Row): Unit = {
    val current = buf.getAs[CountSerDeSQL](0)
    val next = current.copy(sum = current.sum + input.getDouble(0))
    buf.update(0, next)
  }

  /** Adds the state held in `buf2` field-wise into the state held in `buf1`. */
  def merge(buf1: MutableAggregationBuffer, buf2: Row): Unit = {
    val left = buf1.getAs[CountSerDeSQL](0)
    val right = buf2.getAs[CountSerDeSQL](0)
    val combined = CountSerDeSQL(
      nSer = left.nSer + right.nSer,
      nDeSer = left.nDeSer + right.nDeSer,
      sum = left.sum + right.sum)
    buf1.update(0, combined)
  }

  /** The final result is the state stored in slot 0. */
  def evaluate(buf: Row): Any = buf.getAs[CountSerDeSQL](0)
}
package org.apache.spark.tdigest
import org.isarnproject.sketches.TDigest
import org.isarnproject.sketches.tdmap.TDigestMap
import org.apache.spark.internal.Logging
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeArrayData}
import org.apache.spark.sql.catalyst.util._
import org.apache.spark.sql.expressions.MutableAggregationBuffer
import org.apache.spark.sql.expressions.UserDefinedAggregateFunction
import org.apache.spark.sql.expressions.UserDefinedImperativeAggregator
import org.apache.spark.sql.types._
/**
 * Wrapper that carries a `TDigest` as a Spark SQL value.
 *
 * Bound to [[TDigestUDT]] through the `SQLUserDefinedType` annotation.
 *
 * @param tdigest the wrapped t-digest sketch
 */
@SQLUserDefinedType(udt = classOf[TDigestUDT])
case class TDigestSQL(tdigest: TDigest)
/**
 * User-defined type that stores a [[TDigestSQL]] as a five-field struct: `delta`,
 * `maxDiscrete`, `nclusters`, and two parallel double arrays (`clustX`, `clustM`)
 * holding the keys and values of the digest's cluster map.
 */
class TDigestUDT extends UserDefinedType[TDigestSQL] {
  def userClass: Class[TDigestSQL] = classOf[TDigestSQL]

  override def pyUDT: String = "isarnproject.sketches.udt.tdigest.TDigestUDT"

  override def typeName: String = "tdigest"

  private[spark] override def asNullable: TDigestUDT = this

  /** Struct layout of the stored digest; no field and no array element is nullable. */
  def sqlType: DataType = {
    val doubleArray = ArrayType(DoubleType, containsNull = false)
    StructType(Seq(
      StructField("delta", DoubleType, nullable = false),
      StructField("maxDiscrete", IntegerType, nullable = false),
      StructField("nclusters", IntegerType, nullable = false),
      StructField("clustX", doubleArray, nullable = false),
      StructField("clustM", doubleArray, nullable = false)))
  }

  def serialize(tdsql: TDigestSQL): Any = serializeTD(tdsql.tdigest)

  def deserialize(datum: Any): TDigestSQL = TDigestSQL(deserializeTD(datum))

  /** Flattens a digest into a five-field row matching `sqlType`. */
  def serializeTD(td: TDigest): InternalRow = {
    val TDigest(delta, maxDiscrete, nclusters, clusters) = td
    // Keys and values are exported as two parallel primitive arrays.
    val keyArr = clusters.keys.toArray
    val valueArr = clusters.values.toArray
    val out = new GenericInternalRow(5)
    out.setDouble(0, delta)
    out.setInt(1, maxDiscrete)
    out.setInt(2, nclusters)
    out.update(3, UnsafeArrayData.fromPrimitiveArray(keyArr))
    out.update(4, UnsafeArrayData.fromPrimitiveArray(valueArr))
    out
  }

  /**
   * Rebuilds a digest from a five-field row produced by `serializeTD`.
   * A row of any other length fails the `require`; a non-row datum raises an exception.
   */
  def deserializeTD(datum: Any): TDigest = datum match {
    case row: InternalRow =>
      require(row.numFields == 5, s"expected row length 5, got ${row.numFields}")
      val delta = row.getDouble(0)
      val maxDiscrete = row.getInt(1)
      val nclusters = row.getInt(2)
      val keyArr = row.getArray(3).toDoubleArray()
      val valueArr = row.getArray(4).toDoubleArray()
      // Re-insert the (key, value) pairs one at a time into an empty cluster map.
      val rebuilt = keyArr.zip(valueArr).foldLeft(TDigestMap.empty)((acc, kv) => acc + kv)
      TDigest(delta, maxDiscrete, nclusters, rebuilt)
    case u =>
      throw new Exception(s"failed to deserialize: $u")
  }

  /** Every instance of this UDT is equal to every other instance. */
  override def equals(obj: Any): Boolean = obj match {
    case _: TDigestUDT => true
    case _ => false
  }

  /** Hash derived from the class name, so it agrees with `equals`. */
  override def hashCode(): Int = {
    val className = classOf[TDigestUDT].getName
    className.hashCode()
  }
}
/** Singleton [[TDigestUDT]] value, used in this file wherever a `DataType` for the digest is needed. */
case object TDigestUDT extends TDigestUDT
/**
 * `UserDefinedAggregateFunction` that sketches a double column into a t-digest.
 * The aggregation buffer has a single column typed as [[TDigestUDT]].
 *
 * @param deltaV       first argument passed to `TDigest.empty` for the initial digest
 * @param maxDiscreteV second argument passed to `TDigest.empty` for the initial digest
 */
case class TDigestUDAF(deltaV: Double, maxDiscreteV: Int) extends
    UserDefinedAggregateFunction {
  def deterministic: Boolean = false

  def inputSchema: StructType = StructType(Seq(StructField("x", DoubleType)))

  def bufferSchema: StructType = StructType(Seq(StructField("tdigest", TDigestUDT)))

  def dataType: DataType = TDigestUDT

  /** Seeds buffer slot 0 with an empty digest. */
  def initialize(buf: MutableAggregationBuffer): Unit =
    buf.update(0, TDigestSQL(TDigest.empty(deltaV, maxDiscreteV)))

  /** Adds the first input column to the buffered digest; null inputs are ignored. */
  def update(buf: MutableAggregationBuffer, input: Row): Unit = {
    if (!input.isNullAt(0)) {
      val current = buf.getAs[TDigestSQL](0).tdigest
      val next = current + input.getDouble(0)
      buf.update(0, TDigestSQL(next))
    }
  }

  /** Combines the digest held in `buf2` into the digest held in `buf1`. */
  def merge(buf1: MutableAggregationBuffer, buf2: Row): Unit = {
    val left = buf1.getAs[TDigestSQL](0).tdigest
    val right = buf2.getAs[TDigestSQL](0).tdigest
    buf1.update(0, TDigestSQL(left ++ right))
  }

  /** The final result is the wrapped digest stored in slot 0. */
  def evaluate(buf: Row): Any = buf.getAs[TDigestSQL](0)
}
/**
 * Imperative-aggregator variant of the t-digest sketch: the aggregation state is a
 * plain `TDigest`, converted to bytes with Java serialization when required and
 * wrapped in a [[TDigestSQL]] only for the final result.
 *
 * @param deltaV       first argument passed to `TDigest.empty` for the initial digest
 * @param maxDiscreteV second argument passed to `TDigest.empty` for the initial digest
 */
case class TDigestUDIA(deltaV: Double, maxDiscreteV: Int) extends
    UserDefinedImperativeAggregator[TDigest] {
  import java.io._

  def inputSchema: StructType = StructType(StructField("x", DoubleType) :: Nil)

  def resultType: DataType = TDigestUDT

  def deterministic: Boolean = false

  /** Starting state: an empty digest built from the constructor parameters. */
  def initial: TDigest = TDigest.empty(deltaV, maxDiscreteV)

  /** Adds the first input column to the digest; null inputs leave it unchanged. */
  def update(agg: TDigest, input: Row): TDigest =
    if (input.isNullAt(0)) agg else agg + input.getDouble(0)

  /** Combines two partial digests. */
  def merge(agg1: TDigest, agg2: TDigest): TDigest = agg1 ++ agg2

  /** Wraps the final digest so that it matches `resultType`. */
  def evaluate(agg: TDigest): Any = TDigestSQL(agg)

  // scalastyle:off classforname
  /**
   * Object input stream that first resolves classes through the class loader that
   * loaded this stream class, and falls back to the default resolution only when
   * that lookup fails.
   */
  class ObjectInputStreamWithCustomClassLoader(
      inputStream: InputStream) extends ObjectInputStream(inputStream) {
    override def resolveClass(desc: java.io.ObjectStreamClass): Class[_] = {
      try { Class.forName(desc.getName, false, getClass.getClassLoader) }
      catch { case _: ClassNotFoundException => super.resolveClass(desc) }
    }
  }
  // scalastyle:on classforname

  /**
   * Java-serializes the digest to a byte array.
   *
   * The object stream is closed (and therefore flushed) before the bytes are read
   * back, so the result never depends on the stream's internal buffering.
   */
  def serialize(agg: TDigest): Array[Byte] = {
    val bufout = new ByteArrayOutputStream()
    val obout = new ObjectOutputStream(bufout)
    try {
      obout.writeObject(agg)
    } finally {
      obout.close()
    }
    bufout.toByteArray
  }

  /**
   * Restores a digest from bytes produced by `serialize`.
   *
   * NOTE(review): this is plain Java deserialization; it should only ever be fed
   * bytes produced by `serialize`, never data from an untrusted source.
   */
  def deserialize(data: Array[Byte]): TDigest = {
    val obin = new ObjectInputStreamWithCustomClassLoader(new ByteArrayInputStream(data))
    try {
      obin.readObject().asInstanceOf[TDigest]
    } finally {
      obin.close()
    }
  }
}
/** Minimal wall-time measurement helpers for use from the REPL. */
object Benchmark {
  /**
   * Evaluates `blk` once and reports how long it took.
   *
   * Timing uses `System.nanoTime`, which is intended for measuring elapsed time and,
   * unlike `System.currentTimeMillis`, is not affected by wall-clock adjustments.
   *
   * @param blk the expression to time; evaluated exactly once
   * @return a pair of (elapsed seconds, value produced by `blk`)
   */
  def apply[T](blk: => T): (Double, T) = {
    val start = System.nanoTime()
    val value = blk
    val elapsedNanos = System.nanoTime() - start
    (elapsedNanos.toDouble / 1e9, value)
  }

  /**
   * Evaluates `blk` `samples` times in sequence, timing each evaluation separately.
   *
   * @param samples number of evaluations to perform
   * @param blk     the expression to time; re-evaluated for every sample
   * @return one (elapsed seconds, value) pair per evaluation, in execution order
   */
  def sample[T](samples: Int)(blk: => T): Array[(Double, T)] =
    Array.fill(samples)(apply(blk))
}
Using Scala version 2.12.8 (OpenJDK 64-Bit Server VM, Java 1.8.0_212)
Type in expressions to have them evaluated.
Type :help for more information.
scala> import scala.util.Random._, org.apache.spark.countSerDe._, org.apache.spark.sql.Row, org.apache.spark.tdigest._
import scala.util.Random._
import org.apache.spark.countSerDe._
import org.apache.spark.sql.Row
import org.apache.spark.tdigest._
scala> sc.setLogLevel("ERROR")
scala> val udaf = CountSerDeUDAF
udaf: org.apache.spark.countSerDe.CountSerDeUDAF.type = CountSerDeUDAF
scala> val udia = CountSerDeUDIA
udia: org.apache.spark.countSerDe.CountSerDeUDIA.type = CountSerDeUDIA
scala> val data = sc.parallelize(Vector.fill(2000000){(nextInt(2), nextGaussian, nextGaussian)}, 5).toDF("cat", "x1", "x2")
data: org.apache.spark.sql.DataFrame = [cat: int, x1: double ... 1 more field]
scala> val bs = Benchmark.sample(50) { data.agg(udaf($"x1"), udaf($"x2")).first }
bs: Array[(Double, org.apache.spark.sql.Row)] = Array((7.986,[CountSerDeSQL(2000012,2000012,-54.30894451561437),CountSerDeSQL(2000012,2000012,208.8748610102166)]), (7.725,[CountSerDeSQL(2000012,2000012,-54.30894451561437),CountSerDeSQL(2000012,2000012,208.8748610102166)]), (8.81,[CountSerDeSQL(2000012,2000012,-54.30894451561437),CountSerDeSQL(2000012,2000012,208.8748610102166)]), (7.62,[CountSerDeSQL(2000012,2000012,-54.30894451561437),CountSerDeSQL(2000012,2000012,208.8748610102166)]), (9.375,[CountSerDeSQL(2000012,2000012,-54.30894451561437),CountSerDeSQL(2000012,2000012,208.8748610102166)]), (8.504,[CountSerDeSQL(2000012,2000012,-54.30894451561437),CountSerDeSQL(2000012,2000012,208.8748610102166)]), (8.336,[CountSerDeSQL(2000012,2000012,-54.30894451561437),C...
scala> bs.map(_._1)
res0: Array[Double] = Array(7.986, 7.725, 8.81, 7.62, 9.375, 8.504, 8.336, 8.627, 9.151, 8.124, 9.028, 8.849, 9.401, 9.318, 8.703, 8.815, 8.966, 9.302, 9.111, 8.933, 9.022, 8.893, 9.653, 9.284, 8.205, 9.627, 8.947, 9.152, 9.078, 9.839, 9.08, 9.246, 7.5, 8.279, 8.793, 9.492, 9.099, 9.096, 9.132, 9.206, 8.585, 9.093, 8.946, 8.848, 8.366, 8.795, 9.01, 8.946, 9.011, 6.885)
scala> bs.map(_._1).sum / 50
res1: Double = 8.835840000000001
scala> val bs = Benchmark.sample(50) { data.agg(udia($"x1"), udia($"x2")).first }
bs: Array[(Double, org.apache.spark.sql.Row)] = Array((7.514,[CountSerDeSQL(6,6,-54.30894451561437),CountSerDeSQL(6,6,208.8748610102166)]), (5.283,[CountSerDeSQL(6,6,-54.30894451561437),CountSerDeSQL(6,6,208.8748610102166)]), (6.291,[CountSerDeSQL(6,6,-54.30894451561437),CountSerDeSQL(6,6,208.8748610102166)]), (5.625,[CountSerDeSQL(6,6,-54.30894451561437),CountSerDeSQL(6,6,208.8748610102166)]), (6.457,[CountSerDeSQL(6,6,-54.30894451561437),CountSerDeSQL(6,6,208.8748610102166)]), (4.961,[CountSerDeSQL(6,6,-54.30894451561437),CountSerDeSQL(6,6,208.8748610102166)]), (6.089,[CountSerDeSQL(6,6,-54.30894451561437),CountSerDeSQL(6,6,208.8748610102166)]), (5.305,[CountSerDeSQL(6,6,-54.30894451561437),CountSerDeSQL(6,6,208.8748610102166)]), (6.658,[CountSerDeSQL(6,6,-54...
scala> bs.map(_._1)
res2: Array[Double] = Array(7.514, 5.283, 6.291, 5.625, 6.457, 4.961, 6.089, 5.305, 6.658, 5.765, 5.521, 6.172, 6.694, 6.169, 5.507, 6.878, 6.485, 5.872, 5.577, 5.875, 6.315, 7.104, 5.829, 5.975, 5.826, 5.349, 6.513, 5.836, 6.129, 6.868, 6.09, 5.857, 7.068, 6.866, 6.203, 6.521, 6.108, 6.664, 5.94, 6.958, 6.163, 6.785, 6.013, 6.493, 6.23, 6.054, 7.12, 5.849, 6.898, 6.173)
scala> bs.map(_._1).sum / 50
res3: Double = 6.209899999999999
Using Scala version 2.12.8 (OpenJDK 64-Bit Server VM, Java 1.8.0_212)
Type in expressions to have them evaluated.
Type :help for more information.
scala> import scala.util.Random._, org.apache.spark.countSerDe._, org.apache.spark.sql.Row, org.apache.spark.tdigest._
import scala.util.Random._
import org.apache.spark.countSerDe._
import org.apache.spark.sql.Row
import org.apache.spark.tdigest._
scala> sc.setLogLevel("ERROR")
scala> val data = sc.parallelize(Vector.fill(100000){(nextInt(2), nextGaussian, nextGaussian)}, 5).toDF("cat", "x1", "x2")
data: org.apache.spark.sql.DataFrame = [cat: int, x1: double ... 1 more field]
scala> val udia = TDigestUDIA(0.5, 0)
udia: org.apache.spark.tdigest.TDigestUDIA = TDigestUDIA(0.5,0)
scala> val udaf = TDigestUDAF(0.5, 0)
udaf: org.apache.spark.tdigest.TDigestUDAF = TDigestUDAF(0.5,0)
scala> val bs = Benchmark.sample(10) { data.agg(udaf($"x1")).first }
bs: Array[(Double, org.apache.spark.sql.Row)] = Array((15.694,[TDigestSQL(TDigest(0.5,0,133,TDigestMap(-4.421763563584745 -> (1.0, 1.0), -4.254974064796795 -> (1.0, 2.0), -4.232552385036448 -> (1.0, 3.0), -4.090629063164773 -> (0.15027423113120597, 3.150274231131206), -4.026284332714937 -> (1.9991595262149298, 5.149433757346136), -4.011367664781805 -> (0.8505662426538643, 6.0), -3.8819633785079186 -> (2.6052762694868314, 8.605276269486831), -3.855805587944206 -> (2.2542573187962667, 10.859533588283098), -3.788192954654244 -> (2.9647075216215226, 13.82424110990462), -3.7259281425454516 -> (1.1757588900953793, 14.999999999999998), -3.630667958772349 -> (5.498887306391312, 20.49888730639131), -3.5248355152381885 -> (3.6185151659764703, 24.11740247236778), -3.46119...
scala> bs.map(_._1)
res0: Array[Double] = Array(15.694, 18.268, 18.733, 18.623, 18.717, 18.924, 18.921, 18.939, 18.815, 18.995)
scala> bs.map(_._1).sum / 10
res1: Double = 18.462899999999998
scala> val bs = Benchmark.sample(10) { data.agg(udia($"x1")).first }
bs: Array[(Double, org.apache.spark.sql.Row)] = Array((0.29,[TDigestSQL(TDigest(0.5,0,133,TDigestMap(-4.421763563584745 -> (1.0, 1.0), -4.254974064796795 -> (1.0, 2.0), -4.232552385036448 -> (1.0, 3.0), -4.090629063164773 -> (0.15027423113120597, 3.150274231131206), -4.026284332714937 -> (1.9991595262149298, 5.149433757346136), -4.011367664781805 -> (0.8505662426538643, 6.0), -3.8819633785079186 -> (2.6052762694868314, 8.605276269486831), -3.855805587944206 -> (2.2542573187962667, 10.859533588283098), -3.788192954654244 -> (2.9647075216215226, 13.82424110990462), -3.7259281425454516 -> (1.1757588900953793, 14.999999999999998), -3.630667958772349 -> (5.498887306391312, 20.49888730639131), -3.5248355152381885 -> (3.6185151659764694, 24.11740247236778), -3.4611932...
scala> bs.map(_._1)
res2: Array[Double] = Array(0.29, 0.27, 0.194, 0.182, 0.187, 0.183, 0.182, 0.184, 0.181, 0.182)
scala> bs.map(_._1).sum / 10
res3: Double = 0.20350000000000001
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment