Created
March 8, 2014 09:35
-
-
Save cympfh/9427895 to your computer and use it in GitHub Desktop.
決定木のノードを一つだけ作る (Build just one node of a decision tree)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import Data.List | |
-- | Class label attached to every sample. The data set distinguishes
-- exactly two classes, 'SV' and 'C'.
data Ayame
  = SV
  | C
  deriving (Eq, Show)
-- | Toy training set: each sample pairs two integer features with its
-- class label.
datum :: [((Int, Int), Ayame)]
datum =
  [ ((5, 4), SV)
  , ((2, 3), C)
  , ((9, 2), SV)
  , ((31, 39), SV)
  , ((20, 20), C)
  ]
-- | Split 'datum' with the rule found by 'findRule', then print:
-- the score of the whole set, the two halves, and the score of each half.
main :: IO ()
main = do
  let (left, right) = partition (findRule datum) datum
  print (gib datum)
  print (left, right)
  print (gib left, gib right)
-- | Split a list into the elements satisfying the predicate and the rest,
-- like 'partition'. Unlike 'partition', each output list comes out in
-- reverse input order, because elements are consed onto an accumulator.
doRule :: (a -> Bool) -> [a] -> ([a], [a])
doRule keep = foldl' step ([], [])
  where
    step (yes, no) item
      | keep item = (item : yes, no)
      | otherwise = (yes, item : no)
-- | Find a single axis-aligned threshold rule (one decision-tree node) for
-- the labelled points in @ls@.
--
-- Both coordinates are searched. The points are sorted along each axis, and
-- every boundary between two consecutive, distinct coordinate values is a
-- candidate split. A candidate is scored as
-- @length left * gib left + length right * gib right@, where lower is
-- better. Its threshold is the floored midpoint of the two neighbouring
-- values, so everything left of the boundary satisfies @<= threshold@.
--
-- Selection works as follows:
--
-- * Within one axis, the lexicographically smallest (score, threshold)
--   wins.
-- * Between the axes, the first coordinate is used only if its best score
--   is strictly lower. Otherwise the second coordinate is used.
--
-- The returned predicate is 'True' for points on the left (@<=@) side.
-- If neither axis has a boundary (fewer than two points, or all
-- coordinates equal on both axes), the predicate is @const True@ and every
-- point stays on the left.
findRule :: Integral a => [((a, a), Ayame)] -> ((a, a), Ayame) -> Bool
findRule ls =
  case (bestSplit fst ls, bestSplit snd ls) of
    (Just (g1, sh1), Just (g2, sh2))
      | g1 < g2 -> onFirst sh1
      | otherwise -> onSecond sh2
    (Just (_, sh1), Nothing) -> onFirst sh1
    (Nothing, Just (_, sh2)) -> onSecond sh2
    (Nothing, Nothing) -> const True
  where
    -- Threshold tests on the first / second coordinate of a sample.
    onFirst sh x = fst (fst x) <= sh
    onSecond sh x = snd (fst x) <= sh

    -- Best (score, threshold) along the axis selected by @f@, or Nothing
    -- when no threshold can separate the points along that axis.
    bestSplit f xs
      | null candidates = Nothing
      | otherwise = Just (minimum candidates)
      where
        key = f . fst
        sorted = sortBy (\a b -> compare (key a) (key b)) xs
        keys = map key sorted
        -- Only proper, non-empty splits between two different values.
        -- For equal values, the midpoint threshold would put both points
        -- on the left, so the scored split could never be realised. The
        -- empty left/right splits are skipped as well; they have no
        -- neighbour to build a threshold from.
        candidates =
          [ (length l * gib l + length r * gib r, div (lo + hi) 2)
          | (k, lo, hi) <- zip3 [1 ..] keys (drop 1 keys)
          , lo /= hi
          , let (l, r) = splitAt k sorted
          ]
-- | Impurity score of a labelled sample: @nSV * nC * (nSV + nC)@, where
-- @nSV@ and @nC@ count the 'SV' and 'C' labels. The score is zero exactly
-- when the sample contains at most one of the two labels.
gib :: [(a, Ayame)] -> Int
gib samples = countSV * countC * (countSV + countC)
  where
    labels = map snd samples
    countSV = length (filter (== SV) labels)
    countC = length (filter (== C) labels)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment