Created
May 20, 2015 08:22
-
-
Save geekan/471cc0d10f7ecfc769fc to your computer and use it in GitHub Desktop.
Simple Scala functions for cosine similarity
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/*
 * Object in Scala for calculating cosine similarity
 * Reuben Sutton - 2012
 * More information: http://en.wikipedia.org/wiki/Cosine_similarity
 */
object CosineSimilarity {
  /**
   * Cosine similarity of two equal-length integer vectors:
   * (x dot y) / (||x|| * ||y||).
   *
   * The result lies in [-1, 1]. For non-negative inputs it lies in [0, 1],
   * so 0.9925 would read as 99.25% similar.
   *
   * The dot product and magnitudes are accumulated in Double, so large
   * components do not overflow Int (unlike [[dotProduct]], whose Int result
   * can wrap).
   *
   * @param x first vector
   * @param y second vector; must have the same length as `x`
   * @return the cosine of the angle between `x` and `y`, or NaN if either
   *         vector is all zeros (including empty), since the angle is undefined
   * @throws IllegalArgumentException if the arrays differ in length
   */
  def cosineSimilarity(x: Array[Int], y: Array[Int]): Double = {
    require(x.length == y.length,
      s"arrays must have equal length (${x.length} != ${y.length})")
    wideDotProduct(x, y) / (magnitude(x) * magnitude(y))
  }

  /**
   * Dot product of the two arrays: (a(0)*b(0)) + (a(1)*b(1)) + ...
   *
   * Elements are paired with `zip`, so if the lengths differ the extra
   * elements of the longer array are ignored. The result is an Int and wraps
   * on overflow for large inputs.
   */
  def dotProduct(x: Array[Int], y: Array[Int]): Int =
    x.zip(y).map { case (a, b) => a * b }.sum

  /**
   * Euclidean magnitude (L2 norm) of an array: square each element, sum the
   * squares, then take the square root.
   *
   * Squares are computed in Double, so elements whose square exceeds
   * Int.MaxValue do not overflow.
   */
  def magnitude(x: Array[Int]): Double =
    math.sqrt(x.map(i => i.toDouble * i).sum)

  /** Dot product accumulated in Double, so neither the products nor their sum can overflow Int. */
  private def wideDotProduct(x: Array[Int], y: Array[Int]): Double =
    x.zip(y).map { case (a, b) => a.toDouble * b }.sum
}
/**
 * Cosine similarity of two sparse frequency vectors, each given as a map from
 * word to count. A word absent from one map counts as 0 in that vector.
 *
 * @param t1 word -> frequency for the first vector
 * @param t2 word -> frequency for the second vector
 * @return the cosine of the angle between the two vectors, in [-1, 1]; NaN
 *         when either map is empty or all of its counts are zero
 */
def similarity(t1: Map[String, Int], t2: Map[String, Int]): Double = {
  // Counts are used as-is rather than divided by their totals: cosine
  // similarity is scale-invariant, and dividing by a total that is zero or
  // negative would produce NaN or flip the sign of the result.

  // Only words present in both maps contribute to the dot product.
  val dot = t1.foldLeft(0d) { case (acc, (word, freq1)) =>
    acc + freq1.toDouble * t2.getOrElse(word, 0)
  }
  val norm1 = math.sqrt(t1.valuesIterator.map(f => f.toDouble * f).sum)
  val norm2 = math.sqrt(t2.valuesIterator.map(f => f.toDouble * f).sum)
  dot / (norm1 * norm2)
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment