Created
April 7, 2012 22:07
-
-
Save OlegYch/2332409 to your computer and use it in GitHub Desktop.
naive bayes
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/**
 * Naive Bayes text classifier with add-one smoothing.
 *
 * Training documents are accumulated with `addExample`. Every statistic exposed by this
 * class is derived from the current contents of `examples`; the derived values are cached
 * and rebuilt automatically whenever `examples` changes, so training and classification
 * calls may be interleaved.
 */
class NB {

  /** Training data, newest first: each entry is a class label plus the words of one document. */
  var examples: List[(String, Iterable[String])] = Nil

  /**
   * Adds one training document to the classifier.
   *
   * @param klass label of the class the document belongs to
   * @param words words of the document (repeated words are counted repeatedly)
   */
  def addExample(klass: String, words: Iterable[String]): Unit = {
    examples = (klass, words) :: examples
  }

  /**
   * Sums weights per key.
   *
   * Every element of `i` on which `pf` is defined contributes a `(key, weight)` pair; the
   * weights of equal keys are added together.
   *
   * @param i  elements to aggregate
   * @param pf maps an element to its key and weight; elements outside its domain are skipped
   * @return map from key to total weight, answering `0.0` for keys that were never seen
   */
  def count[T, TT](i: Iterable[T])(pf: PartialFunction[T, (TT, Double)]): Map[TT, Double] =
    i.view.collect(pf).foldLeft(Map[TT, Double]().withDefaultValue(0.0)) {
      case (m, (tt, c)) => m + (tt -> (m(tt) + c))
    }

  /** Immutable snapshot of every statistic derived from one particular training list. */
  private final class Model(val source: List[(String, Iterable[String])]) {
    // Hoisted: List.size is O(n), so evaluating it once per example would be quadratic.
    private val exampleShare = 1.0 / source.size

    val classes: Set[String] = source.view.map(_._1).toSet
    val words: Set[String] = source.view.flatMap(_._2).toSet
    val prior: Map[String, Double] = count(source) { case (c, _) => (c, exampleShare) }
    val count_c: Map[String, Double] = count(source) { case (c, ws) => (c, ws.size.toDouble) }
    val wordClass: List[(String, String)] = source.flatMap { case (c, ws) => ws.map((_, c)) }
    val count_w_c: Map[(String, String), Double] = count(wordClass) { case pair => (pair, 1.0) }
    // Deliberately an empty map with a default: every lookup is computed on demand.
    val p_w_c: Map[(String, String), Double] =
      Map.empty[(String, String), Double].withDefault {
        case wc @ (_, c) => (count_w_c(wc) + 1) / (count_c(c) + words.size)
      }
  }

  // Cached statistics; `model` checks them against `examples` before handing them out.
  @volatile private var cached: Model = new Model(examples)

  /** Statistics for the current `examples`, rebuilt only when the list has been replaced. */
  private def model: Model = {
    val snapshot = examples
    val current = cached
    // `examples` is an immutable List, so reference identity is enough to detect a change.
    if (current.source eq snapshot) current
    else {
      val rebuilt = new Model(snapshot)
      cached = rebuilt
      rebuilt
    }
  }

  /** Distinct class labels seen in training. */
  def classes: Set[String] = model.classes

  /** Vocabulary: distinct words seen in training. */
  def words: Set[String] = model.words

  /** Fraction of training documents belonging to each class. */
  def prior: Map[String, Double] = model.prior

  /** Sum of document sizes (`words.size` of each example) per class. */
  def count_c: Map[String, Double] = model.count_c

  /** One `(word, class)` pair for every word occurrence in the training data. */
  def `word->class`: List[(String, String)] = model.wordClass

  /** Number of occurrences of each `(word, class)` pair in the training data. */
  def count_w_c: Map[(String, String), Double] = model.count_w_c

  /**
   * Smoothed likelihood of a word given a class, keyed by `(word, class)`:
   * `(count_w_c + 1) / (count_c + vocabulary size)`. Values are computed on lookup.
   */
  def p_w_c: Map[(String, String), Double] = model.p_w_c

  /**
   * Prints a diagnostic dump of the current statistics to standard output.
   *
   * Declared without parentheses only so that existing `nb.inspect` call sites keep compiling.
   */
  def inspect: Unit = {
    val m = model
    println(m.classes)
    println(m.prior)
    println(m.count_c)
    println(m.count_w_c.take(10))
    println(m.wordClass.toSet[(String, String)].map(m.p_w_c(_)).take(10))
  }

  /**
   * Picks the class with the highest score for the given words, where the score is
   * `log(prior) + sum of log(p_w_c)` over the words.
   *
   * @param words words of the document to classify
   * @return label of the best-scoring class
   * @throws java.lang.UnsupportedOperationException if no training examples have been added
   */
  def classify(words: Iterable[String]): String = {
    val m = model
    if (m.classes.isEmpty)
      throw new UnsupportedOperationException(
        "cannot classify: no training examples have been added")
    val scores = m.classes.map { c =>
      // Log space: the sum of logs avoids underflow from multiplying many small probabilities.
      val score = words.foldLeft(math.log(m.prior(c))) { (acc, w) =>
        acc + math.log(m.p_w_c(w -> c))
      }
      c -> score
    }
    scores.maxBy(_._2)._1
  }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment