Last active
August 29, 2015 14:08
-
-
Save kputnam/a65437b83187e866a201 to your computer and use it in GitHub Desktop.
Sequential and Parallel KMeans
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{-# LANGUAGE ScopedTypeVariables #-}

-- | Sequential and parallel k-means clustering (Lloyd's algorithm) over
-- points represented as boxed vectors.
module KMeans
  ( -- * Distance functions
    euclidean
    -- * Clustering
  , kMeans
  , kMeansPar
    -- * Helpers
  , chunk
  , random
  ) where
import Control.Monad | |
import Control.Monad.Primitive | |
import System.Random.MWC (Variate, Gen, uniform) | |
import Control.Parallel.Strategies | |
import Data.Ord | |
import Data.List | |
import Data.Vector (Vector) | |
import qualified Data.Vector as V | |
import qualified Data.Vector.Mutable as M | |
-- | A data point, stored as a vector of coordinates.
type Point a = Vector a

-- | A cluster centre; it has the same representation as a 'Point'.
type Centroid a = Vector a

-- | A function measuring how far apart two points are, producing a @b@.
type Distance a b = Point a -> Point a -> b
-- | Squared Euclidean distance between two points.
--
-- No square root is taken. Components are paired up with 'V.zipWith', so when
-- the two points differ in length only the components they share contribute.
euclidean :: Num a => Distance a a
euclidean p q = V.sum (V.zipWith squaredGap p q)
  where
    squaredGap u v = let d = u - v in d * d
-- | Running total for one cluster: how many points have been assigned to it
-- and the component-wise sum of those points.
data Intermediate a = Intermediate
  !Int         -- ^ number of points accumulated so far
  !(Vector a)  -- ^ component-wise sum of the accumulated points
  deriving (Eq, Show, Read)
-- | The accumulator for a cluster with no points yet: a count of zero and an
-- all-zero sum vector with the requested number of components.
iempty :: Num a => Int -> Intermediate a
iempty dims = Intermediate 0 zeros
  where
    zeros = V.replicate dims 0
-- | The accumulator holding exactly one point.
isingle :: Point a -> Intermediate a
isingle = Intermediate 1
-- | Add one point to an accumulator: the count goes up by one and the point
-- is added component-wise to the running sum.
iinsert :: Num a => Intermediate a -> Point a -> Intermediate a
iinsert (Intermediate count total) point = Intermediate (count + 1) total'
  where
    total' = V.zipWith (+) total point
-- | Merge two accumulators by adding their counts and adding their sum
-- vectors component-wise.
iappend :: Num a => Intermediate a -> Intermediate a -> Intermediate a
iappend (Intermediate countL sumL) (Intermediate countR sumR) =
  Intermediate count total
  where
    count = countL + countR
    total = V.zipWith (+) sumL sumR
-- | Turn an accumulator into a centroid by dividing every component of the
-- sum by the number of points. The count is not checked here, so callers are
-- expected to skip accumulators whose count is zero.
icentroid :: Fractional a => Intermediate a -> Centroid a
icentroid (Intermediate count total) = V.map (/ divisor) total
  where
    divisor = fromIntegral count
-- | Lloyd's algorithm (sequential).
--
-- Starting from the initial centroids, every point is repeatedly assigned to
-- its nearest centroid and each centroid is replaced by the mean of the
-- points assigned to it. Iteration stops when the centroids no longer change
-- (compared with '==') or when the iteration budget is used up.
--
-- Clusters that end up with no points are dropped, so the result may contain
-- fewer centroids than the initial guess. An empty initial guess yields @[]@;
-- with no data points (and a positive iteration budget) every cluster is
-- empty and the result is likewise @[]@.
kMeans :: forall a b. (Eq a, Fractional a, Ord b)
       => Int           -- ^ maximum number of iterations
       -> Distance a b  -- ^ distance function between points
       -> [Centroid a]  -- ^ initial guess at centroids
       -> [Point a]     -- ^ set of data points to cluster
       -> [Centroid a]
kMeans maxIterations distance initial points
  | null initial = []   -- nothing to assign points to; avoids minimumBy on []
  | otherwise    = loop maxIterations initial
  where
    nClusters = length initial

    -- Dimensionality is taken from the first point; 0 when there is no data.
    nDimensions = case points of
      (p : _) -> V.length p
      []      -> 0

    -- Stop when centroids stop moving or the iteration budget is exhausted.
    loop :: Int -> [Centroid a] -> [Centroid a]
    loop n current
      | n <= 0          = current
      | current == next = current
      | otherwise       = loop (n - 1) next
      where
        next = clusters (update current)

    -- For each point, pick the nearest centroid and fold the point into that
    -- cluster's count and component-wise sum.
    update :: [Centroid a] -> Vector (Intermediate a)
    update current = V.create $ do
      -- Every cluster starts as (0, zero vector).
      results <- M.replicate nClusters (iempty nDimensions)
      forM_ points $ \x -> do
        let k = nearest current x
        acc <- M.read results k
        -- Boxed vectors store lazy elements: without forcing here, each slot
        -- would hold a chain of unevaluated 'iinsert' calls as long as the
        -- number of points assigned to it.
        M.write results k $! forced (iinsert acc x)
      return results

    -- Index of the closest centroid; the first one wins on ties. Each
    -- distance is computed once per centroid.
    nearest :: [Centroid a] -> Point a -> Int
    nearest current x =
      fst (minimumBy (comparing snd) (zip [0 ..] (map (distance x) current)))

    -- Evaluate every component of the running sum to weak head normal form.
    forced :: Intermediate a -> Intermediate a
    forced i@(Intermediate _ total) = V.foldl' (flip seq) () total `seq` i

    -- Given counts and component-wise sums, compute the centroids of the
    -- non-empty clusters.
    clusters :: Vector (Intermediate a) -> [Centroid a]
    clusters sums = [ icentroid i | i@(Intermediate n _) <- V.toList sums, n > 0 ]
-- | Break a vector into @n@ contiguous, non-overlapping chunks that together
-- cover every element (in O(nChunks) time, since each chunk is a slice).
--
-- When the length is not a multiple of @n@, the first @length `rem` n@ chunks
-- each receive one extra element, so chunk sizes differ by at most one.
-- Returns @[]@ when @n <= 0@.
chunk :: Int -> Vector a -> [Vector a]
chunk n xs
  | n <= 0    = []
  | otherwise = [ V.slice (offset k) (len k) xs | k <- [0 .. n - 1] ]
  where
    (size, extra) = V.length xs `quotRem` n

    -- The first 'extra' chunks are one element longer than the rest.
    len k = size + (if k < extra then 1 else 0)

    -- Each chunk starts after all earlier chunks, including the extra
    -- elements they absorbed; omitting that shift makes chunks overlap and
    -- leaves the tail of the vector uncovered.
    offset k = size * k + min k extra
-- | Lloyd's algorithm, with the assignment step evaluated in parallel.
--
-- The data arrives pre-split into groups (see 'chunk'). In each iteration
-- every group is summarised independently into per-cluster counts and sums,
-- the summaries are sparked with @parList rseq@, and the partial results are
-- merged with 'iappend' before new centroids are computed.
--
-- Iteration stops when the centroids no longer change (compared with '==') or
-- when the iteration budget is used up. Clusters that end up with no points
-- are dropped, so the result may contain fewer centroids than the initial
-- guess. An empty initial guess yields @[]@; with no data points at all (and
-- a positive iteration budget) the result is likewise @[]@.
kMeansPar :: forall a b. (Eq a, Fractional a, Ord b)
          => Int                 -- ^ maximum number of iterations
          -> Distance a b        -- ^ distance measure between two points
          -> [Centroid a]        -- ^ initial guess at centroids
          -> [Vector (Point a)]  -- ^ evenly-divided groups of points to cluster
          -> [Centroid a]
kMeansPar maxIterations distance initial groups
  | null initial = []   -- nothing to assign points to; avoids minimumBy on []
  | otherwise    = loop maxIterations initial
  where
    nClusters = length initial

    -- Dimensionality is taken from the first point of the first non-empty
    -- group; 0 when there is no data at all.
    nDimensions = case dropWhile V.null groups of
      (g : _) -> V.length (V.head g)
      []      -> 0

    -- Stop when centroids stop moving or the iteration budget is exhausted.
    loop :: Int -> [Centroid a] -> [Centroid a]
    loop n current
      | n <= 0          = current
      | current == next = current
      | otherwise       = loop (n - 1) next
      where
        next     = clusters (merge partials)
        partials = map (update current) groups `using` parList rseq

    -- Combine the per-group summaries cluster by cluster. With no groups at
    -- all, every cluster is simply empty.
    merge :: [Vector (Intermediate a)] -> Vector (Intermediate a)
    merge []   = V.replicate nClusters (iempty nDimensions)
    merge sums = foldr1 (V.zipWith iappend) sums

    -- For each point of one group, pick the nearest centroid and fold the
    -- point into that cluster's count and component-wise sum.
    update :: [Centroid a] -> Vector (Point a) -> Vector (Intermediate a)
    update current batch = V.create $ do
      -- Every cluster starts as (0, zero vector).
      results <- M.replicate nClusters (iempty nDimensions)
      V.forM_ batch $ \x -> do
        let k = nearest current x
        acc <- M.read results k
        -- Force the new sums now. Otherwise each slot only holds a chain of
        -- unevaluated 'iinsert' calls, 'rseq' returns before the additions
        -- are done, and they are performed later outside the spark.
        M.write results k $! forced (iinsert acc x)
      return results

    -- Index of the closest centroid; the first one wins on ties. Each
    -- distance is computed once per centroid.
    nearest :: [Centroid a] -> Point a -> Int
    nearest current x =
      fst (minimumBy (comparing snd) (zip [0 ..] (map (distance x) current)))

    -- Evaluate every component of the running sum to weak head normal form.
    forced :: Intermediate a -> Intermediate a
    forced i@(Intermediate _ total) = V.foldl' (flip seq) () total `seq` i

    -- Given counts and component-wise sums, compute the centroids of the
    -- non-empty clusters.
    clusters :: Vector (Intermediate a) -> [Centroid a]
    clusters sums = [ icentroid i | i@(Intermediate n _) <- V.toList sums, n > 0 ]
-- | Generate a random centroid with the given number of components, drawing
-- each component independently with 'uniform' from the supplied generator.
-- The range of each component is whatever the 'Variate' instance for @a@
-- produces.
random :: (PrimMonad m, Variate a, Num a) => Int -> Gen (PrimState m) -> m (Centroid a)
random nDimensions = V.replicateM nDimensions . uniform
--------------------------------------------------------------------------------------- | |
-- Example GHCi session (statements are meant to be entered at the prompt, not
-- compiled as part of the module). 'normal' and 'createSystemRandom' are not
-- imported by the module above; presumably they come from mwc-random --
-- confirm and import them before running.

-- A 3-d point whose coordinates are normally distributed around (x, y, z)
-- with standard deviation 10.
let point x y z g = V.fromList <$> mapM (\mean -> normal mean 10 g) [x, y, z]

g <- createSystemRandom

-- Random data normally distributed around four centroids
as <- replicateM 200 (point 100 100 100 g)
bs <- replicateM 150 (point 50 (-10) (-10) g)
cs <- replicateM 130 (point 100 (-10) 15 g)
ds <- replicateM 230 (point (-20) 80 180 g)

let xs = concat [as, bs, cs, ds]

-- Random initial centroid locations
zs <- replicateM 4 (random 3 g)
-- [fromList [0.9848902563026655,0.765287551648172,0.7781795312633198]
-- ,fromList [0.7361694755398108,7.367420245929768e-2,0.6224181897188118]
-- ,fromList [0.8774773711990987,0.4621511922427829,0.5492891219506643]
-- ,fromList [0.43748278208197944,0.8174101215438477,0.8235959386276283]]

-- Sequential clustering
kMeans 1000 euclidean zs xs
-- [fromList [99.89809340402812,100.38225039289729,100.24009968006065]
-- ,fromList [49.761342796903456,-8.898956628199413,-9.947110251417078]
-- ,fromList [100.33603708616106,-9.597087217849067,14.854017607086883]
-- ,fromList [-20.36856619481266,79.89870265298171,179.94788176235113]]

-- Parallel clustering over four groups of points
kMeansPar 1000 euclidean zs (chunk 4 (V.fromList xs))
-- [fromList [99.89809340402812,100.38225039289729,100.24009968006065]
-- ,fromList [49.761342796903456,-8.898956628199413,-9.947110251417078]
-- ,fromList [100.33603708616106,-9.597087217849067,14.854017607086883]
-- ,fromList [-20.36856619481266,79.89870265298171,179.94788176235113]]
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment