Skip to content

Instantly share code, notes, and snippets.

@mchav
Last active January 17, 2026 03:42
Show Gist options
  • Select an option

  • Save mchav/389df62cb378cb973745d0517f6973e0 to your computer and use it in GitHub Desktop.

Select an option

Save mchav/389df62cb378cb973745d0517f6973e0 to your computer and use it in GitHub Desktop.
{-# LANGUAGE NumericUnderscores #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE TypeApplications #-}
import qualified Data.Text as T
import qualified DataFrame as D
import qualified DataFrame.Functions as F
import Data.Char
import Data.Text (Text)
import DataFrame ((|>))
import DataFrame.DecisionTree
import DataFrame.Functions ((.&&), (.==))
import System.Random
-- Template Haskell splice, run at compile time, so the CSV must be readable
-- when this file is built. Presumably it declares one typed column expression
-- per CSV header (the otherwise-undefined names used below, such as `survived`,
-- `ticket`, `name`, `cabin`, `passengerid`) -- TODO confirm against the
-- DataFrame.Functions documentation.
$(F.declareColumnsFromCsvFile "../../Downloads/train.csv")
-- | Fit a decision tree on the labelled CSV, print its accuracy on the
-- training and validation splits, and write predictions for the unlabelled
-- CSV to @./predictions.csv@.
main :: IO ()
main = do
    train <- D.readCsv "../../Downloads/train.csv"
    test <- D.readCsv "../../Downloads/test.csv"
    let labelledRows = D.nRows train
        -- Both files are concatenated first so that the ticket, name and
        -- cabin columns go through exactly the same derivations.
        combined =
            (train <> test)
                |> D.derive
                    (F.name ticket)
                    (F.whenPresent (T.filter isAlpha) (F.match "^([A-Za-z][A-Za-z0-9./]*)" ticket))
                |> D.derive (F.name name) (F.match "\\s*([A-Za-z]+)\\." name)
                |> D.derive (F.name cabin) (F.whenPresent (T.take 1) cabin)
    print combined
    let -- Keeps only rows that carry a label in the @survived@ column.
        withLabel = D.filterJust (F.name survived)
        -- The leading 'labelledRows' rows of 'combined' are split 0.8 after a
        -- shuffle; both generators use fixed seeds.
        (fitFrame, holdout) =
            D.take labelledRows combined
                |> D.shuffle (mkStdGen 1894)
                |> D.randomSplit (mkStdGen 4232) 0.8
        -- Everything after the leading 'labelledRows' rows.
        unlabelled = D.drop labelledRows combined
        synth =
            defaultSynthConfig
                { complexityPenalty = 0
                , maxExprDepth = 3
                , disallowedCombinations =
                    [ (F.name pclass, F.name parch)
                    , (F.name pclass, F.name sibsp)
                    , (F.name cabin, F.name ticket)
                    , (F.name age, F.name fare)
                    ]
                }
        treeConfig =
            defaultTreeConfig
                { maxTreeDepth = 5
                , minSamplesSplit = 25
                , minLeafSize = 15
                , synthConfig = synth
                }
        tree =
            fitDecisionTree
                treeConfig
                survived
                (fitFrame |> withLabel |> D.exclude [F.name passengerid])
        -- Scores a frame by attaching the tree's output as the @prediction@
        -- column and handing the result to 'computeAccuracy'.
        accuracyOf frame =
            computeAccuracy
                (frame |> withLabel |> D.derive (F.name prediction) tree)
    print tree
    putStrLn "Training accuracy: "
    print (accuracyOf fitFrame)
    putStrLn "Validation accuracy: "
    print (accuracyOf holdout)
    D.writeCsv
        "./predictions.csv"
        ( D.derive (F.name survived) tree unlabelled
            |> D.select [F.name passengerid, F.name survived]
        )
-- | Typed handle on the @Int@ column named @"prediction"@. 'main' derives a
-- column under this name from the fitted model, and 'computeAccuracy'
-- compares it with @survived@.
prediction :: D.Expr Int
prediction = F.col @Int "prediction"
-- | Share of rows in which @prediction@ equals @survived@.
--
-- Only rows where both columns hold 0 or 1 are counted, in the numerator and
-- in the denominator alike. If no row qualifies the result is @0 / 0@, i.e.
-- NaN.
computeAccuracy :: D.DataFrame -> Double
computeAccuracy df = agreeing / (agreeing + falsePositives + falseNegatives)
  where
    -- Number of rows with the given (actual, predicted) pair of labels.
    tally actual predicted =
        fromIntegral $
            D.nRows (D.filterWhere (survived .== actual .&& prediction .== predicted) df)
    -- True positives plus true negatives.
    agreeing = tally 1 1 + tally 0 0
    falsePositives = tally 0 1
    falseNegatives = tally 1 0
@mchav
Copy link
Copy Markdown
Author

mchav commented Jan 16, 2026

model =
  ( ifThenElse
      (eq (col @(Maybe Text) "Name") (lit (Just "Mr.")))
      ( ifThenElse
          ( or
              ( geq
                  ( add
                      (add (toDouble (col @Int "Parch")) (toDouble (col @Int "Pclass")))
                      (toDouble (col @Int "Pclass"))
                  )
                  (lit (4.0))
              )
              (eq (col @(Maybe Text) "Cabin") (lit (Nothing)))
          )
          (lit (0))
          ( ifThenElse
              ( or
                  (leq (col @(Maybe Double) "Age") (col @(Maybe Double) "Fare"))
                  (eq (col @(Maybe Text) "Ticket") (lit (Nothing)))
              )
              ( ifThenElse
                  ( or
                      (eq (col @(Maybe Text) "Embarked") (lit (Just "S")))
                      (leq (col @(Maybe Double) "Fare") (col @(Maybe Double) "Age"))
                  )
                  (lit (0))
                  (lit (1))
              )
              (lit (1))
          )
      )
      ( ifThenElse
          (lt (toDouble (col @Int "Pclass")) (lit (3.0)))
          ( ifThenElse
              ( and
                  (leq (col @(Maybe Double) "Fare") (col @(Maybe Double) "Age"))
                  (eq (col @Text "Sex") (lit ("male")))
              )
              (lit (0))
              (lit (1))
          )
          ( ifThenElse
              (gt (add (toDouble (col @Int "Parch")) (toDouble (col @Int "SibSp"))) (lit (3.0)))
              (lit (0))
              ( ifThenElse
                  (geq (col @(Maybe Double) "Age") (lit (Just 13.200000000000003)))
                  ( ifThenElse
                      ( lt
                          ( mult
                              (sub (toDouble (col @Int "SibSp")) (toDouble (col @Int "Pclass")))
                              (toDouble (col @Int "SibSp"))
                          )
                          (lit (0.0))
                      )
                      (lit (0))
                      (lit (1))
                  )
                  (lit (1))
              )
          )
      )
  )

@mchav
Copy link
Copy Markdown
Author

mchav commented Jan 16, 2026

Training accuracy:
0.8521617852161785
Validation accuracy:
0.8275862068965517

@mchav
Copy link
Copy Markdown
Author

mchav commented Jan 16, 2026

Model after initial search:

(ifThenElse (eq (col @(Maybe Text) "Name") (lit (Just "Mr."))) (ifThenElse (or (geq (add (add (toDouble (col @Int "Parch")) (toDouble (col @Int "Pclass"))) (toDouble (col @Int "Pclass"))) (lit (4.0))) (eq (col @(Maybe Text) "Cabin") (lit (Nothing)))) (lit (0)) (ifThenElse (or (leq (col @(Maybe Double) "Age") (col @(Maybe Double) "Fare")) (eq (col @(Maybe Text) "Ticket") (lit (Nothing)))) (ifThenElse (or (eq (col @(Maybe Text) "Embarked") (lit (Just "S"))) (leq (col @(Maybe Double) "Fare") (col @(Maybe Double) "Age"))) (lit (0)) (lit (1))) (lit (1)))) (ifThenElse (lt (toDouble (col @Int "Pclass")) (lit (3.0))) (ifThenElse (and (leq (col @(Maybe Double) "Fare") (col @(Maybe Double) "Age")) (eq (col @Text "Sex") (lit ("male")))) (lit (0)) (lit (1))) (ifThenElse (gt (add (toDouble (col @Int "Parch")) (toDouble (col @Int "SibSp"))) (lit (3.0))) (lit (0)) (ifThenElse (geq (col @(Maybe Double) "Age") (lit (Just 13.200000000000003))) (ifThenElse (lt (mult (sub (toDouble (col @Int "SibSp")) (toDouble (col @Int "Pclass"))) (toDouble (col @Int "SibSp"))) (lit (0.0))) (lit (0)) (lit (1))) (lit (1))))))

@mchav
Copy link
Copy Markdown
Author

mchav commented Jan 17, 2026

(ifThenElse (and (lt (toDouble (col @Int "passenger_class")) (lit (3.0))) (eq (col @Text "Sex") (lit ("female")))) (lit (1)) (ifThenElse (eq (col @(Maybe Text) "title") (lit (Just "Mr."))) (lit (0)) (ifThenElse (gt (add (toDouble (col @Int "number_of_parents_and_children")) (toDouble (col @Int "number_of_siblings_and_spouses"))) (lit (3.0))) (lit (0)) (ifThenElse (and (lt (sub (add (toDouble (col @Int "number_of_parents_and_children")) (toDouble (col @Int "number_of_siblings_and_spouses"))) (toDouble (col @Int "passenger_class"))) (lit (0.0))) (geq (col @(Maybe Double) "Age") (lit (Just 13.200000000000003)))) (ifThenElse (lt (sub (toDouble (col @Int "number_of_siblings_and_spouses")) (toDouble (col @Int "passenger_class"))) (lit (-2.0))) (lit (1)) (lit (0))) (lit (1))))))
Training accuracy:
0.8493723849372385
Validation accuracy:
0.8103448275862069

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment