Skip to content

Instantly share code, notes, and snippets.

@oluies
Created April 14, 2012 16:43
Show Gist options
  • Save oluies/2385780 to your computer and use it in GitHub Desktop.
Save oluies/2385780 to your computer and use it in GitHub Desktop.
StupidBackoffLanguageModel
import collection.GenMap
import scala.math
/**
* Created by IntelliJ IDEA.
* User: orjan
* Date: 2012-03-24
* Time: 15:15
* To change this template use File | Settings | File Templates.
*/
import scala.collection.JavaConverters._
import collection.GenMap
/**
 * Bigram language model scored with "stupid backoff".
 *
 * For a word `current` preceded by `previous` the score is
 *   - `count(previous, current) / count(previous)` when the bigram was seen in the corpus,
 *   - `0.4 * P_unigram(current)` otherwise, where `P_unigram` is the add-one smoothed
 *     unigram probability.
 *
 * All scores are handled as natural-log values and summed over a sentence.
 *
 * @param corpus training data; every word of every sentence is read once at construction time
 */
class StupidBackoffLanguageModel(corpus: HolbrookCorpus) extends LanguageModel {

  /** Log of the 0.4 backoff weight applied when a bigram was never observed. */
  val C = math.log(0.4)

  /** Every word token of the corpus, sentence after sentence. */
  val unigrams = (for (sentence <- corpus.getData.asScala; datum <- sentence.asScala) yield datum.getWord)

  /** Add-one smoothed unigram counts: raw number of occurrences of each word plus 1. */
  val unigramCount: GenMap[String, Double] = unigrams.par.groupBy(e => e).map(e => e._1 -> (e._2.length.toDouble + 1.0))

  /** Denominator of the add-one smoothed unigram model: distinct words plus total tokens. */
  val totalUnigrams: Double = unigramCount.size.toDouble + unigrams.size.toDouble

  /**
   * Every (previous, current) pair of adjacent words; pairs never span two sentences.
   *
   * `sliding(2)` produces a single one-element window for a one-word sentence, so only
   * complete two-word windows are kept (indexing the missing second word used to throw).
   */
  val bigrams = (for (sentence <- corpus.getData.asScala)
    yield sentence.asScala.map(_.getWord).toList.sliding(2).toList.collect {
      case List(previous, current) => (previous, current)
    }).flatten

  /** Raw (unsmoothed) number of occurrences of each (previous, current) bigram. */
  val bigramCount = bigrams.groupBy(e => e).map(kv => (kv._1, kv._2.length.toDouble))

  /** Number of distinct bigrams seen in the corpus. Not used by the scoring code in this class. */
  val totalBigrams: Double = bigramCount.size.toDouble

  /**
   * Returns the log-probability of the sentence under this model.
   *
   * The score is the sum of [[Sbackoff]] over every pair of adjacent words. A sentence
   * with fewer than two words has no such pair and scores 0.0.
   *
   * @param sentence the words of the sentence, in order
   * @return the summed log score of all adjacent word pairs
   */
  def score(sentence: java.util.List[String]): Double = {
    val s: List[String] = sentence.asScala.toList
    val vocabulary = totalUnigrams
    // Summed sequentially: a sentence has only a handful of windows, so a parallel
    // collection would add overhead and make the floating-point summation order vary.
    s.sliding(2)
      .filter(window => window.length == 2) // drop the short window of a one-word sentence
      .map(window => Sbackoff(window, vocabulary))
      .sum
  }

  /**
   * Log stupid-backoff score of the second word of `l` given the first.
   *
   * @param l          a window whose first element is the preceding word and whose second
   *                   element is the word being scored; extra elements are ignored
   * @param vocabulary unused; retained so existing callers keep compiling
   * @return `log(count(previous, current) / count(previous))` when the bigram was seen,
   *         otherwise `log(0.4) + log(P_unigram(current))`
   */
  def Sbackoff(l: List[String], vocabulary: Double): Double = {
    val previous = l(0)
    val current = l(1)
    val bi = countBigram(current, previous)
    if (bi > 0) {
      // countUnigram is add-one smoothed; subtracting 1 recovers the raw count of the
      // preceding word, which is at least `bi` because the bigram was observed.
      math.log(bi / (countUnigram(previous) - 1))
    } else {
      // Unseen bigram: back off to the smoothed unigram probability of the word being scored.
      C + math.log(Punigram(current))
    }
  }

  /** Add-one smoothed count of `token`; a word absent from the corpus counts as 1.0. */
  def countUnigram(token: String): Double =
    unigramCount.get(token).getOrElse(1.0)

  /** Raw number of times `before` was immediately followed by `token`; 0.0 if never seen. */
  def countBigram(token: String, before: String): Double =
    bigramCount.get((before, token)).getOrElse(0.0)

  /** Add-one smoothed unigram probability of `w`. */
  def Punigram(w: String): Double = {
    countUnigram(w) / totalUnigrams
  }

  //def countTrigram(token: String, before1: String, before2: String): Double =
  // trigramCount.get(Triple(before2, before1, token)).getOrElse(0.0)
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment