Skip to content

Instantly share code, notes, and snippets.

@lotz84
Created January 6, 2016 02:37
Show Gist options
  • Save lotz84/b1be4d851b45de6c2243 to your computer and use it in GitHub Desktop.
Save lotz84/b1be4d851b45de6c2243 to your computer and use it in GitHub Desktop.
k-means clustering example using the Iris data set (https://archive.ics.uci.edu/ml/datasets/Iris)
import Data.Foldable (for_)
import Data.List
import Data.Maybe (mapMaybe)
import Text.Read (readMaybe)

import qualified Data.IntMap as IntMap
import Data.List.Extra
import qualified Data.Map as Map
import System.Random (randomRIO)
-- | Cluster @xs@ with k-means, parameterised by a distance function and a
-- function that computes the centre of a group.
--
-- Every point first gets a random cluster label in @[1, k]@. Then, @n@
-- times, the centre of each current cluster is computed with @center@ and
-- every point is moved to the cluster whose centre is nearest under
-- @distance@ (ties go to the earlier centre).
--
-- Only clusters that actually hold points are returned, so the result can
-- contain fewer than @k@ groups, and it is empty when @xs@ is empty.
-- A @k@ below 1 is treated as 1 and a negative @n@ is treated as 0.
kmeans :: Ord b => (a -> a -> b) -> ([a] -> a) -> Int -> Int -> [a] -> IO [[a]]
kmeans distance center k n xs = do
    -- Clamp k: with k < 1 the range (1, k) is inverted, and System.Random
    -- does not guarantee the drawn labels stay within [1, k].
    initial <- initialize (max 1 k) xs
    -- Clamp n: a negative index would make (!!) throw.
    pure (iterate step initial !! max 0 n)
  where
    -- One refinement round: recompute the centres, then reassign every point.
    step clusters = update distance (map center clusters) clusters

    -- Bucket points by cluster label. IntMap yields buckets in ascending
    -- label order, and only for labels that occur, so no bucket is empty.
    -- Within a bucket, points keep their input order.
    grouping :: [(p, Int)] -> [[p]]
    grouping = IntMap.elems . foldr (\(v, label) -> IntMap.insertWith (++) label [v]) IntMap.empty

    -- Random initial partition: one label draw per point, in list order.
    initialize :: Int -> [p] -> IO [[p]]
    initialize clusterCount points =
        grouping . zip points <$> traverse (const (randomRIO (1, clusterCount))) points

    -- Reassign every point to its nearest centre (centre i gets label i).
    -- 'step' passes one centre per group, so 'centres' is non-empty whenever
    -- 'groups' holds a point, and 'minimum' never sees an empty list.
    update :: Ord d => (p -> p -> d) -> [p] -> [[p]] -> [[p]]
    update dist centres groups = grouping [ (x, nearest x) | x <- concat groups ]
      where
        nearest x = snd (minimum [ (dist c x, i) | (c, i) <- zip centres [1 ..] ])
-- | One sample from the Iris data set: four numeric measurements plus the
-- class name taken from the last column of a row.
--
-- Derives 'Eq' and 'Show' so samples (and computed centroids) can be
-- compared and printed while debugging or testing.
data Iris = Iris
    { sepalLength :: Float
    , sepalWidth  :: Float
    , petalLength :: Float
    , petalWidth  :: Float
    , className   :: String  -- ^ class/species label (empty for computed centres)
    }
    deriving (Eq, Show)
-- | Euclidean distance between two irises over their four measurements.
-- The class names do not take part in the comparison.
distanceIris :: Iris -> Iris -> Float
distanceIris p q =
    sqrt (sqDiff sepalLength + sqDiff sepalWidth + sqDiff petalLength + sqDiff petalWidth)
  where
    -- Squared difference of one measurement between the two samples.
    sqDiff field = let d = field p - field q in d * d
-- | Centroid of a group of irises: the arithmetic mean of each of the four
-- measurements. The result carries an empty class name.
-- For an empty input every measurement is 0 / 0, i.e. NaN.
centerIris :: [Iris] -> Iris
centerIris irises =
    Iris (avg sepalLength) (avg sepalWidth) (avg petalLength) (avg petalWidth) ""
  where
    size = genericLength irises
    avg field = sum (map field irises) / size
-- | Build an 'Iris' from one comma-split row: four numeric fields followed
-- by the class name.
--
-- Returns 'Nothing' for a row with the wrong number of fields (e.g. a blank
-- line, which splits into a single empty field). It also returns 'Nothing'
-- for a row whose measurements are not valid 'Float' literals, so malformed
-- data is skipped rather than crashing later when a lazy 'read' is forced.
readIris :: [String] -> Maybe Iris
readIris [a, b, c, d, label] =
    Iris <$> readMaybe a <*> readMaybe b <*> readMaybe c <*> readMaybe d <*> pure label
readIris _ = Nothing
-- | Print the class make-up of one cluster, in this order:
-- a "total: N" line, one "name: count" line per class name in ascending
-- name order, then a separator line followed by a blank line.
statIris :: [Iris] -> IO ()
statIris xs = do
    let histogram = Map.fromListWith (+) [ (className iris, 1 :: Integer) | iris <- xs ]
    putStrLn ("total: " ++ show (sum (Map.elems histogram)))
    for_ (Map.toList histogram) $ \(label, count) ->
        putStrLn (label ++ ": " ++ show count)
    putStrLn "=================\n"
-- | Read @iris.data@, cluster its parseable rows into at most 3 groups
-- using 100 refinement rounds, and print the class make-up of each group.
main :: IO ()
main = do
    raw <- readFile "iris.data"
    let rows   = map (splitOn ",") (lines raw)
        irises = mapMaybe readIris rows
    clusters <- kmeans distanceIris centerIris 3 100 irises
    for_ clusters statIris
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment