Created
February 22, 2012 08:35
-
-
Save ryan5500/1883438 to your computer and use it in GitHub Desktop.
Ruby implementation of the naive Bayes classifier from "Programming Collective Intelligence"
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# coding: utf-8 | |
# Keeps the training counts shared by classifiers: how often each feature was
# seen in each category, and how many items each category has received.
class Classifier
  # get_features - object responding to #call that turns an item into the
  #                collection of features to count.
  def initialize(get_features)
    # feature => { category => count }. The default blocks create entries on
    # access through #[], so the read-only accessors below use #fetch instead.
    @feature_count = Hash.new { |h, k| h[k] = Hash.new { |h1, k1| h1[k1] = 0 } }
    # category => number of items trained into it.
    @category_count = Hash.new { |h, k| h[k] = 0 }
    @get_features = get_features
  end

  # Records one more occurrence of +feature+ in +category+.
  def inc_feature_count(feature, category)
    @feature_count[feature][category] += 1
  end

  # Records one more trained item for +category+.
  def inc_category_count(category)
    @category_count[category] += 1
  end

  # Returns, as a Float, how many times +feature+ was seen in +category+
  # (0.0 when the pair was never trained). Never creates hash entries.
  def feature_count(feature, category)
    per_category = @feature_count.fetch(feature, nil)
    return 0.0 unless per_category

    per_category.fetch(category, 0).to_f
  end

  # Returns, as a Float, how many items were trained into +category+
  # (0.0 for an unknown category). Never creates hash entries.
  def category_count(category)
    @category_count.fetch(category, 0).to_f
  end

  # Returns the total number of trained items across all categories.
  def total_count
    @category_count.values.inject(0) { |sum, count| sum + count }
  end

  # Returns every category that has received at least one item.
  def categories
    @category_count.keys
  end

  # Extracts the features of +doc+ and counts each of them, and the item
  # itself, under +category+.
  def train(doc, category)
    @get_features.call(doc).each do |feature|
      inc_feature_count(feature, category)
    end
    inc_category_count(category)
  end

  # Returns the fraction of items in +category+ that contained +feature+,
  # or 0.0 when the category has no items.
  def feature_prob(feature, category)
    items_in_category = category_count(category)
    return 0.0 if items_in_category.zero?

    feature_count(feature, category) / items_in_category
  end

  # Returns a probability for +feature+ in +category+ that starts at
  # +average_prob+ and moves towards the measured probability as the feature
  # is seen more often:
  #
  #   (weight * average_prob + total * basic) / (weight + total)
  #
  # feature_prob - callable taking (feature, category) and returning the
  #                measured probability. When the argument does not respond
  #                to #call, this object's own #feature_prob is used.
  # weight       - how many observations the assumed probability is worth.
  # average_prob - the assumed probability for a feature never seen before.
  def weighted_prob(feature, category, feature_prob, weight = 1.0, average_prob = 0.5)
    prob_fn = feature_prob.respond_to?(:call) ? feature_prob : method(:feature_prob)
    basic_prob = prob_fn.call(feature, category)
    # Number of times the feature appeared, summed over every category.
    total = categories.inject(0) { |sum, c| sum + feature_count(feature, c) }
    (weight * average_prob + total * basic_prob) / (weight + total)
  end

  # Trains +classifier_obj+ with a small fixed set of example sentences.
  def self.sample_train(classifier_obj)
    classifier_obj.train('Nobody owns the water.', 'good')
    classifier_obj.train('the quick rabbit jumps fences', 'good')
    classifier_obj.train('buy pharmaceuticals now', 'bad')
    classifier_obj.train('make quick money at the online casino', 'bad')
    classifier_obj.train('the quick brown fox jumps', 'good')
  end
end
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# coding: utf-8 | |
# Text helpers used to turn a document into classifier features.
class Document
  # Returns the distinct lower-cased words of +doc+, in order of first
  # appearance. A word is a maximal run of \w characters; punctuation and
  # whitespace only separate words and never produce empty-string entries.
  def self.get_words(doc)
    doc.scan(/\w+/).map { |word| word.downcase }.uniq
  end
end
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env ruby | |
# coding: utf-8 | |
require './document.rb' | |
require './naive_bayes.rb' | |
# Build a classifier that extracts its features with Document.get_words.
classifier = NaiveBeyes.new(Document.method(:get_words))

# Train it with the bundled sample sentences.
NaiveBeyes.sample_train(classifier)

# Classify a new phrase, with 'unknown' as the fallback label.
# The original author's note gives the expected result as 'good'.
verdict = classifier.classify('quick rabbit', 'unknown')
puts verdict
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# coding: utf-8 | |
require './classifier.rb' | |
# Naive Bayes classifier built on the feature and category counts kept by
# Classifier. Each category may be given a threshold: the winning category's
# probability must be at least threshold times every other category's
# probability, otherwise the caller-supplied default is returned.
class NaiveBeyes < Classifier
  # get_features - object responding to #call that turns an item into its
  #                features; passed straight to Classifier#initialize.
  def initialize(get_features)
    super(get_features)
    # category => threshold multiplier set through #set_threshold.
    @thresholds = {}
  end

  # Sets the threshold multiplier +t+ for +category+.
  def set_threshold(category, t)
    @thresholds[category] = t
  end

  # Returns the threshold multiplier for +category+, or 0 when none was set.
  def threshold(category)
    @thresholds.fetch(category, 0)
  end

  # Returns the most probable category for +item+.
  #
  # Returns +default+ when no category has a probability above zero (for
  # example before any training), or when some other category's probability
  # multiplied by the winner's threshold exceeds the winner's probability.
  def classify(item, default = nil)
    probs = {}
    categories.each { |category| probs[category] = prob(item, category) }

    # Pick the first category with the strictly highest probability.
    best = nil
    best_prob = 0.0
    probs.each do |category, value|
      next unless value > best_prob

      best = category
      best_prob = value
    end
    # Nothing scored above zero, so there is no winner to compare against.
    return default if best.nil?

    limit = threshold(best)
    probs.each do |category, value|
      next if category == best
      return default if value * limit > best_prob
    end
    best
  end

  # Returns the product of the weighted probabilities of every feature of
  # +item+ for +category+ (1 when the item has no features).
  def document_prob(item, category)
    prob_fn = method(:feature_prob)
    @get_features.call(item).inject(1) do |product, feature|
      product * weighted_prob(feature, category, prob_fn)
    end
  end

  # Returns document_prob(item, category) scaled by the share of trained
  # items that belong to +category+. Returns 0.0 when nothing has been
  # trained yet, instead of dividing by zero.
  def prob(item, category)
    total = total_count
    return 0.0 if total.zero?

    category_share = category_count(category) / total.to_f
    document_prob(item, category) * category_share
  end
end
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment