@emesday
Last active April 19, 2019 00:43
TopByKeyAggregatorProxy.scala
import org.apache.spark.sql.TypedColumn
import org.apache.spark.sql.expressions.Aggregator

object TopByKeyAggregatorProxy {
  import scala.reflect.runtime.universe._

  /**
   * Works on rows of the form (K1, K2, V) where K1 and K2 are IDs and V is the score value.
   * For each K1, finds the top `num` (K2, V) pairs based on the given Ordering.
   *
   * Spark's TopByKeyAggregator is private[recommendation], so it is instantiated
   * through reflection instead of `new`.
   */
  def asTypedColumn[K1: TypeTag, K2: TypeTag, V: TypeTag]
      (num: Int, ord: Ordering[(K2, V)]): TypedColumn[(K1, K2, V), Array[(K2, V)]] = {
    Class.forName("org.apache.spark.ml.recommendation.TopByKeyAggregator")
      .getConstructors
      .head
      // Constructor arguments: num, ord, then the TypeTag evidence that the
      // class's context bounds compile into.
      .newInstance(Int.box(num), ord, typeTag[K1], typeTag[K2], typeTag[V])
      .asInstanceOf[Aggregator[_, _, _]]
      .toColumn
      .asInstanceOf[TypedColumn[(K1, K2, V), Array[(K2, V)]]]
  }
}
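Why `newInstance` receives five arguments: a context bound such as `[K1: TypeTag]` is shorthand for an extra implicit parameter list, so the compiled constructor takes `num` and `ord` followed by one evidence value per bound, and a reflective call has to supply that evidence explicitly because no implicit search happens at runtime. A minimal sketch of the desugaring with a hypothetical class `Boxed`, assuming `ClassTag` as a stand-in for `TypeTag` so it needs only the standard library:

```scala
import scala.reflect.{ClassTag, classTag}

// Hypothetical class, written in desugared form: in Scala 2 this is equivalent to
// `class Boxed[A: ClassTag](val n: Int)`, the bound becoming an implicit parameter list.
class Boxed[A](val n: Int)(implicit ev: ClassTag[A]) {
  // The captured evidence lets the class allocate an Array[A] at runtime.
  def allocate: Array[A] = new Array[A](n)
  def elementClassName: String = classTag[A].runtimeClass.getName
}

object ContextBoundDemo {
  // Evidence found by the compiler's implicit search.
  val resolved = new Boxed[String](3)
  // The same evidence passed by hand, as a reflective newInstance call must do.
  val passed = new Boxed[String](3)(classTag[String])
}
```

Both values end up holding the same `ClassTag[String]`; the reflective call in `asTypedColumn` is the runtime equivalent of the explicit form, with `typeTag[K1]`, `typeTag[K2]` and `typeTag[V]` in the evidence positions.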

emesday commented Apr 19, 2019

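Usage (assuming a SparkSession named `spark` is in scope):

  import org.apache.spark.sql.{Dataset, TypedColumn}
  import spark.implicits._  // encoders for groupByKey's Int key

  val ds: Dataset[(Int, Int, Float)] = ???  // your (K1, K2, V) rows, e.g. (user, item, score)
  val aggregator: TypedColumn[(Int, Int, Float), Array[(Int, Float)]] = TopByKeyAggregatorProxy.asTypedColumn[Int, Int, Float](10, Ordering.by(_._2))
  val topByKey: Dataset[(Int, Array[(Int, Float)])] = ds.groupByKey(_._1).agg(aggregator)

Each row of `topByKey` pairs a K1 with its 10 highest-scoring (K2, V) entries.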
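What the aggregator computes per key can be mirrored on local collections, which is handy for checking expectations in a unit test without a cluster. A minimal sketch, assuming each K1 maps to its `num` largest (K2, V) pairs under `ord`, listed largest first; `TopByKeyLocal` and `topByKey` are hypothetical names, and this plain sort-and-take is not Spark's implementation, which keeps a bounded priority queue per key:

```scala
// Hypothetical local mirror of the aggregator's per-key result; not Spark code.
object TopByKeyLocal {
  // Groups rows by K1, then keeps the `num` largest (K2, V) pairs under `ord`,
  // ordered from largest to smallest.
  def topByKey[K1, K2, V](rows: Seq[(K1, K2, V)], num: Int, ord: Ordering[(K2, V)]): Map[K1, Seq[(K2, V)]] =
    rows.groupBy(_._1).map { case (k1, group) =>
      val pairs = group.map { case (_, k2, v) => (k2, v) }
      k1 -> pairs.sorted(ord.reverse).take(num)
    }
}
```

For example, `TopByKeyLocal.topByKey(Seq((1, 10, 0.5f), (1, 11, 0.9f), (1, 12, 0.1f)), 2, Ordering.by[(Int, Float), Float](_._2))` maps key 1 to `Seq((11, 0.9f), (10, 0.5f))`. Sorting each group in full is fine for small test data; the bounded queue matters when groups are large.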
