Created
February 16, 2016 14:48
-
-
Save bricef/687066730ba299b46a63 to your computer and use it in GitHub Desktop.
A bayesian spam classifier
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// Learn more about F# at http://fsharp.net
// See the 'F# Tutorial' project for more help.
open Argu | |
/// The two classes a message can be assigned to.
type Label =
    | Spam
    | Ham
/// Training statistics consumed by the probability helpers in this file.
type Corpus =
    {
        /// Number of occurrences of each word under the Spam label.
        spamFreqs : Map<string, int>
        /// Number of occurrences of each word under the Ham label.
        hamFreqs : Map<string, int>
        /// Count of Spam-labelled entries.
        spamCount : int
        /// Count of Ham-labelled entries.
        hamCount : int
        /// Count of all entries, regardless of label.
        totalCount : int
    }
/// Placeholder classifier: labels every message as Spam and never
/// inspects the message text.
let classify (msg : string) : Label =
    Spam
/// Fraction of `labeledData` for which `classifier` returns the recorded label.
/// The result is correct-count / total-count, so an empty input yields NaN (0/0).
let accuracyRate
    (classifier : string -> Label)
    (labeledData : seq<Label * string>) : float =
    // A sample is a hit when the predicted label equals the recorded one.
    let isHit (expected, msg) = classifier msg = expected
    let hits =
        labeledData
        |> Seq.filter isHit
        |> Seq.length
    float hits / float (Seq.length labeledData)
/// Parses a textual label: exactly "ham" maps to Ham; every other string
/// (including differently-cased variants) maps to Spam.
let string2label str =
    if str = "ham" then Ham else Spam
/// Builds a predicate that ignores its argument and answers true whenever a
/// fresh draw from a shared System.Random falls below `fractionTrue`.
let makeRandomPredicate fractionTrue =
    // One generator per predicate; every call to the predicate draws from it.
    let rng = System.Random()
    fun _ -> rng.NextDouble() < fractionTrue
/// Tokenises a message into features by splitting on single space characters.
/// Deliberately naive: consecutive spaces produce empty-string tokens and no
/// case folding or punctuation handling is performed.
let features (msg : string) : seq<string> =
    let tokens = msg.Split( [|' '|] )
    Seq.ofArray tokens
/// Counts how many (label, word) pairs in `labeledWords` carry the given label.
let countLabel (label : Label) (labeledWords : seq<Label * string>) : int =
    labeledWords
    |> Seq.sumBy (fun (l, _) -> if l = label then 1 else 0)
/// Builds a word -> occurrence-count map from the pairs in `labeledWords`
/// whose label equals `label`. Words that never appear under that label are
/// absent from the map.
let frequencies label labeledWords : Map<string, int> =
    labeledWords
    |> Seq.choose (fun (l, w) -> if l = label then Some w else None)
    |> Seq.countBy id
    |> Map.ofSeq
/// Prior probability of Ham: the corpus ham count over its total count.
let pHam (corpus : Corpus) : float =
    let hamTotal = float corpus.hamCount
    let grandTotal = float corpus.totalCount
    hamTotal / grandTotal
/// Prior probability of Spam: the corpus spam count over its total count.
let pSpam (corpus : Corpus) : float =
    let spamTotal = float corpus.spamCount
    let grandTotal = float corpus.totalCount
    spamTotal / grandTotal
/// Looks up `word` in `map` and returns its count as a float. A missing word
/// yields the small constant 0.0001 rather than zero (presumably so a single
/// unseen word does not force a zero probability downstream).
let getCount map word =
    let seen = Map.tryFind word map |> Option.map float
    defaultArg seen 0.0001
/// Likelihood of `word` under Spam: its spam count (or the getCount fallback
/// when unseen) divided by the corpus spam count.
let pWordGivenSpam word (corpus : Corpus) : float =
    let occurrences = getCount corpus.spamFreqs word
    occurrences / float corpus.spamCount
/// Likelihood of `word` under Ham: its ham count (or the getCount fallback
/// when unseen) divided by the corpus ham count.
let pWordGivenHam word (corpus : Corpus) : float =
    let occurrences = getCount corpus.hamFreqs word
    occurrences / float corpus.hamCount
/// Marginal probability of `word`, by total probability over both labels:
/// P(w|Ham)·P(Ham) + P(w|Spam)·P(Spam).
let pWord word (corpus : Corpus) : float =
    let viaHam = (pWordGivenHam word corpus) * (pHam corpus)
    let viaSpam = (pWordGivenSpam word corpus) * (pSpam corpus)
    viaHam + viaSpam
/// Score for Ham given the words of a message: the product over all words of
/// P(w|Ham) / P(w), multiplied by the Ham prior.
///
/// An empty `words` sequence yields just the prior. The fold is seeded with
/// 1.0 (the multiplicative identity) so no exception is raised when there is
/// nothing to multiply; for non-empty input the result is unchanged.
let pHamGivenWords words corpus =
    let product =
        words
        |> Seq.map (fun w -> (pWordGivenHam w corpus) / (pWord w corpus))
        |> Seq.fold (*) 1.0
    product * (pHam corpus)
/// Score for Spam given the words of a message: the product over all words of
/// P(w|Spam) / P(w), multiplied by the Spam prior.
///
/// An empty `words` sequence yields just the prior. The fold is seeded with
/// 1.0 (the multiplicative identity) so no exception is raised when there is
/// nothing to multiply; for non-empty input the result is unchanged.
let pSpamGivenWords words corpus =
    let product =
        words
        |> Seq.map (fun w -> (pWordGivenSpam w corpus) / (pWord w corpus))
        |> Seq.fold (*) 1.0
    product * (pSpam corpus)
/// Loads labelled messages from `filename`, resolved relative to this source
/// file's directory. Each non-blank line must be "<label>\t<message>"; the
/// label is parsed with string2label and only the second tab-separated field
/// is used as the message.
///
/// Blank lines are skipped. A non-blank line without a tab raises an exception
/// naming the file and the offending line. The returned sequence is lazy, so
/// that exception surfaces when the sequence is enumerated.
let makeData filename =
    let path = System.IO.Path.Combine(__SOURCE_DIRECTORY__, filename)
    let lines = System.IO.File.ReadAllLines(path)
    let labeledData =
        lines
        // A trailing or stray blank line carries no sample.
        |> Seq.filter (fun l -> not (System.String.IsNullOrWhiteSpace l))
        |> Seq.map (fun l ->
            let fields = l.Split( [|'\t'|] )
            if fields.Length < 2 then
                failwithf "Malformed line in %s (expected \"<label>\\t<message>\"): %s" path l
            ((string2label fields.[0]), fields.[1]))
    labeledData
/// Splits `data` into two lists at random: each element lands in the first
/// list when a random draw falls below `fraction`, otherwise in the second.
/// The split sizes are therefore only approximately `fraction` / 1 - `fraction`.
let partitionRandomly fraction data =
    let items = List.ofSeq data
    let keep = makeRandomPredicate fraction
    List.partition keep items
/// Builds the training Corpus from labelled messages.
///
/// Every message is tokenised with `features` and flattened into
/// (label, word) pairs. Note that hamCount, spamCount and totalCount are
/// therefore counts of word occurrences, not of messages.
let makeCorpus data =
    let labeledWords : seq<Label*string> =
        data
        |> Seq.collect (fun (label, msg) ->
            features msg |> Seq.map (fun w -> (label, w)))
        // The pairs are consumed five times below; cache them so the input is
        // enumerated and tokenised only once.
        |> Seq.cache
    {
        spamFreqs = frequencies Spam labeledWords
        hamFreqs = frequencies Ham labeledWords
        hamCount = countLabel Ham labeledWords
        spamCount = countLabel Spam labeledWords
        totalCount = Seq.length labeledWords
    }
/// Returns a classifier closed over `corpus`. The classifier tokenises the
/// message, scores it for both labels and answers Ham only when the Ham score
/// is strictly greater; ties (and incomparable scores) fall to Spam.
let makeClassifier corpus =
    fun (msg : string) ->
        let tokens = features msg
        let hamScore = pHamGivenWords tokens corpus
        let spamScore = pSpamGivenWords tokens corpus
        if hamScore > spamScore then Ham else Spam
/// Command-line arguments understood by the program (parsed by Argu).
type Arguments =
    | [<Mandatory>] TrainingData of string
    | Message of string
    | Stats
    interface IArgParserTemplate with
        /// Help text shown for each argument in the generated usage message.
        member s.Usage =
            match s with
            | TrainingData _ -> "The filename with the data to train the classifier"
            | Message _ -> "Message to classify"
            // Stats carries no data, so it is matched without a wildcard.
            | Stats -> "Show the statistics for the classifier"
/// Program entry point. Parses the command line, trains on the full data set,
/// then optionally classifies a single message and/or prints accuracy
/// statistics from a random 80/20 train/validation split. Always returns 0.
[<EntryPoint>]
let main argv =
    let results = ArgumentParser.Create<Arguments>().Parse(argv)
    let data = makeData (results.GetResult <@ TrainingData @>)

    // Classifier trained on everything; used for the one-off message.
    let fullCorpus = makeCorpus data
    let classifyMessage = makeClassifier fullCorpus

    if results.Contains <@ Message @> then
        let message = results.GetResult <@ Message @>
        printfn "%A" (classifyMessage message)

    if results.Contains <@ Stats @> then
        // Hold out roughly 20% of the samples to measure accuracy on unseen data.
        let training, validation = partitionRandomly 0.8 data
        let trainedCorpus = makeCorpus training
        let evaluated = makeClassifier trainedCorpus
        printfn "Training size: %A" (List.length training)
        printfn "The accuracy is %A" (accuracyRate evaluated validation)
        printfn "Total=%A, Ham=%A, Spam=%A" trainedCorpus.totalCount trainedCorpus.hamCount trainedCorpus.spamCount

    0 // return an integer exit code
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment