-
-
Save mchav/389df62cb378cb973745d0517f6973e0 to your computer and use it in GitHub Desktop.
| {-# LANGUAGE NumericUnderscores #-} | |
| {-# LANGUAGE OverloadedStrings #-} | |
| {-# LANGUAGE TemplateHaskell #-} | |
| {-# LANGUAGE TypeApplications #-} | |
| import qualified Data.Text as T | |
| import qualified DataFrame as D | |
| import qualified DataFrame.Functions as F | |
| import Data.Char | |
| import Data.Text (Text) | |
| import DataFrame ((|>)) | |
| import DataFrame.DecisionTree | |
| import DataFrame.Functions ((.&&), (.==)) | |
| import System.Random | |
| $(F.declareColumnsFromCsvFile "../../Downloads/train.csv") | |
-- | Survival-model pipeline over the two passenger CSV files named below.
--
-- In order, it:
--
--   1. reads the training and test CSVs and concatenates them, so both get
--      exactly the same feature engineering (@addFeatures@);
--   2. prints the engineered frame;
--   3. keeps the first @nRows train@ rows, shuffles them and splits them
--      with 'D.randomSplit' at 0.8 into a fitting set and a hold-out set
--      (assumes '<>' places the test rows after the training rows);
--   4. fits a decision tree on the fitting rows that have a @survived@
--      value, with the passenger-id column excluded, and prints it;
--   5. prints accuracy on the fitting set and on the hold-out set;
--   6. writes passenger id and predicted @survived@ for the remaining
--      (test) rows to @./predictions.csv@.
main :: IO ()
main = do
  trainRaw <- D.readCsv trainPath
  testRaw <- D.readCsv testPath
  let nTrain = D.nRows trainRaw
      engineered = addFeatures (trainRaw <> testRaw)
  print engineered
  let (fitSet, holdout) =
        D.take nTrain engineered
          |> D.shuffle (mkStdGen 1894)
          |> D.randomSplit (mkStdGen 4232) 0.8
      unseen = D.drop nTrain engineered
      model =
        fitDecisionTree
          treeConfig
          survived
          ( fitSet
              |> D.filterJust (F.name survived)
              |> D.exclude [F.name passengerid]
          )
      -- Keep only rows with a known label, then attach the model output
      -- under the column name that 'prediction' reads back.
      score frame =
        frame
          |> D.filterJust (F.name survived)
          |> D.derive (F.name prediction) model
  print model
  mapM_
    ( \(heading, frame) -> do
        putStrLn heading
        print (computeAccuracy (score frame))
    )
    [ ("Training accuracy: ", fitSet)
    , ("Validation accuracy: ", holdout)
    ]
  D.writeCsv
    "./predictions.csv"
    ( D.derive (F.name survived) model unseen
        |> D.select [F.name passengerid, F.name survived]
    )
  where
    -- The training path must stay in sync with the literal given to
    -- 'F.declareColumnsFromCsvFile' at the top of the file.
    trainPath = "../../Downloads/train.csv"
    testPath = "../../Downloads/test.csv"

    -- Rewrites three raw text columns into coarser categorical features.
    -- NOTE(review): 'F.whenPresent' presumably maps only over non-missing
    -- values and 'F.match' yields the regex match (or missing) — confirm.
    addFeatures df =
      df
        -- ticket: leading alphanumeric prefix, reduced to its letters
        |> D.derive
          (F.name ticket)
          (F.whenPresent (T.filter isAlpha) (F.match "^([A-Za-z][A-Za-z0-9./]*)" ticket))
        -- name: first letters-then-dot token (e.g. a title such as "Mr.")
        |> D.derive (F.name name) (F.match "\\s*([A-Za-z]+)\\." name)
        -- cabin: first character only
        |> D.derive (F.name cabin) (F.whenPresent (T.take 1) cabin)

    -- Shallow, heavily regularised tree; the listed column pairs are
    -- presumably ones the expression synthesiser must not combine.
    treeConfig =
      defaultTreeConfig
        { maxTreeDepth = 5
        , minSamplesSplit = 25
        , minLeafSize = 15
        , synthConfig =
            defaultSynthConfig
              { complexityPenalty = 0
              , maxExprDepth = 3
              , disallowedCombinations =
                  [ (F.name pclass, F.name parch)
                  , (F.name pclass, F.name sibsp)
                  , (F.name cabin, F.name ticket)
                  , (F.name age, F.name fare)
                  ]
              }
        }
-- | Integer column expression for the @"prediction"@ column, i.e. the
-- model output that is derived onto a frame before it is passed to
-- 'computeAccuracy'. The element type is fixed by the signature.
prediction :: D.Expr Int
prediction = F.col "prediction"
-- | Accuracy of the @prediction@ column against the @survived@ column.
--
-- Only rows falling into one of the four confusion-matrix cells (both
-- values 0 or 1) are counted; the result is
-- @(TP + TN) / (TP + TN + FP + FN)@.
--
-- NOTE(review): a frame with no such rows gives @0 / 0@, i.e. NaN.
computeAccuracy :: D.DataFrame -> Double
computeAccuracy df = correct / (correct + wrong)
  where
    -- Row count for one confusion-matrix cell: actual label @actual@,
    -- model output @guess@.
    cell actual guess =
      fromIntegral
        (D.nRows (D.filterWhere (survived .== actual .&& prediction .== guess) df))
    correct = cell 1 1 + cell 0 0
    wrong = cell 0 1 + cell 1 0
Training accuracy:
0.8521617852161785
Validation accuracy:
0.8275862068965517
Model after initial search:
(ifThenElse (eq (col @maybe Text "Name") (lit (Just "Mr."))) (ifThenElse (or (geq (add (add (toDouble (col @int "Parch")) (toDouble (col @int "Pclass"))) (toDouble (col @int "Pclass"))) (lit (4.0))) (eq (col @maybe Text "Cabin") (lit (Nothing)))) (lit (0)) (ifThenElse (or (leq (col @maybe Double "Age") (col @maybe Double "Fare")) (eq (col @maybe Text "Ticket") (lit (Nothing)))) (ifThenElse (or (eq (col @maybe Text "Embarked") (lit (Just "S"))) (leq (col @maybe Double "Fare") (col @maybe Double "Age"))) (lit (0)) (lit (1))) (lit (1)))) (ifThenElse (lt (toDouble (col @int "Pclass")) (lit (3.0))) (ifThenElse (and (leq (col @maybe Double "Fare") (col @maybe Double "Age")) (eq (col @text "Sex") (lit ("male")))) (lit (0)) (lit (1))) (ifThenElse (gt (add (toDouble (col @int "Parch")) (toDouble (col @int "SibSp"))) (lit (3.0))) (lit (0)) (ifThenElse (geq (col @maybe Double "Age") (lit (Just 13.200000000000003))) (ifThenElse (lt (mult (sub (toDouble (col @int "SibSp")) (toDouble (col @int "Pclass"))) (toDouble (col @int "SibSp"))) (lit (0.0))) (lit (0)) (lit (1))) (lit (1))))))
(ifThenElse (and (lt (toDouble (col @int "passenger_class")) (lit (3.0))) (eq (col @text "Sex") (lit ("female")))) (lit (1)) (ifThenElse (eq (col @maybe Text "title") (lit (Just "Mr."))) (lit (0)) (ifThenElse (gt (add (toDouble (col @int "number_of_parents_and_children")) (toDouble (col @int "number_of_siblings_and_spouses"))) (lit (3.0))) (lit (0)) (ifThenElse (and (lt (sub (add (toDouble (col @int "number_of_parents_and_children")) (toDouble (col @int "number_of_siblings_and_spouses"))) (toDouble (col @int "passenger_class"))) (lit (0.0))) (geq (col @maybe Double "Age") (lit (Just 13.200000000000003)))) (ifThenElse (lt (sub (toDouble (col @int "number_of_siblings_and_spouses")) (toDouble (col @int "passenger_class"))) (lit (-2.0))) (lit (1)) (lit (0))) (lit (1))))))
Training accuracy:
0.8493723849372385
Validation accuracy:
0.8103448275862069
model =
( ifThenElse
(eq (col @maybe Text "Name") (lit (Just "Mr.")))
( ifThenElse
( or
( geq
( add
(add (toDouble (col @int "Parch")) (toDouble (col @int "Pclass")))
(toDouble (col @int "Pclass"))
)
(lit (4.0))
)
(eq (col @maybe Text "Cabin") (lit (Nothing)))
)
(lit (0))
( ifThenElse
( or
(leq (col @maybe Double "Age") (col @maybe Double "Fare"))
(eq (col @maybe Text "Ticket") (lit (Nothing)))
)
( ifThenElse
( or
(eq (col @maybe Text "Embarked") (lit (Just "S")))
(leq (col @maybe Double "Fare") (col @maybe Double "Age"))
)
(lit (0))
(lit (1))
)
(lit (1))
)
)
( ifThenElse
(lt (toDouble (col @int "Pclass")) (lit (3.0)))
( ifThenElse
( and
(leq (col @maybe Double "Fare") (col @maybe Double "Age"))
(eq (col @text "Sex") (lit ("male")))
)
(lit (0))
(lit (1))
)
( ifThenElse
(gt (add (toDouble (col @int "Parch")) (toDouble (col @int "SibSp"))) (lit (3.0)))
(lit (0))
( ifThenElse
(geq (col @maybe Double "Age") (lit (Just 13.200000000000003)))
( ifThenElse
( lt
( mult
(sub (toDouble (col @int "SibSp")) (toDouble (col @int "Pclass")))
(toDouble (col @int "SibSp"))
)
(lit (0.0))
)
(lit (0))
(lit (1))
)
(lit (1))
)
)
)
)