Created
January 21, 2016 22:54
-
-
Save MollsReis/5d58a9de3b06eaf93a4f to your computer and use it in GitHub Desktop.
Multinomial Naive Bayes for a bag of words
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Multinomial Naive Bayes classifier over a bag of words.
#
# Documents are reduced to lowercase alphabetic tokens. Word likelihoods use
# add-one (Laplace) smoothing, so a word never seen in a bucket still gets a
# small non-zero probability.
class MNB
  # @param examples [Array<Array>] training rows; the first element of each
  #   row is the document text (String) and the last element is its bucket
  #   (class label).
  def initialize(examples)
    @examples = examples.map { |ex| [tokenize(ex.first), ex.last] }
    @buckets = @examples.map { |ex| ex.last }.uniq
    @vocab_size = @examples.map { |ex| ex.first }.flatten.uniq.count

    # Count everything in a single pass so that probability lookups do not
    # have to rescan the whole training set.
    @docs_in_bucket = Hash.new(0)
    @tokens_in_bucket = Hash.new(0)
    @word_counts = {}
    @examples.each do |words, bucket|
      @docs_in_bucket[bucket] += 1
      @tokens_in_bucket[bucket] += words.size
      counts = (@word_counts[bucket] ||= Hash.new(0))
      words.each { |word| counts[word] += 1 }
    end
  end

  # Prior probability of a bucket: the fraction of training examples that
  # carry this label.
  #
  # @param bucket [Object] a class label
  # @return [Float]
  def prob_bucket(bucket)
    @docs_in_bucket[bucket] / @examples.count.to_f
  end

  # Smoothed likelihood of a word within a bucket:
  # (occurrences of word in bucket + 1) / (tokens in bucket + vocabulary size).
  #
  # @param word [String] an already-normalized token
  # @param bucket [Object] a class label
  # @return [Float]
  def prob_word_given_bucket(word, bucket)
    counts = @word_counts[bucket]
    seen = counts ? counts[word] : 0
    (seen + 1) / (@tokens_in_bucket[bucket] + @vocab_size).to_f
  end

  # Returns the most probable bucket for the given text.
  #
  # The text is tokenized exactly like the training documents, so case and
  # punctuation do not turn known words into unseen ones. Scores are summed in
  # log space: multiplying many small probabilities underflows to 0.0 on long
  # documents, which would make every bucket tie.
  #
  # @param words [String] the document to classify
  # @return [Object, nil] the winning bucket, or nil when there is no training data
  def classify(words)
    tokens = tokenize(words)
    @buckets.max_by do |bucket|
      tokens.reduce(Math.log(prob_bucket(bucket))) do |score, word|
        score + Math.log(prob_word_given_bucket(word, bucket))
      end
    end
  end

  private

  # Replaces every non-letter with a space, lowercases, and splits on whitespace.
  def tokenize(text)
    text.gsub(/[^a-zA-Z]/, ' ').downcase.split
  end
end
require 'csv'

# Demo: train on the labelled rows embedded after __END__ (one "text,bucket"
# CSV row per line) and print the bucket predicted for a sample document.
puts MNB.new(CSV.new(DATA).to_a).classify('china china china tokyo japan').inspect # => "c"
__END__
china beijing china,c
china china shanghai,c
china macao,c
tokyo japan china,j
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment