Skip to content

Instantly share code, notes, and snippets.

@jongwook
Created September 19, 2016 03:06
Show Gist options
  • Save jongwook/5d4e78290eaef22cb69abbf68b52e597 to your computer and use it in GitHub Desktop.
Save jongwook/5d4e78290eaef22cb69abbf68b52e597 to your computer and use it in GitHub Desktop.
package com.github.jongwook
import net.recommenders.rival.core.DataModel
import net.recommenders.rival.evaluation.metric.ranking.NDCG
import net.recommenders.rival.evaluation.metric.ranking.NDCG.TYPE
import org.apache.spark.SparkConf
import org.apache.spark.mllib.evaluation.RankingMetrics
import org.apache.spark.sql.SparkSession
import scala.util.{Failure, Success, Random, Try}
/** A rough translation of Kaggle's code for NDCG, found at https://www.kaggle.com/wiki/NormalizedDiscountedCumulativeGain
 * Note that these methods only calculate the metric for one query (usually corresponding to a user),
 * while reported evaluation metrics are usually average performance over multiple queries.
 *
 * Gains are exponential (`2^relevance - 1`) and the item at zero-based position `i` is discounted by `log2(i + 2)`.
 */
object KaggleNDCG {

  /** NDCG@k of one ranked submission against graded relevances.
   *
   * @param solution   relevance of each item; items missing from the map count as relevance 0.0
   * @param submission items in predicted order, best first
   * @param k          cutoff; only the first `k` positions of each ranking contribute
   * @return `dcg / idealDcg`; when the ideal DCG is not positive (e.g. `k <= 0` or no positive relevances),
   *         1.0 if the DCG equals the ideal DCG and 0.0 otherwise
   */
  def evaluateSubmissionSubset(solution: Map[Int, Double], submission: Seq[Int], k: Int): Double = {
    val dcg = calculateDcg(k, submission, solution)
    val optimal = calculateOptimalDcg(k, solution)
    // A non-positive ideal DCG would make the ratio meaningless (0/0 = NaN, or a sign flip), so special-case it.
    if (optimal <= 0) {
      if (dcg == optimal) 1.0 else 0.0
    } else {
      dcg / optimal
    }
  }

  /** DCG@k of `items` in the given order, looking up each item's relevance in `itemToRelevances`
   * (items absent from the map count as 0.0).
   */
  def calculateDcg(k: Int, items: Seq[Int], itemToRelevances: Map[Int, Double]): Double =
    // Only the first k positions contribute, so skip the lookups for the rest.
    calculateDcg(k, items.take(k).map(item => itemToRelevances.getOrElse(item, 0.0)))

  /** Ideal DCG@k: the same relevances, arranged in descending order. */
  private def calculateOptimalDcg(k: Int, itemToRelevances: Map[Int, Double]): Double =
    calculateDcg(k, itemToRelevances.values.toSeq.sortBy(x => -x))

  /** Natural log of 2, used to turn natural-log discounts into base-2 discounts. */
  val log2: Double = Math.log(2)

  /** DCG@k of relevances that are already in ranked order. */
  private def calculateDcg(k: Int, numbers: Seq[Double]): Double =
    numbers.take(k).zipWithIndex.map { case (relevance, i) =>
      // gain / ln(i + 2) * ln(2) == gain / log2(i + 2)
      (Math.pow(2.0, relevance) - 1.0) / Math.log(i + 2) * log2
    }.sum

  /** Prints NDCG@1..10 for one synthetic query, computed by this object (Kaggle), Rival and Spark.
   *
   * Kaggle and Rival (with `TYPE.EXP`) use graded relevance and should print identical rows.
   * Spark's `RankingMetrics.ndcgAt` only receives item ids and treats relevance as binary (every item in
   * the label set counts as relevant), so its row is expected to differ; it is printed for comparison.
   */
  def main(args: Array[String]): Unit = {
    // simple synthetic data for ranking query
    val numItems = 10
    val rng = new Random(0)
    val groundTruth = (1 to numItems).map(x => (x, x.toDouble)).toMap
    val prediction = (numItems / 2 to numItems).map(x => (x, rng.nextDouble())).toMap
    val groundTruthRanking = groundTruth.toSeq.sortBy(-_._2).map(_._1).toArray
    val predictionRanking = prediction.toSeq.sortBy(-_._2).map(_._1).toArray
    val ats = (1 to numItems).toArray

    // Kaggle
    printf("%10s", "Kaggle")
    for (k <- ats) {
      printf("%15.12f", evaluateSubmissionSubset(groundTruth, predictionRanking, k))
    }
    println()

    // Rival
    val groundTruthModel = new DataModel[Int, Int]()
    val predictionModel = new DataModel[Int, Int]()
    groundTruth.foreach { case (item, rating) => groundTruthModel.addPreference(0, item, rating) }
    prediction.foreach { case (item, rating) => predictionModel.addPreference(0, item, rating) }
    val rivalNDCG = new NDCG[Int, Int](predictionModel, groundTruthModel, 0, ats, TYPE.EXP)
    rivalNDCG.compute()
    printf("%10s", "Rival")
    for (k <- ats) {
      printf("%15.12f", rivalNDCG.getValueAt(k))
    }
    println()

    // Spark
    Try(Class.forName("org.apache.spark.sql.SparkSession")) match {
      case Success(_) =>
        val spark = SparkSession.builder().master(new SparkConf().get("spark.master", "local[8]")).getOrCreate()
        try {
          val predictionAndLabel = spark.sparkContext.parallelize(Seq((predictionRanking, groundTruthRanking)))
          val metrics = new RankingMetrics(predictionAndLabel)
          // Compute before printing the label so any Spark logging does not land in the middle of the row.
          val ndcgs = ats.map(metrics.ndcgAt)
          printf("%10s", "Spark")
          ndcgs.foreach(ndcg => printf("%15.12f", ndcg))
          println()
        } finally {
          // Release the Spark context (executors, UI, thread pools) even if the computation fails.
          spark.stop()
        }
      case Failure(_) =>
        println("Spark classes not found")
    }
  }
}
@jongwook
Copy link
Author

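Output (Kaggle and Rival agree exactly; Spark differs because `RankingMetrics.ndcgAt` treats relevance as binary, so every ground-truth item counts as equally relevant):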

    Kaggle 0.030303030303 0.082598228866 0.248914935313 0.311894658754 0.561995061237 0.572382206514 0.570557025817 0.569754633694 0.569427136387 0.569322389203
     Rival 0.030303030303 0.082598228866 0.248914935313 0.311894658754 0.561995061237 0.572382206514 0.570557025817 0.569754633694 0.569427136387 0.569322389203
     Spark 1.000000000000 1.000000000000 1.000000000000 1.000000000000 1.000000000000 1.000000000000 0.908374555695 0.835891227182 0.776747107522 0.727329844311

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment