Skip to content

Instantly share code, notes, and snippets.

@dragos
Created November 21, 2015 11:36
Show Gist options
  • Save dragos/7c2d3ec962ee2e6862f3 to your computer and use it in GitHub Desktop.
Save dragos/7c2d3ec962ee2e6862f3 to your computer and use it in GitHub Desktop.
// fails with:
// $ /opt/scala-2.10.4/bin/scalac -d /tmp src/main/scala/infer.scala -cp ../spark/assembly/target/scala-2.10/spark-assembly-1.6.0-SNAPSHOT-hadoop2.2.0.jar
// src/main/scala/infer.scala:27: error: missing parameter type for expanded function ((x$2) => x$2._2)
// ds.groupBy(_._1).agg(sum(_._2), sum(_._3)).collect()
// ^
// one error found
//
import org.apache.spark.sql.functions._
import org.apache.spark.sql.Encoder
import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.TypedColumn
import org.apache.spark._
import org.apache.spark.sql._
/** An `Aggregator` that adds up any numeric type returned by the given function. */
class SumOf[I, N : Numeric](f: I => N) extends Aggregator[I, N, N] with Serializable {
val numeric = implicitly[Numeric[N]]
override def zero: N = numeric.zero
override def reduce(b: N, a: I): N = numeric.plus(b, f(a))
override def merge(b1: N,b2: N): N = numeric.plus(b1, b2)
override def finish(reduction: N): N = reduction
}
class Main {
val sc = new SparkContext(new SparkConf(true))
val sqlContext = new SQLContext(sc)
import sqlContext.implicits._
def test() {
def sum[I, N : Numeric : Encoder](f: I => N): TypedColumn[I, N] = new SumOf(f).toColumn
val ds = Seq((1, 1, 2L), (1, 2, 3L), (1, 3, 4L), (2, 1, 5L)).toDS()
ds.groupBy(_._1).agg(sum(_._2), sum(_._3)).collect()
}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment