Created
July 13, 2013 19:49
-
-
Save ptrelford/5991986 to your computer and use it in GitHub Desktop.
"Titanic: Machine Learning from Disaster" – a guided F# script for the Kaggle predictive modelling competition of the same name.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
module DecisionTree | |
open System.Collections.Generic | |
/// Reflection helpers for F# tuples.
module internal Tuple =
    open Microsoft.FSharp.Reflection

    /// Returns the fields of a boxed F# tuple as an array of boxed values.
    let toArray (tuple: obj) : obj[] = FSharpValue.GetTupleFields tuple
/// Array helpers used by the decision tree code.
module internal Array =
    /// Returns a new array holding the elements of `xs` before index `i`
    /// followed by the elements after index `i`.
    let removeAt i (xs: 'a[]) =
        let before = xs.[.. i - 1]
        let after = xs.[i + 1 ..]
        Array.append before after
/// Selects the rows of `dataSet` whose column `axis` equals `value`
/// and returns them with that column removed.
let splitDataSet (dataSet: obj[][], axis, value) =
    dataSet
    |> Array.filter (fun row -> row.[axis] = value)
    |> Array.map (Array.removeAt axis)
/// Computes the Shannon entropy (base 2) of the class labels in `dataSet`.
/// The class label of a row is its last element.
let calcShannonEnt (dataSet: obj[][]) =
    let total = float dataSet.Length
    let classLabel (row: obj[]) = row.[row.Length - 1]
    dataSet
    |> Seq.countBy classLabel
    |> Seq.sumBy (fun (_, occurrences) ->
        // p is the relative frequency of one class label.
        let p = float occurrences / total
        -p * log p / log 2.0)
/// Returns the index of the feature column whose split gives the largest
/// information gain. The last column of each row is the class label and is
/// not considered as a feature.
let chooseBestFeatureToSplit (dataSet: obj[][]) =
    let featureCount = dataSet.[0].Length - 1
    let baseEntropy = calcShannonEnt dataSet
    let rowCount = float dataSet.Length
    // Weighted entropy that remains after partitioning the rows on column i.
    let entropyAfterSplit i =
        dataSet
        |> Seq.map (fun row -> row.[i])
        |> Seq.distinct
        |> Seq.sumBy (fun value ->
            let subset = splitDataSet (dataSet, i, value)
            let weight = float subset.Length / rowCount
            weight * calcShannonEnt subset)
    let infoGain i = baseEntropy - entropyAfterSplit i
    [ 0 .. featureCount - 1 ] |> List.maxBy infoGain
/// Returns the class label that occurs most often in `classList`.
/// When several labels share the highest count, the one seen first wins.
/// Raises ArgumentException when `classList` is empty.
let majorityCnt (classList: obj[]) =
    classList
    // Counts come back in first-occurrence order, and null labels are allowed.
    |> Seq.countBy id
    // maxBy keeps the first maximum, so ties resolve to the earliest label.
    |> Seq.maxBy snd
    |> fst
/// Name of a feature column.
type Label = string

/// A boxed feature value or class label.
type Value = obj

/// A decision tree. A `Leaf` carries a class label; a `Branch` names the
/// feature it tests and pairs each feature value with the subtree for it.
type Tree =
    | Leaf of Value
    | Branch of Label * (Value * Tree)[]
/// Recursively builds a decision tree from `dataSet`, whose rows hold feature
/// values followed by a class label in the last position. `labels` names the
/// feature columns. Returns a `Leaf` when all rows share one class or when no
/// feature columns remain (majority class); otherwise a `Branch` on the
/// feature chosen by `chooseBestFeatureToSplit`.
let rec createTree (dataSet: obj[][], labels: string[]) =
    let classList = dataSet |> Array.map (fun row -> row.[row.Length - 1])
    let firstClass = classList.[0]
    if Array.forall (fun c -> firstClass = c) classList then
        Leaf firstClass
    elif dataSet.[0].Length = 1 then
        // Only the class column is left: fall back to the majority class.
        Leaf(majorityCnt classList)
    else
        let bestFeat = chooseBestFeatureToSplit dataSet
        let bestFeatLabel = labels.[bestFeat]
        let remainingLabels = labels |> Array.removeAt bestFeat
        let subTrees =
            dataSet
            |> Seq.map (fun row -> row.[bestFeat])
            |> Seq.distinct
            |> Seq.toArray
            |> Array.map (fun value ->
                let subset = splitDataSet (dataSet, bestFeat, value)
                // Each recursive call receives its own copy of the labels.
                value, createTree (subset, Array.copy remainingLabels))
        Branch(bestFeatLabel, subTrees)
/// Returns the most frequent class label in `dataSet`, where the class label
/// of a row is its last element.
let mode (dataSet: obj[][]) =
    let classOf (row: obj[]) = row.[row.Length - 1]
    let mostCommon, _ = dataSet |> Seq.countBy classOf |> Seq.maxBy snd
    mostCommon
/// Walks `inputTree` for the feature vector `testVec`, whose columns are named
/// by `featLabels`. Returns `Some label` when a leaf is reached, or `None`
/// when a branch has no child for the value found in `testVec`.
let rec classify (inputTree, featLabels: string[], testVec: obj[]) =
    match inputTree with
    | Leaf label -> Some label
    | Branch(featName, children) ->
        // Position of the tested feature within the full feature vector.
        let featIndex = Array.findIndex (fun name -> featName = name) featLabels
        children
        |> Array.tryPick (fun (value, subTree) ->
            if testVec.[featIndex] = value then
                classify (subTree, featLabels, testVec)
            else
                None)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
module TallyHo | |
[<AutoOpen>]
module internal Array =
    /// Number of items satisfying `criteria`
    let tally criteria items =
        Array.filter criteria items |> Array.length

    /// Share of items satisfying `criteria`, expressed as a percentage
    let percentage criteria items =
        let total = Array.length items
        let matched = tally criteria items
        float matched * 100.0 / float total

    /// Alias for Array.filter
    let where f xs = xs |> Array.filter f

    /// groupBy that returns arrays, so results display fully in F# interactive
    let groupBy f xs =
        [| for (key, members) in Seq.groupBy f xs -> key, Array.ofSeq members |]
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// | |
// Titanic: Machine Learning from Disaster | |
// | |
#load "TallyHo.fs" | |
#load "DecisionTree.fs" | |
#r "lib\FSharp.Data.dll" | |
open TallyHo | |
open DecisionTree | |
open FSharp.Data | |
// Load training data
[<Literal>]
let path = "C:/titanic/train.csv"

/// CSV type provider over the file at `path`.
type Train = CsvProvider<path, InferRows=0>

/// One row of the CSV file.
type Passenger = Train.Row

/// The first 600 rows of the file.
let passengers : Passenger[] =
    Train.Load(path).Take(600).Data |> Array.ofSeq
// 1. Discover statistics - simple features

/// True when the passenger's Sex column is "female".
let female (p: Passenger) = (p.Sex = "female")

/// True when the passenger's Survived column is 1.
let survived (p: Passenger) = (p.Survived = 1)

// Female passengers
let females = where female passengers
let femaleSurvivors = tally survived females
let femaleSurvivorsPc = percentage survived females

// a) Children under 10
// Your code here <----
// b) Passengers over 50
// Your code here <--
// c) Upper class passengers
// Your code here <---
// 2. Discover statistics - groups

/// Survival percentage of each group of passengers produced by `criteria`
let survivalRate criteria =
    let rateOf (key, members) = key, Array.percentage survived members
    passengers
    |> Array.groupBy criteria
    |> Array.map rateOf

let embarked = survivalRate (fun p -> p.Embarked)

// a) By passenger class
// Your code here <---
// b) By age group (under 10, adult, over 50)
// Your code here <---
// 3. Scoring

/// The rows after the first 600, held back for scoring.
let testPassengers : Passenger[] =
    Train.Load(path).Skip(600).Data |> Array.ofSeq

/// Percentage of test passengers for which predictor `f` agrees with `survived`.
let score f =
    let agrees p = (f p = survived p)
    testPassengers |> Array.percentage agrees

/// Baseline predictor: nobody survives.
let notSurvived (_: Passenger) = false
let notSurvivedRate = score notSurvived

// a) Score by embarked point
// Your code here <---
// b) Construct function to score over 80%
// Your code here <---
// 4. Decision trees

/// Names of the feature columns, in the order produced by `features`.
let labels = [| "sex"; "class" |]

/// Boxed feature vector for one passenger.
let features (p: Passenger) : obj[] =
    [| box p.Sex; box p.Pclass |]

/// Training rows: the feature vector followed by the boxed survival flag.
let dataSet : obj[][] =
    passengers
    |> Array.map (fun p -> Array.append (features p) [| box (p.Survived = 1) |])

let tree = createTree (dataSet, labels)

// Classify
/// Predicts survival from the tree, falling back to the most common
/// training outcome when the tree has no branch for the passenger.
let test (p: Passenger) =
    let predicted =
        match classify (tree, labels, features p) with
        | Some label -> label
        | None -> mode dataSet
    predicted :?> bool

let treeRate = score test

// a) Optimize features
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment