Last active
December 26, 2019 07:35
-
-
Save lotz84/4f6ae2816da6a9f5c81b29bba3844873 to your computer and use it in GitHub Desktop.
"Fast and Accurate Least-Mean-Squares Solvers" in Haskell https://arxiv.org/abs/1906.04705
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{-# LANGUAGE DataKinds #-} | |
{-# LANGUAGE NoStarIsType #-} | |
{-# LANGUAGE RankNTypes #-} | |
{-# LANGUAGE TypeApplications #-} | |
{-# LANGUAGE TypeOperators #-} | |
{-# LANGUAGE TypeFamilies #-} | |
{-# LANGUAGE ScopedTypeVariables #-} | |
{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-} | |
{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-} | |
module Main where | |
import Prelude hiding ((<>)) | |
import Data.List | |
import Data.Maybe | |
import Data.Proxy | |
import GHC.TypeLits | |
import Data.IntMap (IntMap) | |
import qualified Data.IntMap as IntMap | |
import qualified Numeric.LinearAlgebra as H | |
import Numeric.LinearAlgebra.Static | |
import Numeric.LinearAlgebra.Static.Vector | |
import System.Random.MWC hiding (create) | |
-- | A point in @d@-dimensional space paired with its weight.
type WeightedPoint d = (R d, Double)
-- | Split a list into consecutive chunks of at most @k@ elements.
--
-- The last chunk may be shorter than @k@; an empty input yields no chunks.
-- A non-positive @k@ is clamped to 1, because a chunk size of zero would
-- never consume any input and would produce an infinite list of empty chunks.
chunksOf :: Int -> [a] -> [[a]]
chunksOf k = go
  where
    size = max 1 k

    go [] = []
    go xs =
      let (chunk, rest) = splitAt size xs
       in chunk : go rest
-- | Reduce a weighted point set to at most @d + 1@ weighted points.
--
-- Each round picks a vector @v@ from the null space of the matrix whose
-- columns are @p_i - p_1@ (with @p_1@ the point under the smallest key),
-- extends it with @-sum v@ for @p_1@ so that the coefficients sum to zero,
-- and moves the weights along @v@ until the first weight reaches zero.
-- Because the coefficients sum to zero and annihilate the points, the total
-- weight and the weighted sum of the points are preserved up to
-- floating-point rounding.
--
-- Keys of the surviving points are the keys they had in the input map.
--
-- NOTE(review): the step-size selection assumes all input weights are
-- positive -- confirm against callers.
caratheodorySet :: forall d. KnownNat d => IntMap (WeightedPoint d) -> IntMap (WeightedPoint d)
caratheodorySet ps
  | IntMap.size ps <= dim + 1 = ps
  | otherwise = caratheodorySet reduced
  where
    dim = fromIntegral $ natVal (Proxy @d) :: Int

    -- Current weights, keyed like the points.
    us = IntMap.map snd ps

    -- Anchor point (smallest key) and the remaining points.
    ((k1, (p1, _)), rest) = IntMap.deleteFindMin ps
    diffs = IntMap.map (subtract p1 . fst) rest

    -- First null-space basis vector of the difference matrix, re-keyed so
    -- that each coefficient sits under the key of its point.
    vs =
      IntMap.fromList
        $ zip (IntMap.keys diffs)
        $ IntMap.elems diffs
            `withColumns` (\m -> withNullspace m (H.toList . extract . head . toColumns))

    -- Coefficient of the anchor makes the whole vector sum to zero.
    vs' = IntMap.insert k1 (negate (sum vs)) vs

    -- Largest step that keeps every weight non-negative, together with the
    -- key of the point whose weight it drives to zero.
    ratios = IntMap.filter (> 0) $ IntMap.unionWith (/) us vs'
    (alpha, pivot) = minimum [(r, key) | (key, r) <- IntMap.toList ratios]

    -- The pivot's new weight is zero mathematically, but @u - (u / v) * v@
    -- can round to a tiny positive number. Delete it explicitly so that
    -- every round is guaranteed to drop at least one point.
    ws = IntMap.delete pivot $ IntMap.unionWith (-) us (IntMap.map (* alpha) vs')

    reduced =
      IntMap.filter ((> 0) . snd) $
        IntMap.intersectionWith (,) (IntMap.map fst ps) ws
-- | Cluster-based variant of 'caratheodorySet'.
--
-- The points are split (in ascending key order) into roughly @k@ clusters.
-- Each cluster is replaced by its weighted mean carrying the cluster's total
-- weight, 'caratheodorySet' is run on those means, and the points of the
-- surviving clusters are kept with their weights rescaled by
-- @newClusterWeight / oldClusterWeight@. This repeats until at most
-- @d + 1@ points remain.
--
-- Robustness:
--
-- * The chunk size is at least 1, so having fewer points than @k@ (or a
--   non-positive @k@) can neither divide by zero nor request empty chunks.
-- * If a round fails to drop any point (for example because there were no
--   more than @d + 1@ clusters), the function falls back to a direct
--   'caratheodorySet' instead of recursing forever on the same input.
fastCaratheodorySet :: forall d. KnownNat d => Int -> IntMap (WeightedPoint d) -> IntMap (WeightedPoint d)
fastCaratheodorySet k ps
  | total <= dim + 1 = ps
  | IntMap.size rescaled < total = fastCaratheodorySet k rescaled
  | otherwise = caratheodorySet ps
  where
    dim = fromIntegral $ natVal (Proxy @d) :: Int
    total = IntMap.size ps

    chunkSize = max 1 (total `div` max 1 k)

    -- Clusters of (key, weighted point) pairs, numbered from 1.
    clusters :: IntMap [(Int, WeightedPoint d)]
    clusters = IntMap.fromList $ zip [1 ..] $ chunksOf chunkSize (IntMap.toList ps)

    -- Total weight of each cluster.
    clusterWeights = IntMap.map (sum . map (snd . snd)) clusters

    -- Weighted mean of a cluster whose total weight is @u@.
    centroid :: Double -> [(Int, WeightedPoint d)] -> R d
    centroid u xs =
      map (fst . snd) xs
        `withColumns` (\m -> m #> vector (map ((/ u) . snd . snd) xs))

    centroids = IntMap.intersectionWith centroid clusterWeights clusters

    -- Clusters that survive the reduction, with their new weights.
    kept = caratheodorySet $ IntMap.intersectionWith (,) centroids clusterWeights

    -- Spread a surviving cluster's new weight @w@ over its members in
    -- proportion to their share @u / u'@ of the old cluster weight.
    rescale :: Double -> (WeightedPoint d, [(Int, WeightedPoint d)]) -> [(Int, WeightedPoint d)]
    rescale u' ((_, w), xs) = [(key, (p, w * u / u')) | (key, (p, u)) <- xs]

    rescaled =
      IntMap.fromList . concatMap snd . IntMap.toList
        $ IntMap.intersectionWith rescale clusterWeights
        $ IntMap.intersectionWith (,) kept clusters
-- | Select and rescale a subset of the input rows.
--
-- Every row @v@ is mapped to its flattened outer product @v `outer` v@ with
-- weight @1 / n@ (@n@ being the number of rows). 'fastCaratheodorySet' with
-- cluster parameter @k@ reduces those @d * d@-dimensional points, and each
-- surviving row is returned scaled by @sqrt (n * w)@, where @w@ is its
-- reduced weight. Rows are returned in their original relative order.
--
-- An empty input yields an empty result.
caratheodoryMatrix :: forall d. KnownNat d => Int -> [R d] -> [R d]
caratheodoryMatrix k ps = IntMap.elems scaled
  where
    n = fromIntegral (length ps) :: Double

    -- Rows keyed by their 1-based position.
    indexed = IntMap.fromList $ zip [1 ..] ps

    -- Flattened outer product of a row with itself. The flattened matrix
    -- always has @d * d@ entries, so 'create' cannot fail here; the message
    -- only documents the invariant.
    flattenOuter :: R d -> R (d * d)
    flattenOuter v =
      fromMaybe (error "caratheodoryMatrix: flattened outer product does not have d * d entries")
        . create
        . H.flatten
        . extract
        $ v `outer` v

    coreset :: IntMap (WeightedPoint (d * d))
    coreset = fastCaratheodorySet k $ IntMap.map (\v -> (flattenOuter v, 1 / n)) indexed

    -- Keep only the rows whose key survived, scaled by their new weight.
    scaled =
      IntMap.intersectionWith (\(_, w) p -> dvmap (* sqrt (n * w)) p) coreset indexed
-- | Product of the transpose of a matrix with the matrix itself,
-- @tr m <> m@.
--
-- Note that no centring or normalisation is applied: for a matrix whose
-- rows are samples this is the unnormalised second-moment (Gram) matrix.
covariance :: (KnownNat m, KnownNat n) => L m n -> L n n
covariance m = tr m <> m
-- | Draw a random vector of length @n@ whose components are sampled with
-- @uniformR (0, 1)@ from a freshly seeded system generator.
randVec :: forall n. KnownNat n => IO (R n)
randVec = fmap vector (withSystemRandom (asGenIO draw))
  where
    len = fromIntegral (natVal (Proxy @n)) :: Int

    -- One sample per component, all from the same generator.
    draw :: GenIO -> IO [Double]
    draw gen = sequence (replicate len (uniformR (0, 1 :: Double) gen))
-- | Demo driver: samples 1000 random 2-D points with normalised random
-- weights, then prints
--
-- 1. every (point, weight) pair,
-- 2. @covariance@ of the points before and after 'caratheodoryMatrix',
-- 3. the weighted sum of the points, and
-- 4. the reduced sets from 'caratheodorySet' and 'fastCaratheodorySet',
--    each followed by its weighted sum.
--
-- The Japanese labels read: "variance-covariance matrix", "internally
-- dividing point" and "Caratheodory set".
main :: IO ()
main = do
  let sampleN = 1000
  ps <- sequence (replicate sampleN (randVec :: IO (R 2)))
  us <- withSystemRandom . asGenIO $ \gen -> do
    raw <- sequence (replicate sampleN (uniformR (0, 1 :: Double) gen))
    -- Normalise so that the weights sum to one.
    pure (map (/ sum raw) raw)
  mapM_ print (zip ps us)

  putStrLn "分散共分散行列(Before)"
  print (withRows ps covariance)
  putStrLn "分散共分散行列(After)"
  print (withRows (caratheodoryMatrix 6 ps) covariance)

  putStr "内分点: "
  print (weightedSum ps us)

  -- Weighted points keyed by their 1-based position.
  let indexed = IntMap.fromList (zip [1 ..] (zip ps us))

  putStrLn "カラテオドリ集合:_"
  report (caratheodorySet indexed)

  putStrLn "カラテオドリ集合(fast):_"
  report (fastCaratheodorySet 4 indexed)
  where
    -- Sum of the points, each multiplied by its weight.
    weightedSum :: [R 2] -> [Double] -> R 2
    weightedSum pts wts = pts `withColumns` (\m -> m #> vector wts)

    -- Print every surviving (point, weight) pair, then their weighted sum.
    report :: IntMap (WeightedPoint 2) -> IO ()
    report result = do
      let (pts, wts) = unzip (IntMap.elems result)
      mapM_ print (zip pts wts)
      putStr "内分点: "
      print (weightedSum pts wts)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
... | |
executables: | |
caratheodory-set-exe: | |
main: Main.hs | |
source-dirs: app | |
ghc-options: | |
- -threaded | |
- -rtsopts | |
- -with-rtsopts=-N | |
dependencies: | |
- containers | |
- ghc-typelits-knownnat | |
- ghc-typelits-natnormalise | |
- hmatrix | |
- hmatrix-vector-sized | |
- mwc-random | |
- vector-sized | |
- caratheodory-set |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment